├── .github
│   ├── ISSUE_TEMPLATE
│   │   ├── bug_report.yaml
│   │   └── feature-request.yaml
│   └── PULL_REQUEST_TEMPLATE
│       └── pr_template.md
├── .gitignore
├── LICENSE
├── MODEL_LICENSE
├── README.md
├── README_en.md
├── finetune
│   ├── README.md
│   ├── README_en.md
│   ├── configs
│   │   ├── ds_zero_3.json
│   │   ├── lora.yaml
│   │   └── sft.yaml
│   ├── finetune.py
│   ├── finetune_vision.py
│   └── vision_dataset.zip
├── inference
│   ├── cli_demo.py
│   ├── cli_demo_vision.py
│   ├── ov_convert
│   │   ├── convert_chat.py
│   │   └── convert_v.py
│   └── web_demo.py
├── requirements.txt
└── resources
    └── img.png
/.github/ISSUE_TEMPLATE/bug_report.yaml:
--------------------------------------------------------------------------------
1 | name: "\U0001F41B Bug Report"
2 | description: Submit a bug report to help us improve GLM-Edge / 提交一个 Bug 问题报告来帮助我们改进 GLM-Edge
3 | body:
4 | - type: textarea
5 | id: system-info
6 | attributes:
7 | label: System Info / 系統信息
8 | description: Your operating environment / 您的运行环境信息
9 | placeholder: Includes Cuda version, Transformers version, Python version, operating system, hardware information (if you suspect a hardware problem)... / 包括Cuda版本,Transformers版本,Python版本,操作系统,硬件信息(如果您怀疑是硬件方面的问题)...
10 | validations:
11 | required: true
12 |
13 | - type: textarea
14 | id: who-can-help
15 | attributes:
16 | label: Who can help? / 谁可以帮助到您?
17 | description: |
18 | Your issue will be replied to more quickly if you can figure out the right person to tag with @
19 | All issues are read by one of the maintainers, so if you don't know who to tag, just leave this blank and our maintainer will ping the right person.
20 |
21 | Please tag fewer than 3 people.
22 |
23 | 如果您能找到合适的标签 @,您的问题会更快得到回复。
24 | 所有问题都会由我们的维护者阅读,如果您不知道该标记谁,只需留空,我们的维护人员会找到合适的开发组成员来解决问题。
25 |
26 | 标记的人数应该不超过 3 个人。
27 |
28 | If it's not a bug in these three subsections, you may not specify the helper. Our maintainer will find the right person in the development group to solve the problem.
29 |
30 | 如果不是这三个子版块的bug,您可以不指明帮助者,我们的维护人员会找到合适的开发组成员来解决问题。
31 |
32 | placeholder: "@Username ..."
33 |
34 | - type: checkboxes
35 | id: information-scripts-examples
36 | attributes:
37 | label: Information / 问题信息
38 | description: 'The problem arises when using: / 问题出现在'
39 | options:
40 | - label: "The official example scripts / 官方的示例脚本"
41 | - label: "My own modified scripts / 我自己修改的脚本和任务"
42 |
43 | - type: textarea
44 | id: reproduction
45 | validations:
46 | required: true
47 | attributes:
48 | label: Reproduction / 复现过程
49 | description: |
50 | Please provide a code example that reproduces the problem you encountered, preferably with a minimal reproduction unit.
51 | If you have code snippets, error messages, stack traces, please provide them here as well.
52 | Please format your code correctly using code tags. See https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting
53 | Do not use screenshots, as they are difficult to read and (more importantly) do not allow others to copy and paste your code.
54 |
55 | 请提供能重现您遇到的问题的代码示例,最好是最小复现单元。
56 | 如果您有代码片段、错误信息、堆栈跟踪,也请在此提供。
57 | 请使用代码标签正确格式化您的代码。请参见 https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting
58 | 请勿使用截图,因为截图难以阅读,而且(更重要的是)不允许他人复制粘贴您的代码。
59 | placeholder: |
60 | Steps to reproduce the behavior/复现Bug的步骤:
61 |
62 | 1.
63 | 2.
64 | 3.
65 |
66 | - type: textarea
67 | id: expected-behavior
68 | validations:
69 | required: true
70 | attributes:
71 | label: Expected behavior / 期待表现
72 | description: "A clear and concise description of what you would expect to happen. /简单描述您期望发生的事情。"
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature-request.yaml:
--------------------------------------------------------------------------------
1 | name: "\U0001F680 Feature request"
2 | description: Submit a request for a new GLM-Edge feature / 提交一个新的 GLM-Edge 的功能建议
3 | labels: [ "feature" ]
4 | body:
5 | - type: textarea
6 | id: feature-request
7 | validations:
8 | required: true
9 | attributes:
10 | label: Feature request / 功能建议
11 | description: |
12 | A brief description of the functional proposal. Links to corresponding papers and code are desirable.
13 | 对功能建议的简述。最好提供对应的论文和代码链接
14 |
15 | - type: textarea
16 | id: motivation
17 | validations:
18 | required: true
19 | attributes:
20 | label: Motivation / 动机
21 | description: |
22 | Your motivation for making the suggestion. If that motivation is related to another GitHub issue, link to it here.
23 | 您提出建议的动机。如果该动机与另一个 GitHub 问题有关,请在此处提供对应的链接。
24 |
25 | - type: textarea
26 | id: contribution
27 | validations:
28 | required: true
29 | attributes:
30 | label: Your contribution / 您的贡献
31 | description: |
32 |
33 | Your PR link or any other link you can help with.
34 | 您的PR链接或者其他您能提供帮助的链接。
--------------------------------------------------------------------------------
/.github/PULL_REQUEST_TEMPLATE/pr_template.md:
--------------------------------------------------------------------------------
1 | # Raise valuable PR / 提出有价值的PR
2 |
3 | ## Caution/ 注意事项:
4 | Users should keep the following points in mind when submitting PRs:
5 |
6 | 1. The proposed PR should be about this project.
7 | 2. The proposed PR should be focused; if there are multiple ideas or optimizations, they should be split into separate PRs.
8 |
9 | 用户在提交PR时候应该注意以下几点:
10 |
11 | 1. 提出的PR应该是关于本项目的。
12 | 2. 提出的PR应该具有针对性,如果具有多个不同的想法和优化方案,应该分配到不同的PR中。
13 |
14 | ## 不应该提出的PR / PRs that should not be proposed
15 |
16 | If a developer proposes a PR that falls under any of the following, it may be closed or rejected.
17 |
18 | 1. PRs that do not describe the proposed improvement.
19 | 2. PRs that combine multiple issues of different types.
20 | 3. PRs that largely duplicate already existing PRs.
21 |
22 | 如果开发者提出关于以下方面的PR,则可能会被直接关闭或拒绝通过。
23 |
24 | 1. 没有说明改进方案的。
25 | 2. 多个不同类型的问题合并在一个PR中的。
26 | 3. 提出的PR与已经存在的PR高度重复的。
27 |
28 |
29 | # Check your PR / 检查您的PR
30 | - [ ] Have you read the Contributor Guidelines, Pull Request section? / 您是否阅读了贡献者指南、Pull Request 部分?
31 | - [ ] Has this been discussed/approved via a Github issue or forum? If so, add a link. / 是否通过 Github 问题或论坛讨论/批准过?如果是,请添加链接。
32 | - [ ] Did you make sure you updated the documentation with your changes? Here are the Documentation Guidelines, and here are the Documentation Formatting Tips. /您是否确保根据您的更改更新了文档?这里是文档指南,这里是文档格式化技巧。
33 | - [ ] Did you write new required tests? / 您是否编写了新的必要测试?
34 | - [ ] Is your PR for only one issue? / 您的PR是否仅针对一个问题?
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *__pycache__/
2 | samples*/
3 | runs/
4 | checkpoints/
5 | master_ip
6 | logs/
7 | *.DS_Store
8 | .idea/*
9 | output*
10 | test*
11 | pyproject.toml
12 | draft*
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright 2025 GLM-Edge Model Team @ Zhipu AI
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/MODEL_LICENSE:
--------------------------------------------------------------------------------
1 | The GLM-Edge License
2 |
3 | 1. 定义
4 |
5 | “许可方”是指分发其软件的 GLM-Edge 模型团队。
6 | “软件”是指根据本许可提供的 GLM-Edge 模型参数。
7 |
8 | 2. 许可授予
9 |
10 | 根据本许可的条款和条件,许可方特此授予您非排他性、全球性、不可转让、不可再许可、可撤销、免版税的版权许可。
11 | 本许可允许您免费使用本仓库中的所有开源模型进行学术研究,对于希望将模型用于商业目的的用户,需在[这里](https://open.bigmodel.cn/mla/form)完成登记。经过登记的用户可以免费使用本模型进行商业活动,但必须遵守本许可的所有条款和条件。
12 | 上述版权声明和本许可声明应包含在本软件的所有副本或重要部分中。
13 | 如果您分发或提供 THUDM / 智谱AI 关于 GLM-Edge 开源模型的材料(或其任何衍生作品),或使用其中任何材料(包括 GLM-Edge 系列的所有开源模型)的产品或服务,您应:
14 |
15 | (A) 随任何此类 THUDM / 智谱AI 材料提供本协议的副本;
16 | (B) 在相关网站、用户界面、博客文章、关于页面或产品文档上突出显示 “Built with GLM-Edge”。
17 | 如果您使用 THUDM / 智谱AI的 GLM-Edge 开源模型的材料来创建、训练、微调或以其他方式改进已分发或可用的 AI 模型,您还应在任何此类 AI 模型名称的开头添加 “GLM-Edge”。
18 |
19 | 3. 限制
20 |
21 | 您不得出于任何军事或非法目的使用、复制、修改、合并、发布、分发、复制或创建本软件的全部或部分衍生作品。
22 | 您不得利用本软件从事任何危害国家安全和国家统一,危害社会公共利益及公序良俗,侵犯他人商业秘密、知识产权、名誉权、肖像权、财产权等权益的行为。
23 | 您在使用中应遵循使用地所适用的法律法规政策、道德规范等要求。
24 |
25 | 4. 免责声明
26 |
27 | 本软件“按原样”提供,不提供任何明示或暗示的保证,包括但不限于对适销性、特定用途的适用性和非侵权性的保证。
28 | 在任何情况下,作者或版权持有人均不对任何索赔、损害或其他责任负责,无论是在合同诉讼、侵权行为还是其他方面,由软件或软件的使用或其他交易引起、由软件引起或与之相关
29 | 软件。
30 |
31 | 5. 责任限制
32 |
33 | 除适用法律禁止的范围外,在任何情况下且根据任何法律理论,无论是基于侵权行为、疏忽、合同、责任或其他原因,任何许可方均不对您承担任何直接、间接、特殊、偶然、示范性、
34 | 或间接损害,或任何其他商业损失,即使许可人已被告知此类损害的可能性。
35 |
36 | 6. 争议解决
37 |
38 | 本许可受中华人民共和国法律管辖并按其解释。 因本许可引起的或与本许可有关的任何争议应提交北京市海淀区人民法院。
39 | 请注意,许可证可能会更新到更全面的版本。 有关许可和版权的任何问题,请通过 license@zhipuai.cn 或 opensource@zhipuai.cn 与我们联系。
40 |
41 | 1. Definitions
42 |
43 | “Licensor” means the GLM-Edge Model Team that distributes its Software.
44 | “Software” means the GLM-Edge model parameters made available under this license.
45 |
46 | 2. License
47 |
48 | Under the terms and conditions of this license, the Licensor hereby grants you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty-free copyright license.
49 | This license allows you to use all open source models in this repository for free for academic research. For users who wish to use the models for commercial purposes, please do so [here](https://open.bigmodel.cn/mla/form)
50 | Complete registration. Registered users are free to use this model for commercial activities, but must comply with all terms and conditions of this license.
51 | The copyright notice and this license notice shall be included in all copies or substantial portions of the Software.
52 | If you distribute or provide THUDM / Zhipu AI materials on the GLM-Edge open source model (or any derivative works thereof), or products or services that use any materials therein (including all open source models of the GLM-Edge series), you should:
53 |
54 | (A) Provide a copy of this Agreement with any such THUDM/Zhipu AI Materials;
55 | (B) Prominently display "Built with GLM-Edge" on the relevant website, user interface, blog post, related page or product documentation.
56 | If you use materials from THUDM/ZHIPU's GLM-Edge model to create, train, operate, or otherwise improve assigned or available AI models, you should also add "GLM-Edge" to the beginning of any such AI model name.
57 |
58 | 3. Restrictions
59 |
60 | You are not allowed to use, copy, modify, merge, publish, distribute, copy or create all or part of the derivative works of this software for any military or illegal purposes.
61 | You are not allowed to use this software to engage in any behavior that endangers national security and unity, endangers social public interests and public order, infringes on the rights and interests of others such as trade secrets, intellectual property rights, reputation rights, portrait rights, and property rights.
62 | You should comply with the applicable laws, regulations, policies, ethical standards, and other requirements in the place of use during use.
63 |
64 | 4. Disclaimer
65 |
66 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
67 | WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
68 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
69 | OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
70 |
71 | 5. Limitation of Liability
72 |
73 | EXCEPT TO THE EXTENT PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER BASED IN TORT,
74 | NEGLIGENCE, CONTRACT, LIABILITY, OR OTHERWISE WILL ANY LICENSOR BE LIABLE TO YOU FOR ANY DIRECT, INDIRECT, SPECIAL,
75 | INCIDENTAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES, OR ANY OTHER COMMERCIAL LOSSES, EVEN IF THE LICENSOR HAS BEEN ADVISED
76 | OF THE POSSIBILITY OF SUCH DAMAGES.
77 |
78 | 6. Dispute Resolution
79 |
80 | This license shall be governed and construed in accordance with the laws of People’s Republic of China. Any dispute
81 | arising from or in connection with this License shall be submitted to Haidian District People's Court in Beijing.
82 |
83 | Note that the license is subject to update to a more comprehensive version. For any questions related to the license and
84 | copyright, please contact us at license@zhipuai.cn.
85 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # GLM-Edge
2 |
3 | Read this in [English](README_en.md)
4 |
5 | 在 🤗 [这里](https://huggingface.co/spaces/THUDM-HF-SPACE/GLM-Edge-1.5B-Chat-Space) 体验 GLM-Edge-1.5B-Chat 端侧模型
6 |
7 | 在 🤗 [这里](https://huggingface.co/spaces/THUDM-HF-SPACE/GLM-Edge-V-5B-Space) 或者 🤖 这里 体验 GLM-Edge-V-5B 端侧模型
8 |
9 |
10 | ## 模型介绍
11 |
12 | **GLM-Edge** 系列是我们在面向端侧真实落地使用的场景下的一次尝试,由两种尺寸的大语言对话模型和多模态理解模型组成(
13 | `GLM-Edge-1.5B-Chat`,`GLM-Edge-4B-Chat`,`GLM-Edge-V-2B`,`GLM-Edge-V-5B`)。其中,`1.5B / 2B`模型主要面向手机、车机等平台,
14 | `4B / 5B` 模型主要面向PC等平台。
15 |
16 | 基于GLM-4系列的技术积累,我们针对端侧实际部署情况,对模型结构和尺寸做了针对性的调整,以求在模型表现、实机推理效果和落地便利度之间达到平衡。同时,通过与伙伴企业的深入合作和在推理优化上的不懈努力,在一些端侧平台上,GLM-Edge系列模型能以极快的速度运行。
17 |
18 | 例如,在高通骁龙8 Elite平台上,借助其强大的NPU算力,GLM-Edge通过混合量化方案,1.5B对话模型、2B多模态模型能实现每秒60
19 | tokens以上的解码速度。在应用投机采样技术之后,两个模型能以峰值每秒100 tokens以上的解码速度运行。**这些推理方案会由我们或合作伙伴后续放出。暂时不会在本仓库提供。**
20 | 模型下载地址:
21 |
22 | | Model | HuggingFace Model | GGUF Model |
23 | |:------------------:|:--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------:|:-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------:|
24 | | GLM-Edge-1.5B-Chat | [🤗 Huggingface](https://huggingface.co/THUDM/glm-edge-1.5b-chat)<br>[🤖 ModelScope](https://modelscope.cn/models/ZhipuAI/glm-edge-1.5b-chat)<br>[🟣 WiseModel](https://wisemodel.cn/models/ZhipuAI/glm-edge-1.5b-chat) | [🤗 Huggingface](https://huggingface.co/THUDM/glm-edge-1.5b-chat-gguf)<br>[🤖 ModelScope](https://modelscope.cn/models/ZhipuAI/glm-edge-1.5b-chat-gguf)<br>[🟣 WiseModel](https://wisemodel.cn/models/ZhipuAI/glm-edge-1.5b-chat-gguf) |
25 | | GLM-Edge-4B-Chat | [🤗 Huggingface](https://huggingface.co/THUDM/glm-edge-4b-chat)<br>[🤖 ModelScope](https://modelscope.cn/models/ZhipuAI/glm-edge-4b-chat)<br>[🟣 WiseModel](https://wisemodel.cn/models/ZhipuAI/glm-edge-4b-chat) | [🤗 Huggingface](https://huggingface.co/THUDM/glm-edge-4b-chat-gguf)<br>[🤖 ModelScope](https://modelscope.cn/models/ZhipuAI/glm-edge-4b-chat-gguf)<br>[🟣 WiseModel](https://wisemodel.cn/models/ZhipuAI/glm-edge-4b-chat-gguf) |
26 | | GLM-Edge-V-2B | [🤗 Huggingface](https://huggingface.co/THUDM/glm-edge-v-2b)<br>[🤖 ModelScope](https://modelscope.cn/models/ZhipuAI/glm-edge-v-2b)<br>[🟣 WiseModel](https://wisemodel.cn/models/ZhipuAI/glm-edge-v-2b) | [🤗 Huggingface](https://huggingface.co/THUDM/glm-edge-v-2b-gguf)<br>[🤖 ModelScope](https://modelscope.cn/models/ZhipuAI/glm-edge-v-2b-gguf)<br>[🟣 WiseModel](https://wisemodel.cn/models/ZhipuAI/glm-edge-v-2b-gguf) |
27 | | GLM-Edge-V-5B | [🤗 Huggingface](https://huggingface.co/THUDM/glm-edge-v-5b)<br>[🤖 ModelScope](https://modelscope.cn/models/ZhipuAI/glm-edge-v-5b)<br>[🟣 WiseModel](https://wisemodel.cn/models/ZhipuAI/glm-edge-v-5b) | [🤗 Huggingface](https://huggingface.co/THUDM/glm-edge-v-5b-gguf)<br>[🤖 ModelScope](https://modelscope.cn/models/ZhipuAI/glm-edge-v-5b-gguf)<br>[🟣 WiseModel](https://wisemodel.cn/models/ZhipuAI/glm-edge-v-5b-gguf) |
28 |
29 | ## 实机运行数据
30 |
31 | 数据采集日截止到2024年11月28日。我们还在积极地与合作伙伴们一道优化这些性能。
32 |
33 | ### 高通
34 |
35 | | 模型 | 任务 | 量化方案 | 框架 | 1st token latency (ms) | Token Rate (tokens/s) | Peak Memory Footprint (GB) |
36 | |--------------------|------------------------|------|-----|------------------------|-----------------------|----------------------------|
37 | | GLM-Edge-4B-Chat | (input/output=512/128) | INT4 | QNN | 660 | 24 | 2.9 |
38 | | GLM-Edge-1.5B-Chat | (input/output=512/128) | INT4 | QNN | 260 | 65 | 1.2 |
39 |
40 | * 在高通8 Elite(Gen4)平台上测试,模型全部运行在NPU上
41 | * 如运行V模型,另外需要单图890ms的处理时间和约660M的额外内存
42 | * 使用投机解码方案时,Token Rate还有最高50%的提升
43 |
44 | ### Intel
45 |
46 | | 模型 | 任务 | 量化方案 | 框架 | 1st token latency (ms) | Token Rate (tokens/s) | Peak Memory Footprint (GB) |
47 | |--------------------|--------------------------------------|------|----------|------------------------|-----------------------|----------------------------|
48 | | GLM-Edge-4B-Chat | (input/output=1024/128) | INT4 | OPENVINO | 541.2 | 27 | 3.9 |
49 | | GLM-Edge-1.5B-Chat | (input/output=1024/128) | INT4 | OPENVINO | 228.2 | 63 | 2.3 |
50 | | GLM-Edge-V-2B | Single image understanding (672x672) | INT4 | OPENVINO | 362.1 | 70 | 3.4 |
51 |
52 | * 在Intel LNL 288V (ARC 140V 8X@2.05GHz) 平台上测试。
53 | * 如运行V模型,另外需要单图1.7s的处理时间和约2G的额外内存。
54 |
55 | ## 安装依赖
56 |
57 | 请确保你的Python版本为`3.10`或更高版本。并按照如下方式安装依赖,安装以下依赖能确保正确运行本仓库的所有代码。
58 |
59 | ```shell
60 | pip install -r requirements.txt
61 | ```
62 |
63 | ## 模型推理
64 |
65 | ### Transformers / OpenVINO / vLLM Demo
66 |
67 | 我们提供了 vLLM、OpenVINO 和 transformers 三种后端推理方式,运行以下命令即可启动命令行交互 Demo:
68 |
69 | ```shell
70 | python cli_demo.py --backend transformers --model_path THUDM/glm-edge-1.5b-chat --precision bfloat16
71 | python cli_demo.py --backend vllm --model_path THUDM/glm-edge-1.5b-chat --precision bfloat16
72 | python cli_demo.py --backend ov --model_path THUDM/glm-edge-1.5b-chat-ov --precision int4
73 | ```
74 |
75 | > 注意:
76 | >
77 | > OpenVINO 版本模型需要进行转换,请前往 [这里](inference/ov_convert) 运行转换代码。
78 | >
79 | > ```python convert_chat.py --model_path THUDM/glm-edge-1.5b-chat --precision int4 ``` 转换对话模型。
80 | >
81 | > ```python convert_v.py --model_path THUDM/glm-edge-v-2b --precision int4``` 转换视觉理解模型。
82 | >
83 | > 你也可以在 [这里](https://github.com/openvino-dev-samples/glm-edge.openvino) 查看原始的转换代码。
84 | >
85 | > vLLM 版本模型需要从 [这里](https://github.com/sixsixcoder/vllm/tree/glm-4) 源代码 安装 vLLM 以正常运行。
86 |
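如果不想通过 `cli_demo.py`,也可以直接使用 `transformers` 的标准文本生成接口加载对话模型。下面是一个最小示例草稿(假设使用支持 GLM-Edge 的较新版本 transformers,并依赖模型自带的 chat template),仅供参考,不能替代上面的 demo 脚本:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_PATH = "THUDM/glm-edge-1.5b-chat"

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    torch_dtype=torch.bfloat16,  # 与 demo 中的 --precision bfloat16 一致
    device_map="auto",
)

messages = [{"role": "user", "content": "你好"}]
input_ids = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
).to(model.device)

outputs = model.generate(input_ids, max_new_tokens=128)
print(tokenizer.decode(outputs[0][input_ids.shape[1]:], skip_special_tokens=True))
```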
87 | 如果你想使用 glm-edge-v 系列模型,你可以运行以下命令行交互代码
88 |
89 | ```shell
90 | python cli_demo_vision.py --backend transformers --model_path THUDM/glm-edge-v-2b --precision bfloat16
91 | python cli_demo_vision.py --backend ov --model_path THUDM/glm-edge-v-2b-ov --precision int4
92 | ```
93 |
94 | 你也可以使用 Gradio 启动 WebUI。
95 |
96 | ```shell
97 | python web_demo.py --backend transformers --model_path THUDM/glm-edge-1.5b-chat --precision bfloat16
98 | python web_demo.py --backend vllm --model_path THUDM/glm-edge-1.5b-chat --precision int4 # For Int4 Inference
99 | ```
100 |
101 | ### XInference
102 |
103 | 如果你使用 XInference 进行推理,可以通过以下命令启动模型服务,然后通过 OpenAI 兼容接口调用:
104 |
105 | ```shell
106 | xinference launch --model-engine Transformers --model-name glm-edge-v --size-in-billions 2 --model-format pytorch --quantization none
107 | ```
108 |
109 | 使用 OpenAI API进行推理:
110 |
111 | ```python
112 | import openai
113 |
114 | client = openai.Client(
115 | api_key="cannot be empty",
116 | base_url="http://<host>:<port>/v1"  # 填入 XInference 服务地址
117 | )
118 | output = client.chat.completions.create(
119 | model="glm-edge-v",
120 | messages=[
121 | {
122 | "role": "user",
123 | "content": [
124 | {
125 | 'type': 'text',
126 | 'text': 'describe this image',
127 | },
128 | {
129 | 'type': 'image_url',
130 | 'image_url': {
131 | "url": "img.png",
132 | }
133 | },
134 | ],
135 | }
136 | ],
137 | max_tokens=512,
138 | temperature=0.7
139 | )
140 |
141 | print(output)
142 | ```
143 |
144 | ## 微调模型
145 |
146 | 我们提供了微调模型的代码,请参考 [微调教程](finetune/README.md)。
147 |
148 | ## 协议
149 |
150 | 本 GitHub 仓库的代码使用 [Apache 2.0 LICENSE](LICENSE) 协议。
151 |
152 | 模型权重的使用请遵循 [Model License](MODEL_LICENSE)。
153 |
--------------------------------------------------------------------------------
/README_en.md:
--------------------------------------------------------------------------------
1 | # GLM-Edge
2 |
3 | Experience the GLM-Edge-1.5B-Chat edge chat model at 🤗 [here](https://huggingface.co/spaces/THUDM-HF-SPACE/GLM-Edge-1.5B-Chat-Space)
4 |
5 | Experience the GLM-Edge-V-5B edge vision-chat model at 🤗 [here](https://huggingface.co/spaces/THUDM-HF-SPACE/GLM-Edge-V-5B-Space)
6 |
7 | ## Model Introduction
8 |
9 | The **GLM-Edge** series is our attempt to meet real-world deployment scenarios for edge devices. It consists of two sizes of
10 | large language dialogue models and multimodal understanding models (`GLM-Edge-1.5B-Chat`, `GLM-Edge-4B-Chat`,
11 | `GLM-Edge-V-2B`, `GLM-Edge-V-5B`). Among them, the `1.5B / 2B` models are mainly targeted at platforms like mobile
12 | phones and in-vehicle systems, while the `4B / 5B` models are aimed at platforms like PCs.
13 |
14 | Based on the technological advancements of the GLM-4 series, we have made targeted adjustments to the model structure
15 | and size, balancing model performance, real-world inference efficiency, and deployment convenience. Through deep
16 | collaboration with partner enterprises and relentless efforts in inference optimization, the GLM-Edge series models can
17 | run at extremely high speeds on some edge platforms.
18 |
19 | For example, on the Qualcomm Snapdragon 8 Elite platform, leveraging its powerful NPU computing power and using a mixed
20 | quantization scheme, the 1.5B dialogue model and the 2B multimodal model can achieve decoding speeds of over 60 tokens
21 | per second. With speculative sampling techniques, these models can reach peak decoding speeds of over 100 tokens per
22 | second. These inference solutions will be released later by us or our partners; they are not provided in this repository for the time being.
23 |
24 | Download links for the models:
25 |
26 | | Model | HuggingFace Model | GGUF Model |
27 | |:------------------:|:--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------:|:-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------:|
28 | | GLM-Edge-1.5B-Chat | [🤗 HuggingFace](https://huggingface.co/THUDM/glm-edge-1.5b-chat)<br>[🤖 ModelScope](https://modelscope.cn/models/ZhipuAI/glm-edge-1.5b-chat)<br>[🟣 WiseModel](https://wisemodel.cn/models/ZhipuAI/glm-edge-1.5b-chat) | [🤗 HuggingFace](https://huggingface.co/THUDM/glm-edge-1.5b-chat-gguf)<br>[🤖 ModelScope](https://modelscope.cn/models/ZhipuAI/glm-edge-1.5b-chat-gguf)<br>[🟣 WiseModel](https://wisemodel.cn/models/ZhipuAI/glm-edge-1.5b-chat-gguf) |
29 | | GLM-Edge-4B-Chat | [🤗 HuggingFace](https://huggingface.co/THUDM/glm-edge-4b-chat)<br>[🤖 ModelScope](https://modelscope.cn/models/ZhipuAI/glm-edge-4b-chat)<br>[🟣 WiseModel](https://wisemodel.cn/models/ZhipuAI/glm-edge-4b-chat) | [🤗 HuggingFace](https://huggingface.co/THUDM/glm-edge-4b-chat-gguf)<br>[🤖 ModelScope](https://modelscope.cn/models/ZhipuAI/glm-edge-4b-chat-gguf)<br>[🟣 WiseModel](https://wisemodel.cn/models/ZhipuAI/glm-edge-4b-chat-gguf) |
30 | | GLM-Edge-V-2B | [🤗 HuggingFace](https://huggingface.co/THUDM/glm-edge-v-2b)<br>[🤖 ModelScope](https://modelscope.cn/models/ZhipuAI/glm-edge-v-2b)<br>[🟣 WiseModel](https://wisemodel.cn/models/ZhipuAI/glm-edge-v-2b) | [🤗 HuggingFace](https://huggingface.co/THUDM/glm-edge-v-2b-gguf)<br>[🤖 ModelScope](https://modelscope.cn/models/ZhipuAI/glm-edge-v-2b-gguf)<br>[🟣 WiseModel](https://wisemodel.cn/models/ZhipuAI/glm-edge-v-2b-gguf) |
31 | | GLM-Edge-V-5B | [🤗 HuggingFace](https://huggingface.co/THUDM/glm-edge-v-5b)<br>[🤖 ModelScope](https://modelscope.cn/models/ZhipuAI/glm-edge-v-5b)<br>[🟣 WiseModel](https://wisemodel.cn/models/ZhipuAI/glm-edge-v-5b) | [🤗 HuggingFace](https://huggingface.co/THUDM/glm-edge-v-5b-gguf)<br>[🤖 ModelScope](https://modelscope.cn/models/ZhipuAI/glm-edge-v-5b-gguf)<br>[🟣 WiseModel](https://wisemodel.cn/models/ZhipuAI/glm-edge-v-5b-gguf) |
32 |
33 | ## Performance Data
34 |
35 | Data collection is up to November 28, 2024. We are actively working with partners to optimize these performances.
36 |
37 | ### Qualcomm
38 |
39 | | Model | Task | Quantization | Framework | 1st Token Latency (ms) | Token Rate (tokens/s) | Peak Memory Footprint (GB) |
40 | |--------------------|------------------------|--------------|-----------|------------------------|-----------------------|----------------------------|
41 | | GLM-Edge-4B-Chat | (input/output=512/128) | INT4 | QNN | 660 | 24 | 2.9 |
42 | | GLM-Edge-1.5B-Chat | (input/output=512/128) | INT4 | QNN | 260 | 65 | 1.2 |
43 |
44 | - Tested on the Qualcomm 8 Elite (Gen4) platform with models fully running on the NPU.
45 | - For V models, an additional 890ms processing time per image and about 660MB extra memory is required.
46 | - With speculative decoding, the Token Rate can achieve up to 50% improvement.
47 |
48 | ### Intel
49 |
50 | | Model | Task | Quantization | Framework | 1st Token Latency (ms) | Token Rate (tokens/s) | Peak Memory Footprint (GB) |
51 | |--------------------|--------------------------------------|--------------|-----------|------------------------|-----------------------|----------------------------|
52 | | GLM-Edge-4B-Chat | (input/output=1024/128) | INT4 | OPENVINO | 541.2 | 27 | 3.9 |
53 | | GLM-Edge-1.5B-Chat | (input/output=1024/128) | INT4 | OPENVINO | 228.2 | 63 | 2.3 |
54 | | GLM-Edge-V-2B | Single image understanding (672x672) | INT4 | OPENVINO | 362.1 | 70 | 3.4 |
55 |
56 | - Tested on the Intel LNL 288V (ARC 140V 8X@2.05GHz) platform.
57 | - For V models, an additional 1.7s processing time per image and about 2GB extra memory is required.
58 |
59 | ## Install Dependencies
60 |
61 | Ensure your Python version is `3.10` or higher. Install dependencies as follows to ensure all code in this repository
62 | runs correctly:
63 |
64 | ```shell
65 | pip install -r requirements.txt
66 | ```
67 |
68 | ## Model Inference
69 |
70 | ### Transformers / OpenVINO / vLLM Demo
71 |
72 | We provide three inference backends: vLLM, OpenVINO, and transformers. You can run the models with the following
73 | commands, which launch an interactive command-line demo.
74 |
75 | ```shell
76 | python cli_demo.py --backend transformers --model_path THUDM/glm-edge-1.5b-chat --precision bfloat16
77 | python cli_demo.py --backend vllm --model_path THUDM/glm-edge-1.5b-chat --precision bfloat16
78 | python cli_demo.py --backend ov --model_path THUDM/glm-edge-1.5b-chat-ov --precision int4
79 | ```
80 |
81 | > Note:
82 | >
83 | > OpenVINO version models need conversion. Please visit [here](inference/ov_convert) to run the conversion code.
84 | >
85 | > ```python convert_chat.py --model_path THUDM/glm-edge-1.5b-chat --precision int4 ``` to convert dialogue models.
86 | >
87 | > ```python convert_v.py --model_path THUDM/glm-edge-v-2b --precision int4``` to convert visual understanding models.
88 | >
89 | > You can also view the original conversion code [here](https://github.com/openvino-dev-samples/glm-edge.openvino).
90 | >
91 | > vLLM version models require installation of source code from [here](https://github.com/sixsixcoder/vllm/tree/glm-4) to
92 | > run properly.
93 |
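If you prefer not to go through `cli_demo.py`, the chat models can also be loaded with the standard `transformers` text-generation API. The snippet below is a minimal illustrative sketch (it assumes a recent `transformers` release with GLM-Edge support and relies on the chat template shipped with the model); it is not a replacement for the provided demo scripts:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_PATH = "THUDM/glm-edge-1.5b-chat"

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    torch_dtype=torch.bfloat16,  # matches the demo's --precision bfloat16
    device_map="auto",
)

messages = [{"role": "user", "content": "Hello"}]
input_ids = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
).to(model.device)

outputs = model.generate(input_ids, max_new_tokens=128)
print(tokenizer.decode(outputs[0][input_ids.shape[1]:], skip_special_tokens=True))
```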
94 | To use glm-edge-v series models, you can run the following command-line interaction code:
95 |
96 | ```shell
97 | python cli_demo_vision.py --backend transformers --model_path THUDM/glm-edge-v-2b --precision bfloat16
98 | python cli_demo_vision.py --backend ov --model_path THUDM/glm-edge-v-2b-ov --precision int4
99 | ```
100 |
101 | You can also use Gradio to launch a WebUI.
102 |
103 | ```shell
104 | python web_demo.py --backend transformers --model_path THUDM/glm-edge-1.5b-chat --precision bfloat16
105 | python web_demo.py --backend vllm --model_path THUDM/glm-edge-1.5b-chat --precision int4 # For Int4 Inference
106 | ```
107 |
108 | ### XInference
109 |
110 | If you use XInference for inference, you can serve the model with the following command and then query it through
111 | the OpenAI-compatible API.
112 |
113 | ```shell
114 | xinference launch --model-engine Transformers --model-name glm-edge-v --size-in-billions 2 --model-format pytorch --quantization none
115 | ```
116 |
117 | Using OpenAI API for inference:
118 |
119 | ```python
120 | import openai
121 |
122 | client = openai.Client(
123 | api_key="cannot be empty",
124 | base_url="http://<host>:<port>/v1"  # fill in your XInference server address
125 | )
126 | output = client.chat.completions.create(
127 | model="glm-edge-v",
128 | messages=[
129 | {
130 | "role": "user",
131 | "content": [
132 | {
133 | 'type': 'text',
134 | 'text': 'describe this image',
135 | },
136 | {
137 | 'type': 'image_url',
138 | 'image_url': {
139 | "url": "img.png",
140 | }
141 | },
142 | ],
143 | }
144 | ],
145 | max_tokens=512,
146 | temperature=0.7
147 | )
148 |
149 | print(output)
150 | ```
151 |
152 | ## Fine-Tuning Models
153 |
154 | We provide code for fine-tuning models. Please refer to the [Fine-Tuning Tutorial](finetune/README.md).
155 |
156 | ## License
157 |
158 | The code in this GitHub repository is released under the [Apache 2.0 LICENSE](LICENSE).
159 |
160 | Usage of model weights must follow the [Model License](MODEL_LICENSE).
161 |
--------------------------------------------------------------------------------
/finetune/README.md:
--------------------------------------------------------------------------------
1 | # GLM-Edge 对话模型微调
2 |
3 | Read this in [English](README_en.md)
4 |
5 | 本 demo 中,你将体验到如何微调 GLM-Edge 对话开源模型。 请严格按照文档的步骤进行操作,以避免不必要的错误。
6 |
7 | ## 多轮对话格式
8 |
9 | 多轮对话微调示例采用 GLM-Edge 对话格式约定,对不同角色添加不同 `loss_mask` 从而在一遍计算中为多轮回复计算 `loss`。
10 |
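下面用一个简化的示意函数说明这种 loss_mask 约定(仅为假设性示例,并非 finetune.py 中的实际实现):system / user 轮次的 token 标签置为 -100(不计入 loss),assistant 轮次的 token 保留原 id,从而在一次前向计算中对多轮回复同时计算 loss。

```python
def build_labels(conversation, tokenizer):
    """示意函数:演示 loss_mask 约定,并非 finetune.py 的实际代码。"""
    input_ids, labels = [], []
    for message in conversation["messages"]:
        # 实际流程中由 chat template 加入角色标记,这里仅对正文分词以突出思路
        ids = tokenizer.encode(message["content"], add_special_tokens=False)
        input_ids += ids
        if message["role"] == "assistant":
            labels += ids                # assistant 的 token 参与 loss 计算
        else:
            labels += [-100] * len(ids)  # system / user 的 token 被屏蔽
    return {"input_ids": input_ids, "labels": labels}
```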
11 | 对于数据文件,样例采用如下格式:
12 |
13 | 对于glm-edge-chat系列模型,您应该按照以下格式整理数据。
14 |
15 | ```json
16 | [
17 | {
18 | "messages": [
19 | {
20 | "role": "system",
21 | "content": "",
22 | },
23 | {
24 | "role": "user",
25 | "content": ""
26 | },
27 | {
28 | "role": "assistant",
29 | "content": ""
30 | },
31 | // Multi_turns
32 | {
33 | "role": "user",
34 | "content": ""
35 | },
36 | {
37 | "role": "assistant",
38 | "content": ""
39 | },
40 | ]
41 | }
42 | ]
43 | ```
44 |
45 | 这里是一个单轮对话的例子:
46 |
47 | ```json
48 | {
49 | "messages": [
50 | {
51 | "role": "user",
52 | "content": "类型#裤*材质#牛仔布*风格#性感"
53 | },
54 | {
55 | "role": "assistant",
56 | "content": "3x1的这款牛仔裤采用浅白的牛仔面料为裤身材质,其柔然的手感和细腻的质地,在穿着舒适的同时,透露着清纯甜美的个性气质。除此之外,流畅的裤身剪裁将性感的>腿部曲线彰显的淋漓尽致,不失为一款随性出街的必备单品。"
57 | }
58 | ]
59 | }
60 | ```
61 |
62 | 对于glm-edge-v系列模型,您应该按照以下格式整理数据。
63 |
64 | ```json
65 | [
66 | {
67 | "messages": [
68 | {
69 | "role": "user",
70 | "content": [
71 | {
72 | "type": "image",
73 | "image": "path/to/image"
74 | },
75 | {
76 | "type": "text",
77 | "text": "图片中的狗在做什么?"
78 | }
79 | ]
80 | },
81 | {
82 | "role": "assistant",
83 | "content": [
84 | {
85 | "type": "text",
86 | "text": "zRzRzRzRzRzRzR!这只狗躺在公寓客厅的绿色沙发上。"
87 | }
88 | ]
89 | },
90 | {
91 | "role": "user",
92 | "content": [
93 | {
94 | "type": "text",
95 | "text": "这只狗是什么颜色的?"
96 | }
97 | ]
98 | },
99 | {
100 | "role": "assistant",
101 | "content": [
102 | {
103 | "type": "text",
104 | "text": "zRzRzRzRzRzRzR!这只狗是棕色和白色的。"
105 | }
106 | ]
107 | }
108 | ]
109 | }
110 | ]
111 | ```
112 |
113 | ## 配置文件
114 |
115 | 微调配置文件位于 `configs` 目录下,包括以下文件:
116 |
117 | 1. `ds_zero_3.json`: DeepSpeed 配置文件。
118 |
119 | 2. `lora.yaml / sft.yaml`: 模型不同方式的配置文件,包括模型参数、优化器参数、训练参数等。 部分重要参数解释如下:
120 | + data_config 部分
121 | + train_file: 训练数据集的文件路径。
122 | + val_file: 验证数据集的文件路径。
123 | + test_file: 测试数据集的文件路径。
124 | + num_proc: 在加载数据时使用的进程数量。
125 | + max_input_length: 输入序列的最大长度,对于glm-edge-v系列模型,由于图片占位token个数是584,因此值需要设置大些。
126 | + max_output_length: 输出序列的最大长度。
127 | + training_args 部分
128 | + output_dir: 用于保存模型和其他输出的目录。
129 | + max_steps: 训练的最大步数。
130 | + per_device_train_batch_size: 每个设备(如 GPU)的训练批次大小。
131 | + dataloader_num_workers: 加载数据时使用的工作线程数量。
132 | + remove_unused_columns: 是否移除数据中未使用的列。
133 | + save_strategy: 模型保存策略(例如,每隔多少步保存一次)。
134 | + save_steps: 每隔多少步保存一次模型。
135 | + log_level: 日志级别(如 info)。
136 | + logging_strategy: 日志记录策略。
137 | + logging_steps: 每隔多少步记录一次日志。
138 | + per_device_eval_batch_size: 每个设备的评估批次大小。
139 | + evaluation_strategy: 评估策略(例如,每隔多少步进行一次评估)。
140 | + eval_steps: 每隔多少步进行一次评估。
141 | + predict_with_generate: 是否使用生成模式进行预测。
142 | + generation_config 部分
143 | + max_new_tokens: 生成的最大新 token 数量。
144 | + peft_config 部分
145 | + peft_type: 使用的参数有效调整类型 (支持 LORA 和 PREFIX_TUNING)。
146 | + task_type: 任务类型,这里是因果语言模型 (不要改动)。
147 | + Lora 参数:
148 | + r: LoRA 的秩。
149 | + lora_alpha: LoRA 的缩放因子。
150 | + lora_dropout: 在 LoRA 层使用的 dropout 概率。
151 |
152 | ## 开始微调
153 |
154 | 通过以下命令执行 **单机多卡/多机多卡** 微调。该方式使用 `deepspeed` 作为加速方案,您需要先安装 `deepspeed`,然后按照此命令运行:
155 |
156 | ```shell
157 | OMP_NUM_THREADS=1 torchrun --standalone --nnodes=1 --nproc_per_node=8 finetune.py data/AdvertiseGen/ THUDM/glm-edge-4b-chat configs/lora.yaml # For Chat Fine-tune
158 | OMP_NUM_THREADS=1 torchrun --standalone --nnodes=1 --nproc_per_node=8 finetune_vision.py data/CogVLM-311K/ THUDM/glm-edge-v-5b configs/lora.yaml # For VQA Fine-tune
159 | ```
160 |
161 | 通过以下代码执行 **单机单卡** 运行。
162 |
163 | ```shell
164 | python finetune.py data/AdvertiseGen/ THUDM/glm-edge-4b-chat configs/lora.yaml # For Chat Fine-tune
165 | python finetune_vision.py data/CogVLM-311K/ THUDM/glm-edge-v-5b configs/lora.yaml # For VQA Fine-tune
166 | ```
167 |
168 | ## 从保存点进行微调
169 |
170 | 如果按照上述方式进行训练,每次微调都会从头开始,如果你想从训练一半的模型开始微调,你可以加入第四个参数,这个参数有两种传入方式:
171 |
172 | 1. `yes`, 自动从最后一个保存的 Checkpoint开始训练
173 | 2. `XX`,断点编号,例如 `600`,表示从序号为 600 的 Checkpoint 开始训练
174 |
175 | 例如,这就是一个从最后一个保存点继续微调的示例代码
176 |
177 | ```shell
178 | python finetune.py data/AdvertiseGen/ THUDM/glm-edge-4b-chat configs/lora.yaml yes
179 | python finetune_vision.py data/CogVLM-311K/ THUDM/glm-edge-v-5b configs/lora.yaml yes
180 | ```
181 |
--------------------------------------------------------------------------------
/finetune/README_en.md:
--------------------------------------------------------------------------------
1 | # GLM-Edge dialogue model fine-tuning
2 |
3 | Read this in [Chinese](README.md)
4 |
5 | In this demo, you will experience how to fine-tune the GLM-Edge-4B-Chat open source dialogue model. Please strictly follow the steps in the document to avoid unnecessary errors.
6 |
7 | ## Multi-turn dialogue format
8 |
9 | The multi-turn dialogue fine-tuning example uses the GLM-Edge dialogue format convention, adding different `loss_mask` to different roles to calculate `loss` for multiple rounds of replies in one calculation.
10 |
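As an illustration of this convention, the sketch below is a hypothetical helper (not the actual code in `finetune.py`): tokens from `system`/`user` turns get the label `-100` so they are ignored by the loss, while tokens from `assistant` turns keep their ids, so every assistant reply in the conversation contributes to the loss in a single forward pass.

```python
def build_labels(conversation, tokenizer):
    """Hypothetical helper illustrating the loss_mask convention, not finetune.py itself."""
    input_ids, labels = [], []
    for message in conversation["messages"]:
        # The real pipeline lets the chat template add role tokens; this sketch
        # only tokenizes the raw content to keep the idea visible.
        ids = tokenizer.encode(message["content"], add_special_tokens=False)
        input_ids += ids
        if message["role"] == "assistant":
            labels += ids                # assistant tokens contribute to the loss
        else:
            labels += [-100] * len(ids)  # system/user tokens are masked out
    return {"input_ids": input_ids, "labels": labels}
```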
11 | For data files, the sample uses the following format:
12 |
13 | For the glm-edge-chat family of models, you should organize your data in the following format.
14 |
15 | ```json
16 | [
17 | {
18 | "messages": [
19 | {
20 | "role": "system",
21 | "content": "",
22 | },
23 | {
24 | "role": "user",
25 | "content": ""
26 | },
27 | {
28 | "role": "assistant",
29 | "content": ""
30 | },
31 | // Multi_turns
32 | {
33 | "role": "user",
34 | "content": ""
35 | },
36 | {
37 | "role": "assistant",
38 | "content": ""
39 | },
40 | ]
41 | }
42 | ]
43 | ```
44 |
45 | Here is an example of a single-turn conversation:
46 |
47 | ```json
48 | {
49 | "messages": [
50 | {
51 | "role": "user",
52 | "content": "Type#Pants*Material#Denim*Style#Sexy"
53 | },
54 | {
55 | "role": "assistant",
56 | "content": "This pair of jeans from 3x1 is made of light white denim fabric. Its soft feel and delicate texture make it comfortable to wear while revealing a pure and sweet personality. In addition, the smooth cut of the pants fully highlights the sexy leg curves, making it a must-have item for casual outings."
57 | }
58 | ]
59 | }
60 | ```
61 |
62 | For glm-edge-v family of models, you should organize your data in the following format.
63 |
64 | ```json
65 | [
66 | {
67 | "messages": [
68 | {
69 | "role": "user",
70 | "content": [
71 | {
72 | "type": "image",
73 | "image": "path/to/image"
74 | },
75 | {
76 | "type": "text",
77 | "text": "图片中的狗在做什么?"
78 | }
79 | ]
80 | },
81 | {
82 | "role": "assistant",
83 | "content": [
84 | {
85 | "type": "text",
86 | "text": "zRzRzRzRzRzRzR!这只狗躺在公寓客厅的绿色沙发上。"
87 | }
88 | ]
89 | },
90 | {
91 | "role": "user",
92 | "content": [
93 | {
94 | "type": "text",
95 | "text": "这只狗是什么颜色的?"
96 | }
97 | ]
98 | },
99 | {
100 | "role": "assistant",
101 | "content": [
102 | {
103 | "type": "text",
104 | "text": "zRzRzRzRzRzRzR!这只狗是棕色和白色的。"
105 | }
106 | ]
107 | }
108 | ]
109 | }
110 | ]
111 | ```
112 |
113 | ## Configuration Files
114 |
115 | The fine-tuning configuration files are located in the `configs` directory and include the following files:
116 |
117 | 1. `ds_zero_3.json`: DeepSpeed configuration file.
118 |
119 | 2. `lora.yaml / sft.yaml`: Configuration files for the different fine-tuning methods, including model parameters, optimizer parameters, training parameters, etc. Some important parameters are explained as follows:
120 |
121 | + train_file: File path of training dataset.
122 | + val_file: File path of validation dataset.
123 | + test_file: File path of test dataset.
124 | + num_proc: Number of processes to use when loading data.
125 | + max_input_length: Maximum length of the input sequence. For glm-edge-v series models, the image placeholder takes 584 tokens, so this value needs to be set larger.
126 | + max_output_length: Maximum length of output sequence.
127 | + training_args section
128 | + output_dir: Directory for saving model and other outputs.
129 | + max_steps: Maximum number of training steps.
130 | + per_device_train_batch_size: Training batch size per device (such as GPU).
131 | + dataloader_num_workers: Number of worker threads to use when loading data.
132 | + remove_unused_columns: Whether to remove unused columns in data.
133 | + save_strategy: Model saving strategy (for example, how many steps to save).
134 | + save_steps: How many steps to save the model.
135 | + log_level: Log level (such as info).
136 | + logging_strategy: logging strategy.
137 | + logging_steps: how many steps to log at.
138 | + per_device_eval_batch_size: per-device evaluation batch size.
139 | + evaluation_strategy: evaluation strategy (e.g. how many steps to evaluate at).
140 | + eval_steps: how many steps to evaluate at.
141 | + predict_with_generate: whether to use generation mode for prediction.
142 | + generation_config section
143 | + max_new_tokens: maximum number of new tokens to generate.
144 | + peft_config section
145 | + peft_type: The parameter-efficient fine-tuning method to use (LORA and PREFIX_TUNING are supported); see the sketch after this list.
146 | + task_type: Task type; here it is causal language model (do not change).
147 | + Lora parameters:
148 | + r: rank of LoRA.
149 | + lora_alpha: scaling factor of LoRA.
150 | + lora_dropout: dropout probability to use in LoRA layer.
151 |
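For reference, the `peft_config` block in `lora.yaml` is handed over to `peft` roughly as sketched below. This mirrors what `finetune.py` already does via `get_peft_config` / `get_peft_model`; it is shown here only to make the mapping from YAML to code explicit, not as an extra script you need to run:

```python
from peft import get_peft_config, get_peft_model
from transformers import AutoModelForCausalLM

# The same fields as the peft_config section of lora.yaml
peft_config = get_peft_config(
    config_dict={
        "peft_type": "LORA",
        "task_type": "CAUSAL_LM",
        "r": 8,
        "lora_alpha": 32,
        "lora_dropout": 0.1,
        "target_modules": ["q_proj", "k_proj", "v_proj"],
    }
)

model = AutoModelForCausalLM.from_pretrained("THUDM/glm-edge-4b-chat")
model = get_peft_model(model, peft_config)  # wrap the base model with LoRA adapters
model.print_trainable_parameters()          # only the low-rank adapter weights are trainable
```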
152 | ## Start fine-tuning
153 |
154 | Run **single-node multi-GPU / multi-node multi-GPU** fine-tuning with the following commands, which use `deepspeed` as
155 | the acceleration solution, so you need to install `deepspeed` first.
156 |
157 | ```shell
158 | OMP_NUM_THREADS=1 torchrun --standalone --nnodes=1 --nproc_per_node=8 finetune.py data/AdvertiseGen/ THUDM/glm-edge-4b-chat configs/lora.yaml # For Chat Fine-tune
159 | OMP_NUM_THREADS=1 torchrun --standalone --nnodes=1 --nproc_per_node=8 finetune_vision.py data/CogVLM-311K/ THUDM/glm-edge-v-5b configs/lora.yaml # For VQA Fine-tune
160 | ```
161 |
162 | Run **single-node single-GPU** fine-tuning with the following command.
163 |
164 | ```shell
165 | python finetune.py data/AdvertiseGen/ THUDM/glm-edge-4b-chat configs/lora.yaml # For Chat Fine-tune
166 | python finetune_vision.py data/CogVLM-311K/ THUDM/glm-edge-v-5b configs/lora.yaml # For VQA Fine-tune
167 | ```
168 |
169 | ## Fine-tune from a saved point
170 |
171 | If you train as described above, each fine-tuning will start from the beginning. If you want to fine-tune from a
172 | half-trained model, you can add a fourth parameter, which can be passed in two ways:
173 |
174 | 1. `yes`, automatically start training from the last saved Checkpoint
175 |
176 | 2. `XX`, breakpoint number, for example `600`, start training from Checkpoint 600
177 |
178 | For example, this is an example code to continue fine-tuning from the last saved point
179 |
180 | ```shell
181 | python finetune.py data/AdvertiseGen/ THUDM/glm-edge-4b-chat configs/lora.yaml yes
182 | python finetune_vision.py data/CogVLM-311K/ THUDM/glm-edge-v-5b configs/lora.yaml yes
183 | ```
184 |
--------------------------------------------------------------------------------
/finetune/configs/ds_zero_3.json:
--------------------------------------------------------------------------------
1 | {
2 | "train_micro_batch_size_per_gpu": "auto",
3 | "zero_allow_untested_optimizer": true,
4 | "bf16": {
5 | "enabled": "auto"
6 | },
7 | "optimizer": {
8 | "type": "AdamW",
9 | "params": {
10 | "lr": "auto",
11 | "betas": "auto",
12 | "eps": "auto",
13 | "weight_decay": "auto"
14 | }
15 | },
16 | "zero_optimization": {
17 | "stage": 3,
18 | "allgather_partitions": true,
19 | "allgather_bucket_size": 5e8,
20 | "reduce_scatter": true,
21 | "contiguous_gradients": true,
22 | "overlap_comm": true,
23 | "sub_group_size": 1e9,
24 | "reduce_bucket_size": "auto",
25 | "stage3_prefetch_bucket_size": "auto",
26 | "stage3_param_persistence_threshold": "auto",
27 | "stage3_max_live_parameters": 1e9,
28 | "stage3_max_reuse_distance": 1e9,
29 | "stage3_gather_16bit_weights_on_model_save": true
30 | }
31 | }
--------------------------------------------------------------------------------
/finetune/configs/lora.yaml:
--------------------------------------------------------------------------------
1 | data_config:
2 | train_file: train.jsonl
3 | val_file: dev.jsonl
4 | test_file: dev.jsonl
5 | num_proc: 1
6 |
7 | freezeV: False
8 | max_input_length: 2048 # For image inputs, must be larger than 578
9 | max_output_length: 1024
10 |
11 | training_args:
12 | bf16: True
13 | # see `transformers.Seq2SeqTrainingArguments`
14 | output_dir: ./output
15 | max_steps: 3000
16 | # needed to be fit for the dataset
17 | learning_rate: 5e-4
18 | # settings for data loading
19 | per_device_train_batch_size: 16
20 | dataloader_num_workers: 16
21 | remove_unused_columns: false
22 | # settings for saving checkpoints
23 | save_strategy: steps
24 | save_steps: 500
25 | # settings for logging
26 | log_level: info
27 | logging_strategy: steps
28 | logging_steps: 10
29 | # settings for evaluation
30 | per_device_eval_batch_size: 16
31 | eval_strategy: steps
32 | eval_steps: 1000
33 | # adam_epsilon: 1e-6
34 | predict_with_generate: true
35 | generation_config:
36 | max_new_tokens: 512
37 | peft_config:
38 | peft_type: LORA
39 | task_type: CAUSAL_LM
40 | r: 8
41 | lora_alpha: 32
42 | lora_dropout: 0.1
43 | target_modules: ["q_proj", "k_proj", "v_proj"]
44 |
--------------------------------------------------------------------------------
/finetune/configs/sft.yaml:
--------------------------------------------------------------------------------
1 | data_config:
2 | train_file: train.jsonl
3 | val_file: dev.jsonl
4 | test_file: dev.jsonl
5 | num_proc: 1
6 |
7 | combine: True
8 | max_input_length: 2048 # For image inputs, must be larger than 578
9 | max_output_length: 1024
10 |
11 | training_args:
12 | bf16: True
13 | # see `transformers.Seq2SeqTrainingArguments`
14 | output_dir: ./output
15 | max_steps: 3000
16 | # needed to be fit for the dataset
17 | learning_rate: 5e-5
18 | # settings for data loading
19 | per_device_train_batch_size: 4
20 | dataloader_num_workers: 16
21 | remove_unused_columns: false
22 | # settings for saving checkpoints
23 | save_strategy: steps
24 | save_steps: 500
25 | # settings for logging
26 | log_level: info
27 | logging_strategy: steps
28 | logging_steps: 10
29 | # settings for evaluation
30 | per_device_eval_batch_size: 16
31 | eval_strategy: steps
32 | eval_steps: 1000
33 | # settings for optimizer
34 | adam_epsilon: 1e-6
35 | predict_with_generate: true
36 | generation_config:
37 | max_new_tokens: 512
38 | # set your absolute deepspeed path here
39 | deepspeed: configs/ds_zero_3.json
40 |
--------------------------------------------------------------------------------
/finetune/finetune.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import os
3 | import jieba
4 | import dataclasses as dc
5 | import functools
6 | from collections.abc import Callable, Mapping, Sequence
7 | from pathlib import Path
8 | from typing import Annotated, Any, Union
9 | import numpy as np
10 | import ruamel.yaml as yaml
11 | import torch
12 | import typer
13 | from datasets import Dataset, Split
14 | from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
15 | from peft import PeftConfig, get_peft_config, get_peft_model
16 | from rouge_chinese import Rouge
17 | from torch import nn
18 | from transformers import (
19 | AutoModelForCausalLM,
20 | AutoTokenizer,
21 | EvalPrediction,
22 | GenerationConfig,
23 | PreTrainedTokenizer,
24 | Seq2SeqTrainingArguments,
25 | )
26 | from transformers import DataCollatorForSeq2Seq as _DataCollatorForSeq2Seq
27 | from transformers import Seq2SeqTrainer as _Seq2SeqTrainer
28 | from datasets import load_dataset, DatasetDict, NamedSplit
29 | from typing import Optional
30 |
31 | app = typer.Typer(pretty_exceptions_show_locals=False)
32 |
33 |
34 | class DataCollatorForSeq2Seq(_DataCollatorForSeq2Seq):
35 | def __call__(self, features, return_tensors=None):
36 | output_ids = [feature["output_ids"] for feature in features] if "output_ids" in features[0].keys() else None
37 | if output_ids is not None:
38 | max_output_length = max(len(out) for out in output_ids)
39 | if self.pad_to_multiple_of is not None:
40 | max_output_length = (
41 | (max_output_length + self.pad_to_multiple_of - 1)
42 | // self.pad_to_multiple_of
43 | * self.pad_to_multiple_of
44 | )
45 | for feature in features:
46 | remainder = [self.tokenizer.pad_token_id] * (max_output_length - len(feature["output_ids"]))
47 | if isinstance(feature["output_ids"], list):
48 | feature["output_ids"] = feature["output_ids"] + remainder
49 | else:
50 | feature["output_ids"] = np.concatenate([feature["output_ids"], remainder]).astype(np.int64)
51 | return super().__call__(features, return_tensors)
52 |
53 |
54 | class Seq2SeqTrainer(_Seq2SeqTrainer):
55 | def prediction_step(
56 | self,
57 | model: nn.Module,
58 | inputs: dict[str, Any],
59 | prediction_loss_only: bool,
60 | ignore_keys=None,
61 | **gen_kwargs,
62 | ) -> tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
63 | with torch.no_grad(): # Ensure no gradient computation
64 | if self.args.predict_with_generate:
65 | output_ids = inputs.pop("output_ids")
66 | input_ids = inputs["input_ids"]
67 |
68 | if "labels" in inputs:
69 | del inputs["labels"]
70 |
71 | loss, generated_tokens, labels = super().prediction_step(
72 | model, inputs, prediction_loss_only, ignore_keys, **gen_kwargs
73 | )
74 |
75 | generated_tokens = generated_tokens[:, input_ids.size()[1] :]
76 | labels = output_ids
77 |
78 | del inputs, input_ids, output_ids
79 | torch.cuda.empty_cache()
80 |
81 | return loss, generated_tokens, labels
82 |
83 |
84 | @dc.dataclass
85 | class DataConfig(object):
86 | train_file: Optional[str] = None
87 | val_file: Optional[str] = None
88 | test_file: Optional[str] = None
89 | num_proc: Optional[int] = None
90 |
91 | @property
92 | def data_format(self) -> str:
93 | return Path(self.train_file).suffix
94 |
95 | @property
96 | def data_files(self) -> dict[NamedSplit, str]:
97 | return {
98 | split: data_file
99 | for split, data_file in zip(
100 | [Split.TRAIN, Split.VALIDATION, Split.TEST],
101 | [self.train_file, self.val_file, self.test_file],
102 | )
103 | if data_file is not None
104 | }
105 |
106 |
107 | @dc.dataclass
108 | class FinetuningConfig(object):
109 | data_config: DataConfig
110 |
111 | max_input_length: int
112 | max_output_length: int
113 | freezeV: bool
114 |
115 | training_args: Seq2SeqTrainingArguments = dc.field(
116 | default_factory=lambda: Seq2SeqTrainingArguments(output_dir="./output")
117 | )
118 | peft_config: Optional[PeftConfig] = None
119 |
120 | def __post_init__(self):
121 | if not self.training_args.do_eval or self.data_config.val_file is None:
122 | self.training_args.do_eval = False
123 | self.training_args.evaluation_strategy = "no"
124 | self.data_config.val_file = None
125 | else:
126 | self.training_args.per_device_eval_batch_size = (
127 | self.training_args.per_device_eval_batch_size or self.training_args.per_device_train_batch_size
128 | )
129 |
130 | @classmethod
131 | def from_dict(cls, **kwargs) -> "FinetuningConfig":
132 | training_args = kwargs.get("training_args", None)
133 | if training_args is not None and not isinstance(training_args, Seq2SeqTrainingArguments):
134 | gen_config = training_args.get("generation_config")
135 | if not isinstance(gen_config, GenerationConfig):
136 | training_args["generation_config"] = GenerationConfig(**gen_config)
137 | kwargs["training_args"] = Seq2SeqTrainingArguments(**training_args)
138 |
139 | data_config = kwargs.get("data_config")
140 | if not isinstance(data_config, DataConfig):
141 | kwargs["data_config"] = DataConfig(**data_config)
142 |
143 | peft_config = kwargs.get("peft_config", None)
144 | if peft_config is not None and not isinstance(peft_config, PeftConfig):
145 | kwargs["peft_config"] = get_peft_config(config_dict=peft_config)
146 | return cls(**kwargs)
147 |
148 | @classmethod
149 | def from_file(cls, path: Union[str, Path]) -> "FinetuningConfig":
150 | path = Path(path)
151 | parser = yaml.YAML(typ="safe", pure=True)
152 | parser.indent(mapping=2, offset=2, sequence=4)
153 | parser.default_flow_style = False
154 | kwargs = parser.load(path)
155 | return cls.from_dict(**kwargs)
156 |
157 |
158 | def _load_datasets(
159 | data_dir: str,
160 | data_format: str,
161 | data_files: dict[NamedSplit, str],
162 | num_proc: Optional[int],
163 | ) -> DatasetDict:
164 | if data_format == ".jsonl":
165 | dataset_dct = load_dataset(
166 | data_dir,
167 | data_files=data_files,
168 | split=None,
169 | num_proc=num_proc,
170 | )
171 | else:
172 | raise NotImplementedError(f"Cannot load dataset in the '{data_format}' format.")
173 | return dataset_dct
174 |
175 |
176 | class DataManager(object):
177 | def __init__(self, data_dir: str, data_config: DataConfig):
178 | self._num_proc = data_config.num_proc
179 |
180 | self._dataset_dct = _load_datasets(
181 | data_dir,
182 | data_config.data_format,
183 | data_config.data_files,
184 | self._num_proc,
185 | )
186 |
187 | def _get_dataset(self, split: NamedSplit) -> Optional[Dataset]:
188 | return self._dataset_dct.get(split, None)
189 |
190 | def get_dataset(
191 | self,
192 | split: NamedSplit,
193 | process_fn: Callable[[dict[str, Any]], dict[str, Any]],
194 | batched: bool = True,
195 | remove_orig_columns: bool = True,
196 | ) -> Optional[Dataset]:
197 | orig_dataset = self._get_dataset(split)
198 | if orig_dataset is None:
199 | return
200 |
201 | if remove_orig_columns:
202 | remove_columns = orig_dataset.column_names
203 | else:
204 | remove_columns = None
205 | return orig_dataset.map(
206 | process_fn,
207 | batched=batched,
208 | remove_columns=remove_columns,
209 | num_proc=self._num_proc,
210 | )
211 |
212 |
213 | def process_message(message):
214 | if "tools" in message and message["role"] == "system":
215 | for tool in message["tools"]:
216 | parameters = tool["function"]["parameters"]["properties"]
217 | tool["function"]["parameters"]["properties"] = {k: v for k, v in parameters.items() if v is not None}
218 | elif "tools" in message:
219 | del message["tools"]
220 | return message
221 |
222 |
223 | def process_batch(
224 | batch: Mapping[str, Sequence],
225 | tokenizer: PreTrainedTokenizer,
226 | max_input_length: int,
227 | max_output_length: int,
228 | ) -> dict[str, list]:
229 | batched_conv = batch["messages"]
230 | batched_input_ids = []
231 | batched_labels = []
232 | for conv in batched_conv:
233 | new_input_ids = tokenizer.apply_chat_template(
234 | conv, tokenize=True, return_dict=False, add_generation_prompt=False
235 | )
236 | input_ids = new_input_ids
237 | loss_masks = [False] * len(input_ids)
238 | last_assistant_index = len(input_ids) - input_ids[::-1].index(59254) - 1 # <|assistant|>
239 | for j in range(last_assistant_index + 1, len(input_ids)):
240 | loss_masks[j] = True
241 |
242 | input_ids.append(59253) # <|user|> token, used as EOS for chat
243 | loss_masks = [False, *loss_masks]
244 | labels = []
245 | for input_id, mask in zip(input_ids, loss_masks):
246 | if mask:
247 | labels.append(input_id)
248 | else:
249 | labels.append(-100)
250 | max_length = max_input_length + max_output_length + 1
251 | batched_input_ids.append(input_ids[:max_length])
252 | batched_labels.append(labels[:max_length])
253 |
254 | del batched_conv, conv, input_ids, loss_masks, new_input_ids, labels
255 | torch.cuda.empty_cache()
256 |
257 | return {"input_ids": batched_input_ids, "labels": batched_labels}
258 |
259 |
260 | def process_batch_eval(
261 | batch: Mapping[str, Sequence],
262 | tokenizer: PreTrainedTokenizer,
263 | max_input_length: int,
264 | max_output_length: int,
265 | ) -> dict[str, list]:
266 | batched_conv = batch["messages"]
267 | batched_input_ids = []
268 | batched_output_ids = []
269 |
270 | for conv in batched_conv:
271 | new_input_ids = tokenizer.apply_chat_template(conv, tokenize=True, return_dict=False)
272 | input_ids = new_input_ids
273 | last_assistant_index = len(input_ids) - input_ids[::-1].index(59254) - 1
274 | output_prompt, output_ids = (
275 | input_ids[:1],
276 | input_ids[last_assistant_index:],
277 | )
278 | output_ids.append(59253)
279 | batched_input_ids.append(input_ids[:max_input_length] + output_prompt[:1])
280 | batched_output_ids.append(output_ids[:max_output_length])
281 |
282 | del batched_conv, conv, input_ids, new_input_ids, output_prompt, output_ids
283 | torch.cuda.empty_cache()
284 |
285 | return {"input_ids": batched_input_ids, "output_ids": batched_output_ids}
286 |
287 |
288 | def load_tokenizer_and_model(
289 | model_dir: str,
290 | peft_config: Optional[PeftConfig] = None,
291 | ):
292 | tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True, padding_side="left")
293 |
294 | model = AutoModelForCausalLM.from_pretrained(
295 | model_dir,
296 | trust_remote_code=True,
297 | use_cache=False,
298 | torch_dtype=torch.bfloat16, # Must use BFloat 16
299 | )
300 |
301 | if peft_config is not None:
302 | model = get_peft_model(model, peft_config)
303 | model.print_trainable_parameters()
304 |
305 | return tokenizer, model
306 |
307 |
308 | def compute_metrics(eval_preds: EvalPrediction, tokenizer):
309 | batched_pred_ids, batched_label_ids = eval_preds
310 | batched_pred_ids[batched_pred_ids == -100] = tokenizer.pad_token_id
311 | batched_label_ids[batched_label_ids == -100] = tokenizer.pad_token_id
312 | metrics_dct = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}
313 | for pred_ids, label_ids in zip(batched_pred_ids, batched_label_ids):
314 | pred_txt = tokenizer.decode(pred_ids, skip_special_tokens=True).strip()
315 | label_txt = tokenizer.decode(label_ids, skip_special_tokens=True).strip()
316 | pred_tokens = list(jieba.cut(pred_txt))
317 | label_tokens = list(jieba.cut(label_txt))
318 | rouge = Rouge()
319 | scores = rouge.get_scores(" ".join(pred_tokens), " ".join(label_tokens))
320 | for k, v in scores[0].items():
321 | metrics_dct[k].append(round(v["f"] * 100, 4))
322 | metrics_dct["bleu-4"].append(
323 | sentence_bleu([label_tokens], pred_tokens, smoothing_function=SmoothingFunction().method3)
324 | )
325 | return {k: np.mean(v) for k, v in metrics_dct.items()}
326 |
327 |
328 | @app.command()
329 | def main(
330 | data_dir: Annotated[str, typer.Argument(help="Path to the directory containing the dataset files declared in the config")],
331 | model_dir: Annotated[
332 | str,
333 | typer.Argument(
334 | help="A string that specifies the model id of a pretrained model configuration hosted on huggingface.co, or a path to a directory containing a model configuration file."
335 | ),
336 | ],
337 | config_file: Annotated[str, typer.Argument(help="Path to the fine-tuning config, e.g. configs/lora.yaml or configs/sft.yaml")],
338 | auto_resume_from_checkpoint: str = typer.Argument(
339 | default="",
340 | help="If entered as yes, automatically use the latest save checkpoint. If it is a numerical example 12 15, use the corresponding save checkpoint. If the input is no, restart training",
341 | ),
342 | ):
343 | ft_config = FinetuningConfig.from_file(config_file)
344 | tokenizer, model = load_tokenizer_and_model(model_dir, peft_config=ft_config.peft_config)
345 | data_manager = DataManager(data_dir, ft_config.data_config)
346 |
347 | train_dataset = data_manager.get_dataset(
348 | Split.TRAIN,
349 | functools.partial(
350 | process_batch,
351 | tokenizer=tokenizer,
352 | max_input_length=ft_config.max_input_length,
353 | max_output_length=ft_config.max_output_length,
354 | ),
355 | batched=True,
356 | )
357 |
358 | val_dataset = data_manager.get_dataset(
359 | Split.VALIDATION,
360 | functools.partial(
361 | process_batch_eval,
362 | tokenizer=tokenizer,
363 | max_input_length=ft_config.max_input_length,
364 | max_output_length=ft_config.max_output_length,
365 | ),
366 | batched=True,
367 | )
368 |
369 | test_dataset = data_manager.get_dataset(
370 | Split.TEST,
371 | functools.partial(
372 | process_batch_eval,
373 | tokenizer=tokenizer,
374 | max_input_length=ft_config.max_input_length,
375 | max_output_length=ft_config.max_output_length,
376 | ),
377 | batched=True,
378 | )
379 |
380 | trainer = Seq2SeqTrainer(
381 | model=model,
382 | args=ft_config.training_args,
383 | data_collator=DataCollatorForSeq2Seq(
384 | tokenizer=tokenizer,
385 | padding="longest",
386 | return_tensors="pt",
387 | ),
388 | train_dataset=train_dataset,
389 | eval_dataset=val_dataset,
390 | compute_metrics=functools.partial(compute_metrics, tokenizer=tokenizer),
391 | )
392 |
393 | if auto_resume_from_checkpoint is None or auto_resume_from_checkpoint.upper() == "":
394 | trainer.train()
395 | else:
396 | output_dir = ft_config.training_args.output_dir
397 | dirlist = os.listdir(output_dir)
398 | checkpoint_sn = 0
399 | for checkpoint_str in dirlist:
400 | if checkpoint_str.find("eckpoint") > 0 and checkpoint_str.find("tmp") == -1:
401 | checkpoint = int(checkpoint_str.replace("checkpoint-", ""))
402 | if checkpoint > checkpoint_sn:
403 | checkpoint_sn = checkpoint
404 | if auto_resume_from_checkpoint.upper() == "YES":
405 | if checkpoint_sn > 0:
406 | model.gradient_checkpointing_enable()
407 | model.enable_input_require_grads()
408 | checkpoint_directory = os.path.join(output_dir, "checkpoint-" + str(checkpoint_sn))
409 | print("resume checkpoint from checkpoint-" + str(checkpoint_sn))
410 | trainer.train(resume_from_checkpoint=checkpoint_directory)
411 | else:
412 | trainer.train()
413 | else:
414 | if auto_resume_from_checkpoint.isdigit():
415 | if int(auto_resume_from_checkpoint) > 0:
416 | checkpoint_sn = int(auto_resume_from_checkpoint)
417 | model.gradient_checkpointing_enable()
418 | model.enable_input_require_grads()
419 | checkpoint_directory = os.path.join(output_dir, "checkpoint-" + str(checkpoint_sn))
420 | print("resume checkpoint from checkpoint-" + str(checkpoint_sn))
421 | trainer.train(resume_from_checkpoint=checkpoint_directory)
422 | else:
423 | print(
424 | auto_resume_from_checkpoint,
425 | "The specified checkpoint sn("
426 | + auto_resume_from_checkpoint
427 | + ") has not been saved. Please search for the correct checkpoint in the model output directory",
428 | )
429 |
430 | if test_dataset is not None:
431 | trainer.predict(test_dataset)
432 |
433 |
434 | if __name__ == "__main__":
435 | app()
436 |
--------------------------------------------------------------------------------
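
Note: finetune.py exposes a typer CLI that takes data_dir, model_dir and config_file positionally, e.g. "python finetune.py data/ THUDM/glm-edge-1.5b-chat configs/lora.yaml". After a LoRA run, the adapter saved under output_dir can be attached to the base model for inference; a minimal sketch, assuming the adapter landed in output/checkpoint-3000:

    import torch
    from peft import PeftModel
    from transformers import AutoModelForCausalLM, AutoTokenizer

    base = "THUDM/glm-edge-1.5b-chat"   # assumption: your base checkpoint
    adapter = "output/checkpoint-3000"  # assumption: a checkpoint written by finetune.py

    tokenizer = AutoTokenizer.from_pretrained(base, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(base, trust_remote_code=True, torch_dtype=torch.bfloat16)
    model = PeftModel.from_pretrained(model, adapter)  # attach the LoRA weights
    model.eval()
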
/finetune/finetune_vision.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import os
3 |
4 | os.environ["WANDB_DISABLED"] = "true"
5 | import jieba
6 | import dataclasses as dc
7 | import functools
8 | from collections.abc import Callable, Mapping, Sequence
9 | from pathlib import Path
10 | from typing import Annotated, Any, Union
11 | import numpy as np
12 | import ruamel.yaml as yaml
13 | import torch
14 | import typer
15 | from datasets import Dataset, Split
16 | from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
17 | from peft import PeftConfig, get_peft_config, get_peft_model
18 | from rouge_chinese import Rouge
19 | from torch import nn
20 | from transformers import (
21 | AutoModelForCausalLM,
22 | AutoImageProcessor,
23 | AutoTokenizer,
24 | EvalPrediction,
25 | GenerationConfig,
26 | PreTrainedTokenizer,
27 | Seq2SeqTrainingArguments,
28 | )
29 | from transformers import DataCollatorForSeq2Seq as _DataCollatorForSeq2Seq
30 | from transformers import Seq2SeqTrainer as _Seq2SeqTrainer
31 | from datasets import load_dataset, DatasetDict, NamedSplit
32 | from typing import Optional
33 | from PIL import Image
34 |
35 | app = typer.Typer(pretty_exceptions_show_locals=False)
36 |
37 |
38 | class DataCollatorForSeq2Seq(_DataCollatorForSeq2Seq):
39 | def __call__(self, features, return_tensors=None):
40 | output_ids = [feature["output_ids"] for feature in features] if "output_ids" in features[0].keys() else None
41 | if output_ids is not None:
42 | max_output_length = max(len(out) for out in output_ids)
43 | if self.pad_to_multiple_of is not None:
44 | max_output_length = (
45 | (max_output_length + self.pad_to_multiple_of - 1)
46 | // self.pad_to_multiple_of
47 | * self.pad_to_multiple_of
48 | )
49 | for feature in features:
50 | remainder = [self.tokenizer.pad_token_id] * (max_output_length - len(feature["output_ids"]))
51 | if isinstance(feature["output_ids"], list):
52 | feature["output_ids"] = feature["output_ids"] + remainder
53 | else:
54 | feature["output_ids"] = np.concatenate([feature["output_ids"], remainder]).astype(np.int64)
55 | return super().__call__(features, return_tensors)
56 |
57 |
58 | class Seq2SeqTrainer(_Seq2SeqTrainer):
59 | def prediction_step(
60 | self,
61 | model: nn.Module,
62 | inputs: dict,
63 | prediction_loss_only: bool,
64 | ignore_keys=None,
65 | **gen_kwargs,
66 | ) -> tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
67 | with torch.no_grad():
68 | if self.args.predict_with_generate:
69 | output_ids = inputs.pop("output_ids", None)
70 |
71 | if "labels" in inputs:
72 | del inputs["labels"]
73 |
74 | loss, generated_tokens, labels = super().prediction_step(
75 | model=model,
76 | inputs=inputs,
77 | prediction_loss_only=prediction_loss_only,
78 | ignore_keys=ignore_keys,
79 | **gen_kwargs,
80 | )
81 |
82 | if generated_tokens is not None:
83 | generated_tokens = generated_tokens[:, inputs["input_ids"].size()[1] :]
84 |
85 | if self.args.predict_with_generate:
86 | labels = output_ids
87 |
88 | del inputs, output_ids
89 | torch.cuda.empty_cache()
90 |
91 | return loss, generated_tokens, labels
92 |
93 |
94 | @dc.dataclass
95 | class DataConfig(object):
96 | train_file: Optional[str] = None
97 | val_file: Optional[str] = None
98 | test_file: Optional[str] = None
99 | num_proc: Optional[int] = None
100 |
101 | @property
102 | def data_format(self) -> str:
103 | return Path(self.train_file).suffix
104 |
105 | @property
106 | def data_files(self) -> dict[NamedSplit, str]:
107 | return {
108 | split: data_file
109 | for split, data_file in zip(
110 | [Split.TRAIN, Split.VALIDATION, Split.TEST],
111 | [self.train_file, self.val_file, self.test_file],
112 | )
113 | if data_file is not None
114 | }
115 |
116 |
117 | @dc.dataclass
118 | class FinetuningConfig(object):
119 | data_config: DataConfig
120 |
121 | max_input_length: int
122 | max_output_length: int
123 | freezeV: bool
124 |
125 | training_args: Seq2SeqTrainingArguments = dc.field(
126 | default_factory=lambda: Seq2SeqTrainingArguments(output_dir="./output")
127 | )
128 | peft_config: Optional[PeftConfig] = None
129 |
130 | def __post_init__(self):
131 | if not self.training_args.do_eval or self.data_config.val_file is None:
132 | self.training_args.do_eval = False
133 | self.training_args.evaluation_strategy = "no"
134 | self.data_config.val_file = None
135 | else:
136 | self.training_args.per_device_eval_batch_size = (
137 | self.training_args.per_device_eval_batch_size or self.training_args.per_device_train_batch_size
138 | )
139 |
140 | @classmethod
141 | def from_dict(cls, **kwargs) -> "FinetuningConfig":
142 | training_args = kwargs.get("training_args", None)
143 | if training_args is not None and not isinstance(training_args, Seq2SeqTrainingArguments):
144 | gen_config = training_args.get("generation_config")
145 | if not isinstance(gen_config, GenerationConfig):
146 | training_args["generation_config"] = GenerationConfig(**gen_config)
147 | kwargs["training_args"] = Seq2SeqTrainingArguments(**training_args)
148 |
149 | data_config = kwargs.get("data_config")
150 | if not isinstance(data_config, DataConfig):
151 | kwargs["data_config"] = DataConfig(**data_config)
152 |
153 | peft_config = kwargs.get("peft_config", None)
154 | if peft_config is not None and not isinstance(peft_config, PeftConfig):
155 | kwargs["peft_config"] = get_peft_config(config_dict=peft_config)
156 | return cls(**kwargs)
157 |
158 | @classmethod
159 | def from_file(cls, path: Union[str, Path]) -> "FinetuningConfig":
160 | path = Path(path)
161 | parser = yaml.YAML(typ="safe", pure=True)
162 | parser.indent(mapping=2, offset=2, sequence=4)
163 | parser.default_flow_style = False
164 | kwargs = parser.load(path)
165 | return cls.from_dict(**kwargs)
166 |
167 |
168 | def _load_datasets(
169 | data_dir: str,
170 | data_format: str,
171 | data_files: dict[NamedSplit, str],
172 | num_proc: Optional[int],
173 | ) -> DatasetDict:
174 | if data_format == ".jsonl":
175 | dataset_dct = load_dataset(
176 | data_dir,
177 | data_files=data_files,
178 | split=None,
179 | num_proc=num_proc,
180 | )
181 | else:
182 | raise NotImplementedError(f"Cannot load dataset in the '{data_format}' format.")
183 | return dataset_dct
184 |
185 |
186 | class DataManager(object):
187 | def __init__(self, data_dir: str, data_config: DataConfig):
188 | self._num_proc = data_config.num_proc
189 |
190 | self._dataset_dct = _load_datasets(
191 | data_dir,
192 | data_config.data_format,
193 | data_config.data_files,
194 | self._num_proc,
195 | )
196 |
197 | def _get_dataset(self, split: NamedSplit) -> Optional[Dataset]:
198 | return self._dataset_dct.get(split, None)
199 |
200 | def get_dataset(
201 | self,
202 | split: NamedSplit,
203 | process_fn: Callable[[dict[str, Any]], dict[str, Any]],
204 | batched: bool = True,
205 | remove_orig_columns: bool = True,
206 | ) -> Optional[Dataset]:
207 | orig_dataset = self._get_dataset(split)
208 | if orig_dataset is None:
209 | return
210 | if remove_orig_columns:
211 | remove_columns = orig_dataset.column_names
212 | else:
213 | remove_columns = None
214 | return orig_dataset.map(
215 | process_fn,
216 | batched=batched,
217 | remove_columns=remove_columns,
218 | num_proc=self._num_proc,
219 | # These are the default values of orig_dataset.map; reduce them if you run into memory issues
220 | # https://github.com/THUDM/GLM-4/issues/277
221 | writer_batch_size=1000,
222 | batch_size=1000,
223 | )
224 |
225 |
226 | def process_batch(
227 | batch: Mapping[str, Sequence],
228 | tokenizer: PreTrainedTokenizer,
229 | processor,
230 | max_input_length: int,
231 | max_output_length: int,
232 | ) -> dict[str, list]:
233 | batched_conv = batch["messages"]
234 | batched_input_ids = []
235 | batched_attention_mask = []
236 | batched_position_ids = []
237 | batched_labels = []
238 | batched_images = []
239 |
240 | max_length = max_input_length + max_output_length
241 |
242 | for conv in batched_conv:
243 | input_ids = []
244 | attention_mask = []
245 | position_ids = []
246 | loss_masks = []
247 | pixel_values = []
248 |
249 | if conv[0]["content"][0].get("image"):
250 | image = Image.open(conv[0]["content"][0]["image"])
251 | pixel_values.append(torch.tensor(processor(image).pixel_values))
252 |
253 | for message in conv:
254 | loss_mask_val = False if message["role"] in ("system", "user") else True
255 | new_input_ids_all = tokenizer.apply_chat_template(
256 | [message],
257 | add_generation_prompt=False,
258 | tokenize=True,
259 | return_dict=True,
260 | return_tensors="pt",
261 | )
262 | new_input_ids = new_input_ids_all["input_ids"][0].tolist()
263 | new_attention_mask = new_input_ids_all["attention_mask"][0].tolist()
264 | new_position_ids = list(range(len(position_ids), len(position_ids) + len(new_input_ids)))
265 |
266 | new_loss_masks = [loss_mask_val] * len(new_input_ids)
267 | input_ids += new_input_ids
268 | attention_mask += new_attention_mask
269 | position_ids += new_position_ids
270 | loss_masks += new_loss_masks
271 |
272 | input_ids.append(59253) # EOS
273 | attention_mask.append(1)
274 | position_ids.append(len(position_ids))
275 | loss_masks.append(True)
276 |
277 | padding_length = max(0, max_length - len(input_ids))
278 |
279 | # Left padding so all sequences in the batch share the same length
280 | input_ids = [tokenizer.pad_token_id] * padding_length + input_ids[-max_length:]
281 | attention_mask = [0] * padding_length + attention_mask[-max_length:]
282 | position_ids = [0] * padding_length + position_ids[-max_length:]
283 | loss_masks = [False] * padding_length + loss_masks[-max_length:]
284 |
285 | labels = []
286 | for input_id, mask in zip(input_ids, loss_masks):
287 | if mask:
288 | labels.append(input_id)
289 | else:
290 | labels.append(-100)
291 |
292 | batched_input_ids.append(input_ids[:max_length])
293 | batched_attention_mask.append(attention_mask[:max_length])
294 | batched_position_ids.append(position_ids[:max_length])
295 | batched_labels.append(labels[:max_length])
296 | if len(pixel_values) > 0:
297 | batched_images.append(pixel_values[0][0])
298 | else:
299 | batched_images.append(torch.zeros([1, 1, 3, 672, 672]))
300 |
301 | del (
302 | batched_conv,
303 | conv,
304 | input_ids,
305 | attention_mask,
306 | position_ids,
307 | loss_masks,
308 | message,
309 | new_input_ids,
310 | new_loss_masks,
311 | labels,
312 | input_id,
313 | mask,
314 | )
315 | torch.cuda.empty_cache()
316 |
317 | return {
318 | "input_ids": batched_input_ids,
319 | "attention_mask": batched_attention_mask,
320 | "position_ids": batched_position_ids,
321 | "labels": batched_labels,
322 | "pixel_values": batched_images,
323 | }
324 |
325 |
326 | def process_batch_eval(
327 | batch: Mapping[str, Sequence],
328 | tokenizer: PreTrainedTokenizer,
329 | processor,
330 | max_input_length: int,
331 | max_output_length: int,
332 | ) -> dict[str, list]:
333 | batched_conv = batch["messages"]
334 | batched_input_ids = []
335 | batched_attention_mask = []
336 | batched_position_ids = []
337 | batched_output_ids = []
338 | batched_images = []
339 |
340 | for conv in batched_conv:
341 | if conv[0]["content"][0].get("image"):
342 | image = Image.open(conv[0]["content"][0]["image"])
343 |
344 | new_input_ids_all = tokenizer.apply_chat_template(
345 | conv,
346 | add_generation_prompt=False,
347 | tokenize=True,
348 | padding=True,
349 | return_dict=True,
350 | return_tensors="pt",
351 | )
352 |
353 | input_ids = new_input_ids_all["input_ids"][0].tolist()
354 | attention_mask = new_input_ids_all["attention_mask"][0].tolist()
355 | position_ids = list(range(len(input_ids)))
356 |
357 | dialogue_parts = [0]
358 | user_idx = []
359 | for idx, token_id in enumerate(input_ids):
360 | if token_id == 59254:
361 | dialogue_parts.append(idx + 1)
362 | elif token_id == 59253:
363 | user_idx.append(idx)
364 |
365 | if user_idx[-1] != len(input_ids):
366 | user_idx.append(len(input_ids))
367 |
368 | # Split the conversation into multiple dialogue segments
369 | for end_idx in range(1, len(dialogue_parts)):
370 | input_segment = input_ids[: dialogue_parts[end_idx]]
371 | attention_segment = attention_mask[: dialogue_parts[end_idx]]
372 | position_segment = position_ids[: dialogue_parts[end_idx]]
373 | output_segment = input_ids[dialogue_parts[end_idx] : user_idx[end_idx]]
374 | output_segment.append(59253) # Add EOS token
375 |
376 | # Left Padding
377 | padding_length = max(0, max_input_length - len(input_segment))
378 | input_segment = [tokenizer.pad_token_id] * padding_length + input_segment[:max_input_length]
379 | attention_segment = [0] * padding_length + attention_segment[:max_input_length]
380 | position_segment = [0] * padding_length + position_segment[:max_input_length]
381 | output_segment = [tokenizer.pad_token_id] * padding_length + output_segment[:max_output_length]
382 |
383 | batched_input_ids.append(input_segment[:max_input_length])
384 | batched_attention_mask.append(attention_segment[:max_input_length])
385 | batched_position_ids.append(position_segment[:max_input_length])
386 | batched_output_ids.append(output_segment[:max_output_length])
387 | if conv[0]["content"][0].get("image"):
388 | batched_images.append(torch.tensor(processor(image).pixel_values)[0])
389 | else:
390 | batched_images.append(torch.zeros([1, 1, 3, 672, 672]))
391 |
392 | del batched_conv, input_ids, attention_mask, position_ids, new_input_ids_all, output_segment
393 | torch.cuda.empty_cache()
394 |
395 | return {
396 | "input_ids": batched_input_ids,
397 | "attention_mask": batched_attention_mask,
398 | "position_ids": batched_position_ids,
399 | "output_ids": batched_output_ids,
400 | "pixel_values": batched_images,
401 | }
402 |
403 |
404 | def load_tokenizer_and_model(
405 | model_dir: str,
406 | peft_config: Optional[PeftConfig] = None,
407 | ):
408 | tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True, padding_side="left")
409 | processor = AutoImageProcessor.from_pretrained(model_dir, trust_remote_code=True, dtype=torch.bfloat16)
410 | if peft_config is not None:
411 | model = AutoModelForCausalLM.from_pretrained(model_dir, trust_remote_code=True, torch_dtype=torch.bfloat16)
412 | model = get_peft_model(model, peft_config)
413 | model.print_trainable_parameters()
414 | else:
415 | model = AutoModelForCausalLM.from_pretrained(model_dir, trust_remote_code=True, torch_dtype=torch.bfloat16)
416 | return tokenizer, model, processor
417 |
418 |
419 | def compute_metrics(eval_preds: EvalPrediction, tokenizer):
420 | batched_pred_ids, batched_label_ids = eval_preds
421 | batched_pred_ids[batched_pred_ids == -100] = tokenizer.pad_token_id
422 | batched_label_ids[batched_label_ids == -100] = tokenizer.pad_token_id
423 | metrics_dct = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}
424 | for pred_ids, label_ids in zip(batched_pred_ids, batched_label_ids):
425 | pred_txt = tokenizer.decode(pred_ids, skip_special_tokens=True).strip()
426 | label_txt = tokenizer.decode(label_ids, skip_special_tokens=True).strip()
427 | pred_tokens = list(jieba.cut(pred_txt))
428 | label_tokens = list(jieba.cut(label_txt))
429 | rouge = Rouge()
430 | scores = rouge.get_scores(" ".join(pred_tokens), " ".join(label_tokens))
431 | for k, v in scores[0].items():
432 | metrics_dct[k].append(round(v["f"] * 100, 4))
433 | metrics_dct["bleu-4"].append(
434 | sentence_bleu([label_tokens], pred_tokens, smoothing_function=SmoothingFunction().method3)
435 | )
436 | return {k: np.mean(v) for k, v in metrics_dct.items()}
437 |
438 |
439 | @app.command()
440 | def main(
441 | data_dir: Annotated[str, typer.Argument(help="Path to the directory containing the dataset files declared in the config")],
442 | model_dir: Annotated[
443 | str,
444 | typer.Argument(
445 | help="A string that specifies the model id of a pretrained model configuration hosted on huggingface.co, or a path to a directory containing a model configuration file."
446 | ),
447 | ],
448 | config_file: Annotated[str, typer.Argument(help="Path to the fine-tuning config, e.g. configs/sft.yaml")],
449 | auto_resume_from_checkpoint: str = typer.Argument(
450 | default="",
451 | help="If entered as yes, automatically use the latest save checkpoint. If it is a numerical example 12 15, use the corresponding save checkpoint. If the input is no, restart training",
452 | ),
453 | ):
454 | ft_config = FinetuningConfig.from_file(config_file)
455 | tokenizer, model, processor = load_tokenizer_and_model(model_dir, peft_config=ft_config.peft_config)
456 |
457 | if ft_config.freezeV:
458 | for param in model.base_model.model.model.vision.parameters():
459 | param.requires_grad = False
460 | data_manager = DataManager(data_dir, ft_config.data_config)
461 |
462 | train_dataset = data_manager.get_dataset(
463 | Split.TRAIN,
464 | functools.partial(
465 | process_batch,
466 | tokenizer=tokenizer,
467 | processor=processor,
468 | max_input_length=ft_config.max_input_length,
469 | max_output_length=ft_config.max_output_length,
470 | ),
471 | batched=True,
472 | )
473 |
474 | val_dataset = data_manager.get_dataset(
475 | Split.VALIDATION,
476 | functools.partial(
477 | process_batch_eval,
478 | tokenizer=tokenizer,
479 | processor=processor,
480 | max_input_length=ft_config.max_input_length,
481 | max_output_length=ft_config.max_output_length,
482 | ),
483 | batched=True,
484 | )
485 |
486 | test_dataset = data_manager.get_dataset(
487 | Split.TEST,
488 | functools.partial(
489 | process_batch_eval,
490 | tokenizer=tokenizer,
491 | processor=processor,
492 | max_input_length=ft_config.max_input_length,
493 | max_output_length=ft_config.max_output_length,
494 | ),
495 | batched=True,
496 | )
497 |
498 | model.gradient_checkpointing_enable()
499 | model.enable_input_require_grads()
500 |
501 | trainer = Seq2SeqTrainer(
502 | model=model,
503 | args=ft_config.training_args,
504 | data_collator=DataCollatorForSeq2Seq(
505 | tokenizer=tokenizer,
506 | padding="longest",
507 | return_tensors="pt",
508 | ),
509 | train_dataset=train_dataset,
510 | eval_dataset=val_dataset,
511 | compute_metrics=functools.partial(compute_metrics, tokenizer=tokenizer),
512 | )
513 |
514 | if auto_resume_from_checkpoint is None or auto_resume_from_checkpoint.upper() == "":
515 | trainer.train()
516 | else:
517 | output_dir = ft_config.training_args.output_dir
518 | dirlist = os.listdir(output_dir)
519 | checkpoint_sn = 0
520 | for checkpoint_str in dirlist:
521 | if checkpoint_str.find("eckpoint") > 0 and checkpoint_str.find("tmp") == -1:
522 | checkpoint = int(checkpoint_str.replace("checkpoint-", ""))
523 | if checkpoint > checkpoint_sn:
524 | checkpoint_sn = checkpoint
525 | if auto_resume_from_checkpoint.upper() == "YES":
526 | if checkpoint_sn > 0:
527 | model.gradient_checkpointing_enable()
528 | model.enable_input_require_grads()
529 | checkpoint_directory = os.path.join(output_dir, "checkpoint-" + str(checkpoint_sn))
530 | print("resume checkpoint from checkpoint-" + str(checkpoint_sn))
531 | trainer.train(resume_from_checkpoint=checkpoint_directory)
532 | else:
533 | trainer.train()
534 | else:
535 | if auto_resume_from_checkpoint.isdigit():
536 | if int(auto_resume_from_checkpoint) > 0:
537 | checkpoint_sn = int(auto_resume_from_checkpoint)
538 | model.gradient_checkpointing_enable()
539 | model.enable_input_require_grads()
540 | checkpoint_directory = os.path.join(output_dir, "checkpoint-" + str(checkpoint_sn))
541 | print("resume checkpoint from checkpoint-" + str(checkpoint_sn))
542 | trainer.train(resume_from_checkpoint=checkpoint_directory)
543 | else:
544 | print(
545 | auto_resume_from_checkpoint,
546 | "The specified checkpoint sn("
547 | + auto_resume_from_checkpoint
548 | + ") has not been saved. Please search for the correct checkpoint in the model output directory",
549 | )
550 |
551 | if test_dataset is not None:
552 | trainer.predict(test_dataset)
553 |
554 |
555 | if __name__ == "__main__":
556 | app()
557 |
--------------------------------------------------------------------------------
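
Note: finetune_vision.py reads the image path from conv[0]["content"][0]["image"], so the first user message of each record should embed the image alongside the text. A hypothetical record layout inferred from that access pattern (the bundled vision_dataset.zip below is the authoritative example):

    import json

    # Hypothetical train.jsonl line for finetune_vision.py; the first content item of
    # the first message carries the image path, the rest is plain chat.
    record = {
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": "images/example.jpg"},
                    {"type": "text", "text": "Describe this picture."},
                ],
            },
            {"role": "assistant", "content": [{"type": "text", "text": "A cat sitting on a windowsill."}]},
        ]
    }
    print(json.dumps(record, ensure_ascii=False))
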
/finetune/vision_dataset.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THUDM/GLM-Edge/7a93ea5047e91cf27b74d1e64ad490f75d329b17/finetune/vision_dataset.zip
--------------------------------------------------------------------------------
/inference/cli_demo.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import time
3 | import asyncio
4 | from threading import Thread
5 | from transformers import (
6 | AutoTokenizer,
7 | AutoModelForCausalLM,
8 | TextIteratorStreamer,
9 | BitsAndBytesConfig,
10 | )
11 | from vllm import SamplingParams, AsyncEngineArgs, AsyncLLMEngine
12 | from vllm.lora.request import LoRARequest
13 | from optimum.intel.openvino import OVModelForCausalLM
14 | import torch
15 |
16 |
17 | # Load Model and Tokenizer for VLLM
18 | def load_vllm_model_and_tokenizer(model_dir: str, lora_path: str, precision: str):
19 | enable_lora = bool(lora_path)
20 | tokenizer = AutoTokenizer.from_pretrained(model_dir)
21 | engine_args = AsyncEngineArgs(
22 | model=model_dir,
23 | tokenizer=model_dir,
24 | enable_lora=enable_lora,
25 | tensor_parallel_size=1,
26 | dtype="bfloat16" if precision == "bfloat16" else "float16",
27 | gpu_memory_utilization=0.9,
28 | enforce_eager=True,
29 | worker_use_ray=True,
30 | disable_log_requests=True,
31 | )
32 | engine = AsyncLLMEngine.from_engine_args(engine_args)
33 | return engine, tokenizer, enable_lora
34 |
35 |
36 | async def vllm_gen(engine, tokenizer, lora_path, enable_lora, messages, top_p, temperature, max_length):
37 | inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
38 | sampling_params = SamplingParams(
39 | n=1,
40 | best_of=1,
41 | presence_penalty=1.0,
42 | frequency_penalty=0.0,
43 | temperature=temperature,
44 | top_p=top_p,
45 | max_tokens=max_length,
46 | )
47 | if enable_lora:
48 | async for output in engine.generate(
49 | inputs=inputs,
50 | sampling_params=sampling_params,
51 | request_id=f"{time.time()}",
52 | lora_request=LoRARequest("GLM-Edge-lora", 1, lora_path=lora_path),
53 | ):
54 | yield output.outputs[0].text
55 | else:
56 | async for output in engine.generate(
57 | prompt=inputs, sampling_params=sampling_params, request_id=f"{time.time()}"
58 | ):
59 | yield output.outputs[0].text
60 |
61 |
62 | # CLI Chat Function for Transformers and OpenVINO
63 | def generic_chat(tokenizer, model, temperature, top_p, max_length, backend="transformers"):
64 | history = []
65 | backend_label = "OpenVINO" if backend == "ov" else "Transformers"
66 | print(f"Welcome to the GLM-Edge CLI chat ({backend_label}). Type your messages below.")
67 | while True:
68 | user_input = input("\nYou: ")
69 | if user_input.lower() in ["exit", "quit"]:
70 | break
71 | history.append([user_input, ""])
72 |
73 | messages = [{"role": "user", "content": user_input}]
74 | model_inputs = tokenizer.apply_chat_template(
75 | messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt"
76 | )
77 | model_inputs = {k: v.to("cpu") for k, v in model_inputs.items()} # Ensure CPU for OpenVINO
78 |
79 | streamer = TextIteratorStreamer(tokenizer=tokenizer, timeout=60, skip_prompt=True, skip_special_tokens=True)
80 | generate_kwargs = {
81 | "input_ids": model_inputs["input_ids"],
82 | "attention_mask": model_inputs["attention_mask"],
83 | "streamer": streamer,
84 | "max_new_tokens": max_length,
85 | "do_sample": True,
86 | "top_p": top_p,
87 | "temperature": temperature,
88 | "repetition_penalty": 1.2,
89 | "eos_token_id": tokenizer.encode("<|user|>"),
90 | }
91 | t = Thread(target=model.generate, kwargs=generate_kwargs)
92 | t.start()
93 |
94 | print("GLM-Edge:", end="", flush=True)
95 | for new_token in streamer:
96 | print(new_token, end="", flush=True)
97 | history[-1][1] += new_token
98 |
99 | history[-1][1] = history[-1][1].strip()
100 |
101 |
102 | # Main Async Chat Function for VLLM
103 | async def vllm_chat(engine, tokenizer, lora_path, enable_lora, temperature, top_p, max_length):
104 | history = []
105 | print("Welcome to the GLM-Edge DEMO chat (VLLM). Type your messages below.")
106 | while True:
107 | user_input = input("\nYou: ")
108 | if user_input.lower() in ["exit", "quit"]:
109 | break
110 | history.append([user_input, ""])
111 |
112 | messages = [{"role": "user", "content": user_input}]
113 | print("\nGLM-Edge: ", end="")
114 | current_length = 0
115 | output = ""
116 | async for output in vllm_gen(
117 | engine, tokenizer, lora_path, enable_lora, messages, top_p, temperature, max_length
118 | ):
119 | print(output[current_length:], end="", flush=True)
120 | current_length = len(output)
121 | history[-1][1] = output
122 |
123 |
124 | def main():
125 | parser = argparse.ArgumentParser(description="Run GLM-Edge DEMO Chat with VLLM, Transformers, or OpenVINO backend")
126 | parser.add_argument(
127 | "--backend",
128 | type=str,
129 | choices=["vllm", "transformers", "ov"],
130 | required=True,
131 | help="Choose inference backend: vllm, transformers, or OpenVINO",
132 | )
133 | parser.add_argument("--model_path", type=str, required=True, help="Path to the model")
134 | parser.add_argument("--lora_path", type=str, default=None, help="Path to LoRA (leave empty to skip)")
135 | parser.add_argument(
136 | "--precision", type=str, default="bfloat16", choices=["float16", "bfloat16", "int4"], help="Model precision"
137 | )
138 | parser.add_argument("--temperature", type=float, default=0.6, help="Temperature for sampling")
139 | parser.add_argument("--top_p", type=float, default=0.8, help="Top-p (nucleus) sampling probability")
140 | parser.add_argument("--max_length", type=int, default=8192, help="Maximum token length for generation")
141 | args = parser.parse_args()
142 |
143 | if args.backend == "vllm":
144 | engine, tokenizer, enable_lora = load_vllm_model_and_tokenizer(args.model_path, args.lora_path, args.precision)
145 | asyncio.run(
146 | vllm_chat(engine, tokenizer, args.lora_path, enable_lora, args.temperature, args.top_p, args.max_length)
147 | )
148 | elif args.backend == "ov":
149 | tokenizer = AutoTokenizer.from_pretrained(args.model_path)
150 | model = OVModelForCausalLM.from_pretrained(args.model_path, device="CPU") # CPU,GPU and XPU are supported
151 | generic_chat(tokenizer, model, args.temperature, args.top_p, args.max_length, backend="ov")
152 | else:
153 | tokenizer = AutoTokenizer.from_pretrained(args.model_path)
154 | if args.precision == "int4":
155 | model = AutoModelForCausalLM.from_pretrained(
156 | args.model_path,
157 | trust_remote_code=True,
158 | quantization_config=BitsAndBytesConfig(load_in_4bit=True),
159 | torch_dtype=torch.bfloat16,
160 | low_cpu_mem_usage=True,
161 | ).eval()
162 | else:
163 | model = AutoModelForCausalLM.from_pretrained(
164 | args.model_path,
165 | torch_dtype=torch.bfloat16 if args.precision == "bfloat16" else torch.float16,
166 | trust_remote_code=True,
167 | device_map="auto",
168 | ).eval()
169 | generic_chat(tokenizer, model, args.temperature, args.top_p, args.max_length, backend="transformers")
170 |
171 |
172 | if __name__ == "__main__":
173 | main()
174 |
--------------------------------------------------------------------------------
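
Note: the interactive loop in cli_demo.py can be reduced to a short non-streaming call for quick checks. A minimal sketch of the transformers path, assuming THUDM/glm-edge-1.5b-chat (the default model id used by the OpenVINO converter below) is available:

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_path = "THUDM/glm-edge-1.5b-chat"  # assumption: swap in your own checkpoint
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(
        model_path, torch_dtype=torch.bfloat16, trust_remote_code=True, device_map="auto"
    ).eval()

    inputs = tokenizer.apply_chat_template(
        [{"role": "user", "content": "Hello!"}],
        add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt",
    ).to(model.device)
    out = model.generate(**inputs, max_new_tokens=128, do_sample=True, top_p=0.8, temperature=0.6)
    print(tokenizer.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True))
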
/inference/cli_demo_vision.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from threading import Thread
3 |
4 | from PIL import Image
5 | from transformers import (
6 | AutoTokenizer,
7 | AutoModelForCausalLM,
8 | AutoImageProcessor,
9 | TextIteratorStreamer,
10 | BitsAndBytesConfig,
11 | )
12 | import torch
13 | from ov_convert.convert_v import OvGLMv
14 |
15 |
16 | def generic_chat(tokenizer, processor, model, temperature, top_p, max_length, backend="transformers"):
17 | history = []
18 | image = None
19 | backend_label = "OpenVINO" if backend == "ov" else "Transformers"
20 | print(f"Welcome to the GLM-Edge-v CLI chat ({backend_label}). Type your messages below.")
21 | image_path = input("Image Path:")
22 | try:
23 | image = Image.open(image_path).convert("RGB")
24 | pixel_values = torch.tensor(processor(image).pixel_values).to(model.device)
25 | except Exception:
26 | print("Invalid image path. Continuing with text conversation.")
27 |
28 | while True:
29 | user_input = input("\nYou: ")
30 | if user_input.lower() in ["exit", "quit"]:
31 | break
32 | history.append([user_input, ""])
33 |
34 | messages = []
35 | for idx, (user_msg, model_msg) in enumerate(history):
36 | if idx == len(history) - 1 and not model_msg:
37 | messages.append({"role": "user", "content": [{"type": "text", "text": user_msg}, {"type": "image"}]})
38 | break
39 | if user_msg:
40 | messages.append({"role": "user", "content": [{"type": "text", "text": user_msg}]})
41 | if model_msg:
42 | messages.append({"role": "assistant", "content": [{"type": "text", "text": model_msg}]})
43 |
44 | model_inputs = tokenizer.apply_chat_template(
45 | messages, add_generation_prompt=True, tokenize=True, return_tensors="pt", return_dict=True
46 | )
47 |
48 | model_inputs = model_inputs.to(model.device)
49 |
50 | streamer = TextIteratorStreamer(tokenizer=tokenizer, timeout=60, skip_prompt=True, skip_special_tokens=True)
51 | generate_kwargs = {
52 | **model_inputs,
53 | "streamer": streamer,
54 | "max_new_tokens": max_length,
55 | "do_sample": True,
56 | "top_p": top_p,
57 | "temperature": temperature,
58 | "repetition_penalty": 1.2,
59 | "pixel_values": pixel_values if image else None,
60 | }
61 |
62 | t = Thread(target=model.generate, kwargs=generate_kwargs)
63 | t.start()
64 | print("GLM-Edge-v:", end="", flush=True)
65 | for new_token in streamer:
66 | if new_token:
67 | print(new_token, end="", flush=True)
68 | history[-1][1] += new_token
69 |
70 | history[-1][1] = history[-1][1].strip()
71 |
72 |
73 | def main():
74 | parser = argparse.ArgumentParser(description="Run GLM-Edge-v DEMO Chat with Transformers or OpenVINO backend")
75 |
76 | parser.add_argument(
77 | "--backend",
78 | type=str,
79 | choices=["transformers", "ov"],
80 | required=True,
81 | help="Choose inference backend: transformers or OpenVINO",
82 | )
83 | parser.add_argument("--model_path", type=str, required=True, help="Path to the model")
84 | parser.add_argument(
85 | "--precision", type=str, default="bfloat16", choices=["float16", "bfloat16", "int4"], help="Model precision"
86 | )
87 | parser.add_argument("--temperature", type=float, default=0.6, help="Temperature for sampling")
88 | parser.add_argument("--top_p", type=float, default=0.8, help="Top-p (nucleus) sampling probability")
89 | parser.add_argument("--max_length", type=int, default=8192, help="Maximum token length for generation")
90 | args = parser.parse_args()
91 |
92 | tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True, encode_special_tokens=True)
93 | processor = AutoImageProcessor.from_pretrained(args.model_path, trust_remote_code=True)
94 |
95 | if args.backend == "ov":
96 | model = OvGLMv(args.model_path, device="CPU") # Use OpenVINO
97 | else:
98 | if args.precision == "int4":
99 | model = AutoModelForCausalLM.from_pretrained(
100 | args.model_path,
101 | trust_remote_code=True,
102 | quantization_config=BitsAndBytesConfig(load_in_4bit=True),
103 | torch_dtype=torch.bfloat16,
104 | low_cpu_mem_usage=True,
105 | ).eval()
106 | else:
107 | model = AutoModelForCausalLM.from_pretrained(
108 | args.model_path,
109 | torch_dtype=torch.bfloat16 if args.precision == "bfloat16" else torch.float16,
110 | trust_remote_code=True,
111 | device_map="auto",
112 | ).eval()
113 |
114 | generic_chat(tokenizer, processor, model, args.temperature, args.top_p, args.max_length, backend=args.backend)
115 |
116 |
117 | if __name__ == "__main__":
118 | main()
119 |
--------------------------------------------------------------------------------
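
Note: the vision demo boils down to running AutoImageProcessor on the image and passing the resulting pixel_values to generate together with the templated chat prompt. A compact sketch of that path, assuming a local example.jpg and a GLM-Edge-V checkpoint (THUDM/glm-edge-v-2b used here as a placeholder):

    import torch
    from PIL import Image
    from transformers import AutoImageProcessor, AutoModelForCausalLM, AutoTokenizer

    model_path = "THUDM/glm-edge-v-2b"  # assumption: adjust to your vision checkpoint
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    processor = AutoImageProcessor.from_pretrained(model_path, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_path, torch_dtype=torch.bfloat16, trust_remote_code=True, device_map="auto"
    ).eval()

    image = Image.open("example.jpg").convert("RGB")
    pixel_values = torch.tensor(processor(image).pixel_values).to(model.device)
    messages = [{"role": "user", "content": [{"type": "text", "text": "Describe this picture."}, {"type": "image"}]}]
    inputs = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt"
    ).to(model.device)
    out = model.generate(**inputs, pixel_values=pixel_values, max_new_tokens=128)
    print(tokenizer.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True))
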
/inference/ov_convert/convert_chat.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoTokenizer
2 | from optimum.intel import OVWeightQuantizationConfig
3 | from optimum.intel.openvino import OVModelForCausalLM
4 | from optimum.exporters.tasks import TasksManager
5 | import os
6 | import argparse
7 |
8 | TasksManager._SUPPORTED_MODEL_TYPE["glm"] = TasksManager._SUPPORTED_MODEL_TYPE[
9 | "llama"
10 | ] # reuse the Llama export configuration for the GLM model type during conversion
11 |
12 | if __name__ == "__main__":
13 | parser = argparse.ArgumentParser(add_help=False)
14 | parser.add_argument("--model_path", default="THUDM/glm-edge-1.5b-chat", type=str, help="orignal model path")
15 | parser.add_argument(
16 | "--precision", default="int4", type=str, choices=["fp16", "int8", "int4"], help="fp16, int8 or int4"
17 | )
18 | parser.add_argument("--output_path", default="glm-edge-1.5b-chat-ov", type=str, help="path to save the IR model")
19 | args = parser.parse_args()
20 | os.makedirs(args.output_path, exist_ok=True)
21 |
22 | compression_configs = {
23 | "sym": True,
24 | "group_size": 128,
25 | "ratio": 0.8,
26 | }
27 |
28 | TasksManager._SUPPORTED_MODEL_TYPE["glm"] = TasksManager._SUPPORTED_MODEL_TYPE["llama"]
29 |
30 | print("====Exporting IR=====")
31 | if args.precision == "int4":
32 | ov_model = OVModelForCausalLM.from_pretrained(
33 | args.model_path,
34 | export=True,
35 | compile=False,
36 | quantization_config=OVWeightQuantizationConfig(bits=4, **compression_configs),
37 | trust_remote_code=True,
38 | )
39 | elif args.precision == "int8":
40 | ov_model = OVModelForCausalLM.from_pretrained(
41 | args.model_path, export=True, compile=False, load_in_8bit=True, trust_remote_code=True
42 | )
43 | else:
44 | ov_model = OVModelForCausalLM.from_pretrained(
45 | args.model_path, export=True, compile=False, load_in_8bit=False, trust_remote_code=True
46 | )
47 |
48 | print("====Saving IR=====")
49 | ov_model.save_pretrained(args.output_path)
50 |
51 | print("====Exporting tokenizer=====")
52 | tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)
53 | tokenizer.save_pretrained(args.output_path)
54 |
55 | print("====Exporting IR tokenizer=====")
56 | from optimum.exporters.openvino.convert import export_tokenizer
57 |
58 | export_tokenizer(tokenizer, args.output_path)
59 | print("====Finished=====")
60 |
--------------------------------------------------------------------------------
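
Note: once the IR has been saved, it can be loaded back with optimum-intel exactly as the ov branch of cli_demo.py does. A brief usage sketch, assuming the export above was written to glm-edge-1.5b-chat-ov:

    from optimum.intel.openvino import OVModelForCausalLM
    from transformers import AutoTokenizer

    ir_path = "glm-edge-1.5b-chat-ov"  # assumption: the --output_path used above
    tokenizer = AutoTokenizer.from_pretrained(ir_path, trust_remote_code=True)
    model = OVModelForCausalLM.from_pretrained(ir_path, device="CPU")  # CPU, GPU and XPU are supported

    inputs = tokenizer.apply_chat_template(
        [{"role": "user", "content": "Hello!"}],
        add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt",
    )
    out = model.generate(**inputs, max_new_tokens=64)
    print(tokenizer.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True))
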
/inference/ov_convert/convert_v.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from pathlib import Path
4 | import types
5 | import gc
6 |
7 | import openvino as ov
8 | from openvino.runtime import opset13
9 | import nncf
10 | import numpy as np
11 | import torch
12 | from transformers.cache_utils import Cache
13 | from transformers import AutoModelForCausalLM, AutoImageProcessor, AutoConfig, AutoTokenizer
14 | from transformers.generation import GenerationMixin
15 | from transformers.modeling_outputs import CausalLMOutputWithPast, BaseModelOutputWithPast
16 | from typing import Optional, Tuple, Union, List, Dict, Any
17 | from transformers import __version__ as transformers_version
18 | from transformers.generation.utils import GenerationConfig, ModelOutput
19 |
20 |
21 | def _chatglm_transformer_forward(
22 | self,
23 | input_ids: torch.LongTensor = None,
24 | attention_mask: Optional[torch.Tensor] = None,
25 | position_ids: Optional[torch.LongTensor] = None,
26 | past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
27 | inputs_embeds: Optional[torch.FloatTensor] = None,
28 | labels: Optional[torch.LongTensor] = None,
29 | use_cache: Optional[bool] = None,
30 | output_attentions: Optional[bool] = None,
31 | output_hidden_states: Optional[bool] = None,
32 | return_dict: Optional[bool] = None,
33 | cache_position: Optional[torch.LongTensor] = None,
34 | num_logits_to_keep: int = 0,
35 | **loss_kwargs,
36 | ) -> Union[Tuple, BaseModelOutputWithPast]:
37 | """take care of image_encode, position_ids and (attention_mask = None is fine)"""
38 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
39 | output_hidden_states = (
40 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
41 | )
42 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
43 |
44 | outputs = self.model(
45 | input_ids=input_ids,
46 | attention_mask=attention_mask,
47 | position_ids=position_ids,
48 | past_key_values=past_key_values,
49 | inputs_embeds=inputs_embeds,
50 | use_cache=use_cache,
51 | output_attentions=output_attentions,
52 | output_hidden_states=output_hidden_states,
53 | return_dict=return_dict,
54 | cache_position=cache_position,
55 | )
56 |
57 | hidden_states = outputs[0]
58 | logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
59 | logits = logits.to(torch.float32)
60 | output = (logits,) + outputs[1:]
61 | return output
62 |
63 |
64 | def model_has_state(ov_model: ov.Model):
65 | return len(ov_model.get_sinks()) > 0
66 |
67 |
68 | def model_has_input_output_name(ov_model: ov.Model, name: str):
69 | """
70 | Helper function for checking that model has specified input or output name
71 |
72 | Parameters:
73 | ov_model (ov.Model):
74 | name (str):
75 | name of input or output
76 |
77 | Returns:
78 | True if input or output with requested name exists else False
79 | """
80 | return name in sum([list(t.get_names()) for t in ov_model.inputs + ov_model.outputs], [])
81 |
82 |
83 | def fuse_cache_reorder(
84 | ov_model: ov.Model,
85 | not_kv_inputs: List[str],
86 | key_value_input_names: List[str],
87 | gather_dim: int,
88 | ):
89 | """
90 | Fuses reorder_cache during the generate cycle into ov.Model. Used with stateful models, because we cannot modify the model state directly.
91 |
92 | Adds a new beam_idx parameter and Gather op per each kv-cache input in a given model.
93 | Should be run before make_stateful. Implements optimum's _reorder_cache
94 | inside the model at the beginning of each iteration.
95 | Gather works along given gather_dim dimension that may vary from model to model.
96 | KV-cache inputs are identified based on names in key_value_input_names.
97 | Append the new beam_idx parameter to not_kv_inputs.
98 |
99 | Parameters:
100 | ov_model (`ov.Model`):
101 | openvino model for processing
102 | not_kv_inputs (`List[str]`):
103 | list of input nodes in the model that are not related to past key values
104 | key_value_input_names (`List[str]`):
105 | list of names for key value input layers
106 | gather_dim (int):
107 | dimension for gathering cache during reorder pass
108 | """
109 |
110 | if model_has_input_output_name(ov_model, "beam_idx"):
111 | raise ValueError("Model already has fused cache")
112 | input_batch = ov_model.input("inputs_embeds").get_partial_shape()[0]
113 | beam_idx = opset13.parameter(name="beam_idx", dtype=ov.Type.i32, shape=ov.PartialShape([input_batch]))
114 | beam_idx.output(0).get_tensor().add_names({"beam_idx"}) # why list is not accepted?
115 | ov_model.add_parameters([beam_idx])
116 | not_kv_inputs.append(ov_model.inputs[-1])
117 | # Go over all cache parameters and fuse _reorder_cache with indices provided by the new parameter beam_idx
118 | for input_name in key_value_input_names:
119 | parameter_output_port = ov_model.input(input_name)
120 | consumers = parameter_output_port.get_target_inputs()
121 | gather = opset13.gather(parameter_output_port, beam_idx, opset13.constant(gather_dim))
122 | for consumer in consumers:
123 | consumer.replace_source_output(gather.output(0))
124 | ov_model.validate_nodes_and_infer_types()
125 |
126 |
127 | def build_state_initializer(ov_model: ov.Model, batch_dim: int):
128 | """
129 | Build initialization ShapeOf Expression for all ReadValue ops
130 |
131 | Parameters:
132 | ov_model (ov.Model):
133 | openvino model
134 | batch_dim (int):
135 | index of dimension corresponding to batch size
136 | """
137 | input_ids = ov_model.input("inputs_embeds")
138 | batch = opset13.gather(
139 | opset13.shape_of(input_ids, output_type="i64"),
140 | opset13.constant([0]),
141 | opset13.constant(0),
142 | )
143 | for op in ov_model.get_ops():
144 | if op.get_type_name() == "ReadValue":
145 | dims = [dim.min_length for dim in list(op.get_output_partial_shape(0))]
146 | dims[batch_dim] = batch
147 | dims = [
148 | (opset13.constant(np.array([dim], dtype=np.int64)) if isinstance(dim, int) else dim) for dim in dims
149 | ]
150 | shape = opset13.concat(dims, axis=0)
151 | broadcast = opset13.broadcast(opset13.constant(0.0, dtype=op.get_output_element_type(0)), shape)
152 | op.set_arguments([broadcast])
153 | ov_model.validate_nodes_and_infer_types()
154 |
155 |
156 | def make_stateful(
157 | ov_model: ov.Model,
158 | not_kv_inputs: List[str],
159 | key_value_input_names: List[str],
160 | key_value_output_names: List[str],
161 | batch_dim: int,
162 | num_attention_heads: int,
163 | num_beams_and_batch: int = None,
164 | ):
165 | """
166 | Hides kv-cache inputs and outputs inside the model as variables.
167 |
168 | Parameters:
169 | ov_model (ov.Model):
170 | openvino model
171 | not_kv_inputs (`List[str]`):
172 | list of input nodes in the model that are not related to past key values
173 | key_value_input_names (`List[str]`):
174 | list of names for key value input layers
175 | key_value_output_names (`List[str]`):
176 | list of names for key value output layers
177 | batch_dim (int):
178 | index of batch dimension in key value layers
179 | num_attention_heads (int):
180 | number of attention heads for batch dimension initialization
181 | num_beams_and_batch (int):
182 | precalculated number of beams and batch for shapes initialization
183 | """
184 | from openvino._offline_transformations import apply_make_stateful_transformation
185 |
186 | input_output_map = {}
187 |
188 | if num_beams_and_batch is not None:
189 | # Set the batch size for input_ids and attention_mask so that a dynamic dimension is not propagated from the end of the model back to ReadValue
190 | for input in not_kv_inputs:
191 | shape = input.get_partial_shape()
192 | if shape.rank.get_length() <= 2: # == 1 for beam_index
193 | shape[0] = num_beams_and_batch
194 | input.get_node().set_partial_shape(shape)
195 | for kv_name_pair in zip(key_value_input_names, key_value_output_names):
196 | input_output_map[kv_name_pair[0]] = kv_name_pair[1]
197 | if num_beams_and_batch is not None:
198 | input = ov_model.input(kv_name_pair[0])
199 | shape = input.get_partial_shape()
200 | shape[batch_dim] = num_beams_and_batch * num_attention_heads
201 | input.get_node().set_partial_shape(shape)
202 |
203 | if num_beams_and_batch is not None:
204 | # Re-validate the model if shapes were altered above
205 | ov_model.validate_nodes_and_infer_types()
206 |
207 | apply_make_stateful_transformation(ov_model, input_output_map)
208 | if num_beams_and_batch is None:
209 | build_state_initializer(ov_model, batch_dim)
210 |
211 |
212 | def patch_stateful(ov_model):
213 | key_value_input_names = [key.get_any_name() for key in ov_model.inputs[2:-1]]
214 | key_value_output_names = [key.get_any_name() for key in ov_model.outputs[1:]]
215 | not_kv_inputs = [
216 | input for input in ov_model.inputs if not any(name in key_value_input_names for name in input.get_names())
217 | ]
218 | if not key_value_input_names or not key_value_output_names:
219 | return
220 | batch_dim = 0
221 | num_attention_heads = 1
222 |
223 | fuse_cache_reorder(ov_model, not_kv_inputs, key_value_input_names, batch_dim)
224 | make_stateful(
225 | ov_model,
226 | not_kv_inputs,
227 | key_value_input_names,
228 | key_value_output_names,
229 | batch_dim,
230 | num_attention_heads,
231 | None,
232 | )
233 |
234 |
235 | core = ov.Core()
236 |
237 |
238 | def cleanup_torchscript_cache():
239 | """
240 | Helper for removing cached model representation
241 | """
242 | torch._C._jit_clear_class_registry()
243 | torch.jit._recursive.concrete_type_store = torch.jit._recursive.ConcreteTypeStore()
244 | torch.jit._state._clear_class_state()
245 |
246 |
247 | def convert_glmv_model(model_id, output_dir, quantization_config):
248 | model_name = Path(model_id).name
249 | output_dir = Path(output_dir)
250 |
251 | lang_model_path = output_dir / "openvino_language_model.xml"
252 | image_embed_path = output_dir / "openvino_vision.xml"
253 | embed_token_path = output_dir / "openvino_embedding.xml"
254 | config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
255 | image_size = config.vision_config["image_size"]
256 |
257 | if all(
258 | [
259 | lang_model_path.exists(),
260 | image_embed_path.exists(),
261 | embed_token_path.exists(),
262 | ]
263 | ):
264 | print(f"✅ {model_name} model already converted. You can find results in {output_dir}")
265 | return
266 | print(f"⌛ {model_name} conversion started. Be patient, it may takes some time.")
267 | print("⌛ Load Original model")
268 | model = AutoModelForCausalLM.from_pretrained(
269 | model_id, trust_remote_code=True, torch_dtype=torch.float32, _attn_implementation="eager"
270 | )
271 | model.config.save_pretrained(output_dir)
272 | tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
273 | tokenizer.save_pretrained(output_dir)
274 | processor = AutoImageProcessor.from_pretrained(model_id, trust_remote_code=True)
275 | processor.save_pretrained(output_dir)
276 | # shutil.copy2(ori_token_config_path, ov_token_config_path)
277 |
278 | print("✅ Original model successfully loaded")
279 |
280 | if not embed_token_path.exists():
281 | print("⌛ Convert Input embedding model")
282 | ov_model = ov.convert_model(
283 | model.model.embed_tokens,
284 | example_input=torch.ones([1, 10], dtype=torch.int64),
285 | )
286 | ov.save_model(ov_model, embed_token_path)
287 | del ov_model
288 | cleanup_torchscript_cache()
289 | gc.collect()
290 | print("✅ Input embedding model successfully converted")
291 |
292 | if not image_embed_path.exists():
293 | print("⌛ Convert Image embedding model")
294 | # vision_embed_tokens.forward = vision_embed_tokens.vit
295 | ov_model = ov.convert_model(model.model.vision, example_input=torch.ones([1, 3, image_size, image_size]))
296 | ov.save_model(ov_model, image_embed_path)
297 | del ov_model
298 | cleanup_torchscript_cache()
299 | gc.collect()
300 | print("✅ Image embedding model successfully converted")
301 |
302 | if not lang_model_path.exists():
303 | print("⌛ Convert Language model")
304 |
305 | input_ids = torch.zeros([2, 2], dtype=torch.int64)
306 | inputs_embeds = torch.zeros([2, 2, config.hidden_size], dtype=torch.float32)
307 |
308 | pkv = model.model(
309 | input_ids=input_ids,
310 | attention_mask=torch.ones((2, 2), dtype=torch.int64),
311 | images=torch.zeros([1, 3, image_size, image_size])
312 | )[1]
313 | model.forward = types.MethodType(_chatglm_transformer_forward, model)
314 |
315 | model.config.torchscript = True
316 | model_inputs = ["attention_mask", "position_ids"]
317 | model_outputs = ["logits"]
318 | for idx in range(len(pkv)):
319 | model_inputs.extend([f"past_key_values.{idx}.key", f"past_key_values.{idx}.value"])
320 | model_outputs.extend([f"present.{idx}.key", f"present.{idx}.value"])
321 | model_inputs.append("inputs_embeds")
322 | position_ids = torch.tensor([[2, 3], [2, 3]])
323 | ov_model = ov.convert_model(
324 | model,
325 | example_input={
326 | "position_ids": position_ids,
327 | "inputs_embeds": inputs_embeds,
328 | "attention_mask": torch.ones([2, 4], dtype=torch.int64),
329 | "past_key_values": pkv,
330 | },
331 | )
332 |
333 | for input, input_name in zip(ov_model.inputs, model_inputs):
334 | input.get_tensor().set_names({input_name})
335 |
336 | for output, output_name in zip(ov_model.outputs, model_outputs):
337 | output.get_tensor().set_names({output_name})
338 | patch_stateful(ov_model)
339 | print("✅ Language model successfully converted")
340 |
341 | if quantization_config is not None:
342 | print(f"⌛ Weights compression with {quantization_config['mode']} mode started")
343 | ov_model = nncf.compress_weights(ov_model, **quantization_config)
344 | print("✅ Weights compression finished")
345 |
346 | ov.save_model(ov_model, lang_model_path)
347 | del ov_model
348 | cleanup_torchscript_cache()
349 | del model
350 | gc.collect()
351 | print(f"✅ {model_name} model conversion finished. You can find results in {output_dir}")
352 |
353 |
354 | def is_empty(images_list: Optional[List[List[torch.Tensor]]]):
355 | if images_list is None or len(images_list) == 0:
356 | return True
357 | for image_list in images_list:
358 | if image_list is not None:
359 | return False
360 | return True
361 |
362 |
363 | class OvGLMv(GenerationMixin):
364 | def __init__(self, model_dir, device):
365 | model_dir = Path(model_dir)
366 | self.model = core.read_model(model_dir / "openvino_language_model.xml")
367 | self.vision = core.compile_model(model_dir / "openvino_vision.xml", "CPU")
368 | self.embedding = core.compile_model(model_dir / "openvino_embedding.xml", "CPU")
369 | self.input_names = {key.get_any_name(): idx for idx, key in enumerate(self.model.inputs)}
370 | self.output_names = {key.get_any_name(): idx for idx, key in enumerate(self.model.outputs)}
371 | # compiled_model = core.compile_model(self.model, device, config={"GPU_ENABLE_SDPA_OPTIMIZATION": "NO", "INFERENCE_PRECISION_HINT": "FP32"})
372 | compiled_model = core.compile_model(self.model, device)
373 |
374 | self.request = compiled_model.create_infer_request()
375 | self.config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
376 | self.generation_config = GenerationConfig.from_model_config(self.config)
377 | self.main_input_name = "input_ids"
378 | self.device = torch.device("cpu")
379 | self.num_pkv = 2
380 | self._supports_cache_class = False
381 | self.next_beam_idx = None
382 | self.hd_transform_order = "glb_sub"
383 |
384 | def can_generate(self):
385 | """Returns True to validate the check that the model using `GenerationMixin.generate()` can indeed generate."""
386 | return True
387 |
388 | def __call__(
389 | self,
390 | input_ids: torch.LongTensor = None,
391 | pixel_values: torch.Tensor = None,
392 | position_ids: Optional[torch.Tensor] = None,
393 | attention_mask: Optional[torch.BoolTensor] = None,
394 | past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
395 | inputs_embeds: Optional[torch.Tensor] = None,
396 | **kwargs,
397 | ) -> CausalLMOutputWithPast:
398 | return self.forward(
399 | input_ids=input_ids,
400 | pixel_values=pixel_values,
401 | attention_mask=attention_mask,
402 | past_key_values=past_key_values,
403 | position_ids=position_ids,
404 | inputs_embeds=inputs_embeds,
405 | **kwargs,
406 | )
407 |
408 | def forward(
409 | self,
410 | input_ids: torch.LongTensor = None,
411 | pixel_values: torch.Tensor = None,
412 | position_ids: Optional[torch.Tensor] = None,
413 | attention_mask: Optional[torch.BoolTensor] = None,
414 | past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
415 | inputs_embeds: Optional[torch.Tensor] = None,
416 | return_dict: Optional[bool] = None,
417 | ) -> Union[Tuple, CausalLMOutputWithPast]:
418 | batch_size, num_concurrent_media, num_tiles, num_channels, height, width = pixel_values.shape
419 | pixel_values = pixel_values.reshape(batch_size * num_concurrent_media * num_tiles, num_channels, height, width)
420 | if not past_key_values:
421 | self.request.reset_state()
422 | self.next_beam_idx = np.arange(input_ids.shape[0], dtype=int)
423 | # inputs_embeds is not allowed here, because the image features must be merged into the embeddings below
424 | assert input_ids is not None and inputs_embeds is None, f"{input_ids} {inputs_embeds}"
425 | inputs_embeds = torch.from_numpy(self.embedding(input_ids)[0])
426 | new_input_embeds = []
427 | multi_flags = [self.config.boi_token_id in input_id.tolist() for input_id in input_ids]
428 | images_features = None
429 | if not is_empty(pixel_values):
430 | images_features = torch.from_numpy(self.vision(pixel_values)[0])
431 | image_count = 0
432 | for i in range(len(input_ids)):
433 | input_id = input_ids[i].tolist()
434 | if multi_flags[i]:
435 | boi_token_pos = input_id.index(self.config.boi_token_id)
436 | assert boi_token_pos >= 0, "begin_of_image not found!"
437 | num_image_padding_tokens = input_id.count(self.config.boi_token_id)
438 | assert (
439 | num_image_padding_tokens == images_features[image_count].shape[0]
440 | ), f"Wrong image padding token number: {num_image_padding_tokens}"
441 | new_input_embeds.append(
442 | torch.cat(
443 | (
444 | inputs_embeds[i, :boi_token_pos],
445 | images_features[image_count].to(inputs_embeds.device),
446 | inputs_embeds[i, boi_token_pos + num_image_padding_tokens :],
447 | )
448 | )
449 | )
450 | image_count += 1
451 | else:
452 | new_input_embeds.append(inputs_embeds[i])
453 | inputs_embeds = torch.stack(new_input_embeds, dim=0)
454 |
455 | if inputs_embeds is None:
456 | inputs_embeds = self.embedding(input_ids)[0]
457 | inputs = {}
458 | inputs["inputs_embeds"] = inputs_embeds
459 | inputs["attention_mask"] = attention_mask
460 | inputs["position_ids"] = position_ids
461 | if "beam_idx" in self.input_names:
462 | inputs["beam_idx"] = (
463 | self.next_beam_idx if self.next_beam_idx is not None else np.arange(inputs_embeds.shape[0], dtype=int)
464 | )
465 | self.request.start_async(inputs, share_inputs=True)
466 | self.request.wait()
467 | logits = self.request.get_tensor("logits").data
468 | logits = torch.from_numpy(logits).to(self.device)
469 | past_key_values = ((),)
470 |
471 | return CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values)
472 |
473 | def _reorder_cache(
474 | self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
475 | ) -> Tuple[Tuple[torch.Tensor]]:
476 | """
477 | This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
478 | [`~PreTrainedModel.beam_sample`] is called.
479 | This is required to match `past_key_values` with the correct beam_idx at every generation step.
480 | """
481 | self.next_beam_idx = np.array(beam_idx) # save beam_idx to be used as an input in the next iteration
482 | return past_key_values
483 |
484 | def _update_model_kwargs_for_generation(
485 | self,
486 | outputs: ModelOutput,
487 | model_kwargs: Dict[str, Any],
488 | is_encoder_decoder: bool = False,
489 | standardize_cache_format: bool = False,
490 | ) -> Dict[str, Any]:
491 | # update past_key_values
492 | if int(transformers_version.split(".")[1]) >= 44:
493 | assert not standardize_cache_format
494 | _, cache = self._extract_past_from_model_output(outputs)
495 | model_kwargs["past_key_values"] = cache
496 | else:
497 | model_kwargs["past_key_values"] = self._extract_past_from_model_output(outputs, standardize_cache_format)
498 |
499 | # update attention mask
500 | if "attention_mask" in model_kwargs:
501 | attention_mask = model_kwargs["attention_mask"]
502 | model_kwargs["attention_mask"] = torch.cat(
503 | [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
504 | )
505 |
506 | # update position ids
507 | if "position_ids" in model_kwargs:
508 | position_ids = model_kwargs["position_ids"]
509 | new_position_id = position_ids[..., -1:].clone()
510 | new_position_id += 1
511 | model_kwargs["position_ids"] = torch.cat([position_ids, new_position_id], dim=-1)
512 |
513 | model_kwargs["is_first_forward"] = False
514 | return model_kwargs
515 |
516 | def prepare_inputs_for_generation(
517 | self,
518 | input_ids: torch.LongTensor,
519 | pixel_values: Optional[torch.Tensor] = torch.zeros([1, 1, 1, 3, 672, 672]),
520 | past_key_values: Optional[torch.Tensor] = None,
521 | attention_mask: Optional[torch.Tensor] = None,
522 | position_ids: Optional[torch.Tensor] = None,
523 | use_cache: Optional[bool] = None,
524 | is_first_forward: bool = True,
525 | **kwargs,
526 | ) -> dict:
527 | if position_ids is None:
528 | if attention_mask is None:
529 | # Can only build sequential ids. Raise error right now
530 | raise ValueError("Cannot create position ids when attention mask is None")
531 | else:
532 | position_ids = self._create_position_ids_from_attention_mask(attention_mask)
533 | if not is_first_forward:
534 | if past_key_values is not None:
535 | position_ids = position_ids[..., -1:]
536 | input_ids = input_ids[:, -1:]
537 | return {
538 | "input_ids": input_ids,
539 | "pixel_values": pixel_values,
540 | "past_key_values": past_key_values,
541 | "position_ids": position_ids,
542 | "attention_mask": attention_mask,
543 | }
544 |
545 | def _create_position_ids_from_attention_mask(self, attention_mask):
546 | # Initialize a tensor of the same shape as attention_mask to hold position IDs
547 | position_ids = torch.zeros_like(attention_mask, dtype=torch.long, device=attention_mask.device)
548 | # Iterate over the batch
549 | for i, mask in enumerate(attention_mask):
550 | # Find the positions where the mask is 1
551 | positions = torch.nonzero(mask, as_tuple=False).squeeze(1).to(attention_mask.device)
552 | # Assign position IDs to those positions
553 | position_ids[i, positions] = torch.arange(start=0, end=positions.size(0), dtype=torch.long).to(
554 | attention_mask.device
555 | )
556 | return position_ids
557 |
558 |
559 | if __name__ == "__main__":
560 | parser = argparse.ArgumentParser()
561 | parser.add_argument("--model_path", default="THUDM/glm-edge-v-2b", type=str, help="orignal model path")
562 | parser.add_argument(
563 | "--precision", default="int4", type=str, choices=["fp16", "int8", "int4"], help="fp16, int8 or int4"
564 | )
565 | parser.add_argument("--output_path", default="glm-edge-v-2b-ov", help="path to save the ir model")
566 | args = parser.parse_args()
567 | os.makedirs(args.output_path, exist_ok=True)
568 | if args.precision == "int4":
569 | compression_configuration = {
570 | "mode": nncf.CompressWeightsMode.INT4_SYM,
571 | "group_size": 64,
572 | "ratio": 0.6,
573 | }
574 | elif args.precision == "int8":
575 | compression_configuration = {
576 | "mode": nncf.CompressWeightsMode.INT8,
577 | "group_size": 64,
578 | "ratio": 0.6,
579 | }
580 | else:
581 | compression_configuration = None
582 | convert_glmv_model(args.model_path, args.output_path, compression_configuration)
583 |
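584 | # A minimal, hypothetical usage sketch for the converted model. The output path,
585 | # prompt, and generation settings below are illustrative assumptions, not part of
586 | # this script; they only show how OvGLMv plugs into GenerationMixin.generate():
587 | #
588 | #   model = OvGLMv("glm-edge-v-2b-ov", "CPU")
589 | #   tokenizer = AutoTokenizer.from_pretrained("glm-edge-v-2b-ov", trust_remote_code=True)
590 | #   messages = [{"role": "user", "content": "Hello"}]
591 | #   inputs = tokenizer.apply_chat_template(
592 | #       messages, add_generation_prompt=True, tokenize=True, return_tensors="pt", return_dict=True
593 | #   )
594 | #   output = model.generate(
595 | #       inputs["input_ids"], attention_mask=inputs["attention_mask"], max_new_tokens=32
596 | #   )
597 | #   print(tokenizer.decode(output[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True))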
--------------------------------------------------------------------------------
/inference/web_demo.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from pathlib import Path
3 | from threading import Thread
4 | from typing import Union
5 | import requests
6 | from io import BytesIO
7 | from PIL import Image
8 | import re
9 | import gradio as gr
10 | import torch
11 | from peft import AutoPeftModelForCausalLM
12 | from transformers import (
13 | AutoTokenizer,
14 | AutoModelForCausalLM,
15 | AutoImageProcessor,
16 | TextIteratorStreamer,
17 | BitsAndBytesConfig,
18 | )
19 |
20 | # Parse arguments
21 | def parse_args():
22 | parser = argparse.ArgumentParser(description="GLM-Edge-Chat Gradio Demo with adjustable parameters")
23 | parser.add_argument("--model_path", type=str, required=True, help="Path to the model")
24 | parser.add_argument("--server_name", type=str, default="127.0.0.1", help="Server name")
25 | parser.add_argument("--server_port", type=int, default=7860, help="Server port")
26 | parser.add_argument("--lora_path", type=str, default=None, help="Path to LoRA model if available")
27 | parser.add_argument(
28 | "--precision",
29 | type=str,
30 | choices=["float16", "bfloat16", "int4"],
31 | default="bfloat16",
32 | help="Precision for model",
33 | )
34 | return parser.parse_args()
35 |
36 | args = parse_args()
37 |
38 | def _resolve_path(path: Union[str, Path]) -> Path:
39 | return Path(path).expanduser().resolve()
40 |
41 | # Load Model and Tokenizer for transformers
42 | def load_model_and_tokenizer(model_dir: Union[str, Path], precision: str, trust_remote_code: bool = True):
43 | model_dir = _resolve_path(model_dir)
44 | if precision == "int4":
45 | model = AutoModelForCausalLM.from_pretrained(
46 | model_dir,
47 | trust_remote_code=trust_remote_code,
48 | quantization_config=BitsAndBytesConfig(load_in_4bit=True),
49 | torch_dtype=torch.bfloat16,
50 | low_cpu_mem_usage=True,
51 | ).eval()
52 | elif (model_dir / "adapter_config.json").exists():
53 | model = AutoPeftModelForCausalLM.from_pretrained(
54 | model_dir, trust_remote_code=trust_remote_code, device_map="auto"
55 | )
56 | else:
57 | model = AutoModelForCausalLM.from_pretrained(model_dir, trust_remote_code=trust_remote_code, device_map="auto")
58 |
59 | tokenizer_dir = (
60 | model.peft_config["default"].base_model_name_or_path if hasattr(model, "peft_config") else model_dir
61 | )
62 | tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir, trust_remote_code=trust_remote_code, use_fast=False)
63 | return model, tokenizer
64 |
65 | model, tokenizer = load_model_and_tokenizer(args.model_path, args.precision, trust_remote_code=True)
66 |
67 | def is_url(s):
68 | if re.match(r'^(?:http|ftp)s?://', s):
69 | return True
70 | return False
71 |
72 | def get_image(image):
73 | if is_url(image):
74 | response = requests.get(image)
75 | return Image.open(BytesIO(response.content)).convert("RGB")
76 | elif image:
77 | return Image.open(image).convert("RGB")
78 | return None
79 |
80 | def preprocess_messages(history, prompt, image):
81 | messages = []
82 | pixel_values = None
83 |
84 | if prompt:
85 | messages.append({"role": "system", "content": prompt})
86 | for idx, (user_msg, model_msg) in enumerate(history):
87 | if prompt and idx == 0:
88 | continue
89 | if idx == len(history) - 1 and not model_msg:
90 | messages.append({"role": "user", "content": user_msg})
91 | break
92 | if user_msg:
93 | messages.append({"role": "user", "content": user_msg})
94 | if model_msg:
95 | messages.append({"role": "assistant", "content": messages})
96 |
97 | if hasattr(model.config, "vision_config"):
98 | for item in messages:
99 | msg = item['content']
100 | item['content'] = [{"type": "text", "text": msg}]
101 | if image:
102 | messages[-1]['content'].append({"type": "image"})
103 | try:
104 | image_input = get_image(image)
105 |
106 | processor = AutoImageProcessor.from_pretrained(
107 | args.model_path,
108 | trust_remote_code=True
109 | )
110 | pixel_values = torch.tensor(
111 | processor(image_input).pixel_values).to(model.device)
112 | except Exception:
113 | print("Invalid image path. Continuing with text conversation.")
114 |
115 | return messages, pixel_values
116 |
117 | def predict(history, prompt, max_length, top_p, temperature, image=None):
118 | messages, pixel_values = preprocess_messages(history, prompt, image)
119 | model_inputs = tokenizer.apply_chat_template(
120 | messages, add_generation_prompt=True, tokenize=True, return_tensors="pt", return_dict=True
121 | )
122 |
123 | streamer = TextIteratorStreamer(tokenizer, timeout=60, skip_prompt=True, skip_special_tokens=True)
124 | generate_kwargs = {
125 | "input_ids": model_inputs["input_ids"].to(model.device),
126 | "attention_mask": model_inputs["attention_mask"].to(model.device),
127 | "streamer": streamer,
128 | "max_new_tokens": max_length,
129 | "do_sample": True,
130 | "top_p": top_p,
131 | "temperature": temperature,
132 | "repetition_penalty": 1.2,
133 | }
134 | if hasattr(model.config, "vision_config"):
135 | generate_kwargs['eos_token_id'] = [59246, 59253, 59255]
136 | if image and isinstance(pixel_values, torch.Tensor):
137 | generate_kwargs['pixel_values'] = pixel_values
138 | else:
139 | generate_kwargs['eos_token_id'] = tokenizer.encode("<|user|>")
140 | t = Thread(target=model.generate, kwargs=generate_kwargs)
141 | t.start()
142 | for new_token in streamer:
143 | if new_token:
144 | history[-1][1] += new_token
145 | yield history
146 |
147 | def main():
148 | with gr.Blocks() as demo:
149 | gr.HTML("""GLM-Edge-Chat Gradio Chat Demo
""")
150 |
151 | # Top row: Chatbot and Image upload
152 | with gr.Row():
153 | with gr.Column(scale=3):
154 | chatbot = gr.Chatbot()
155 | with gr.Column(scale=1):
156 | image_input = gr.Image(label="Upload an Image", type="filepath")
157 |
158 | # Bottom row: System prompt, user input, and controls
159 | with gr.Row():
160 | with gr.Column(scale=2):
161 | prompt_input = gr.Textbox(show_label=True, placeholder="System Prompt", label="System Prompt")
162 | user_input = gr.Textbox(show_label=True, placeholder="Input...", label="User Input")
163 | submitBtn = gr.Button("Submit")
164 | pBtn = gr.Button("Set System prompt")
165 | emptyBtn = gr.Button("Clear History")
166 | with gr.Column(scale=1):
167 | max_length = gr.Slider(0, 8192, value=4096, step=1.0, label="Maximum length", interactive=True)
168 | top_p = gr.Slider(0, 1, value=0.8, step=0.01, label="Top P", interactive=True)
169 | temperature = gr.Slider(0.01, 1, value=0.6, step=0.01, label="Temperature", interactive=True)
170 |
171 | # Define functions for button actions
172 | def user(query, history):
173 | return "", history + [[query, ""]]
174 |
175 | def set_prompt(prompt_text):
176 | return [[prompt_text, "Prompt set successfully"]]
177 |
178 | # Button actions and callbacks
179 | pBtn.click(set_prompt, inputs=[prompt_input], outputs=chatbot)
180 | submitBtn.click(user, [user_input, chatbot], [user_input, chatbot], queue=False).then(
181 | predict, [chatbot, prompt_input, max_length, top_p, temperature, image_input], chatbot
182 | )
183 | emptyBtn.click(lambda: (None, None), None, [chatbot, prompt_input], queue=False)
184 |
185 | demo.queue()
186 | demo.launch(server_name=args.server_name, server_port=args.server_port)
187 |
188 |
189 | if __name__ == "__main__":
190 | main()
191 |
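192 | # Example launch command (the model path is a hypothetical placeholder; point it at
193 | # your own chat or vision checkpoint):
194 | #   python web_demo.py --model_path THUDM/glm-edge-1.5b-chat --precision bfloat16 \
195 | #       --server_name 127.0.0.1 --server_port 7860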
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | # Install transformers from source:
2 | git+https://github.com/huggingface/transformers.git
3 | torch>=2.5.1
4 | torchvision>=0.20.0
5 | huggingface-hub>=0.25.1
6 | sentencepiece>=0.2.0
7 | jinja2>=3.1.4
8 | pydantic>=2.9.2
9 | timm>=1.0.9
10 | tiktoken>=0.8.0
11 | numpy==1.26.4
12 | accelerate>=1.1.1
13 | sentence_transformers>=3.1.1
14 | gradio>=5.6.0
15 | openai>=1.55.0
16 | einops>=0.8.0
17 | pillow>=10.4.0
18 | sse-starlette>=2.1.3
19 | bitsandbytes>=0.43.2 # INT4 Loading
20 |
21 | # For Intel OpenVINO conversion
22 | # optimum-intel>=1.20.1
23 | # openvino>=1.26.4
24 | # nncf>=2.14.0
25 |
26 | # vllm>=0.6.4.post1 # for use with the vLLM framework
27 | # peft>=0.14.0 # for use with fine-tuned (LoRA) models
28 |
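29 | # Minimal install sketch (uncomment the optional blocks above only if you need them):
30 | #   pip install -r requirements.txt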
--------------------------------------------------------------------------------
/resources/img.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THUDM/GLM-Edge/7a93ea5047e91cf27b74d1e64ad490f75d329b17/resources/img.png
--------------------------------------------------------------------------------