├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.yaml │ └── feature-request.yaml └── PULL_REQUEST_TEMPLATE │ └── pr_template.md ├── .gitignore ├── LICENSE ├── MODEL_LICENSE ├── README.md ├── README_en.md ├── finetune ├── README.md ├── README_en.md ├── configs │ ├── ds_zero_3.json │ ├── lora.yaml │ └── sft.yaml ├── finetune.py ├── finetune_vision.py └── vision_dataset.zip ├── inference ├── cli_demo.py ├── cli_demo_vision.py ├── ov_convert │ ├── convert_chat.py │ └── convert_v.py └── web_demo.py ├── requirements.txt └── resources └── img.png /.github/ISSUE_TEMPLATE/bug_report.yaml: -------------------------------------------------------------------------------- 1 | name: "\U0001F41B Bug Report" 2 | description: Submit a bug report to help us improve GLM-Edge / 提交一个 Bug 问题报告来帮助我们改进 GLM-Edge 3 | body: 4 | - type: textarea 5 | id: system-info 6 | attributes: 7 | label: System Info / 系統信息 8 | description: Your operating environment / 您的运行环境信息 9 | placeholder: Includes Cuda version, Transformers version, Python version, operating system, hardware information (if you suspect a hardware problem)... / 包括Cuda版本,Transformers版本,Python版本,操作系统,硬件信息(如果您怀疑是硬件方面的问题)... 10 | validations: 11 | required: true 12 | 13 | - type: textarea 14 | id: who-can-help 15 | attributes: 16 | label: Who can help? / 谁可以帮助到您? 17 | description: | 18 | Your issue will be replied to more quickly if you can figure out the right person to tag with @ 19 | All issues are read by one of the maintainers, so if you don't know who to tag, just leave this blank and our maintainer will ping the right person. 20 | 21 | Please tag fewer than 3 people. 22 | 23 | 如果您能找到合适的标签 @,您的问题会更快得到回复。 24 | 所有问题都会由我们的维护者阅读,如果您不知道该标记谁,只需留空,我们的维护人员会找到合适的开发组成员来解决问题。 25 | 26 | 标记的人数应该不超过 3 个人。 27 | 28 | If it's not a bug in these three subsections, you may not specify the helper. Our maintainer will find the right person in the development group to solve the problem. 29 | 30 | 如果不是这三个子版块的bug,您可以不指明帮助者,我们的维护人员会找到合适的开发组成员来解决问题。 31 | 32 | placeholder: "@Username ..." 33 | 34 | - type: checkboxes 35 | id: information-scripts-examples 36 | attributes: 37 | label: Information / 问题信息 38 | description: 'The problem arises when using: / 问题出现在' 39 | options: 40 | - label: "The official example scripts / 官方的示例脚本" 41 | - label: "My own modified scripts / 我自己修改的脚本和任务" 42 | 43 | - type: textarea 44 | id: reproduction 45 | validations: 46 | required: true 47 | attributes: 48 | label: Reproduction / 复现过程 49 | description: | 50 | Please provide a code example that reproduces the problem you encountered, preferably with a minimal reproduction unit. 51 | If you have code snippets, error messages, stack traces, please provide them here as well. 52 | Please format your code correctly using code tags. See https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting 53 | Do not use screenshots, as they are difficult to read and (more importantly) do not allow others to copy and paste your code. 54 | 55 | 请提供能重现您遇到的问题的代码示例,最好是最小复现单元。 56 | 如果您有代码片段、错误信息、堆栈跟踪,也请在此提供。 57 | 请使用代码标签正确格式化您的代码。请参见 https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting 58 | 请勿使用截图,因为截图难以阅读,而且(更重要的是)不允许他人复制粘贴您的代码。 59 | placeholder: | 60 | Steps to reproduce the behavior/复现Bug的步骤: 61 | 62 | 1. 63 | 2. 64 | 3. 
65 | 66 | - type: textarea 67 | id: expected-behavior 68 | validations: 69 | required: true 70 | attributes: 71 | label: Expected behavior / 期待表现 72 | description: "A clear and concise description of what you would expect to happen. /简单描述您期望发生的事情。" -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature-request.yaml: -------------------------------------------------------------------------------- 1 | name: "\U0001F680 Feature request" 2 | description: Submit a request for a new GLM-Edge feature / 提交一个新的 GLM-Edge 的功能建议 3 | labels: [ "feature" ] 4 | body: 5 | - type: textarea 6 | id: feature-request 7 | validations: 8 | required: true 9 | attributes: 10 | label: Feature request / 功能建议 11 | description: | 12 | A brief description of the functional proposal. Links to corresponding papers and code are desirable. 13 | 对功能建议的简述。最好提供对应的论文和代码链接 14 | 15 | - type: textarea 16 | id: motivation 17 | validations: 18 | required: true 19 | attributes: 20 | label: Motivation / 动机 21 | description: | 22 | Your motivation for making the suggestion. If that motivation is related to another GitHub issue, link to it here. 23 | 您提出建议的动机。如果该动机与另一个 GitHub 问题有关,请在此处提供对应的链接。 24 | 25 | - type: textarea 26 | id: contribution 27 | validations: 28 | required: true 29 | attributes: 30 | label: Your contribution / 您的贡献 31 | description: | 32 | 33 | Your PR link or any other link you can help with. 34 | 您的PR链接或者其他您能提供帮助的链接。 -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE/pr_template.md: -------------------------------------------------------------------------------- 1 | # Raise valuable PR / 提出有价值的PR 2 | 3 | ## Caution/ 注意事项: 4 | Users should keep the following points in mind when submitting PRs: 5 | 6 | 1. The proposed PR should be about this project. 7 | 2. the proposed PR should be relevant, if there are multiple ideas and optimizations, they should be assigned to different PRs. 8 | 9 | 用户在提交PR时候应该注意以下几点: 10 | 11 | 1. 提出的PR应该是关于本项目的。 12 | 2. 提出的PR应该具有针对性,如果具有多个不同的想法和优化方案,应该分配到不同的PR中。 13 | 14 | ## 不应该提出的PR / PRs that should not be proposed 15 | 16 | If a developer proposes a PR about any of the following, it may be closed or Rejected. 17 | 18 | 1. those that don't describe improvement options. 19 | 2. multiple issues of different types combined in one PR. 20 | 3. The proposed PR is highly duplicative of already existing PRs. 21 | 22 | 如果开发者提出关于以下方面的PR,则可能会被直接关闭或拒绝通过。 23 | 24 | 1. 没有说明改进方案的。 25 | 2. 多个不同类型的问题合并在一个PR中的。 26 | 3. 提出的PR与已经存在的PR高度重复的。 27 | 28 | 29 | # 检查您的PR 30 | - [ ] Have you read the Contributor Guidelines, Pull Request section? / 您是否阅读了贡献者指南、Pull Request 部分? 31 | - [ ] Has this been discussed/approved via a Github issue or forum? If so, add a link. / 是否通过 Github 问题或论坛讨论/批准过?如果是,请添加链接。 32 | - [ ] Did you make sure you updated the documentation with your changes? Here are the Documentation Guidelines, and here are the Documentation Formatting Tips. /您是否确保根据您的更改更新了文档?这里是文档指南,这里是文档格式化技巧。 33 | - [ ] Did you write new required tests? / 您是否编写了新的必要测试? 
34 | - [ ] Are your PRs for only one issue / 您的PR是否仅针对一个问题 -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *__pycache__/ 2 | samples*/ 3 | runs/ 4 | checkpoints/ 5 | master_ip 6 | logs/ 7 | *.DS_Store 8 | .idea/* 9 | output* 10 | test* 11 | pyproject.toml 12 | draft* -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2025 GLM-Edge Model Team @ Zhipu AI 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /MODEL_LICENSE: -------------------------------------------------------------------------------- 1 | The GLM-Edge License 2 | 3 | 1. 定义 4 | 5 | “许可方”是指分发其软件的 GLM-Edge 模型团队。 6 | “软件”是指根据本许可提供的 GLM-Edge 模型参数。 7 | 8 | 2. 
许可授予 9 | 10 | 根据本许可的条款和条件,许可方特此授予您非排他性、全球性、不可转让、不可再许可、可撤销、免版税的版权许可。 11 | 本许可允许您免费使用本仓库中的所有开源模型进行学术研究,对于希望将模型用于商业目的的用户,需在[这里](https://open.bigmodel.cn/mla/form)完成登记。经过登记的用户可以免费使用本模型进行商业活动,但必须遵守本许可的所有条款和条件。 12 | 上述版权声明和本许可声明应包含在本软件的所有副本或重要部分中。 13 | 如果您分发或提供 THUDM / 智谱AI 关于 GLM-Edge 开源模型的材料(或其任何衍生作品),或使用其中任何材料(包括 GLM-Edge 系列的所有开源模型)的产品或服务,您应: 14 | 15 | (A) 随任何此类 THUDM / 智谱AI 材料提供本协议的副本; 16 | (B) 在相关网站、用户界面、博客文章、关于页面或产品文档上突出显示 “Built with GLM-Edge”。 17 | 如果您使用 THUDM / 智谱AI的 GLM-Edge 开源模型的材料来创建、训练、微调或以其他方式改进已分发或可用的 AI 模型,您还应在任何此类 AI 模型名称的开头添加 “GLM-Edge”。 18 | 19 | 3. 限制 20 | 21 | 您不得出于任何军事或非法目的使用、复制、修改、合并、发布、分发、复制或创建本软件的全部或部分衍生作品。 22 | 您不得利用本软件从事任何危害国家安全和国家统一,危害社会公共利益及公序良俗,侵犯他人商业秘密、知识产权、名誉权、肖像权、财产权等权益的行为。 23 | 您在使用中应遵循使用地所适用的法律法规政策、道德规范等要求。 24 | 25 | 4. 免责声明 26 | 27 | 本软件“按原样”提供,不提供任何明示或暗示的保证,包括但不限于对适销性、特定用途的适用性和非侵权性的保证。 28 | 在任何情况下,作者或版权持有人均不对任何索赔、损害或其他责任负责,无论是在合同诉讼、侵权行为还是其他方面,由软件或软件的使用或其他交易引起、由软件引起或与之相关 29 | 软件。 30 | 31 | 5. 责任限制 32 | 33 | 除适用法律禁止的范围外,在任何情况下且根据任何法律理论,无论是基于侵权行为、疏忽、合同、责任或其他原因,任何许可方均不对您承担任何直接、间接、特殊、偶然、示范性、 34 | 或间接损害,或任何其他商业损失,即使许可人已被告知此类损害的可能性。 35 | 36 | 6. 争议解决 37 | 38 | 本许可受中华人民共和国法律管辖并按其解释。 因本许可引起的或与本许可有关的任何争议应提交北京市海淀区人民法院。 39 | 请注意,许可证可能会更新到更全面的版本。 有关许可和版权的任何问题,请通过 license@zhipuai.cn 或 opensource@zhipuai.cn 与我们联系。 40 | 41 | 1. Definitions 42 | 43 | “Licensor” means the GLM-Edge Model Team that distributes its Software. 44 | “Software” means the GLM-Edge model parameters made available under this license. 45 | 46 | 2. License 47 | 48 | Under the terms and conditions of this license, the Licensor hereby grants you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty-free copyright license. 49 | This license allows you to use all open source models in this repository for free for academic research. For users who wish to use the models for commercial purposes, please do so [here](https://open.bigmodel.cn/mla/form) 50 | Complete registration. Registered users are free to use this model for commercial activities, but must comply with all terms and conditions of this license. 51 | The copyright notice and this license notice shall be included in all copies or substantial portions of the Software. 52 | If you distribute or provide THUDM / Zhipu AI materials on the GLM-Edge open source model (or any derivative works thereof), or products or services that use any materials therein (including all open source models of the GLM-Edge series), you should: 53 | 54 | (A) Provide a copy of this Agreement with any such THUDM/Zhipu AI Materials; 55 | (B) Prominently display "Built with GLM-Edge" on the relevant website, user interface, blog post, related page or product documentation. 56 | If you use materials from THUDM/ZHIPU's GLM-Edge model to create, train, operate, or otherwise improve assigned or available AI models, you should also add "GLM-Edge" to the beginning of any such AI model name. 57 | 58 | 3. Restrictions 59 | 60 | You are not allowed to use, copy, modify, merge, publish, distribute, copy or create all or part of the derivative works of this software for any military or illegal purposes. 61 | You are not allowed to use this software to engage in any behavior that endangers national security and unity, endangers social public interests and public order, infringes on the rights and interests of others such as trade secrets, intellectual property rights, reputation rights, portrait rights, and property rights. 
62 | You should comply with the applicable laws, regulations, policies, ethical standards, and other requirements in the place of use during use. 63 | 64 | 4. Disclaimer 65 | 66 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE 67 | WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 68 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 69 | OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 70 | 71 | 5. Limitation of Liability 72 | 73 | EXCEPT TO THE EXTENT PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER BASED IN TORT, 74 | NEGLIGENCE, CONTRACT, LIABILITY, OR OTHERWISE WILL ANY LICENSOR BE LIABLE TO YOU FOR ANY DIRECT, INDIRECT, SPECIAL, 75 | INCIDENTAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES, OR ANY OTHER COMMERCIAL LOSSES, EVEN IF THE LICENSOR HAS BEEN ADVISED 76 | OF THE POSSIBILITY OF SUCH DAMAGES. 77 | 78 | 6. Dispute Resolution 79 | 80 | This license shall be governed and construed in accordance with the laws of People’s Republic of China. Any dispute 81 | arising from or in connection with this License shall be submitted to Haidian District People's Court in Beijing. 82 | 83 | Note that the license is subject to update to a more comprehensive version. For any questions related to the license and 84 | copyright, please contact us at license@zhipuai.cn. 85 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GLM-Edge 2 | 3 | Read this in [English](README_en.md) 4 | 5 | 在 🤗 这里 体验 GLM-Edge-1.5B-Chat 端侧模型 6 | 7 | 在 🤗 这里 或者 🤖 这里 体验 GLM-Edge-V-5B 端侧模型 8 | 9 | 10 | ## 模型介绍 11 | 12 | **GLM-Edge** 系列是我们在面向端侧真实落地使用的场景下的一次尝试,由两种尺寸的大语言对话模型和多模态理解模型组成( 13 | `GLM-Edge-1.5B-Chat`,`GLM-Edge-4B-Chat`,`GLM-Edge-V-2B`,`GLM-Edge-V-5B`)。其中,`1.5B / 2B`模型主要面向手机、车机等平台, 14 | `4B / 5B` 模型主要面向PC等平台。 15 | 16 | 基于GLM-4系列的技术积累,我们针对端侧实际部署情况,对模型结构和尺寸做了针对性的调整,以求在模型表现、实机推理效果和落地便利度之间达到平衡。同时,通过与伙伴企业的深入合作和在推理优化上的不懈努力,在一些端侧平台上,GLM-Edge系列模型能以极快的速度运行。 17 | 18 | 例如,在高通骁龙8 Elite平台上,借助其强大的NPU算力,GLM-Edge通过混合量化方案,1.5B对话模型、2B多模态模型能实现每秒60 19 | tokens以上的解码速度。在应用投机采样技术之后,两个模型能以峰值每秒100 tokens以上的解码速度运行。**这些推理方案会由我们或合作伙伴后续放出。暂时不会在本仓库提供。** 20 | 模型下载地址: 21 | 22 | | Model | HuggingFace Model | GGUF Model | 23 | |:------------------:|:--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------:|:-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------:| 24 | | GLM-Edge-1.5B-Chat | [🤗 Huggingface](https://huggingface.co/THUDM/glm-edge-1.5b-chat)
[🤖 ModelScope](https://modelscope.cn/models/ZhipuAI/glm-edge-1.5b-chat)
[🟣 WiseModel](https://wisemodel.cn/models/ZhipuAI/glm-edge-1.5b-chat) | [🤗 Huggingface](https://huggingface.co/THUDM/glm-edge-1.5b-chat-gguf)
[🤖 ModelScope](https://modelscope.cn/models/ZhipuAI/glm-edge-1.5b-chat-gguf)
[🟣 WiseModel](https://wisemodel.cn/models/ZhipuAI/glm-edge-1.5b-chat-gguf) | 25 | | GLM-Edge-4B-Chat | [🤗 Huggingface](https://huggingface.co/THUDM/glm-edge-4b-chat)
[🤖 ModelScope](https://modelscope.cn/models/ZhipuAI/glm-edge-4b-chat)
[🟣 WiseModel](https://wisemodel.cn/models/ZhipuAI/glm-edge-4b-chat) | [🤗 Huggingface](https://huggingface.co/THUDM/glm-edge-4b-chat-gguf)
[🤖 ModelScope](https://modelscope.cn/models/ZhipuAI/glm-edge-4b-chat-gguf)
[🟣 WiseModel](https://wisemodel.cn/models/ZhipuAI/glm-edge-4b-chat-gguf) | 26 | | GLM-Edge-V-2B | [🤗 Huggingface](https://huggingface.co/THUDM/glm-edge-v-2b)
[🤖 ModelScope](https://modelscope.cn/models/ZhipuAI/glm-edge-v-2b)
[🟣 WiseModel](https://wisemodel.cn/models/ZhipuAI/glm-edge-v-2b) | [🤗 Huggingface](https://huggingface.co/THUDM/glm-edge-v-2b-gguf)
[🤖 ModelScope](https://modelscope.cn/models/ZhipuAI/glm-edge-v-2b-gguf)
[🟣 WiseModel](https://wisemodel.cn/models/ZhipuAI/glm-edge-v-2b-gguf) | 27 | | GLM-Edge-V-5B | [🤗 Huggingface](https://huggingface.co/THUDM/glm-edge-v-5b)
[🤖 ModelScope](https://modelscope.cn/models/ZhipuAI/glm-edge-v-5b)
[🟣 WiseModel](https://wisemodel.cn/models/ZhipuAI/glm-edge-v-5b) | [🤗 Huggingface](https://huggingface.co/THUDM/glm-edge-v-5b-gguf)
[🤖 ModelScope](https://modelscope.cn/models/ZhipuAI/glm-edge-v-5b-gguf)
[🟣 WiseModel](https://wisemodel.cn/models/ZhipuAI/glm-edge-v-5b-gguf) | 28 | 29 | ## 实机运行数据 30 | 31 | 数据采集日截止到2024年11月28日。我们还在积极地与合作伙伴们一道优化这些性能。 32 | 33 | ### 高通 34 | 35 | | 模型 | 任务 | 量化方案 | 框架 | 1st token latency (ms) | Token Rate (tokens/s) | Peak Memory Footprint (GB) | 36 | |--------------------|------------------------|------|-----|------------------------|-----------------------|----------------------------| 37 | | GLM-Edge-4B-Chat | (input/output=512/128) | INT4 | QNN | 660 | 24 | 2.9 | 38 | | GLM-Edge-1.5B-Chat | (input/output=512/128) | INT4 | QNN | 260 | 65 | 1.2 | 39 | 40 | * 在高通8 Elite(Gen4)平台上测试,模型全部运行在NPU上 41 | * 如运行V模型,另外需要单图890ms的处理时间和约660M的额外内存 42 | * 使用投机解码方案时,Token Rate还有最高50%的提升 43 | 44 | ### Intel 45 | 46 | | 模型 | 任务 | 量化方案 | 框架 | 1st token latency (ms) | Token Rate (tokens/s) | Peak Memory Footprint (GB) | 47 | |--------------------|--------------------------------------|------|----------|------------------------|-----------------------|----------------------------| 48 | | GLM-Edge-4B-Chat | (input/output=1024/128) | INT4 | OPENVINO | 541.2 | 27 | 3.9 | 49 | | GLM-Edge-1.5B-Chat | (input/output=1024/128) | INT4 | OPENVINO | 228.2 | 63 | 2.3 | 50 | | GLM-Edge-V-2B | Single image understanding (672x672) | INT4 | OPENVINO | 362.1 | 70 | 3.4 | 51 | 52 | * 在Intel LNL 288V (ARC 140V 8X@2.05GHz) 平台上测试。 53 | * 如运行V模型,另外需要单图1.7s的处理时间和约2G的额外内存。 54 | 55 | ## 安装依赖 56 | 57 | 请确保你的Python版本为`3.10`或更高版本。并按照如下方式安装依赖,安装以下依赖能确保正确运行本仓库的所有代码。 58 | 59 | ```shell 60 | pip install -r requirements.txt 61 | ``` 62 | 63 | ## 模型推理 64 | 65 | ### Transformers / OpenVINO / vLLM Demo 66 | 67 | 我们提供了 vLLM, OpenVINO 和 transformers 三种后端推理方式,你可以通过运行以下命令来运行模型。这是一个命令行交互代码。 68 | 69 | ```shell 70 | python cli_demo.py --backend transformers --model_path THUDM/glm-edge-1.5b-chat --precision bfloat16 71 | python cli_demo.py --backend vllm --model_path THUDM/glm-edge-1.5b-chat --precision bfloat16 72 | python cli_demo.py --backend ov --model_path THUDM/glm-edge-1.5b-chat-ov --precision int4 73 | ``` 74 | 75 | > 注意: 76 | > 77 | > OpenVINO 版本模型需要进行转换,请前往 [这里](inference/ov_convert) 运行转换代码。 78 | > 79 | > ```python convert_chat.py --model_path THUDM/glm-edge-1.5b-chat --precision int4 ``` 转换对话模型。 80 | > 81 | > ```python convert.py --model_path THUDM/glm-edge-v-2b --precision int4``` 转换视觉理解模型。 82 | > 83 | > 你也可以在 [这里](https://github.com/openvino-dev-samples/glm-edge.openvino) 查看原始的转换代码。 84 | > 85 | > vLLM 版本模型需要从 [这里](https://github.com/sixsixcoder/vllm/tree/glm-4) 源代码 安装 vLLM 以正常运行。 86 | 87 | 如果你想使用 glm-edge-v 系列模型,你可以运行以下命令行交互代码 88 | 89 | ```shell 90 | python cli_demo_vision.py --backend transformers --model_path THUDM/glm-edge-v-2b --precision bfloat16 91 | python cli_demo.py --backend ov --model_path THUDM/glm-edge-1.5b-chat-ov --precision int4 92 | ``` 93 | 94 | 你也可以使用 Gradio 启动 WebUI。 95 | 96 | ```shell 97 | python cli_demo.py --backend transformers --model_path THUDM/glm-edge-1.5b-chat --precision bfloat16 98 | python cli_demo.py --backend vllm --model_path THUDM/glm-edge-1.5b-chat --precision int4 # For Int4 Inference 99 | ``` 100 | 101 | ### XInference 102 | 103 | 如果你使用 XInference 进行推理,你可以通过运行以下命令来运行模型。这是一个命令行交互代码。 104 | 105 | ```shell 106 | xinference launch --model-engine Transformers --model-name glm-edge-v --size-in-billions 2 --model-format pytorch --quantization none 107 | ``` 108 | 109 | 使用 OpenAI API进行推理: 110 | 111 | ```python 112 | import openai 113 | 114 | client = openai.Client( 115 | api_key="cannot be empty", 116 | base_url="http://:/v1" 117 | ) 118 | output = client.chat.completions.create( 119 
| model="glm-edge-v", 120 | messages=[ 121 | { 122 | "role": "user", 123 | "content": [ 124 | { 125 | 'type': 'text', 126 | 'text': 'describe this image', 127 | }, 128 | { 129 | 'type': 'image_url', 130 | 'image_url': { 131 | "url": "img.png", 132 | } 133 | }, 134 | ], 135 | } 136 | ], 137 | max_tokens=512, 138 | temperature=0.7 139 | ) 140 | 141 | print(output) 142 | ``` 143 | 144 | ## 微调模型 145 | 146 | 我们提供了微调模型的代码,请参考 [微调教程](finetune/README.md)。 147 | 148 | ## 协议 149 | 150 | 本 github 仓库代码的使用 [Apache2.0 LICENSE](LICENSE)。 151 | 152 | 模型权重的使用请遵循 [Model License](MODEL_LICENSE)。 153 | -------------------------------------------------------------------------------- /README_en.md: -------------------------------------------------------------------------------- 1 | # GLM-Edge 2 | 3 | Experience the GLM-Edge-1.5B-Chat edge chat model at 🤗 [here](https://huggingface.co/spaces/THUDM-HF-SPACE/GLM-Edge-1.5B-Chat-Space) 4 | 5 | Experience the GLM-Edge-V-5B edge vision-chat model at 🤗 [here](https://huggingface.co/spaces/THUDM-HF-SPACE/GLM-Edge-V-5B-Space) 6 | 7 | ## Model Introduction 8 | 9 | The **GLM-Edge** series is our attempt to meet real-world deployment scenarios for edge devices. It consists of two sizes of 10 | large language dialogue models and multimodal understanding models (`GLM-Edge-1.5B-Chat`, `GLM-Edge-4B-Chat`, 11 | `GLM-Edge-V-2B`, `GLM-Edge-V-5B`). Among them, the `1.5B / 2B` models are mainly targeted at platforms like mobile 12 | phones and car machines, while the `4B / 5B` models are aimed at platforms like PCs. 13 | 14 | Based on the technological advancements of the GLM-4 series, we have made targeted adjustments to the model structure 15 | and size, balancing model performance, real-world inference efficiency, and deployment convenience. Through deep 16 | collaboration with partner enterprises and relentless efforts in inference optimization, the GLM-Edge series models can 17 | run at extremely high speeds on some edge platforms. 18 | 19 | For example, on the Qualcomm Snapdragon 8 Elite platform, leveraging its powerful NPU computing power and using a mixed 20 | quantization scheme, the 1.5B dialogue model and the 2B multimodal model can achieve decoding speeds of over 60 tokens 21 | per second. With speculative sampling techniques, these models can reach peak decoding speeds of over 100 tokens per 22 | second. These inference solutions will be released later by us or our partners. 23 | 24 | Download links for the models: 25 | 26 | | Model | HuggingFace Model | GGUF Model | 27 | |:------------------:|:--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------:|:-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------:| 28 | | GLM-Edge-1.5B-Chat | [🤗 HuggingFace](https://huggingface.co/THUDM/glm-edge-1.5b-chat)
[🤖 ModelScope](https://modelscope.cn/models/ZhipuAI/glm-edge-1.5b-chat)
[🟣 WiseModel](https://wisemodel.cn/models/ZhipuAI/glm-edge-1.5b-chat) | [🤗 HuggingFace](https://huggingface.co/THUDM/glm-edge-1.5b-chat-gguf)
[🤖 ModelScope](https://modelscope.cn/models/ZhipuAI/glm-edge-1.5b-chat-gguf)
[🟣 WiseModel](https://wisemodel.cn/models/ZhipuAI/glm-edge-1.5b-chat-gguf) | 29 | | GLM-Edge-4B-Chat | [🤗 HuggingFace](https://huggingface.co/THUDM/glm-edge-4b-chat)
[🤖 ModelScope](https://modelscope.cn/models/ZhipuAI/glm-edge-4b-chat)
[🟣 WiseModel](https://wisemodel.cn/models/ZhipuAI/glm-edge-4b-chat) | [🤗 HuggingFace](https://huggingface.co/THUDM/glm-edge-4b-chat-gguf)
[🤖 ModelScope](https://modelscope.cn/models/ZhipuAI/glm-edge-4b-chat-gguf)
[🟣 WiseModel](https://wisemodel.cn/models/ZhipuAI/glm-edge-4b-chat-gguf) | 30 | | GLM-Edge-V-2B | [🤗 HuggingFace](https://huggingface.co/THUDM/glm-edge-v-2b)
[🤖 ModelScope](https://modelscope.cn/models/ZhipuAI/glm-edge-v-2b)
[🟣 WiseModel](https://wisemodel.cn/models/ZhipuAI/glm-edge-v-2b) | [🤗 HuggingFace](https://huggingface.co/THUDM/glm-edge-v-2b-gguf)
[🤖 ModelScope](https://modelscope.cn/models/ZhipuAI/glm-edge-v-2b-gguf)
[🟣 WiseModel](https://wisemodel.cn/models/ZhipuAI/glm-edge-v-2b-gguf) | 31 | | GLM-Edge-V-5B | [🤗 HuggingFace](https://huggingface.co/THUDM/glm-edge-v-5b)
[🤖 ModelScope](https://modelscope.cn/models/ZhipuAI/glm-edge-v-5b)
[🟣 WiseModel](https://wisemodel.cn/models/ZhipuAI/glm-edge-v-5b) | [🤗 HuggingFace](https://huggingface.co/THUDM/glm-edge-v-5b-gguf)
[🤖 ModelScope](https://modelscope.cn/models/ZhipuAI/glm-edge-v-5b-gguf)
[🟣 WiseModel](https://wisemodel.cn/models/ZhipuAI/glm-edge-v-5b-gguf) | 32 | 33 | ## Performance Data 34 | 35 | Performance data was collected up to November 28, 2024. We are actively working with partners to further improve these results. 36 | 37 | ### Qualcomm 38 | 39 | | Model | Task | Quantization | Framework | 1st Token Latency (ms) | Token Rate (tokens/s) | Peak Memory Footprint (GB) | 40 | |--------------------|------------------------|--------------|-----------|------------------------|-----------------------|----------------------------| 41 | | GLM-Edge-4B-Chat | (input/output=512/128) | INT4 | QNN | 660 | 24 | 2.9 | 42 | | GLM-Edge-1.5B-Chat | (input/output=512/128) | INT4 | QNN | 260 | 65 | 1.2 | 43 | 44 | - Tested on the Qualcomm 8 Elite (Gen4) platform with models fully running on the NPU. 45 | - For V models, an additional 890ms processing time per image and about 660MB extra memory is required. 46 | - With speculative decoding, the Token Rate can achieve up to 50% improvement. 47 | 48 | ### Intel 49 | 50 | | Model | Task | Quantization | Framework | 1st Token Latency (ms) | Token Rate (tokens/s) | Peak Memory Footprint (GB) | 51 | |--------------------|--------------------------------------|--------------|-----------|------------------------|-----------------------|----------------------------| 52 | | GLM-Edge-4B-Chat | (input/output=1024/128) | INT4 | OPENVINO | 541.2 | 27 | 3.9 | 53 | | GLM-Edge-1.5B-Chat | (input/output=1024/128) | INT4 | OPENVINO | 228.2 | 63 | 2.3 | 54 | | GLM-Edge-V-2B | Single image understanding (672x672) | INT4 | OPENVINO | 362.1 | 70 | 3.4 | 55 | 56 | - Tested on the Intel LNL 288V (ARC 140V 8X@2.05GHz) platform. 57 | - For V models, an additional 1.7s processing time per image and about 2GB extra memory is required. 58 | 59 | ## Install Dependencies 60 | 61 | Ensure your Python version is `3.10` or higher. Install dependencies as follows to ensure all code in this repository 62 | runs correctly: 63 | 64 | ```shell 65 | pip install -r requirements.txt 66 | ``` 67 | 68 | ## Model Inference 69 | 70 | ### Transformers / OpenVINO / vLLM Demo 71 | 72 | We provide three backend inference options: vLLM, OpenVINO, and transformers. You can run the models using the following 73 | commands, which launch an interactive command-line demo. 74 | 75 | ```shell 76 | python cli_demo.py --backend transformers --model_path THUDM/glm-edge-1.5b-chat --precision bfloat16 77 | python cli_demo.py --backend vllm --model_path THUDM/glm-edge-1.5b-chat --precision bfloat16 78 | python cli_demo.py --backend ov --model_path THUDM/glm-edge-1.5b-chat-ov --precision int4 79 | ``` 80 | 81 | > Note: 82 | > 83 | > OpenVINO version models need conversion. Please visit [here](inference/ov_convert) to run the conversion code. 84 | > 85 | > Run ```python convert_chat.py --model_path THUDM/glm-edge-1.5b-chat --precision int4``` to convert dialogue models. 86 | > 87 | > Run ```python convert_v.py --model_path THUDM/glm-edge-v-2b --precision int4``` to convert visual understanding models. 88 | > 89 | > You can also view the original conversion code [here](https://github.com/openvino-dev-samples/glm-edge.openvino). 90 | > 91 | > vLLM version models require installation of source code from [here](https://github.com/sixsixcoder/vllm/tree/glm-4) to 92 | > run properly. 
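If you prefer to call the chat model directly from Python rather than going through `cli_demo.py`, a minimal `transformers` sketch is shown below. Treat it as an illustrative example rather than an official API of this repository: it assumes the released checkpoint loads with `trust_remote_code=True` and ships a chat template, and you should adjust the dtype and device placement to your hardware.

```python
# Minimal sketch: direct transformers inference with GLM-Edge-1.5B-Chat.
# Assumption: the checkpoint provides a chat template; device_map="auto" requires `accelerate`.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = "THUDM/glm-edge-1.5b-chat"
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)

messages = [{"role": "user", "content": "Give a one-sentence introduction to GLM-Edge."}]
input_ids = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
).to(model.device)

output_ids = model.generate(input_ids, max_new_tokens=256)
# Strip the prompt tokens and decode only the newly generated reply.
print(tokenizer.decode(output_ids[0][input_ids.shape[1]:], skip_special_tokens=True))
```

For actual edge deployment, the OpenVINO, vLLM, and GGUF routes above remain the intended paths; this snippet is mainly useful for quick correctness checks on a development machine.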
93 | 94 | To use glm-edge-v series models, you can run the following command-line interaction code: 95 | 96 | ```shell 97 | python cli_demo_vision.py --backend transformers --model_path THUDM/glm-edge-v-2b --precision bfloat16 98 | python cli_demo.py --backend ov --model_path THUDM/glm-edge-1.5b-chat-ov --precision int4 99 | ``` 100 | 101 | You can also use Gradio to launch a WebUI. 102 | 103 | ```shell 104 | python cli_demo.py --backend transformers --model_path THUDM/glm-edge-1.5b-chat --precision bfloat16 105 | python cli_demo.py --backend vllm --model_path THUDM/glm-edge-1.5b-chat --precision int4 # For Int4 Inference 106 | ``` 107 | 108 | ### XInference 109 | 110 | If you use XInference for inference, you can run the model using the following commands. This is a command-line 111 | interaction code. 112 | 113 | ```shell 114 | xinference launch --model-engine Transformers --model-name glm-edge-v --size-in-billions 2 --model-format pytorch --quantization none 115 | ``` 116 | 117 | Using OpenAI API for inference: 118 | 119 | ```python 120 | import openai 121 | 122 | client = openai.Client( 123 | api_key="cannot be empty", 124 | base_url="http://:/v1" 125 | ) 126 | output = client.chat.completions.create( 127 | model="glm-edge-v", 128 | messages=[ 129 | { 130 | "role": "user", 131 | "content": [ 132 | { 133 | 'type': 'text', 134 | 'text': 'describe this image', 135 | }, 136 | { 137 | 'type': 'image_url', 138 | 'image_url': { 139 | "url": "img.png", 140 | } 141 | }, 142 | ], 143 | } 144 | ], 145 | max_tokens=512, 146 | temperature=0.7 147 | ) 148 | 149 | print(output) 150 | ``` 151 | 152 | ## Fine-Tuning Models 153 | 154 | We provide code for fine-tuning models. Please refer to the [Fine-Tuning Tutorial](finetune/README.md). 155 | 156 | ## License 157 | 158 | The code in this GitHub repository uses the [Apache2.0 LICENSE](LICENSE). 159 | 160 | Usage of model weights must follow the [Model License](MODEL_LICENSE). 161 | -------------------------------------------------------------------------------- /finetune/README.md: -------------------------------------------------------------------------------- 1 | # GLM-Edge 对话模型微调 2 | 3 | Read this in [English](README_en.md) 4 | 5 | 本 demo 中,你将体验到如何微调 GLM-Edge 对话开源模型。 请严格按照文档的步骤进行操作,以避免不必要的错误。 6 | 7 | ## 多轮对话格式 8 | 9 | 多轮对话微调示例采用 GLM-Edge 对话格式约定,对不同角色添加不同 `loss_mask` 从而在一遍计算中为多轮回复计算 `loss`。 10 | 11 | 对于数据文件,样例采用如下格式: 12 | 13 | 对于glm-edge-chat系列模型,您应该按照以下格式整理数据。 14 | 15 | ```json 16 | [ 17 | { 18 | "messages": [ 19 | { 20 | "role": "system", 21 | "content": "", 22 | }, 23 | { 24 | "role": "user", 25 | "content": "" 26 | }, 27 | { 28 | "role": "assistant", 29 | "content": "" 30 | }, 31 | // Multi_turns 32 | { 33 | "role": "user", 34 | "content": "" 35 | }, 36 | { 37 | "role": "assistant", 38 | "content": "" 39 | }, 40 | ] 41 | } 42 | ] 43 | ``` 44 | 45 | 这里是一个单轮对话的例子: 46 | 47 | ```json 48 | { 49 | "messages": [ 50 | { 51 | "role": "user", 52 | "content": "类型#裤*材质#牛仔布*风格#性感" 53 | }, 54 | { 55 | "role": "assistant", 56 | "content": "3x1的这款牛仔裤采用浅白的牛仔面料为裤身材质,其柔然的手感和细腻的质地,在穿着舒适的同时,透露着清纯甜美的个性气质。除此之外,流畅的裤身剪裁将性感的>腿部曲线彰显的淋漓尽致,不失为一款随性出街的必备单品。" 57 | } 58 | ] 59 | } 60 | ``` 61 | 62 | 对于glm-edge-v系列模型,您应该按照以下格式整理数据。 63 | 64 | ```json 65 | [ 66 | { 67 | "messages": [ 68 | { 69 | "role": "user", 70 | "content": [ 71 | { 72 | "type": "image", 73 | "image": "path/to/image" 74 | }, 75 | { 76 | "type": "text", 77 | "text": "图片中的狗在做什么?" 
78 | } 79 | ] 80 | }, 81 | { 82 | "role": "assistant", 83 | "content": [ 84 | { 85 | "type": "text", 86 | "text": "zRzRzRzRzRzRzR!这只狗躺在公寓客厅的绿色沙发上。" 87 | } 88 | ] 89 | }, 90 | { 91 | "role": "user", 92 | "content": [ 93 | { 94 | "type": "text", 95 | "text": "这只狗是什么颜色的?" 96 | } 97 | ] 98 | }, 99 | { 100 | "role": "assistant", 101 | "content": [ 102 | { 103 | "type": "text", 104 | "text": "zRzRzRzRzRzRzR!这只狗是棕色和白色的。" 105 | } 106 | ] 107 | } 108 | ] 109 | } 110 | ] 111 | ``` 112 | 113 | ## 配置文件 114 | 115 | 微调配置文件位于 `config` 目录下,包括以下文件: 116 | 117 | 1. `ds_zereo_2 / ds_zereo_3.json`: deepspeed 配置文件。 118 | 119 | 2. `lora.yaml / sft.yaml`: 模型不同方式的配置文件,包括模型参数、优化器参数、训练参数等。 部分重要参数解释如下: 120 | + data_config 部分 121 | + train_file: 训练数据集的文件路径。 122 | + val_file: 验证数据集的文件路径。 123 | + test_file: 测试数据集的文件路径。 124 | + num_proc: 在加载数据时使用的进程数量。 125 | + max_input_length: 输入序列的最大长度,对于glm-edge-v系列模型,由于图片占位token个数是584,因此值需要设置大些。 126 | + max_output_length: 输出序列的最大长度。 127 | + training_args 部分 128 | + output_dir: 用于保存模型和其他输出的目录。 129 | + max_steps: 训练的最大步数。 130 | + per_device_train_batch_size: 每个设备(如 GPU)的训练批次大小。 131 | + dataloader_num_workers: 加载数据时使用的工作线程数量。 132 | + remove_unused_columns: 是否移除数据中未使用的列。 133 | + save_strategy: 模型保存策略(例如,每隔多少步保存一次)。 134 | + save_steps: 每隔多少步保存一次模型。 135 | + log_level: 日志级别(如 info)。 136 | + logging_strategy: 日志记录策略。 137 | + logging_steps: 每隔多少步记录一次日志。 138 | + per_device_eval_batch_size: 每个设备的评估批次大小。 139 | + evaluation_strategy: 评估策略(例如,每隔多少步进行一次评估)。 140 | + eval_steps: 每隔多少步进行一次评估。 141 | + predict_with_generate: 是否使用生成模式进行预测。 142 | + generation_config 部分 143 | + max_new_tokens: 生成的最大新 token 数量。 144 | + peft_config 部分 145 | + peft_type: 使用的参数有效调整类型 (支持 LORA 和 PREFIX_TUNING)。 146 | + task_type: 任务类型,这里是因果语言模型 (不要改动)。 147 | + Lora 参数: 148 | + r: LoRA 的秩。 149 | + lora_alpha: LoRA 的缩放因子。 150 | + lora_dropout: 在 LoRA 层使用的 dropout 概率。 151 | 152 | ## 开始微调 153 | 154 | 通过以下代码执行 **单机多卡/多机多卡** 运行,这是使用 `deepspeed` 作为加速方案的,您需要安装 `deepspeed`。接着,按照此命令运行: 155 | 156 | ```shell 157 | OMP_NUM_THREADS=1 torchrun --standalone --nnodes=1 --nproc_per_node=8 finetune.py data/AdvertiseGen/ THUDM/glm-edge-4b-chat configs/lora.yaml # For Chat Fine-tune 158 | OMP_NUM_THREADS=1 torchrun --standalone --nnodes=1 --nproc_per_node=8 finetune_vision.py data/CogVLM-311K/ THUDM/glm-edge-v-5b configs/lora.yaml # For VQA Fine-tune 159 | ``` 160 | 161 | 通过以下代码执行 **单机单卡** 运行。 162 | 163 | ```shell 164 | python finetune.py data/AdvertiseGen/ THUDM/glm-edge-4b-chat configs/lora.yaml # For Chat Fine-tune 165 | python finetune_vision.py data/CogVLM-311K/ THUDM/glm-edge-v-5b configs/lora.yaml # For VQA Fine-tune 166 | ``` 167 | 168 | ## 从保存点进行微调 169 | 170 | 如果按照上述方式进行训练,每次微调都会从头开始,如果你想从训练一半的模型开始微调,你可以加入第四个参数,这个参数有两种传入方式: 171 | 172 | 1. `yes`, 自动从最后一个保存的 Checkpoint开始训练 173 | 2. `XX`, 断点号数字 例 `600` 则从序号600 Checkpoint开始训练 174 | 175 | 例如,这就是一个从最后一个保存点继续微调的示例代码 176 | 177 | ```shell 178 | python finetune.py data/AdvertiseGen/ THUDM/glm-edge-4b-chat configs/lora.yaml yes 179 | python finetune_vision.py data/CogVLM-311K/ THUDM/glm-edge-4b-chat configs/lora.yaml yes 180 | ``` 181 | -------------------------------------------------------------------------------- /finetune/README_en.md: -------------------------------------------------------------------------------- 1 | # GLM-Edge dialogue model fine-tuning 2 | 3 | Read this in [Chinese](README.md) 4 | 5 | In this demo, you will experience how to fine-tune the GLM-Edge-4B-Chat open source dialogue model. Please strictly follow the steps in the document to avoid unnecessary errors. 
6 | 7 | ## Multi-turn dialogue format 8 | 9 | The multi-turn dialogue fine-tuning example uses the GLM-Edge dialogue format convention, adding different `loss_mask` to different roles to calculate `loss` for multiple rounds of replies in one calculation. 10 | 11 | For data files, the sample uses the following format: 12 | 13 | For the glm-edge-chat family of models, you should organize your data in the following format. 14 | 15 | ```json 16 | [ 17 | { 18 | "messages": [ 19 | { 20 | "role": "system", 21 | "content": "", 22 | }, 23 | { 24 | "role": "user", 25 | "content": "" 26 | }, 27 | { 28 | "role": "assistant", 29 | "content": "" 30 | }, 31 | // Multi_turns 32 | { 33 | "role": "user", 34 | "content": "" 35 | }, 36 | { 37 | "role": "assistant", 38 | "content": "" 39 | }, 40 | ] 41 | } 42 | ] 43 | ``` 44 | 45 | Here is an example of a single-turn conversation: 46 | 47 | ```json 48 | { 49 | "messages": [ 50 | { 51 | "role": "user", 52 | "content": "Type#Pants*Material#Denim*Style#Sexy" 53 | }, 54 | { 55 | "role": "assistant", 56 | "content": "This pair of jeans from 3x1 is made of light white denim fabric. Its soft feel and delicate texture make it comfortable to wear while revealing a pure and sweet personality. In addition, the smooth cut of the pants fully highlights the sexy leg curves, making it a must-have item for casual outings." 57 | } 58 | ] 59 | } 60 | ``` 61 | 62 | For glm-edge-v family of models, you should organize your data in the following format. 63 | 64 | ```json 65 | [ 66 | { 67 | "messages": [ 68 | { 69 | "role": "user", 70 | "content": [ 71 | { 72 | "type": "image", 73 | "image": "path/to/image" 74 | }, 75 | { 76 | "type": "text", 77 | "text": "图片中的狗在做什么?" 78 | } 79 | ] 80 | }, 81 | { 82 | "role": "assistant", 83 | "content": [ 84 | { 85 | "type": "text", 86 | "text": "zRzRzRzRzRzRzR!这只狗躺在公寓客厅的绿色沙发上。" 87 | } 88 | ] 89 | }, 90 | { 91 | "role": "user", 92 | "content": [ 93 | { 94 | "type": "text", 95 | "text": "这只狗是什么颜色的?" 96 | } 97 | ] 98 | }, 99 | { 100 | "role": "assistant", 101 | "content": [ 102 | { 103 | "type": "text", 104 | "text": "zRzRzRzRzRzRzR!这只狗是棕色和白色的。" 105 | } 106 | ] 107 | } 108 | ] 109 | } 110 | ] 111 | ``` 112 | 113 | ## Configuration Files 114 | 115 | The fine-tuning configuration files are located in the `config` directory and include the following files: 116 | 117 | 1. `ds_zereo_2 / ds_zereo_3.json`: deepspeed configuration file 118 | 119 | 2. `lora.yaml / sft.yaml`: Configuration files of different models, including model parameters, optimizer parameters, training parameters, etc. Some important parameters are explained as follows: 120 | 121 | + train_file: File path of training dataset. 122 | + val_file: File path of validation dataset. 123 | + test_file: File path of test dataset. 124 | + num_proc: Number of processes to use when loading data. 125 | + max_input_length: Maximum length of input sequence, since the number of image placeholder token is 584, the value needs to be set larger. 126 | + max_output_length: Maximum length of output sequence. 127 | + training_args section 128 | + output_dir: Directory for saving model and other outputs. 129 | + max_steps: Maximum number of training steps. 130 | + per_device_train_batch_size: Training batch size per device (such as GPU). 131 | + dataloader_num_workers: Number of worker threads to use when loading data. 132 | + remove_unused_columns: Whether to remove unused columns in data. 133 | + save_strategy: Model saving strategy (for example, how many steps to save). 
134 | + save_steps: How many steps to save the model. 135 | + log_level: Log level (such as info). 136 | + logging_strategy: logging strategy. 137 | + logging_steps: how many steps to log at. 138 | + per_device_eval_batch_size: per-device evaluation batch size. 139 | + evaluation_strategy: evaluation strategy (e.g. how many steps to evaluate at). 140 | + eval_steps: how many steps to evaluate at. 141 | + predict_with_generate: whether to use generation mode for prediction. 142 | + generation_config section 143 | + max_new_tokens: maximum number of new tokens to generate. 144 | + peft_config section 145 | + peft_type: type of parameter tuning to use (supports LORA and PREFIX_TUNING). 146 | + task_type: task type, here is causal language model (don't change). 147 | + Lora parameters: 148 | + r: rank of LoRA. 149 | + lora_alpha: scaling factor of LoRA. 150 | + lora_dropout: dropout probability to use in LoRA layer. 151 | 152 | ## Start fine-tuning 153 | 154 | Execute **single machine multi-card/multi-machine multi-card** run through the following code, which uses `deepspeed` as 155 | the acceleration solution, and you need to install `deepspeed`. 156 | 157 | ```shell 158 | OMP_NUM_THREADS=1 torchrun --standalone --nnodes=1 --nproc_per_node=8 finetune.py data/AdvertiseGen/ THUDM/glm-edge-4b-chat configs/lora.yaml # For Chat Fine-tune 159 | OMP_NUM_THREADS=1 torchrun --standalone --nnodes=1 --nproc_per_node=8 finetune_vision.py data/CogVLM-311K/ THUDM/glm-edge-v-5b configs/lora.yaml # For VQA Fine-tune 160 | ``` 161 | 162 | Execute **single machine single card** run through the following code. 163 | 164 | ```shell 165 | python finetune.py data/AdvertiseGen/ THUDM/glm-edge-4b-chat configs/lora.yaml # For Chat Fine-tune 166 | python finetune_vision.py data/CogVLM-311K/ THUDM/glm-edge-v-5b configs/lora.yaml # For VQA Fine-tune 167 | ``` 168 | 169 | ## Fine-tune from a saved point 170 | 171 | If you train as described above, each fine-tuning will start from the beginning. If you want to fine-tune from a 172 | half-trained model, you can add a fourth parameter, which can be passed in two ways: 173 | 174 | 1. `yes`, automatically start training from the last saved Checkpoint 175 | 176 | 2. 
`XX`, breakpoint number, for example `600`, start training from Checkpoint 600 177 | 178 | For example, this is an example code to continue fine-tuning from the last saved point 179 | 180 | ```shell 181 | python finetune.py data/AdvertiseGen/ THUDM/glm-edge-4b-chat configs/lora.yaml yes 182 | python finetune_vision.py data/CogVLM-311K/ THUDM/glm-edge-4b-chat configs/lora.yaml yes 183 | ``` 184 | -------------------------------------------------------------------------------- /finetune/configs/ds_zero_3.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_micro_batch_size_per_gpu": "auto", 3 | "zero_allow_untested_optimizer": true, 4 | "bf16": { 5 | "enabled": "auto" 6 | }, 7 | "optimizer": { 8 | "type": "AdamW", 9 | "params": { 10 | "lr": "auto", 11 | "betas": "auto", 12 | "eps": "auto", 13 | "weight_decay": "auto" 14 | } 15 | }, 16 | "zero_optimization": { 17 | "stage": 3, 18 | "allgather_partitions": true, 19 | "allgather_bucket_size": 5e8, 20 | "reduce_scatter": true, 21 | "contiguous_gradients": true, 22 | "overlap_comm": true, 23 | "sub_group_size": 1e9, 24 | "reduce_bucket_size": "auto", 25 | "stage3_prefetch_bucket_size": "auto", 26 | "stage3_param_persistence_threshold": "auto", 27 | "stage3_max_live_parameters": 1e9, 28 | "stage3_max_reuse_distance": 1e9, 29 | "stage3_gather_16bit_weights_on_model_save": true 30 | } 31 | } -------------------------------------------------------------------------------- /finetune/configs/lora.yaml: -------------------------------------------------------------------------------- 1 | data_config: 2 | train_file: train.jsonl 3 | val_file: dev.jsonl 4 | test_file: dev.jsonl 5 | num_proc: 1 6 | 7 | freezeV: False 8 | max_input_length: 2048 # For Image Must larger than 578 9 | max_output_length: 1024 10 | 11 | training_args: 12 | bf16: True 13 | # see `transformers.Seq2SeqTrainingArguments` 14 | output_dir: ./output 15 | max_steps: 3000 16 | # needed to be fit for the dataset 17 | learning_rate: 5e-4 18 | # settings for data loading 19 | per_device_train_batch_size: 16 20 | dataloader_num_workers: 16 21 | remove_unused_columns: false 22 | # settings for saving checkpoints 23 | save_strategy: steps 24 | save_steps: 500 25 | # settings for logging 26 | log_level: info 27 | logging_strategy: steps 28 | logging_steps: 10 29 | # settings for evaluation 30 | per_device_eval_batch_size: 16 31 | eval_strategy: steps 32 | eval_steps: 1000 33 | # adam_epsilon: 1e-6 34 | predict_with_generate: true 35 | generation_config: 36 | max_new_tokens: 512 37 | peft_config: 38 | peft_type: LORA 39 | task_type: CAUSAL_LM 40 | r: 8 41 | lora_alpha: 32 42 | lora_dropout: 0.1 43 | target_modules: ["q_proj", "k_proj", "v_proj"] 44 | -------------------------------------------------------------------------------- /finetune/configs/sft.yaml: -------------------------------------------------------------------------------- 1 | data_config: 2 | train_file: train.jsonl 3 | val_file: dev.jsonl 4 | test_file: dev.jsonl 5 | num_proc: 1 6 | 7 | combine: True 8 | max_input_length: 2048 # For Image Must larger than 578 9 | max_output_length: 1024 10 | 11 | training_args: 12 | bf16: True 13 | # see `transformers.Seq2SeqTrainingArguments` 14 | output_dir: ./output 15 | max_steps: 3000 16 | # needed to be fit for the dataset 17 | learning_rate: 5e-5 18 | # settings for data loading 19 | per_device_train_batch_size: 4 20 | dataloader_num_workers: 16 21 | remove_unused_columns: false 22 | # settings for saving checkpoints 23 | save_strategy: 
steps 24 | save_steps: 500 25 | # settings for logging 26 | log_level: info 27 | logging_strategy: steps 28 | logging_steps: 10 29 | # settings for evaluation 30 | per_device_eval_batch_size: 16 31 | eval_strategy: steps 32 | eval_steps: 1000 33 | # settings for optimizer 34 | adam_epsilon: 1e-6 35 | predict_with_generate: true 36 | generation_config: 37 | max_new_tokens: 512 38 | # set your absolute deepspeed path here 39 | deepspeed: configs/ds_zero_3.json 40 | -------------------------------------------------------------------------------- /finetune/finetune.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | import jieba 4 | import dataclasses as dc 5 | import functools 6 | from collections.abc import Callable, Mapping, Sequence 7 | from pathlib import Path 8 | from typing import Annotated, Any, Union 9 | import numpy as np 10 | import ruamel.yaml as yaml 11 | import torch 12 | import typer 13 | from datasets import Dataset, Split 14 | from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction 15 | from peft import PeftConfig, get_peft_config, get_peft_model 16 | from rouge_chinese import Rouge 17 | from torch import nn 18 | from transformers import ( 19 | AutoModelForCausalLM, 20 | AutoTokenizer, 21 | EvalPrediction, 22 | GenerationConfig, 23 | PreTrainedTokenizer, 24 | Seq2SeqTrainingArguments, 25 | ) 26 | from transformers import DataCollatorForSeq2Seq as _DataCollatorForSeq2Seq 27 | from transformers import Seq2SeqTrainer as _Seq2SeqTrainer 28 | from datasets import load_dataset, DatasetDict, NamedSplit 29 | from typing import Optional 30 | 31 | app = typer.Typer(pretty_exceptions_show_locals=False) 32 | 33 | 34 | class DataCollatorForSeq2Seq(_DataCollatorForSeq2Seq): 35 | def __call__(self, features, return_tensors=None): 36 | output_ids = [feature["output_ids"] for feature in features] if "output_ids" in features[0].keys() else None 37 | if output_ids is not None: 38 | max_output_length = max(len(out) for out in output_ids) 39 | if self.pad_to_multiple_of is not None: 40 | max_output_length = ( 41 | (max_output_length + self.pad_to_multiple_of - 1) 42 | // self.pad_to_multiple_of 43 | * self.pad_to_multiple_of 44 | ) 45 | for feature in features: 46 | remainder = [self.tokenizer.pad_token_id] * (max_output_length - len(feature["output_ids"])) 47 | if isinstance(feature["output_ids"], list): 48 | feature["output_ids"] = feature["output_ids"] + remainder 49 | else: 50 | feature["output_ids"] = np.concatenate([feature["output_ids"], remainder]).astype(np.int64) 51 | return super().__call__(features, return_tensors) 52 | 53 | 54 | class Seq2SeqTrainer(_Seq2SeqTrainer): 55 | def prediction_step( 56 | self, 57 | model: nn.Module, 58 | inputs: dict[str, Any], 59 | prediction_loss_only: bool, 60 | ignore_keys=None, 61 | **gen_kwargs, 62 | ) -> tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: 63 | with torch.no_grad(): # Ensure no gradient computation 64 | if self.args.predict_with_generate: 65 | output_ids = inputs.pop("output_ids") 66 | input_ids = inputs["input_ids"] 67 | 68 | if "labels" in inputs: 69 | del inputs["labels"] 70 | 71 | loss, generated_tokens, labels = super().prediction_step( 72 | model, inputs, prediction_loss_only, ignore_keys, **gen_kwargs 73 | ) 74 | 75 | generated_tokens = generated_tokens[:, input_ids.size()[1] :] 76 | labels = output_ids 77 | 78 | del inputs, input_ids, output_ids 79 | torch.cuda.empty_cache() 80 | 81 | return loss, generated_tokens, 
labels 82 | 83 | 84 | @dc.dataclass 85 | class DataConfig(object): 86 | train_file: Optional[str] = None 87 | val_file: Optional[str] = None 88 | test_file: Optional[str] = None 89 | num_proc: Optional[int] = None 90 | 91 | @property 92 | def data_format(self) -> str: 93 | return Path(self.train_file).suffix 94 | 95 | @property 96 | def data_files(self) -> dict[NamedSplit, str]: 97 | return { 98 | split: data_file 99 | for split, data_file in zip( 100 | [Split.TRAIN, Split.VALIDATION, Split.TEST], 101 | [self.train_file, self.val_file, self.test_file], 102 | ) 103 | if data_file is not None 104 | } 105 | 106 | 107 | @dc.dataclass 108 | class FinetuningConfig(object): 109 | data_config: DataConfig 110 | 111 | max_input_length: int 112 | max_output_length: int 113 | freezeV: bool 114 | 115 | training_args: Seq2SeqTrainingArguments = dc.field( 116 | default_factory=lambda: Seq2SeqTrainingArguments(output_dir="./output") 117 | ) 118 | peft_config: Optional[PeftConfig] = None 119 | 120 | def __post_init__(self): 121 | if not self.training_args.do_eval or self.data_config.val_file is None: 122 | self.training_args.do_eval = False 123 | self.training_args.evaluation_strategy = "no" 124 | self.data_config.val_file = None 125 | else: 126 | self.training_args.per_device_eval_batch_size = ( 127 | self.training_args.per_device_eval_batch_size or self.training_args.per_device_train_batch_size 128 | ) 129 | 130 | @classmethod 131 | def from_dict(cls, **kwargs) -> "FinetuningConfig": 132 | training_args = kwargs.get("training_args", None) 133 | if training_args is not None and not isinstance(training_args, Seq2SeqTrainingArguments): 134 | gen_config = training_args.get("generation_config") 135 | if not isinstance(gen_config, GenerationConfig): 136 | training_args["generation_config"] = GenerationConfig(**gen_config) 137 | kwargs["training_args"] = Seq2SeqTrainingArguments(**training_args) 138 | 139 | data_config = kwargs.get("data_config") 140 | if not isinstance(data_config, DataConfig): 141 | kwargs["data_config"] = DataConfig(**data_config) 142 | 143 | peft_config = kwargs.get("peft_config", None) 144 | if peft_config is not None and not isinstance(peft_config, PeftConfig): 145 | kwargs["peft_config"] = get_peft_config(config_dict=peft_config) 146 | return cls(**kwargs) 147 | 148 | @classmethod 149 | def from_file(cls, path: Union[str, Path]) -> "FinetuningConfig": 150 | path = Path(path) 151 | parser = yaml.YAML(typ="safe", pure=True) 152 | parser.indent(mapping=2, offset=2, sequence=4) 153 | parser.default_flow_style = False 154 | kwargs = parser.load(path) 155 | return cls.from_dict(**kwargs) 156 | 157 | 158 | def _load_datasets( 159 | data_dir: str, 160 | data_format: str, 161 | data_files: dict[NamedSplit, str], 162 | num_proc: Optional[int], 163 | ) -> DatasetDict: 164 | if data_format == ".jsonl": 165 | dataset_dct = load_dataset( 166 | data_dir, 167 | data_files=data_files, 168 | split=None, 169 | num_proc=num_proc, 170 | ) 171 | else: 172 | raise NotImplementedError(f"Cannot load dataset in the '{data_format}' format.") 173 | return dataset_dct 174 | 175 | 176 | class DataManager(object): 177 | def __init__(self, data_dir: str, data_config: DataConfig): 178 | self._num_proc = data_config.num_proc 179 | 180 | self._dataset_dct = _load_datasets( 181 | data_dir, 182 | data_config.data_format, 183 | data_config.data_files, 184 | self._num_proc, 185 | ) 186 | 187 | def _get_dataset(self, split: NamedSplit) -> Optional[Dataset]: 188 | return self._dataset_dct.get(split, None) 189 | 190 | def 
get_dataset( 191 | self, 192 | split: NamedSplit, 193 | process_fn: Callable[[dict[str, Any]], dict[str, Any]], 194 | batched: bool = True, 195 | remove_orig_columns: bool = True, 196 | ) -> Optional[Dataset]: 197 | orig_dataset = self._get_dataset(split) 198 | if orig_dataset is None: 199 | return 200 | 201 | if remove_orig_columns: 202 | remove_columns = orig_dataset.column_names 203 | else: 204 | remove_columns = None 205 | return orig_dataset.map( 206 | process_fn, 207 | batched=batched, 208 | remove_columns=remove_columns, 209 | num_proc=self._num_proc, 210 | ) 211 | 212 | 213 | def process_message(message): 214 | if "tools" in message and message["role"] == "system": 215 | for tool in message["tools"]: 216 | parameters = tool["function"]["parameters"]["properties"] 217 | tool["function"]["parameters"]["properties"] = {k: v for k, v in parameters.items() if v is not None} 218 | elif "tools" in message: 219 | del message["tools"] 220 | return message 221 | 222 | 223 | def process_batch( 224 | batch: Mapping[str, Sequence], 225 | tokenizer: PreTrainedTokenizer, 226 | max_input_length: int, 227 | max_output_length: int, 228 | ) -> dict[str, list]: 229 | batched_conv = batch["messages"] 230 | batched_input_ids = [] 231 | batched_labels = [] 232 | for conv in batched_conv: 233 | new_input_ids = tokenizer.apply_chat_template( 234 | conv, tokenize=True, return_dict=False, add_generation_prompt=False 235 | ) 236 | input_ids = new_input_ids 237 | loss_masks = [False] * len(input_ids) 238 | last_assistant_index = len(input_ids) - input_ids[::-1].index(59254) - 1 # <|assistant|> 239 | for j in range(last_assistant_index + 1, len(input_ids)): 240 | loss_masks[j] = True 241 | 242 | input_ids.append(59253) # EOS for chat <|user|> 243 | loss_masks = [False, *loss_masks] 244 | labels = [] 245 | for input_id, mask in zip(input_ids, loss_masks): 246 | if mask: 247 | labels.append(input_id) 248 | else: 249 | labels.append(-100) 250 | max_length = max_input_length + max_output_length + 1 251 | batched_input_ids.append(input_ids[:max_length]) 252 | batched_labels.append(labels[:max_length]) 253 | 254 | del batched_conv, conv, input_ids, loss_masks, new_input_ids, labels 255 | torch.cuda.empty_cache() 256 | 257 | return {"input_ids": batched_input_ids, "labels": batched_labels} 258 | 259 | 260 | def process_batch_eval( 261 | batch: Mapping[str, Sequence], 262 | tokenizer: PreTrainedTokenizer, 263 | max_input_length: int, 264 | max_output_length: int, 265 | ) -> dict[str, list]: 266 | batched_conv = batch["messages"] 267 | batched_input_ids = [] 268 | batched_output_ids = [] 269 | 270 | for conv in batched_conv: 271 | new_input_ids = tokenizer.apply_chat_template(conv, tokenize=True, return_dict=False) 272 | input_ids = new_input_ids 273 | last_assistant_index = len(input_ids) - input_ids[::-1].index(59254) - 1 274 | output_prompt, output_ids = ( 275 | input_ids[:1], 276 | input_ids[last_assistant_index:], 277 | ) 278 | output_ids.append(59253) 279 | batched_input_ids.append(input_ids[:max_input_length] + output_prompt[:1]) 280 | batched_output_ids.append(output_ids[:max_output_length]) 281 | 282 | del batched_conv, conv, input_ids, new_input_ids, output_prompt, output_ids 283 | torch.cuda.empty_cache() 284 | 285 | return {"input_ids": batched_input_ids, "output_ids": batched_output_ids} 286 | 287 | 288 | def load_tokenizer_and_model( 289 | model_dir: str, 290 | peft_config: Optional[PeftConfig] = None, 291 | ): 292 | tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True, 
padding_side="left") 293 | 294 | model = AutoModelForCausalLM.from_pretrained( 295 | model_dir, 296 | trust_remote_code=True, 297 | use_cache=False, 298 | torch_dtype=torch.bfloat16, # Must use BFloat 16 299 | ) 300 | 301 | if peft_config is not None: 302 | model = get_peft_model(model, peft_config) 303 | model.print_trainable_parameters() 304 | 305 | return tokenizer, model 306 | 307 | 308 | def compute_metrics(eval_preds: EvalPrediction, tokenizer): 309 | batched_pred_ids, batched_label_ids = eval_preds 310 | batched_pred_ids[batched_pred_ids == -100] = tokenizer.pad_token_id 311 | batched_label_ids[batched_label_ids == -100] = tokenizer.pad_token_id 312 | metrics_dct = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []} 313 | for pred_ids, label_ids in zip(batched_pred_ids, batched_label_ids): 314 | pred_txt = tokenizer.decode(pred_ids, skip_special_tokens=True).strip() 315 | label_txt = tokenizer.decode(label_ids, skip_special_tokens=True).strip() 316 | pred_tokens = list(jieba.cut(pred_txt)) 317 | label_tokens = list(jieba.cut(label_txt)) 318 | rouge = Rouge() 319 | scores = rouge.get_scores(" ".join(pred_tokens), " ".join(label_tokens)) 320 | for k, v in scores[0].items(): 321 | metrics_dct[k].append(round(v["f"] * 100, 4)) 322 | metrics_dct["bleu-4"].append( 323 | sentence_bleu([label_tokens], pred_tokens, smoothing_function=SmoothingFunction().method3) 324 | ) 325 | return {k: np.mean(v) for k, v in metrics_dct.items()} 326 | 327 | 328 | @app.command() 329 | def main( 330 | data_dir: Annotated[str, typer.Argument(help="")], 331 | model_dir: Annotated[ 332 | str, 333 | typer.Argument( 334 | help="A string that specifies the model id of a pretrained model configuration hosted on huggingface.co, or a path to a directory containing a model configuration file." 335 | ), 336 | ], 337 | config_file: Annotated[str, typer.Argument(help="")], 338 | auto_resume_from_checkpoint: str = typer.Argument( 339 | default="", 340 | help="If entered as yes, automatically use the latest save checkpoint. If it is a numerical example 12 15, use the corresponding save checkpoint. 
If the input is no, restart training", 341 | ), 342 | ): 343 | ft_config = FinetuningConfig.from_file(config_file) 344 | tokenizer, model = load_tokenizer_and_model(model_dir, peft_config=ft_config.peft_config) 345 | data_manager = DataManager(data_dir, ft_config.data_config) 346 | 347 | train_dataset = data_manager.get_dataset( 348 | Split.TRAIN, 349 | functools.partial( 350 | process_batch, 351 | tokenizer=tokenizer, 352 | max_input_length=ft_config.max_input_length, 353 | max_output_length=ft_config.max_output_length, 354 | ), 355 | batched=True, 356 | ) 357 | 358 | val_dataset = data_manager.get_dataset( 359 | Split.VALIDATION, 360 | functools.partial( 361 | process_batch_eval, 362 | tokenizer=tokenizer, 363 | max_input_length=ft_config.max_input_length, 364 | max_output_length=ft_config.max_output_length, 365 | ), 366 | batched=True, 367 | ) 368 | 369 | test_dataset = data_manager.get_dataset( 370 | Split.TEST, 371 | functools.partial( 372 | process_batch_eval, 373 | tokenizer=tokenizer, 374 | max_input_length=ft_config.max_input_length, 375 | max_output_length=ft_config.max_output_length, 376 | ), 377 | batched=True, 378 | ) 379 | 380 | trainer = Seq2SeqTrainer( 381 | model=model, 382 | args=ft_config.training_args, 383 | data_collator=DataCollatorForSeq2Seq( 384 | tokenizer=tokenizer, 385 | padding="longest", 386 | return_tensors="pt", 387 | ), 388 | train_dataset=train_dataset, 389 | eval_dataset=val_dataset, 390 | compute_metrics=functools.partial(compute_metrics, tokenizer=tokenizer), 391 | ) 392 | 393 | if auto_resume_from_checkpoint.upper() == "" or auto_resume_from_checkpoint is None: 394 | trainer.train() 395 | else: 396 | output_dir = ft_config.training_args.output_dir 397 | dirlist = os.listdir(output_dir) 398 | checkpoint_sn = 0 399 | for checkpoint_str in dirlist: 400 | if checkpoint_str.find("eckpoint") > 0 and checkpoint_str.find("tmp") == -1: 401 | checkpoint = int(checkpoint_str.replace("checkpoint-", "")) 402 | if checkpoint > checkpoint_sn: 403 | checkpoint_sn = checkpoint 404 | if auto_resume_from_checkpoint.upper() == "YES": 405 | if checkpoint_sn > 0: 406 | model.gradient_checkpointing_enable() 407 | model.enable_input_require_grads() 408 | checkpoint_directory = os.path.join(output_dir, "checkpoint-" + str(checkpoint_sn)) 409 | print("resume checkpoint from checkpoint-" + str(checkpoint_sn)) 410 | trainer.train(resume_from_checkpoint=checkpoint_directory) 411 | else: 412 | trainer.train() 413 | else: 414 | if auto_resume_from_checkpoint.isdigit(): 415 | if int(auto_resume_from_checkpoint) > 0: 416 | checkpoint_sn = int(auto_resume_from_checkpoint) 417 | model.gradient_checkpointing_enable() 418 | model.enable_input_require_grads() 419 | checkpoint_directory = os.path.join(output_dir, "checkpoint-" + str(checkpoint_sn)) 420 | print("resume checkpoint from checkpoint-" + str(checkpoint_sn)) 421 | trainer.train(resume_from_checkpoint=checkpoint_directory) 422 | else: 423 | print( 424 | auto_resume_from_checkpoint, 425 | "The specified checkpoint sn(" 426 | + auto_resume_from_checkpoint 427 | + ") has not been saved. 
Please search for the correct checkpoint in the model output directory", 428 | ) 429 | 430 | if test_dataset is not None: 431 | trainer.predict(test_dataset) 432 | 433 | 434 | if __name__ == "__main__": 435 | app() 436 | -------------------------------------------------------------------------------- /finetune/finetune_vision.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | 4 | os.environ["WANDB_DISABLED"] = "true" 5 | import jieba 6 | import dataclasses as dc 7 | import functools 8 | from collections.abc import Callable, Mapping, Sequence 9 | from pathlib import Path 10 | from typing import Annotated, Any, Union 11 | import numpy as np 12 | import ruamel.yaml as yaml 13 | import torch 14 | import typer 15 | from datasets import Dataset, Split 16 | from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction 17 | from peft import PeftConfig, get_peft_config, get_peft_model 18 | from rouge_chinese import Rouge 19 | from torch import nn 20 | from transformers import ( 21 | AutoModelForCausalLM, 22 | AutoImageProcessor, 23 | AutoTokenizer, 24 | EvalPrediction, 25 | GenerationConfig, 26 | PreTrainedTokenizer, 27 | Seq2SeqTrainingArguments, 28 | ) 29 | from transformers import DataCollatorForSeq2Seq as _DataCollatorForSeq2Seq 30 | from transformers import Seq2SeqTrainer as _Seq2SeqTrainer 31 | from datasets import load_dataset, DatasetDict, NamedSplit 32 | from typing import Optional 33 | from PIL import Image 34 | 35 | app = typer.Typer(pretty_exceptions_show_locals=False) 36 | 37 | 38 | class DataCollatorForSeq2Seq(_DataCollatorForSeq2Seq): 39 | def __call__(self, features, return_tensors=None): 40 | output_ids = [feature["output_ids"] for feature in features] if "output_ids" in features[0].keys() else None 41 | if output_ids is not None: 42 | max_output_length = max(len(out) for out in output_ids) 43 | if self.pad_to_multiple_of is not None: 44 | max_output_length = ( 45 | (max_output_length + self.pad_to_multiple_of - 1) 46 | // self.pad_to_multiple_of 47 | * self.pad_to_multiple_of 48 | ) 49 | for feature in features: 50 | remainder = [self.tokenizer.pad_token_id] * (max_output_length - len(feature["output_ids"])) 51 | if isinstance(feature["output_ids"], list): 52 | feature["output_ids"] = feature["output_ids"] + remainder 53 | else: 54 | feature["output_ids"] = np.concatenate([feature["output_ids"], remainder]).astype(np.int64) 55 | return super().__call__(features, return_tensors) 56 | 57 | 58 | class Seq2SeqTrainer(_Seq2SeqTrainer): 59 | def prediction_step( 60 | self, 61 | model: nn.Module, 62 | inputs: dict, 63 | prediction_loss_only: bool, 64 | ignore_keys=None, 65 | **gen_kwargs, 66 | ) -> tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: 67 | with torch.no_grad(): 68 | if self.args.predict_with_generate: 69 | output_ids = inputs.pop("output_ids", None) 70 | 71 | if "labels" in inputs: 72 | del inputs["labels"] 73 | 74 | loss, generated_tokens, labels = super().prediction_step( 75 | model=model, 76 | inputs=inputs, 77 | prediction_loss_only=prediction_loss_only, 78 | ignore_keys=ignore_keys, 79 | **gen_kwargs, 80 | ) 81 | 82 | if generated_tokens is not None: 83 | generated_tokens = generated_tokens[:, inputs["input_ids"].size()[1] :] 84 | 85 | if self.args.predict_with_generate: 86 | labels = output_ids 87 | 88 | del inputs, output_ids 89 | torch.cuda.empty_cache() 90 | 91 | return loss, generated_tokens, labels 92 | 93 | 94 | @dc.dataclass 95 | class 
DataConfig(object): 96 | train_file: Optional[str] = None 97 | val_file: Optional[str] = None 98 | test_file: Optional[str] = None 99 | num_proc: Optional[int] = None 100 | 101 | @property 102 | def data_format(self) -> str: 103 | return Path(self.train_file).suffix 104 | 105 | @property 106 | def data_files(self) -> dict[NamedSplit, str]: 107 | return { 108 | split: data_file 109 | for split, data_file in zip( 110 | [Split.TRAIN, Split.VALIDATION, Split.TEST], 111 | [self.train_file, self.val_file, self.test_file], 112 | ) 113 | if data_file is not None 114 | } 115 | 116 | 117 | @dc.dataclass 118 | class FinetuningConfig(object): 119 | data_config: DataConfig 120 | 121 | max_input_length: int 122 | max_output_length: int 123 | freezeV: bool 124 | 125 | training_args: Seq2SeqTrainingArguments = dc.field( 126 | default_factory=lambda: Seq2SeqTrainingArguments(output_dir="./output") 127 | ) 128 | peft_config: Optional[PeftConfig] = None 129 | 130 | def __post_init__(self): 131 | if not self.training_args.do_eval or self.data_config.val_file is None: 132 | self.training_args.do_eval = False 133 | self.training_args.evaluation_strategy = "no" 134 | self.data_config.val_file = None 135 | else: 136 | self.training_args.per_device_eval_batch_size = ( 137 | self.training_args.per_device_eval_batch_size or self.training_args.per_device_train_batch_size 138 | ) 139 | 140 | @classmethod 141 | def from_dict(cls, **kwargs) -> "FinetuningConfig": 142 | training_args = kwargs.get("training_args", None) 143 | if training_args is not None and not isinstance(training_args, Seq2SeqTrainingArguments): 144 | gen_config = training_args.get("generation_config") 145 | if not isinstance(gen_config, GenerationConfig): 146 | training_args["generation_config"] = GenerationConfig(**gen_config) 147 | kwargs["training_args"] = Seq2SeqTrainingArguments(**training_args) 148 | 149 | data_config = kwargs.get("data_config") 150 | if not isinstance(data_config, DataConfig): 151 | kwargs["data_config"] = DataConfig(**data_config) 152 | 153 | peft_config = kwargs.get("peft_config", None) 154 | if peft_config is not None and not isinstance(peft_config, PeftConfig): 155 | kwargs["peft_config"] = get_peft_config(config_dict=peft_config) 156 | return cls(**kwargs) 157 | 158 | @classmethod 159 | def from_file(cls, path: Union[str, Path]) -> "FinetuningConfig": 160 | path = Path(path) 161 | parser = yaml.YAML(typ="safe", pure=True) 162 | parser.indent(mapping=2, offset=2, sequence=4) 163 | parser.default_flow_style = False 164 | kwargs = parser.load(path) 165 | return cls.from_dict(**kwargs) 166 | 167 | 168 | def _load_datasets( 169 | data_dir: str, 170 | data_format: str, 171 | data_files: dict[NamedSplit, str], 172 | num_proc: Optional[int], 173 | ) -> DatasetDict: 174 | if data_format == ".jsonl": 175 | dataset_dct = load_dataset( 176 | data_dir, 177 | data_files=data_files, 178 | split=None, 179 | num_proc=num_proc, 180 | ) 181 | else: 182 | raise NotImplementedError(f"Cannot load dataset in the '{data_format}' format.") 183 | return dataset_dct 184 | 185 | 186 | class DataManager(object): 187 | def __init__(self, data_dir: str, data_config: DataConfig): 188 | self._num_proc = data_config.num_proc 189 | 190 | self._dataset_dct = _load_datasets( 191 | data_dir, 192 | data_config.data_format, 193 | data_config.data_files, 194 | self._num_proc, 195 | ) 196 | 197 | def _get_dataset(self, split: NamedSplit) -> Optional[Dataset]: 198 | return self._dataset_dct.get(split, None) 199 | 200 | def get_dataset( 201 | self, 202 | split: 
NamedSplit, 203 | process_fn: Callable[[dict[str, Any]], dict[str, Any]], 204 | batched: bool = True, 205 | remove_orig_columns: bool = True, 206 | ) -> Optional[Dataset]: 207 | orig_dataset = self._get_dataset(split) 208 | if orig_dataset is None: 209 | return 210 | if remove_orig_columns: 211 | remove_columns = orig_dataset.column_names 212 | else: 213 | remove_columns = None 214 | return orig_dataset.map( 215 | process_fn, 216 | batched=batched, 217 | remove_columns=remove_columns, 218 | num_proc=self._num_proc, 219 | # This is default params of orig_dataset.map, and you can change it smaller 220 | # https://github.com/THUDM/GLM-4/issues/277 221 | writer_batch_size=1000, 222 | batch_size=1000, 223 | ) 224 | 225 | 226 | def process_batch( 227 | batch: Mapping[str, Sequence], 228 | tokenizer: PreTrainedTokenizer, 229 | processor, 230 | max_input_length: int, 231 | max_output_length: int, 232 | ) -> dict[str, list]: 233 | batched_conv = batch["messages"] 234 | batched_input_ids = [] 235 | batched_attention_mask = [] 236 | batched_position_ids = [] 237 | batched_labels = [] 238 | batched_images = [] 239 | 240 | max_length = max_input_length + max_output_length 241 | 242 | for conv in batched_conv: 243 | input_ids = [] 244 | attention_mask = [] 245 | position_ids = [] 246 | loss_masks = [] 247 | pixel_values = [] 248 | 249 | if conv[0]["content"][0].get("image"): 250 | image = Image.open(conv[0]["content"][0]["image"]) 251 | pixel_values.append(torch.tensor(processor(image).pixel_values)) 252 | 253 | for message in conv: 254 | loss_mask_val = False if message["role"] in ("system", "user") else True 255 | new_input_ids_all = tokenizer.apply_chat_template( 256 | [message], 257 | add_generation_prompt=False, 258 | tokenize=True, 259 | return_dict=True, 260 | return_tensors="pt", 261 | ) 262 | new_input_ids = new_input_ids_all["input_ids"][0].tolist() 263 | new_attention_mask = new_input_ids_all["attention_mask"][0].tolist() 264 | new_position_ids = list(range(len(position_ids), len(position_ids) + len(new_input_ids))) 265 | 266 | new_loss_masks = [loss_mask_val] * len(new_input_ids) 267 | input_ids += new_input_ids 268 | attention_mask += new_attention_mask 269 | position_ids += new_position_ids 270 | loss_masks += new_loss_masks 271 | 272 | input_ids.append(59253) # EOS 273 | attention_mask.append(1) 274 | position_ids.append(len(position_ids)) 275 | loss_masks.append(True) 276 | 277 | padding_length = max(0, max_length - len(input_ids)) 278 | 279 | # Left padding with batch 280 | input_ids = [tokenizer.pad_token_id] * padding_length + input_ids[-max_length:] 281 | attention_mask = [0] * padding_length + attention_mask[-max_length:] 282 | position_ids = [0] * padding_length + position_ids[-max_length:] 283 | loss_masks = [False] * padding_length + loss_masks[-max_length:] 284 | 285 | labels = [] 286 | for input_id, mask in zip(input_ids, loss_masks): 287 | if mask: 288 | labels.append(input_id) 289 | else: 290 | labels.append(-100) 291 | 292 | batched_input_ids.append(input_ids[:max_length]) 293 | batched_attention_mask.append(attention_mask[:max_length]) 294 | batched_position_ids.append(position_ids[:max_length]) 295 | batched_labels.append(labels[:max_length]) 296 | if len(pixel_values) > 0: 297 | batched_images.append(pixel_values[0][0]) 298 | else: 299 | batched_images.append(torch.zeros([1, 1, 3, 672, 672])) 300 | 301 | del ( 302 | batched_conv, 303 | conv, 304 | input_ids, 305 | attention_mask, 306 | position_ids, 307 | loss_masks, 308 | message, 309 | new_input_ids, 310 | 
new_loss_masks, 311 | labels, 312 | input_id, 313 | mask, 314 | ) 315 | torch.cuda.empty_cache() 316 | 317 | return { 318 | "input_ids": batched_input_ids, 319 | "attention_mask": batched_attention_mask, 320 | "position_ids": batched_position_ids, 321 | "labels": batched_labels, 322 | "pixel_values": batched_images, 323 | } 324 | 325 | 326 | def process_batch_eval( 327 | batch: Mapping[str, Sequence], 328 | tokenizer: PreTrainedTokenizer, 329 | processor, 330 | max_input_length: int, 331 | max_output_length: int, 332 | ) -> dict[str, list]: 333 | batched_conv = batch["messages"] 334 | batched_input_ids = [] 335 | batched_attention_mask = [] 336 | batched_position_ids = [] 337 | batched_output_ids = [] 338 | batched_images = [] 339 | 340 | for conv in batched_conv: 341 | if conv[0]["content"][0].get("image"): 342 | image = Image.open(conv[0]["content"][0]["image"]) 343 | 344 | new_input_ids_all = tokenizer.apply_chat_template( 345 | conv, 346 | add_generation_prompt=False, 347 | tokenize=True, 348 | padding=True, 349 | return_dict=True, 350 | return_tensors="pt", 351 | ) 352 | 353 | input_ids = new_input_ids_all["input_ids"][0].tolist() 354 | attention_mask = new_input_ids_all["attention_mask"][0].tolist() 355 | position_ids = list(range(len(input_ids))) 356 | 357 | dialogue_parts = [0] 358 | user_idx = [] 359 | for idx, token_id in enumerate(input_ids): 360 | if token_id == 59254: 361 | dialogue_parts.append(idx + 1) 362 | elif token_id == 59253: 363 | user_idx.append(idx) 364 | 365 | if user_idx[-1] != len(input_ids): 366 | user_idx.append(len(input_ids)) 367 | 368 | # Split the conversation into multiple dialogue segments 369 | for end_idx in range(1, len(dialogue_parts)): 370 | input_segment = input_ids[: dialogue_parts[end_idx]] 371 | attention_segment = attention_mask[: dialogue_parts[end_idx]] 372 | position_segment = position_ids[: dialogue_parts[end_idx]] 373 | output_segment = input_ids[dialogue_parts[end_idx] : user_idx[end_idx]] 374 | output_segment.append(59253) # Add EOS token 375 | 376 | # Left Padding 377 | padding_length = max(0, max_input_length - len(input_segment)) 378 | input_segment = [tokenizer.pad_token_id] * padding_length + input_segment[:max_input_length] 379 | attention_segment = [0] * padding_length + attention_segment[:max_input_length] 380 | position_segment = [0] * padding_length + position_segment[:max_input_length] 381 | output_segment = [tokenizer.pad_token_id] * padding_length + output_segment[:max_output_length] 382 | 383 | batched_input_ids.append(input_segment[:max_input_length]) 384 | batched_attention_mask.append(attention_segment[:max_input_length]) 385 | batched_position_ids.append(position_segment[:max_input_length]) 386 | batched_output_ids.append(output_segment[:max_output_length]) 387 | if conv[0]["content"][0].get("image"): 388 | batched_images.append(torch.tensor(processor(image).pixel_values)[0]) 389 | else: 390 | batched_images.append(torch.zeros([1, 1, 3, 672, 672])) 391 | 392 | del batched_conv, input_ids, attention_mask, position_ids, new_input_ids_all, output_segment 393 | torch.cuda.empty_cache() 394 | 395 | return { 396 | "input_ids": batched_input_ids, 397 | "attention_mask": batched_attention_mask, 398 | "position_ids": batched_position_ids, 399 | "output_ids": batched_output_ids, 400 | "pixel_values": batched_images, 401 | } 402 | 403 | 404 | def load_tokenizer_and_model( 405 | model_dir: str, 406 | peft_config: Optional[PeftConfig] = None, 407 | ): 408 | tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True, 
padding_side="left") 409 | processor = AutoImageProcessor.from_pretrained(model_dir, trust_remote_code=True, dtype=torch.bfloat16) 410 | if peft_config is not None: 411 | model = AutoModelForCausalLM.from_pretrained(model_dir, trust_remote_code=True, torch_dtype=torch.bfloat16) 412 | model = get_peft_model(model, peft_config) 413 | model.print_trainable_parameters() 414 | else: 415 | model = AutoModelForCausalLM.from_pretrained(model_dir, trust_remote_code=True, torch_dtype=torch.bfloat16) 416 | return tokenizer, model, processor 417 | 418 | 419 | def compute_metrics(eval_preds: EvalPrediction, tokenizer): 420 | batched_pred_ids, batched_label_ids = eval_preds 421 | batched_pred_ids[batched_pred_ids == -100] = tokenizer.pad_token_id 422 | batched_label_ids[batched_label_ids == -100] = tokenizer.pad_token_id 423 | metrics_dct = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []} 424 | for pred_ids, label_ids in zip(batched_pred_ids, batched_label_ids): 425 | pred_txt = tokenizer.decode(pred_ids, skip_special_tokens=True).strip() 426 | label_txt = tokenizer.decode(label_ids, skip_special_tokens=True).strip() 427 | pred_tokens = list(jieba.cut(pred_txt)) 428 | label_tokens = list(jieba.cut(label_txt)) 429 | rouge = Rouge() 430 | scores = rouge.get_scores(" ".join(pred_tokens), " ".join(label_tokens)) 431 | for k, v in scores[0].items(): 432 | metrics_dct[k].append(round(v["f"] * 100, 4)) 433 | metrics_dct["bleu-4"].append( 434 | sentence_bleu([label_tokens], pred_tokens, smoothing_function=SmoothingFunction().method3) 435 | ) 436 | return {k: np.mean(v) for k, v in metrics_dct.items()} 437 | 438 | 439 | @app.command() 440 | def main( 441 | data_dir: Annotated[str, typer.Argument(help="")], 442 | model_dir: Annotated[ 443 | str, 444 | typer.Argument( 445 | help="A string that specifies the model id of a pretrained model configuration hosted on huggingface.co, or a path to a directory containing a model configuration file." 446 | ), 447 | ], 448 | config_file: Annotated[str, typer.Argument(help="")], 449 | auto_resume_from_checkpoint: str = typer.Argument( 450 | default="", 451 | help="If entered as yes, automatically use the latest save checkpoint. If it is a numerical example 12 15, use the corresponding save checkpoint. 
If the input is no, restart training", 452 | ), 453 | ): 454 | ft_config = FinetuningConfig.from_file(config_file) 455 | tokenizer, model, processor = load_tokenizer_and_model(model_dir, peft_config=ft_config.peft_config) 456 | 457 | if ft_config.freezeV: 458 | for param in model.base_model.model.model.vision.parameters(): 459 | param.requires_grad = False 460 | data_manager = DataManager(data_dir, ft_config.data_config) 461 | 462 | train_dataset = data_manager.get_dataset( 463 | Split.TRAIN, 464 | functools.partial( 465 | process_batch, 466 | tokenizer=tokenizer, 467 | processor=processor, 468 | max_input_length=ft_config.max_input_length, 469 | max_output_length=ft_config.max_output_length, 470 | ), 471 | batched=True, 472 | ) 473 | 474 | val_dataset = data_manager.get_dataset( 475 | Split.VALIDATION, 476 | functools.partial( 477 | process_batch_eval, 478 | tokenizer=tokenizer, 479 | processor=processor, 480 | max_input_length=ft_config.max_input_length, 481 | max_output_length=ft_config.max_output_length, 482 | ), 483 | batched=True, 484 | ) 485 | 486 | test_dataset = data_manager.get_dataset( 487 | Split.TEST, 488 | functools.partial( 489 | process_batch_eval, 490 | tokenizer=tokenizer, 491 | processor=processor, 492 | max_input_length=ft_config.max_input_length, 493 | max_output_length=ft_config.max_output_length, 494 | ), 495 | batched=True, 496 | ) 497 | 498 | model.gradient_checkpointing_enable() 499 | model.enable_input_require_grads() 500 | 501 | trainer = Seq2SeqTrainer( 502 | model=model, 503 | args=ft_config.training_args, 504 | data_collator=DataCollatorForSeq2Seq( 505 | tokenizer=tokenizer, 506 | padding="longest", 507 | return_tensors="pt", 508 | ), 509 | train_dataset=train_dataset, 510 | eval_dataset=val_dataset, 511 | compute_metrics=functools.partial(compute_metrics, tokenizer=tokenizer), 512 | ) 513 | 514 | if auto_resume_from_checkpoint.upper() == "" or auto_resume_from_checkpoint is None: 515 | trainer.train() 516 | else: 517 | output_dir = ft_config.training_args.output_dir 518 | dirlist = os.listdir(output_dir) 519 | checkpoint_sn = 0 520 | for checkpoint_str in dirlist: 521 | if checkpoint_str.find("eckpoint") > 0 and checkpoint_str.find("tmp") == -1: 522 | checkpoint = int(checkpoint_str.replace("checkpoint-", "")) 523 | if checkpoint > checkpoint_sn: 524 | checkpoint_sn = checkpoint 525 | if auto_resume_from_checkpoint.upper() == "YES": 526 | if checkpoint_sn > 0: 527 | model.gradient_checkpointing_enable() 528 | model.enable_input_require_grads() 529 | checkpoint_directory = os.path.join(output_dir, "checkpoint-" + str(checkpoint_sn)) 530 | print("resume checkpoint from checkpoint-" + str(checkpoint_sn)) 531 | trainer.train(resume_from_checkpoint=checkpoint_directory) 532 | else: 533 | trainer.train() 534 | else: 535 | if auto_resume_from_checkpoint.isdigit(): 536 | if int(auto_resume_from_checkpoint) > 0: 537 | checkpoint_sn = int(auto_resume_from_checkpoint) 538 | model.gradient_checkpointing_enable() 539 | model.enable_input_require_grads() 540 | checkpoint_directory = os.path.join(output_dir, "checkpoint-" + str(checkpoint_sn)) 541 | print("resume checkpoint from checkpoint-" + str(checkpoint_sn)) 542 | trainer.train(resume_from_checkpoint=checkpoint_directory) 543 | else: 544 | print( 545 | auto_resume_from_checkpoint, 546 | "The specified checkpoint sn(" 547 | + auto_resume_from_checkpoint 548 | + ") has not been saved. 
Please search for the correct checkpoint in the model output directory", 549 | ) 550 | 551 | if test_dataset is not None: 552 | trainer.predict(test_dataset) 553 | 554 | 555 | if __name__ == "__main__": 556 | app() 557 | -------------------------------------------------------------------------------- /finetune/vision_dataset.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THUDM/GLM-Edge/7a93ea5047e91cf27b74d1e64ad490f75d329b17/finetune/vision_dataset.zip -------------------------------------------------------------------------------- /inference/cli_demo.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | import asyncio 4 | from threading import Thread 5 | from transformers import ( 6 | AutoTokenizer, 7 | AutoModelForCausalLM, 8 | TextIteratorStreamer, 9 | BitsAndBytesConfig, 10 | ) 11 | from vllm import SamplingParams, AsyncEngineArgs, AsyncLLMEngine 12 | from vllm.lora.request import LoRARequest 13 | from optimum.intel.openvino import OVModelForCausalLM 14 | import torch 15 | 16 | 17 | # Load Model and Tokenizer for VLLM 18 | def load_vllm_model_and_tokenizer(model_dir: str, lora_path: str, precision: str): 19 | enable_lora = bool(lora_path) 20 | tokenizer = AutoTokenizer.from_pretrained(model_dir) 21 | engine_args = AsyncEngineArgs( 22 | model=model_dir, 23 | tokenizer=model_dir, 24 | enable_lora=enable_lora, 25 | tensor_parallel_size=1, 26 | dtype="bfloat16" if precision == "bfloat16" else "float16", 27 | gpu_memory_utilization=0.9, 28 | enforce_eager=True, 29 | worker_use_ray=True, 30 | disable_log_requests=True, 31 | ) 32 | engine = AsyncLLMEngine.from_engine_args(engine_args) 33 | return engine, tokenizer, enable_lora 34 | 35 | 36 | async def vllm_gen(engine, tokenizer, lora_path, enable_lora, messages, top_p, temperature, max_length): 37 | inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False) 38 | sampling_params = SamplingParams( 39 | n=1, 40 | best_of=1, 41 | presence_penalty=1.0, 42 | frequency_penalty=0.0, 43 | temperature=temperature, 44 | top_p=top_p, 45 | max_tokens=max_length, 46 | ) 47 | if enable_lora: 48 | async for output in engine.generate( 49 | inputs=inputs, 50 | sampling_params=sampling_params, 51 | request_id=f"{time.time()}", 52 | lora_request=LoRARequest("GLM-Edge-lora", 1, lora_path=lora_path), 53 | ): 54 | yield output.outputs[0].text 55 | else: 56 | async for output in engine.generate( 57 | prompt=inputs, sampling_params=sampling_params, request_id=f"{time.time()}" 58 | ): 59 | yield output.outputs[0].text 60 | 61 | 62 | # CLI Chat Function for Transformers and OpenVINO 63 | def generic_chat(tokenizer, model, temperature, top_p, max_length, backend="transformers"): 64 | history = [] 65 | backend_label = "OpenVINO" if backend == "ov" else "Transformers" 66 | print(f"Welcome to the GLM-Edge CLI chat ({backend_label}). 
Type your messages below.") 67 | while True: 68 | user_input = input("\nYou: ") 69 | if user_input.lower() in ["exit", "quit"]: 70 | break 71 | history.append([user_input, ""]) 72 | 73 | messages = [{"role": "user", "content": user_input}] 74 | model_inputs = tokenizer.apply_chat_template( 75 | messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt" 76 | ) 77 | model_inputs = {k: v.to("cpu") for k, v in model_inputs.items()} # Ensure CPU for OpenVINO 78 | 79 | streamer = TextIteratorStreamer(tokenizer=tokenizer, timeout=60, skip_prompt=True, skip_special_tokens=True) 80 | generate_kwargs = { 81 | "input_ids": model_inputs["input_ids"], 82 | "attention_mask": model_inputs["attention_mask"], 83 | "streamer": streamer, 84 | "max_new_tokens": max_length, 85 | "do_sample": True, 86 | "top_p": top_p, 87 | "temperature": temperature, 88 | "repetition_penalty": 1.2, 89 | "eos_token_id": tokenizer.encode("<|user|>"), 90 | } 91 | t = Thread(target=model.generate, kwargs=generate_kwargs) 92 | t.start() 93 | 94 | print("GLM-Edge:", end="", flush=True) 95 | for new_token in streamer: 96 | print(new_token, end="", flush=True) 97 | history[-1][1] += new_token 98 | 99 | history[-1][1] = history[-1][1].strip() 100 | 101 | 102 | # Main Async Chat Function for VLLM 103 | async def vllm_chat(engine, tokenizer, lora_path, enable_lora, temperature, top_p, max_length): 104 | history = [] 105 | print("Welcome to the GLM-Edge DEMO chat (VLLM). Type your messages below.") 106 | while True: 107 | user_input = input("\nYou: ") 108 | if user_input.lower() in ["exit", "quit"]: 109 | break 110 | history.append([user_input, ""]) 111 | 112 | messages = [{"role": "user", "content": user_input}] 113 | print("\nGLM-Edge: ", end="") 114 | current_length = 0 115 | output = "" 116 | async for output in vllm_gen( 117 | engine, tokenizer, lora_path, enable_lora, messages, top_p, temperature, max_length 118 | ): 119 | print(output[current_length:], end="", flush=True) 120 | current_length = len(output) 121 | history[-1][1] = output 122 | 123 | 124 | def main(): 125 | parser = argparse.ArgumentParser(description="Run GLM-Edge DEMO Chat with VLLM, Transformers, or OpenVINO backend") 126 | parser.add_argument( 127 | "--backend", 128 | type=str, 129 | choices=["vllm", "transformers", "ov"], 130 | required=True, 131 | help="Choose inference backend: vllm, transformers, or OpenVINO", 132 | ) 133 | parser.add_argument("--model_path", type=str, required=True, help="Path to the model") 134 | parser.add_argument("--lora_path", type=str, default=None, help="Path to LoRA (leave empty to skip)") 135 | parser.add_argument( 136 | "--precision", type=str, default="bfloat16", choices=["float16", "bfloat16", "int4"], help="Model precision" 137 | ) 138 | parser.add_argument("--temperature", type=float, default=0.6, help="Temperature for sampling") 139 | parser.add_argument("--top_p", type=float, default=0.8, help="Top-p (nucleus) sampling probability") 140 | parser.add_argument("--max_length", type=int, default=8192, help="Maximum token length for generation") 141 | args = parser.parse_args() 142 | 143 | if args.backend == "vllm": 144 | engine, tokenizer, enable_lora = load_vllm_model_and_tokenizer(args.model_path, args.lora_path, args.precision) 145 | asyncio.run( 146 | vllm_chat(engine, tokenizer, args.lora_path, enable_lora, args.temperature, args.top_p, args.max_length) 147 | ) 148 | elif args.backend == "ov": 149 | tokenizer = AutoTokenizer.from_pretrained(args.model_path) 150 | model = 
OVModelForCausalLM.from_pretrained(args.model_path, device="CPU") # CPU,GPU and XPU are supported 151 | generic_chat(tokenizer, model, args.temperature, args.top_p, args.max_length, backend="ov") 152 | else: 153 | tokenizer = AutoTokenizer.from_pretrained(args.model_path) 154 | if args.precision == "int4": 155 | model = AutoModelForCausalLM.from_pretrained( 156 | args.model_path, 157 | trust_remote_code=True, 158 | quantization_config=BitsAndBytesConfig(load_in_4bit=True), 159 | torch_dtype=torch.bfloat16, 160 | low_cpu_mem_usage=True, 161 | ).eval() 162 | else: 163 | model = AutoModelForCausalLM.from_pretrained( 164 | args.model_path, 165 | torch_dtype=torch.bfloat16 if args.precision == "bfloat16" else torch.float16, 166 | trust_remote_code=True, 167 | device_map="auto", 168 | ).eval() 169 | generic_chat(tokenizer, model, args.temperature, args.top_p, args.max_length, backend="transformers") 170 | 171 | 172 | if __name__ == "__main__": 173 | main() 174 | -------------------------------------------------------------------------------- /inference/cli_demo_vision.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from threading import Thread 3 | 4 | from PIL import Image 5 | from transformers import ( 6 | AutoTokenizer, 7 | AutoModelForCausalLM, 8 | AutoImageProcessor, 9 | TextIteratorStreamer, 10 | BitsAndBytesConfig, 11 | ) 12 | import torch 13 | from ov_convert.convert_v import OvGLMv 14 | 15 | 16 | def generic_chat(tokenizer, processor, model, temperature, top_p, max_length, backend="transformers"): 17 | history = [] 18 | image = None 19 | backend_label = "OpenVINO" if backend == "ov" else "Transformers" 20 | print(f"Welcome to the GLM-Edge-v CLI chat ({backend_label}). Type your messages below.") 21 | image_path = input("Image Path:") 22 | try: 23 | image = Image.open(image_path).convert("RGB") 24 | pixel_values = torch.tensor(processor(image).pixel_values).to(model.device) 25 | except: 26 | print("Invalid image path. 
Continuing with text conversation.") 27 | 28 | while True: 29 | user_input = input("\nYou: ") 30 | if user_input.lower() in ["exit", "quit"]: 31 | break 32 | history.append([user_input, ""]) 33 | 34 | messages = [] 35 | for idx, (user_msg, model_msg) in enumerate(history): 36 | if idx == len(history) - 1 and not model_msg: 37 | messages.append({"role": "user", "content": [{"type": "text", "text": user_msg}, {"type": "image"}]}) 38 | break 39 | if user_msg: 40 | messages.append({"role": "user", "content": [{"type": "text", "text": user_msg}]}) 41 | if model_msg: 42 | messages.append({"role": "assistant", "content": [{"type": "text", "text": model_msg}]}) 43 | 44 | model_inputs = tokenizer.apply_chat_template( 45 | messages, add_generation_prompt=True, tokenize=True, return_tensors="pt", return_dict=True 46 | ) 47 | 48 | model_inputs = model_inputs.to(model.device) 49 | 50 | streamer = TextIteratorStreamer(tokenizer=tokenizer, timeout=60, skip_prompt=True, skip_special_tokens=True) 51 | generate_kwargs = { 52 | **model_inputs, 53 | "streamer": streamer, 54 | "max_new_tokens": max_length, 55 | "do_sample": True, 56 | "top_p": top_p, 57 | "temperature": temperature, 58 | "repetition_penalty": 1.2, 59 | "pixel_values": pixel_values if image else None, 60 | } 61 | 62 | t = Thread(target=model.generate, kwargs=generate_kwargs) 63 | t.start() 64 | print("GLM-Edge-v:", end="", flush=True) 65 | for new_token in streamer: 66 | if new_token: 67 | print(new_token, end="", flush=True) 68 | history[-1][1] += new_token 69 | 70 | history[-1][1] = history[-1][1].strip() 71 | 72 | 73 | def main(): 74 | parser = argparse.ArgumentParser(description="Run GLM-Edge-v DEMO Chat with Transformers or OpenVINO backend") 75 | 76 | parser.add_argument( 77 | "--backend", 78 | type=str, 79 | choices=["transformers", "ov"], 80 | required=True, 81 | help="Choose inference backend: transformers or OpenVINO", 82 | ) 83 | parser.add_argument("--model_path", type=str, required=True, help="Path to the model") 84 | parser.add_argument( 85 | "--precision", type=str, default="bfloat16", choices=["float16", "bfloat16", "int4"], help="Model precision" 86 | ) 87 | parser.add_argument("--temperature", type=float, default=0.6, help="Temperature for sampling") 88 | parser.add_argument("--top_p", type=float, default=0.8, help="Top-p (nucleus) sampling probability") 89 | parser.add_argument("--max_length", type=int, default=8192, help="Maximum token length for generation") 90 | args = parser.parse_args() 91 | 92 | tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True, encode_special_tokens=True) 93 | processor = AutoImageProcessor.from_pretrained(args.model_path, trust_remote_code=True) 94 | 95 | if args.backend == "ov": 96 | model = OvGLMv(args.model_path, device="CPU") # Use OpenVINO 97 | else: 98 | if args.precision == "int4": 99 | model = AutoModelForCausalLM.from_pretrained( 100 | args.model_path, 101 | trust_remote_code=True, 102 | quantization_config=BitsAndBytesConfig(load_in_4bit=True), 103 | torch_dtype=torch.bfloat16, 104 | low_cpu_mem_usage=True, 105 | ).eval() 106 | else: 107 | model = AutoModelForCausalLM.from_pretrained( 108 | args.model_path, 109 | torch_dtype=torch.bfloat16 if args.precision == "bfloat16" else torch.float16, 110 | trust_remote_code=True, 111 | device_map="auto", 112 | ).eval() 113 | 114 | generic_chat(tokenizer, processor, model, args.temperature, args.top_p, args.max_length, backend=args.backend) 115 | 116 | 117 | if __name__ == "__main__": 118 | main() 119 | 
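Note: for readers who want to exercise the vision demo above without the interactive loop, here is a minimal non-interactive sketch of the same Transformers code path that cli_demo_vision.py wraps. It is not part of the repository; the model id, image path, and prompt text are placeholders (any local GLM-Edge-V checkpoint directory should work), and the sampling values mirror the script's defaults.

import torch
from PIL import Image
from transformers import AutoTokenizer, AutoImageProcessor, AutoModelForCausalLM

model_path = "THUDM/glm-edge-v-2b"  # placeholder: point this at your GLM-Edge-V checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
processor = AutoImageProcessor.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_path, torch_dtype=torch.bfloat16, trust_remote_code=True, device_map="auto"
).eval()

# Preprocess the image exactly as generic_chat() does
image = Image.open("example.jpg").convert("RGB")  # placeholder image path
pixel_values = torch.tensor(processor(image).pixel_values).to(model.device)

# A single-turn message with a text part and an image placeholder
messages = [{"role": "user", "content": [{"type": "text", "text": "Describe this image."}, {"type": "image"}]}]
inputs = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, tokenize=True, return_tensors="pt", return_dict=True
).to(model.device)

output = model.generate(
    **inputs, pixel_values=pixel_values, max_new_tokens=512, do_sample=True, top_p=0.8, temperature=0.6
)
# Strip the prompt tokens and decode only the newly generated continuation
print(tokenizer.decode(output[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True))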
-------------------------------------------------------------------------------- /inference/ov_convert/convert_chat.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoTokenizer 2 | from optimum.intel import OVWeightQuantizationConfig 3 | from optimum.intel.openvino import OVModelForCausalLM 4 | from optimum.exporters.tasks import TasksManager 5 | import os 6 | import argparse 7 | 8 | TasksManager._SUPPORTED_MODEL_TYPE["glm"] = TasksManager._SUPPORTED_MODEL_TYPE[ 9 | "llama" 10 | ] # Using with Llama Type in converting 11 | 12 | if __name__ == "__main__": 13 | parser = argparse.ArgumentParser(add_help=False) 14 | parser.add_argument("--model_path", default="THUDM/glm-edge-1.5b-chat", type=str, help="orignal model path") 15 | parser.add_argument( 16 | "--precision", default="int4", type=str, choices=["fp16", "int8", "int4"], help="fp16, int8 or int4" 17 | ) 18 | parser.add_argument("--output_path", default="glm-edge-1.5b-chat-ov", type=str, help="path to save the IR model") 19 | args = parser.parse_args() 20 | os.makedirs(args.output_path, exist_ok=True) 21 | 22 | compression_configs = { 23 | "sym": True, 24 | "group_size": 128, 25 | "ratio": 0.8, 26 | } 27 | 28 | TasksManager._SUPPORTED_MODEL_TYPE["glm"] = TasksManager._SUPPORTED_MODEL_TYPE["llama"] 29 | 30 | print("====Exporting IR=====") 31 | if args.precision == "int4": 32 | ov_model = OVModelForCausalLM.from_pretrained( 33 | args.model_path, 34 | export=True, 35 | compile=False, 36 | quantization_config=OVWeightQuantizationConfig(bits=4, **compression_configs), 37 | trust_remote_code=True, 38 | ) 39 | elif args.precision == "int8": 40 | ov_model = OVModelForCausalLM.from_pretrained( 41 | args.model_path, export=True, compile=False, load_in_8bit=True, trust_remote_code=True 42 | ) 43 | else: 44 | ov_model = OVModelForCausalLM.from_pretrained( 45 | args.model_path, export=True, compile=False, load_in_8bit=False, trust_remote_code=True 46 | ) 47 | 48 | print("====Saving IR=====") 49 | ov_model.save_pretrained(args.output_path) 50 | 51 | print("====Exporting tokenizer=====") 52 | tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True) 53 | tokenizer.save_pretrained(args.output_path) 54 | 55 | print("====Exporting IR tokenizer=====") 56 | from optimum.exporters.openvino.convert import export_tokenizer 57 | 58 | export_tokenizer(tokenizer, args.output_path) 59 | print("====Finished=====") 60 | -------------------------------------------------------------------------------- /inference/ov_convert/convert_v.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from pathlib import Path 4 | import types 5 | import gc 6 | 7 | import openvino as ov 8 | from openvino.runtime import opset13 9 | import nncf 10 | import numpy as np 11 | import torch 12 | from transformers.cache_utils import Cache 13 | from transformers import AutoModelForCausalLM, AutoImageProcessor, AutoConfig, AutoTokenizer 14 | from transformers.generation import GenerationMixin 15 | from transformers.modeling_outputs import CausalLMOutputWithPast, BaseModelOutputWithPast 16 | from typing import Optional, Tuple, Union, List, Dict, Any 17 | from transformers import __version__ as transformers_version 18 | from transformers.generation.utils import GenerationConfig, ModelOutput 19 | 20 | 21 | def _chatglm_transformer_forward( 22 | self, 23 | input_ids: torch.LongTensor = None, 24 | attention_mask: Optional[torch.Tensor] = 
None, 25 | position_ids: Optional[torch.LongTensor] = None, 26 | past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, 27 | inputs_embeds: Optional[torch.FloatTensor] = None, 28 | labels: Optional[torch.LongTensor] = None, 29 | use_cache: Optional[bool] = None, 30 | output_attentions: Optional[bool] = None, 31 | output_hidden_states: Optional[bool] = None, 32 | return_dict: Optional[bool] = None, 33 | cache_position: Optional[torch.LongTensor] = None, 34 | num_logits_to_keep: int = 0, 35 | **loss_kwargs, 36 | ) -> Union[Tuple, BaseModelOutputWithPast]: 37 | """take care of image_encode, position_ids and (attention_mask = None is fine)""" 38 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 39 | output_hidden_states = ( 40 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 41 | ) 42 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 43 | 44 | outputs = self.model( 45 | input_ids=input_ids, 46 | attention_mask=attention_mask, 47 | position_ids=position_ids, 48 | past_key_values=past_key_values, 49 | inputs_embeds=inputs_embeds, 50 | use_cache=use_cache, 51 | output_attentions=output_attentions, 52 | output_hidden_states=output_hidden_states, 53 | return_dict=return_dict, 54 | cache_position=cache_position, 55 | ) 56 | 57 | hidden_states = outputs[0] 58 | logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) 59 | logits = logits.to(torch.float32) 60 | output = (logits,) + outputs[1:] 61 | return output 62 | 63 | 64 | def model_has_state(ov_model: ov.Model): 65 | return len(ov_model.get_sinks()) > 0 66 | 67 | 68 | def model_has_input_output_name(ov_model: ov.Model, name: str): 69 | """ 70 | Helper function for checking that the model has the specified input or output name 71 | 72 | Parameters: 73 | ov_model (ov.Model): 74 | name (str): 75 | name of input or output 76 | 77 | Returns: 78 | True if an input or output with the requested name exists, else False 79 | """ 80 | return name in sum([list(t.get_names()) for t in ov_model.inputs + ov_model.outputs], []) 81 | 82 | 83 | def fuse_cache_reorder( 84 | ov_model: ov.Model, 85 | not_kv_inputs: List[str], 86 | key_value_input_names: List[str], 87 | gather_dim: int, 88 | ): 89 | """ 90 | Fuses reorder_cache during the generate cycle into ov.Model. Used with stateful models, because we cannot modify the model state directly. 91 | 92 | Adds a new beam_idx parameter and Gather op per each kv-cache input in a given model. 93 | Should be run before make_stateful. Implements optimum's _reorder_cache 94 | inside the model at the beginning of each iteration. 95 | Gather works along the given gather_dim dimension, which may vary from model to model. 96 | KV-cache inputs are identified based on names in key_value_input_names. 97 | Appends the new beam_idx parameter to not_kv_inputs. 
98 | 99 | Parameters: 100 | ov_model (`ov.Model`): 101 | openvino model for processing 102 | not_kv_inputs (`List[str]`): 103 | list of input nodes in model that not related to past key values 104 | key_value_input_names (`List[str]`): 105 | list of names for key value input layers 106 | gather_dim (int): 107 | dimension for gathering cache during reorder pass 108 | """ 109 | 110 | if model_has_input_output_name(ov_model, "beam_idx"): 111 | raise ValueError("Model already has fused cache") 112 | input_batch = ov_model.input("inputs_embeds").get_partial_shape()[0] 113 | beam_idx = opset13.parameter(name="beam_idx", dtype=ov.Type.i32, shape=ov.PartialShape([input_batch])) 114 | beam_idx.output(0).get_tensor().add_names({"beam_idx"}) # why list is not accepted? 115 | ov_model.add_parameters([beam_idx]) 116 | not_kv_inputs.append(ov_model.inputs[-1]) 117 | # Go over all cache parameters and fuse _reorder_cache with indices provided by the new parameter beam_idx 118 | for input_name in key_value_input_names: 119 | parameter_output_port = ov_model.input(input_name) 120 | consumers = parameter_output_port.get_target_inputs() 121 | gather = opset13.gather(parameter_output_port, beam_idx, opset13.constant(gather_dim)) 122 | for consumer in consumers: 123 | consumer.replace_source_output(gather.output(0)) 124 | ov_model.validate_nodes_and_infer_types() 125 | 126 | 127 | def build_state_initializer(ov_model: ov.Model, batch_dim: int): 128 | """ 129 | Build initialization ShapeOf Expression for all ReadValue ops 130 | 131 | Parameters: 132 | ov_model (ov.Model): 133 | openvino model 134 | batch_dim (int): 135 | index of dimension corresponding to batch size 136 | """ 137 | input_ids = ov_model.input("inputs_embeds") 138 | batch = opset13.gather( 139 | opset13.shape_of(input_ids, output_type="i64"), 140 | opset13.constant([0]), 141 | opset13.constant(0), 142 | ) 143 | for op in ov_model.get_ops(): 144 | if op.get_type_name() == "ReadValue": 145 | dims = [dim.min_length for dim in list(op.get_output_partial_shape(0))] 146 | dims[batch_dim] = batch 147 | dims = [ 148 | (opset13.constant(np.array([dim], dtype=np.int64)) if isinstance(dim, int) else dim) for dim in dims 149 | ] 150 | shape = opset13.concat(dims, axis=0) 151 | broadcast = opset13.broadcast(opset13.constant(0.0, dtype=op.get_output_element_type(0)), shape) 152 | op.set_arguments([broadcast]) 153 | ov_model.validate_nodes_and_infer_types() 154 | 155 | 156 | def make_stateful( 157 | ov_model: ov.Model, 158 | not_kv_inputs: List[str], 159 | key_value_input_names: List[str], 160 | key_value_output_names: List[str], 161 | batch_dim: int, 162 | num_attention_heads: int, 163 | num_beams_and_batch: int = None, 164 | ): 165 | """ 166 | Hides kv-cache inputs and outputs inside the model as variables. 
167 | 168 | Parameters: 169 | ov_model (ov.Model): 170 | openvino model 171 | not_kv_inputs (`List[str]`): 172 | list of input nodes in model that not related to past key values 173 | key_value_input_names (`List[str]`): 174 | list of names for key value input layers 175 | key_value_output_names (`List[str]`): 176 | list of names for key value input layers 177 | batch_dim (int): 178 | index of batch dimension in key value layers 179 | num_attention_heads (int): 180 | number of attention heads for batch dimension initialization 181 | num_beams_an_batch (int): 182 | precalculated number of beams and batch for shapes initialization 183 | """ 184 | from openvino._offline_transformations import apply_make_stateful_transformation 185 | 186 | input_output_map = {} 187 | 188 | if num_beams_and_batch is not None: 189 | # Set batch size for input_ids and attention mask to avoid dynamic dimension got propagated from the end of the model back to ReadValue 190 | for input in not_kv_inputs: 191 | shape = input.get_partial_shape() 192 | if shape.rank.get_length() <= 2: # == 1 for beam_index 193 | shape[0] = num_beams_and_batch 194 | input.get_node().set_partial_shape(shape) 195 | for kv_name_pair in zip(key_value_input_names, key_value_output_names): 196 | input_output_map[kv_name_pair[0]] = kv_name_pair[1] 197 | if num_beams_and_batch is not None: 198 | input = ov_model.input(kv_name_pair[0]) 199 | shape = input.get_partial_shape() 200 | shape[batch_dim] = num_beams_and_batch * num_attention_heads 201 | input.get_node().set_partial_shape(shape) 202 | 203 | if num_beams_and_batch is not None: 204 | # Re-validation model if shapes are altered above 205 | ov_model.validate_nodes_and_infer_types() 206 | 207 | apply_make_stateful_transformation(ov_model, input_output_map) 208 | if num_beams_and_batch is None: 209 | build_state_initializer(ov_model, batch_dim) 210 | 211 | 212 | def patch_stateful(ov_model): 213 | key_value_input_names = [key.get_any_name() for key in ov_model.inputs[2:-1]] 214 | key_value_output_names = [key.get_any_name() for key in ov_model.outputs[1:]] 215 | not_kv_inputs = [ 216 | input for input in ov_model.inputs if not any(name in key_value_input_names for name in input.get_names()) 217 | ] 218 | if not key_value_input_names or not key_value_output_names: 219 | return 220 | batch_dim = 0 221 | num_attention_heads = 1 222 | 223 | fuse_cache_reorder(ov_model, not_kv_inputs, key_value_input_names, batch_dim) 224 | make_stateful( 225 | ov_model, 226 | not_kv_inputs, 227 | key_value_input_names, 228 | key_value_output_names, 229 | batch_dim, 230 | num_attention_heads, 231 | None, 232 | ) 233 | 234 | 235 | core = ov.Core() 236 | 237 | 238 | def cleanup_torchscript_cache(): 239 | """ 240 | Helper for removing cached model representation 241 | """ 242 | torch._C._jit_clear_class_registry() 243 | torch.jit._recursive.concrete_type_store = torch.jit._recursive.ConcreteTypeStore() 244 | torch.jit._state._clear_class_state() 245 | 246 | 247 | def convert_glmv_model(model_id, output_dir, quantization_config): 248 | model_name = Path(model_id).name 249 | output_dir = Path(output_dir) 250 | 251 | lang_model_path = output_dir / "openvino_language_model.xml" 252 | image_embed_path = output_dir / "openvino_vision.xml" 253 | embed_token_path = output_dir / "openvino_embedding.xml" 254 | config = AutoConfig.from_pretrained(model_id, trust_remote_code=True) 255 | image_size = config.vision_config["image_size"] 256 | 257 | if all( 258 | [ 259 | lang_model_path.exists(), 260 | image_embed_path.exists(), 261 
261 |             embed_token_path.exists(),
262 |         ]
263 |     ):
264 |         print(f"✅ {model_name} model already converted. You can find results in {output_dir}")
265 |         return
266 |     print(f"⌛ {model_name} conversion started. Be patient, it may take some time.")
267 |     print("⌛ Load Original model")
268 |     model = AutoModelForCausalLM.from_pretrained(
269 |         model_id, trust_remote_code=True, torch_dtype=torch.float32, _attn_implementation="eager"
270 |     )
271 |     model.config.save_pretrained(output_dir)
272 |     tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
273 |     tokenizer.save_pretrained(output_dir)
274 |     processor = AutoImageProcessor.from_pretrained(model_id, trust_remote_code=True)
275 |     processor.save_pretrained(output_dir)
276 |     # shutil.copy2(ori_token_config_path, ov_token_config_path)
277 | 
278 |     print("✅ Original model successfully loaded")
279 | 
280 |     if not embed_token_path.exists():
281 |         print("⌛ Convert Input embedding model")
282 |         ov_model = ov.convert_model(
283 |             model.model.embed_tokens,
284 |             example_input=torch.ones([1, 10], dtype=torch.int64),
285 |         )
286 |         ov.save_model(ov_model, embed_token_path)
287 |         del ov_model
288 |         cleanup_torchscript_cache()
289 |         gc.collect()
290 |         print("✅ Input embedding model successfully converted")
291 | 
292 |     if not image_embed_path.exists():
293 |         print("⌛ Convert Image embedding model")
294 |         # vision_embed_tokens.forward = vision_embed_tokens.vit
295 |         ov_model = ov.convert_model(model.model.vision, example_input=torch.ones([1, 3, image_size, image_size]))
296 |         ov.save_model(ov_model, image_embed_path)
297 |         del ov_model
298 |         cleanup_torchscript_cache()
299 |         gc.collect()
300 |         print("✅ Image embedding model successfully converted")
301 | 
302 |     if not lang_model_path.exists():
303 |         print("⌛ Convert Language model")
304 | 
305 |         input_ids = torch.zeros([2, 2], dtype=torch.int64)
306 |         inputs_embeds = torch.zeros([2, 2, config.hidden_size], dtype=torch.float32)
307 | 
308 |         pkv = model.model(
309 |             input_ids=input_ids,
310 |             attention_mask=torch.ones((2, 2), dtype=torch.int64),
311 |             images=torch.zeros([1, 3, image_size, image_size])
312 |         )[1]
313 |         model.forward = types.MethodType(_chatglm_transformer_forward, model)
314 | 
315 |         model.config.torchscript = True
316 |         model_inputs = ["attention_mask", "position_ids"]
317 |         model_outputs = ["logits"]
318 |         for idx in range(len(pkv)):
319 |             model_inputs.extend([f"past_key_values.{idx}.key", f"past_key_values.{idx}.value"])
320 |             model_outputs.extend([f"present.{idx}.key", f"present.{idx}.value"])
321 |         model_inputs.append("inputs_embeds")
322 |         position_ids = torch.tensor([[2, 3], [2, 3]])
323 |         ov_model = ov.convert_model(
324 |             model,
325 |             example_input={
326 |                 "position_ids": position_ids,
327 |                 "inputs_embeds": inputs_embeds,
328 |                 "attention_mask": torch.ones([2, 4], dtype=torch.int64),
329 |                 "past_key_values": pkv,
330 |             },
331 |         )
332 | 
333 |         for input, input_name in zip(ov_model.inputs, model_inputs):
334 |             input.get_tensor().set_names({input_name})
335 | 
336 |         for output, output_name in zip(ov_model.outputs, model_outputs):
337 |             output.get_tensor().set_names({output_name})
338 |         patch_stateful(ov_model)
339 |         print("✅ Language model successfully converted")
340 | 
341 |         if quantization_config is not None:
342 |             print(f"⌛ Weights compression with {quantization_config['mode']} mode started")
343 |             ov_model = nncf.compress_weights(ov_model, **quantization_config)
344 |             print("✅ Weights compression finished")
345 | 
346 |         ov.save_model(ov_model, lang_model_path)
347 |         del ov_model
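    # At this point all three IR artifacts written by this script exist in output_dir:
    # openvino_language_model.xml, openvino_vision.xml and openvino_embedding.xml.
    # These are the files that the OvGLMv wrapper below reads back for inference.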
348 | cleanup_torchscript_cache() 349 | del model 350 | gc.collect() 351 | print(f"✅ {model_name} model conversion finished. You can find results in {output_dir}") 352 | 353 | 354 | def is_empty(images_list: Optional[List[List[torch.Tensor]]]): 355 | if images_list is None or len(images_list) == 0: 356 | return True 357 | for image_list in images_list: 358 | if image_list is not None: 359 | return False 360 | return True 361 | 362 | 363 | class OvGLMv(GenerationMixin): 364 | def __init__(self, model_dir, device): 365 | model_dir = Path(model_dir) 366 | self.model = core.read_model(model_dir / "openvino_language_model.xml") 367 | self.vision = core.compile_model(model_dir / "openvino_vision.xml", "CPU") 368 | self.embedding = core.compile_model(model_dir / "openvino_embedding.xml", "CPU") 369 | self.input_names = {key.get_any_name(): idx for idx, key in enumerate(self.model.inputs)} 370 | self.output_names = {key.get_any_name(): idx for idx, key in enumerate(self.model.outputs)} 371 | # compiled_model = core.compile_model(self.model, device, config={"GPU_ENABLE_SDPA_OPTIMIZATION": "NO", "INFERENCE_PRECISION_HINT": "FP32"}) 372 | compiled_model = core.compile_model(self.model, device) 373 | 374 | self.request = compiled_model.create_infer_request() 375 | self.config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True) 376 | self.generation_config = GenerationConfig.from_model_config(self.config) 377 | self.main_input_name = "input_ids" 378 | self.device = torch.device("cpu") 379 | self.num_pkv = 2 380 | self._supports_cache_class = False 381 | self.next_beam_idx = None 382 | self.hd_transform_order = "glb_sub" 383 | 384 | def can_generate(self): 385 | """Returns True to validate the check that the model using `GenerationMixin.generate()` can indeed generate.""" 386 | return True 387 | 388 | def __call__( 389 | self, 390 | input_ids: torch.LongTensor = None, 391 | pixel_values: torch.Tensor = None, 392 | position_ids: Optional[torch.Tensor] = None, 393 | attention_mask: Optional[torch.BoolTensor] = None, 394 | past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, 395 | inputs_embeds: Optional[torch.Tensor] = None, 396 | **kwargs, 397 | ) -> CausalLMOutputWithPast: 398 | return self.forward( 399 | input_ids=input_ids, 400 | pixel_values=pixel_values, 401 | attention_mask=attention_mask, 402 | past_key_values=past_key_values, 403 | position_ids=position_ids, 404 | inputs_embeds=inputs_embeds, 405 | **kwargs, 406 | ) 407 | 408 | def forward( 409 | self, 410 | input_ids: torch.LongTensor = None, 411 | pixel_values: torch.Tensor = None, 412 | position_ids: Optional[torch.Tensor] = None, 413 | attention_mask: Optional[torch.BoolTensor] = None, 414 | past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, 415 | inputs_embeds: Optional[torch.Tensor] = None, 416 | return_dict: Optional[bool] = None, 417 | ) -> Union[Tuple, BaseModelOutputWithPast]: 418 | batch_size, num_concurrent_media, num_tiles, num_channels, height, width = pixel_values.shape 419 | pixel_values = pixel_values.reshape(batch_size * num_concurrent_media * num_tiles, num_channels, height, width) 420 | if not past_key_values: 421 | self.request.reset_state() 422 | self.next_beam_idx = np.arange(input_ids.shape[0], dtype=int) 423 | # not allow for inputs_embeds, because we want to process image feature 424 | assert input_ids is not None and inputs_embeds is None, f"{input_ids} {inputs_embeds}" 425 | inputs_embeds = torch.from_numpy(self.embedding(input_ids)[0]) 426 | 
new_input_embeds = [] 427 | multi_flags = [True if self.config.boi_token_id in input_id.tolist() else False for input_id in input_ids] 428 | images_features = None 429 | if not is_empty(pixel_values): 430 | images_features = torch.from_numpy(self.vision(pixel_values)[0]) 431 | image_count = 0 432 | for i in range(len(input_ids)): 433 | input_id = input_ids[i].tolist() 434 | if multi_flags[i]: 435 | boi_token_pos = input_id.index(self.config.boi_token_id) 436 | assert boi_token_pos >= 0, "begin_of_image not found!" 437 | num_image_padding_tokens = input_id.count(self.config.boi_token_id) 438 | assert ( 439 | num_image_padding_tokens == images_features[image_count].shape[0] 440 | ), f"Wrong image padding token number: {num_image_padding_tokens}" 441 | new_input_embeds.append( 442 | torch.cat( 443 | ( 444 | inputs_embeds[i, :boi_token_pos], 445 | images_features[image_count].to(inputs_embeds.device), 446 | inputs_embeds[i, boi_token_pos + num_image_padding_tokens :], 447 | ) 448 | ) 449 | ) 450 | image_count += 1 451 | else: 452 | new_input_embeds.append(inputs_embeds[i]) 453 | inputs_embeds = torch.stack(new_input_embeds, dim=0) 454 | 455 | if inputs_embeds is None: 456 | inputs_embeds = self.embedding(input_ids)[0] 457 | inputs = {} 458 | inputs["inputs_embeds"] = inputs_embeds 459 | inputs["attention_mask"] = attention_mask 460 | inputs["position_ids"] = position_ids 461 | if "beam_idx" in self.input_names: 462 | inputs["beam_idx"] = ( 463 | self.next_beam_idx if self.next_beam_idx is not None else np.arange(inputs_embeds.shape[0], dtype=int) 464 | ) 465 | self.request.start_async(inputs, share_inputs=True) 466 | self.request.wait() 467 | logits = self.request.get_tensor("logits").data 468 | logits = torch.from_numpy(logits).to(self.device) 469 | past_key_values = ((),) 470 | 471 | return CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values) 472 | 473 | def _reorder_cache( 474 | self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor 475 | ) -> Tuple[Tuple[torch.Tensor]]: 476 | """ 477 | This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or 478 | [`~PreTrainedModel.beam_sample`] is called. 479 | This is required to match `past_key_values` with the correct beam_idx at every generation step. 
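        For this stateful OpenVINO model the key/value cache is kept inside the compiled
        infer request rather than in `past_key_values`, so this method only records the
        beam index; it is passed back to the model as the `beam_idx` input on the next step.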
480 |         """
481 |         self.next_beam_idx = np.array(beam_idx)  # save beam_idx to be used as an input in the next iteration
482 |         return past_key_values
483 | 
484 |     def _update_model_kwargs_for_generation(
485 |         self,
486 |         outputs: ModelOutput,
487 |         model_kwargs: Dict[str, Any],
488 |         is_encoder_decoder: bool = False,
489 |         standardize_cache_format: bool = False,
490 |     ) -> Dict[str, Any]:
491 |         # update past_key_values
492 |         if int(transformers_version.split(".")[1]) >= 44:
493 |             assert not standardize_cache_format
494 |             _, cache = self._extract_past_from_model_output(outputs)
495 |             model_kwargs["past_key_values"] = cache
496 |         else:
497 |             model_kwargs["past_key_values"] = self._extract_past_from_model_output(outputs, standardize_cache_format)
498 | 
499 |         # update attention mask
500 |         if "attention_mask" in model_kwargs:
501 |             attention_mask = model_kwargs["attention_mask"]
502 |             model_kwargs["attention_mask"] = torch.cat(
503 |                 [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
504 |             )
505 | 
506 |         # update position ids
507 |         if "position_ids" in model_kwargs:
508 |             position_ids = model_kwargs["position_ids"]
509 |             new_position_id = position_ids[..., -1:].clone()
510 |             new_position_id += 1
511 |             model_kwargs["position_ids"] = torch.cat([position_ids, new_position_id], dim=-1)
512 | 
513 |         model_kwargs["is_first_forward"] = False
514 |         return model_kwargs
515 | 
516 |     def prepare_inputs_for_generation(
517 |         self,
518 |         input_ids: torch.LongTensor,
519 |         pixel_values: Optional[torch.Tensor] = torch.zeros([1, 1, 1, 3, 672, 672]),
520 |         past_key_values: Optional[torch.Tensor] = None,
521 |         attention_mask: Optional[torch.Tensor] = None,
522 |         position_ids: Optional[torch.Tensor] = None,
523 |         use_cache: Optional[bool] = None,
524 |         is_first_forward: bool = True,
525 |         **kwargs,
526 |     ) -> dict:
527 |         if position_ids is None:
528 |             if attention_mask is None:
529 |                 # Position ids can only be built from the attention mask; raise an error right away
530 |                 raise ValueError("Cannot create position ids when attention mask is None")
531 |             else:
532 |                 position_ids = self._create_position_ids_from_attention_mask(attention_mask)
533 |         if not is_first_forward:
534 |             if past_key_values is not None:
535 |                 position_ids = position_ids[..., -1:]
536 |                 input_ids = input_ids[:, -1:]
537 |         return {
538 |             "input_ids": input_ids,
539 |             "pixel_values": pixel_values,
540 |             "past_key_values": past_key_values,
541 |             "position_ids": position_ids,
542 |             "attention_mask": attention_mask,
543 |         }
544 | 
545 |     def _create_position_ids_from_attention_mask(self, attention_mask):
546 |         # Initialize a tensor of the same shape as attention_mask to hold position IDs
547 |         position_ids = torch.zeros_like(attention_mask, dtype=torch.long, device=attention_mask.device)
548 |         # Iterate over the batch
549 |         for i, mask in enumerate(attention_mask):
550 |             # Find the positions where the mask is 1
551 |             positions = torch.nonzero(mask, as_tuple=False).squeeze(1).to(attention_mask.device)
552 |             # Assign position IDs to those positions
553 |             position_ids[i, positions] = torch.arange(start=0, end=positions.size(0), dtype=torch.long).to(
554 |                 attention_mask.device
555 |             )
556 |         return position_ids
557 | 
558 | 
559 | if __name__ == "__main__":
560 |     parser = argparse.ArgumentParser()
561 |     parser.add_argument("--model_path", default="THUDM/glm-edge-v-2b", type=str, help="original model path")
562 |     parser.add_argument(
563 |         "--precision", default="int4", type=str, choices=["fp16", "int8", "int4"], help="fp16, int8 or int4"
564 |     )
565 |     parser.add_argument("--output_path", default="glm-edge-v-2b-ov", help="path to save the IR model")
566 |     args = parser.parse_args()
567 |     os.makedirs(args.output_path, exist_ok=True)
568 |     if args.precision == "int4":
569 |         compression_configuration = {
570 |             "mode": nncf.CompressWeightsMode.INT4_SYM,
571 |             "group_size": 64,
572 |             "ratio": 0.6,
573 |         }
574 |     elif args.precision == "int8":
575 |         compression_configuration = {
576 |             "mode": nncf.CompressWeightsMode.INT8,
577 |             "group_size": 64,
578 |             "ratio": 0.6,
579 |         }
580 |     else:
581 |         compression_configuration = None
582 |     convert_glmv_model(args.model_path, args.output_path, compression_configuration)
583 | 
--------------------------------------------------------------------------------
/inference/web_demo.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from pathlib import Path
3 | from threading import Thread
4 | from typing import Union
5 | import requests
6 | from io import BytesIO
7 | from PIL import Image
8 | import re
9 | import gradio as gr
10 | import torch
11 | from peft import AutoPeftModelForCausalLM
12 | from transformers import (
13 |     AutoTokenizer,
14 |     AutoModelForCausalLM,
15 |     AutoImageProcessor,
16 |     TextIteratorStreamer,
17 |     BitsAndBytesConfig,
18 | )
19 | 
20 | # Parse arguments
21 | def parse_args():
22 |     parser = argparse.ArgumentParser(description="GLM-Edge-Chat Gradio Demo with adjustable parameters")
23 |     parser.add_argument("--model_path", type=str, required=True, help="Path to the model")
24 |     parser.add_argument("--server_name", type=str, default="127.0.0.1", help="Server name")
25 |     parser.add_argument("--server_port", type=int, default=7860, help="Server port")
26 |     parser.add_argument("--lora_path", type=str, default=None, help="Path to LoRA model if available")
27 |     parser.add_argument(
28 |         "--precision",
29 |         type=str,
30 |         choices=["float16", "bfloat16", "int4"],
31 |         default="bfloat16",
32 |         help="Precision for model",
33 |     )
34 |     return parser.parse_args()
35 | 
36 | args = parse_args()
37 | 
38 | def _resolve_path(path: Union[str, Path]) -> Path:
39 |     return Path(path).expanduser().resolve()
40 | 
41 | # Load Model and Tokenizer for transformers
42 | def load_model_and_tokenizer(model_dir: Union[str, Path], precision: str, trust_remote_code: bool = True):
43 |     model_dir = _resolve_path(model_dir)
44 |     if precision == "int4":
45 |         model = AutoModelForCausalLM.from_pretrained(
46 |             model_dir,
47 |             trust_remote_code=trust_remote_code,
48 |             quantization_config=BitsAndBytesConfig(load_in_4bit=True),
49 |             torch_dtype=torch.bfloat16,
50 |             low_cpu_mem_usage=True,
51 |         ).eval()
52 |     elif (model_dir / "adapter_config.json").exists():
53 |         model = AutoPeftModelForCausalLM.from_pretrained(
54 |             model_dir, trust_remote_code=trust_remote_code, device_map="auto"
55 |         )
56 |     else:
57 |         model = AutoModelForCausalLM.from_pretrained(model_dir, trust_remote_code=trust_remote_code, device_map="auto")
58 | 
59 |     tokenizer_dir = (
60 |         model.peft_config["default"].base_model_name_or_path if hasattr(model, "peft_config") else model_dir
61 |     )
62 |     tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir, trust_remote_code=trust_remote_code, use_fast=False)
63 |     return model, tokenizer
64 | 
65 | model, tokenizer = load_model_and_tokenizer(args.model_path, args.precision, trust_remote_code=True)
66 | 
67 | def is_url(s):
68 |     if re.match(r'^(?:http|ftp)s?://', s):
69 |         return True
70 |     return False
71 | 
72 | def get_image(image):
73 |     if is_url(image):
74 |         response = requests.get(image)
75 |         return Image.open(BytesIO(response.content)).convert("RGB")
76 |     elif image:
77 |         return Image.open(image).convert("RGB")
78 |     return None
79 | 
80 | def preprocess_messages(history, prompt, image):
81 |     messages = []
82 |     pixel_values = None
83 | 
84 |     if prompt:
85 |         messages.append({"role": "system", "content": prompt})
86 |     for idx, (user_msg, model_msg) in enumerate(history):
87 |         if prompt and idx == 0:
88 |             continue
89 |         if idx == len(history) - 1 and not messages:
90 |             messages.append({"role": "user", "content": user_msg})
91 |             break
92 |         if user_msg:
93 |             messages.append({"role": "user", "content": user_msg})
94 |         if model_msg:
95 |             messages.append({"role": "assistant", "content": model_msg})
96 | 
97 |     if hasattr(model.config, "vision_config"):
98 |         for item in messages:
99 |             msg = item['content']
100 |             item['content'] = [{"type": "text", "text": msg}]
101 |         if image:
102 |             messages[-1]['content'].append({"type": "image"})
103 |             try:
104 |                 image_input = get_image(image)
105 | 
106 |                 processor = AutoImageProcessor.from_pretrained(
107 |                     args.model_path,
108 |                     trust_remote_code=True
109 |                 )
110 |                 pixel_values = torch.tensor(
111 |                     processor(image_input).pixel_values).to(model.device)
112 |             except Exception:
113 |                 print("Invalid image path.
Continuing with text conversation.") 114 | 115 | return messages, pixel_values 116 | 117 | def predict(history, prompt, max_length, top_p, temperature, image=None): 118 | messages, pixel_values = preprocess_messages(history, prompt, image) 119 | model_inputs = tokenizer.apply_chat_template( 120 | messages, add_generation_prompt=True, tokenize=True, return_tensors="pt", return_dict=True 121 | ) 122 | 123 | streamer = TextIteratorStreamer(tokenizer, timeout=60, skip_prompt=True, skip_special_tokens=True) 124 | generate_kwargs = { 125 | "input_ids": model_inputs["input_ids"].to(model.device), 126 | "attention_mask": model_inputs["attention_mask"].to(model.device), 127 | "streamer": streamer, 128 | "max_new_tokens": max_length, 129 | "do_sample": True, 130 | "top_p": top_p, 131 | "temperature": temperature, 132 | "repetition_penalty": 1.2, 133 | } 134 | if hasattr(model.config, "vision_config"): 135 | generate_kwargs['eos_token_id'] = [59246, 59253, 59255] 136 | if image and isinstance(pixel_values, torch.Tensor): 137 | generate_kwargs['pixel_values'] = pixel_values 138 | else: 139 | generate_kwargs['eos_token_id'] = tokenizer.encode("<|user|>") 140 | t = Thread(target=model.generate, kwargs=generate_kwargs) 141 | t.start() 142 | for new_token in streamer: 143 | if new_token: 144 | history[-1][1] += new_token 145 | yield history 146 | 147 | def main(): 148 | with gr.Blocks() as demo: 149 | gr.HTML("""

GLM-Edge-Chat Gradio Chat Demo

""")
150 | 
151 |         # Top row: Chatbot and Image upload
152 |         with gr.Row():
153 |             with gr.Column(scale=3):
154 |                 chatbot = gr.Chatbot()
155 |             with gr.Column(scale=1):
156 |                 image_input = gr.Image(label="Upload an Image", type="filepath")
157 | 
158 |         # Bottom row: System prompt, user input, and controls
159 |         with gr.Row():
160 |             with gr.Column(scale=2):
161 |                 prompt_input = gr.Textbox(show_label=True, placeholder="System Prompt", label="System Prompt")
162 |                 user_input = gr.Textbox(show_label=True, placeholder="Input...", label="User Input")
163 |                 submitBtn = gr.Button("Submit")
164 |                 pBtn = gr.Button("Set System Prompt")
165 |                 emptyBtn = gr.Button("Clear History")
166 |             with gr.Column(scale=1):
167 |                 max_length = gr.Slider(0, 8192, value=4096, step=1.0, label="Maximum length", interactive=True)
168 |                 top_p = gr.Slider(0, 1, value=0.8, step=0.01, label="Top P", interactive=True)
169 |                 temperature = gr.Slider(0.01, 1, value=0.6, step=0.01, label="Temperature", interactive=True)
170 | 
171 |         # Define functions for button actions
172 |         def user(query, history):
173 |             return "", history + [[query, ""]]
174 | 
175 |         def set_prompt(prompt_text):
176 |             return [[prompt_text, "Prompt set successfully"]]
177 | 
178 |         # Button actions and callbacks
179 |         pBtn.click(set_prompt, inputs=[prompt_input], outputs=chatbot)
180 |         submitBtn.click(user, [user_input, chatbot], [user_input, chatbot], queue=False).then(
181 |             predict, [chatbot, prompt_input, max_length, top_p, temperature, image_input], chatbot
182 |         )
183 |         emptyBtn.click(lambda: (None, None), None, [chatbot, prompt_input], queue=False)
184 | 
185 |     demo.queue()
186 |     demo.launch(server_name=args.server_name, server_port=args.server_port)
187 | 
188 | 
189 | if __name__ == "__main__":
190 |     main()
191 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | # Install transformers from source:
2 | git+https://github.com/huggingface/transformers.git
3 | torch>=2.5.1
4 | torchvision>=0.20.0
5 | huggingface-hub>=0.25.1
6 | sentencepiece>=0.2.0
7 | jinja2>=3.1.4
8 | pydantic>=2.9.2
9 | timm>=1.0.9
10 | tiktoken>=0.8.0
11 | numpy==1.26.4
12 | accelerate>=1.1.1
13 | sentence_transformers>=3.1.1
14 | gradio>=5.6.0
15 | openai>=1.55.0
16 | einops>=0.8.0
17 | pillow>=10.4.0
18 | sse-starlette>=2.1.3
19 | bitsandbytes>=0.43.2 # INT4 Loading
20 | 
21 | # For Intel OpenVINO conversion
22 | # optimum-intel>=1.20.1
23 | # openvino>=1.26.4
24 | # nncf>=2.14.0
25 | 
26 | # vllm>=0.6.4.post1 # use with the vLLM framework
27 | # peft>=0.14.0 # use with fine-tuned (LoRA) models
--------------------------------------------------------------------------------
/resources/img.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THUDM/GLM-Edge/7a93ea5047e91cf27b74d1e64ad490f75d329b17/resources/img.png
--------------------------------------------------------------------------------