├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.yaml │ └── feature-request.yaml └── PULL_REQUEST_TEMPLATE.md ├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── README.md ├── README_20240605.md ├── README_zh.md ├── README_zh_240605.md ├── demo ├── composite_demo │ ├── .gitignore │ ├── README.md │ ├── README_en.md │ ├── assets │ │ ├── cogview.png │ │ ├── demo.png │ │ ├── doc_reader.png │ │ ├── tool.png │ │ ├── vlm.png │ │ ├── weather.png │ │ ├── web_plot_1.png │ │ └── web_plot_2.png │ ├── browser │ │ ├── .gitignore │ │ ├── package-lock.json │ │ ├── package.json │ │ ├── pnpm-lock.yaml │ │ ├── src │ │ │ ├── browser.ts │ │ │ ├── config.ts │ │ │ ├── server.ts │ │ │ ├── types.ts │ │ │ └── utils.ts │ │ └── tsconfig.json │ ├── requirements.txt │ └── src │ │ ├── client.py │ │ ├── clients │ │ ├── hf.py │ │ ├── openai.py │ │ └── vllm.py │ │ ├── conversation.py │ │ ├── main.py │ │ ├── tools │ │ ├── browser.py │ │ ├── cogview.py │ │ ├── config.py │ │ ├── interface.py │ │ ├── python.py │ │ └── tool_registry.py │ │ └── utils.py └── intel_device_demo │ ├── itrex │ ├── README.md │ ├── README_en.md │ ├── itrex_cli_demo.py │ └── requirements.txt │ └── openvino │ ├── README.md │ ├── README_en.md │ ├── convert.py │ ├── openvino_cli_demo.py │ └── requirements.txt ├── finetune ├── README.md ├── README_zh.md ├── configs │ ├── ds_zero_2.json │ ├── ds_zero_3.json │ ├── lora.yaml │ └── sft.yaml ├── finetune.py ├── finetune_vision.py └── requirements.txt ├── inference ├── README.md ├── README_zh.md ├── demo.jpg ├── glm4v_api_request.py ├── glm4v_server.py ├── requirements.txt ├── trans_batch_demo.py ├── trans_cli_demo.py ├── trans_cli_vision_demo.py ├── trans_stress_test.py ├── trans_web_demo.py ├── trans_web_vision_demo.py ├── vllm_cli_demo.py └── vllm_cli_vision_demo.py ├── pyproject.toml └── resources ├── Bench-32B.png ├── Bench-Z1-32B.png ├── Bench-Z1-9B.png ├── WECHAT.md ├── eval_needle.jpeg ├── longbench.png └── wechat.jpg /.github/ISSUE_TEMPLATE/bug_report.yaml: -------------------------------------------------------------------------------- 1 | name: "\U0001F41B Bug Report" 2 | description: Submit a bug report to help us improve GLM-4-9B / 提交一个 Bug 问题报告来帮助我们改进 GLM-4-9B 3 | body: 4 | - type: textarea 5 | id: system-info 6 | attributes: 7 | label: System Info / 系統信息 8 | description: Your operating environment / 您的运行环境信息 9 | placeholder: Includes Cuda version, Transformers version, Python version, operating system, hardware information (if you suspect a hardware problem)... / 包括Cuda版本,Transformers版本,Python版本,操作系统,硬件信息(如果您怀疑是硬件方面的问题)... 10 | validations: 11 | required: true 12 | 13 | - type: textarea 14 | id: who-can-help 15 | attributes: 16 | label: Who can help? / 谁可以帮助到您? 17 | description: | 18 | Your issue will be replied to more quickly if you can figure out the right person to tag with @ 19 | All issues are read by one of the maintainers, so if you don't know who to tag, just leave this blank and our maintainer will ping the right person. 20 | 21 | Please tag fewer than 3 people. 22 | 23 | 如果您能找到合适的标签 @,您的问题会更快得到回复。 24 | 所有问题都会由我们的维护者阅读,如果您不知道该标记谁,只需留空,我们的维护人员会找到合适的开发组成员来解决问题。 25 | 26 | 标记的人数应该不超过 3 个人。 27 | 28 | If it's not a bug in these three subsections, you may not specify the helper. Our maintainer will find the right person in the development group to solve the problem. 29 | 30 | 如果不是这三个子版块的bug,您可以不指明帮助者,我们的维护人员会找到合适的开发组成员来解决问题。 31 | 32 | placeholder: "@Username ..." 
33 | 34 | - type: checkboxes 35 | id: information-scripts-examples 36 | attributes: 37 | label: Information / 问题信息 38 | description: 'The problem arises when using: / 问题出现在' 39 | options: 40 | - label: "The official example scripts / 官方的示例脚本" 41 | - label: "My own modified scripts / 我自己修改的脚本和任务" 42 | 43 | - type: textarea 44 | id: reproduction 45 | validations: 46 | required: true 47 | attributes: 48 | label: Reproduction / 复现过程 49 | description: | 50 | Please provide a code example that reproduces the problem you encountered, preferably with a minimal reproduction unit. 51 | If you have code snippets, error messages, stack traces, please provide them here as well. 52 | Please format your code correctly using code tags. See https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting 53 | Do not use screenshots, as they are difficult to read and (more importantly) do not allow others to copy and paste your code. 54 | 55 | 请提供能重现您遇到的问题的代码示例,最好是最小复现单元。 56 | 如果您有代码片段、错误信息、堆栈跟踪,也请在此提供。 57 | 请使用代码标签正确格式化您的代码。请参见 https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting 58 | 请勿使用截图,因为截图难以阅读,而且(更重要的是)不允许他人复制粘贴您的代码。 59 | placeholder: | 60 | Steps to reproduce the behavior/复现Bug的步骤: 61 | 62 | 1. 63 | 2. 64 | 3. 65 | 66 | - type: textarea 67 | id: expected-behavior 68 | validations: 69 | required: true 70 | attributes: 71 | label: Expected behavior / 期待表现 72 | description: "A clear and concise description of what you would expect to happen. /简单描述您期望发生的事情。" 73 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature-request.yaml: -------------------------------------------------------------------------------- 1 | name: "\U0001F680 Feature request" 2 | description: Submit a request for a new GLM-4-9B feature / 提交一个新的 GLM-4-9B 的功能建议 3 | labels: [ "feature" ] 4 | body: 5 | - type: textarea 6 | id: feature-request 7 | validations: 8 | required: true 9 | attributes: 10 | label: Feature request / 功能建议 11 | description: | 12 | A brief description of the functional proposal. Links to corresponding papers and code are desirable. 13 | 对功能建议的简述。最好提供对应的论文和代码链接 14 | 15 | - type: textarea 16 | id: motivation 17 | validations: 18 | required: true 19 | attributes: 20 | label: Motivation / 动机 21 | description: | 22 | Your motivation for making the suggestion. If that motivation is related to another GitHub issue, link to it here. 23 | 您提出建议的动机。如果该动机与另一个 GitHub 问题有关,请在此处提供对应的链接。 24 | 25 | - type: textarea 26 | id: contribution 27 | validations: 28 | required: true 29 | attributes: 30 | label: Your contribution / 您的贡献 31 | description: | 32 | 33 | Your PR link or any other link you can help with. 34 | 您的PR链接或者其他您能提供帮助的链接。 35 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | # Contribution Guide 2 | 3 | We welcome your contributions to this repository. To ensure elegant code style and better code quality, we have prepared the following contribution guidelines. 4 | 5 | ## What We Accept 6 | 7 | + This PR fixes a typo or improves the documentation (if this is the case, you may skip the other checks). 8 | + This PR fixes a specific issue — please reference the issue number in the PR description. Make sure your code strictly follows the coding standards below. 
9 | + This PR introduces a new feature — please clearly explain the necessity and implementation of the feature. Make sure your code strictly follows the coding standards below. 10 | 11 | ## Code Style Guide 12 | 13 | Good code style is an art. We have prepared a `pyproject.toml` and a `pre-commit` hook to enforce consistent code formatting across the project. You can clean up your code following the steps below: 14 | 15 | 1. Install the required dependencies: 16 | ```shell 17 | pip install ruff pre-commit 18 | ``` 19 | 2. Then, run the following command: 20 | ```shell 21 | pre-commit run --all-files 22 | ``` 23 | If your code complies with the standards, you should not see any errors. 24 | 25 | ## Naming Conventions 26 | 27 | - Please use **English** for naming; do not use Pinyin or other languages. All comments should also be in English. 28 | - Follow **PEP8** naming conventions strictly, and use underscores to separate words. Avoid meaningless names such as `a`, `b`, `c`. 29 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *venv 2 | *.DS_Store 3 | *.idea/ 4 | dataset 5 | test* 6 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/astral-sh/ruff-pre-commit 3 | rev: v0.4.5 4 | hooks: 5 | - id: ruff 6 | args: [--fix, --respect-gitignore, --config=pyproject.toml] 7 | - id: ruff-format 8 | args: [--config=pyproject.toml] 9 | 10 | - repo: https://github.com/pre-commit/pre-commit-hooks 11 | rev: v4.5.0 12 | hooks: 13 | - id: trailing-whitespace 14 | - id: end-of-file-fixer 15 | - id: check-yaml 16 | - id: check-toml 17 | - id: check-case-conflict 18 | - id: check-merge-conflict 19 | - id: debug-statements 20 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 
29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2025 Zhipu AI 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /demo/composite_demo/.gitignore: -------------------------------------------------------------------------------- 1 | *venv 2 | *.DS_Store 3 | *model 4 | *.idea/ 5 | 6 | # Created by https://www.toptal.com/developers/gitignore/api/python 7 | # Edit at https://www.toptal.com/developers/gitignore?templates=python 8 | 9 | ### Python ### 10 | # Byte-compiled / optimized / DLL files 11 | __pycache__/ 12 | *.py[cod] 13 | *$py.class 14 | 15 | # C extensions 16 | *.so 17 | 18 | # Distribution / packaging 19 | .Python 20 | build/ 21 | develop-eggs/ 22 | dist/ 23 | downloads/ 24 | eggs/ 25 | .eggs/ 26 | lib/ 27 | lib64/ 28 | parts/ 29 | sdist/ 30 | var/ 31 | wheels/ 32 | share/python-wheels/ 33 | *.egg-info/ 34 | .installed.cfg 35 | *.egg 36 | MANIFEST 37 | 38 | # PyInstaller 39 | # Usually these files are written by a python script from a template 40 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 41 | *.manifest 42 | *.spec 43 | 44 | # Installer logs 45 | pip-log.txt 46 | pip-delete-this-directory.txt 47 | 48 | # Unit test / coverage reports 49 | htmlcov/ 50 | .tox/ 51 | .nox/ 52 | .coverage 53 | .coverage.* 54 | .cache 55 | nosetests.xml 56 | coverage.xml 57 | *.cover 58 | *.py,cover 59 | .hypothesis/ 60 | .pytest_cache/ 61 | cover/ 62 | 63 | # Translations 64 | *.mo 65 | *.pot 66 | 67 | # Django stuff: 68 | *.log 69 | local_settings.py 70 | db.sqlite3 71 | db.sqlite3-journal 72 | 73 | # Flask stuff: 74 | instance/ 75 | .webassets-cache 76 | 77 | # Scrapy stuff: 78 | .scrapy 79 | 80 | # Sphinx documentation 81 | docs/_build/ 82 | 83 | # PyBuilder 84 | .pybuilder/ 85 | target/ 86 | 87 | # Jupyter Notebook 88 | .ipynb_checkpoints 89 | 90 | # IPython 91 | profile_default/ 92 | ipython_config.py 93 | 94 | # pyenv 95 | # For a library or package, you might want to ignore these files since the code is 96 | # intended to run in multiple environments; otherwise, check them in: 97 | # .python-version 98 | 99 | # pipenv 100 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 101 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 102 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 103 | # install all needed dependencies. 104 | #Pipfile.lock 105 | 106 | # poetry 107 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 108 | # This is especially recommended for binary packages to ensure reproducibility, and is more 109 | # commonly ignored for libraries. 110 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 111 | #poetry.lock 112 | 113 | # pdm 114 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 115 | #pdm.lock 116 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 117 | # in version control. 118 | # https://pdm.fming.dev/#use-with-ide 119 | .pdm.toml 120 | 121 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 122 | __pypackages__/ 123 | 124 | # Celery stuff 125 | celerybeat-schedule 126 | celerybeat.pid 127 | 128 | # SageMath parsed files 129 | *.sage.py 130 | 131 | # Environments 132 | .env 133 | .venv 134 | env/ 135 | venv/ 136 | ENV/ 137 | env.bak/ 138 | venv.bak/ 139 | 140 | # Spyder project settings 141 | .spyderproject 142 | .spyproject 143 | 144 | # Rope project settings 145 | .ropeproject 146 | 147 | # mkdocs documentation 148 | /site 149 | 150 | # mypy 151 | .mypy_cache/ 152 | .dmypy.json 153 | dmypy.json 154 | 155 | # Pyre type checker 156 | .pyre/ 157 | 158 | # pytype static type analyzer 159 | .pytype/ 160 | 161 | # Cython debug symbols 162 | cython_debug/ 163 | 164 | # PyCharm 165 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 166 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 167 | # and can be added to the global gitignore or merged into this file. For a more nuclear 168 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 169 | #.idea/ 170 | 171 | ### Python Patch ### 172 | # Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration 173 | poetry.toml 174 | 175 | # ruff 176 | .ruff_cache/ 177 | 178 | # LSP config files 179 | pyrightconfig.json 180 | 181 | # End of https://www.toptal.com/developers/gitignore/api/python 182 | -------------------------------------------------------------------------------- /demo/composite_demo/README.md: -------------------------------------------------------------------------------- 1 | # GLM-4-9B Web Demo 2 | 3 | Read this in [English](README_en.md) 4 | 5 | ![Demo webpage](assets/demo.png) 6 | 7 | ## 安装 8 | 9 | 我们建议通过 [Conda](https://docs.conda.io/en/latest/) 进行环境管理。 10 | 执行以下命令新建一个 conda 环境并安装所需依赖: 11 | 12 | ```bash 13 | conda create -n glm-4-demo python=3.12 14 | conda activate glm-4-demo 15 | pip install -r requirements.txt 16 | ``` 17 | 18 | 请注意,本项目需要 Python 3.10 或更高版本。 19 | 此外,使用 Code Interpreter 还需要安装 Jupyter 内核: 20 | 21 | ```bash 22 | ipython kernel install --name glm-4-demo --user 23 | ``` 24 | 25 | 您可以修改 `~/.local/share/jupyter/kernels/glm-4-demo/kernel.json` 来改变 Jupyter 内核的配置,包括内核的启动参数等。例如,若您希望在使用 All Tools 的 Python 代码执行能力时使用 Matplotlib 画图,可以在 `argv` 数组中添加 `"--matplotlib=inline"`。 26 | 27 | 若要使用浏览器和搜索功能,还需要启动浏览器后端。首先,根据 [Node.js](https://nodejs.org/en/download/package-manager) 28 | 官网的指示安装 Node.js,然后安装包管理器 [PNPM](https://pnpm.io) 之后安装浏览器服务的依赖: 29 | 30 | ```bash 31 | cd browser 32 | npm install -g pnpm 33 | pnpm install 34 | ``` 35 | 36 | ## 运行 37 | 38 | 1. 修改 `browser/src/config.ts` 中的 `BING_SEARCH_API_KEY` 配置浏览器服务需要使用的 Bing 搜索 API Key: 39 | 40 | ```diff 41 | export default { 42 | 43 | BROWSER_TIMEOUT: 10000, 44 | BING_SEARCH_API_URL: 'https://api.bing.microsoft.com/v7.0', 45 | BING_SEARCH_API_KEY: '', 46 | 47 | HOST: 'localhost', 48 | PORT: 3000, 49 | }; 50 | ``` 51 | 如果您注册的是Bing Customer Search的API,您可以修改您的配置文件为如下,并且填写您的Custom Configuration ID: 52 | 53 | ```diff 54 | export default { 55 | LOG_LEVEL: 'debug', 56 | BROWSER_TIMEOUT: 10000, 57 | BING_SEARCH_API_URL: 'https://api.bing.microsoft.com/v7.0/custom/', 58 | BING_SEARCH_API_KEY: 'YOUR_BING_SEARCH_API_KEY', 59 | CUSTOM_CONFIG_ID : 'YOUR_CUSTOM_CONFIG_ID', //将您的Custom Configuration ID放在此处 60 | HOST: 'localhost', 61 | PORT: 3000, 62 | }; 63 | ``` 64 | 65 | 2. 
文生图功能需要调用 CogView API。修改 `src/tools/config.py` 66 | ,提供文生图功能需要使用的 [智谱 AI 开放平台](https://open.bigmodel.cn) API Key: 67 | 68 | ```diff 69 | BROWSER_SERVER_URL = 'http://localhost:3000' 70 | 71 | IPYKERNEL = 'glm-4-demo' 72 | 73 | ZHIPU_AI_KEY = '' 74 | COGVIEW_MODEL = 'cogview-3' 75 | ``` 76 | 77 | 3. 启动浏览器后端,在单独的 shell 中: 78 | 79 | ```bash 80 | cd browser 81 | pnpm start 82 | ``` 83 | 84 | 4. 运行以下命令在本地加载模型并启动 demo: 85 | 86 | ```bash 87 | streamlit run src/main.py 88 | ``` 89 | 90 | 之后即可从命令行中看到 demo 的地址,点击即可访问。初次访问需要下载并加载模型,可能需要花费一定时间。 91 | 92 | 如果已经在本地下载了模型,可以通过 `export *_MODEL_PATH=/path/to/model` 来指定从本地加载模型。可以指定的模型包括: 93 | - `CHAT_MODEL_PATH`: 用于 All Tools 模式与文档解读模式,默认为 `THUDM/glm-4-9b-chatglm-4-9b-chat`。 94 | - `VLM_MODEL_PATH`: 用于 VLM 模式,默认为 `THUDM/glm-4v-9b`。 95 | 96 | Chat 模型支持使用 [vLLM](https://github.com/vllm-project/vllm) 推理。若要使用,请安装 vLLM 并设置环境变量 `USE_VLLM=1`。 97 | 98 | Chat 模型支持使用 [OpenAI API](https://platform.openai.com/docs/api-reference/introduction) 推理。若要使用,请启动basic_demo目录下的openai_api_server并设置环境变量 `USE_API=1`。该功能可以解耦推理服务器和demo服务器。 99 | 100 | 如果需要自定义 Jupyter 内核,可以通过 `export IPYKERNEL=` 来指定。 101 | 102 | ## 使用 103 | 104 | GLM-4 Demo 拥有三种模式: 105 | 106 | - All Tools: 具有完整工具调用能力的对话模式,原生支持网页浏览、代码执行、图片生成,并支持自定义工具。 107 | - 文档解读: 支持上传文档进行文档解读与对话。 108 | - 多模态: 支持上传图像进行图像理解与对话。 109 | 110 | ### All Tools 111 | 112 | 本模式兼容 ChatGLM3-6B 的工具注册流程。 113 | + 代码能力,绘图能力,联网能力已经自动集成,用户只需按照要求配置对应的Key。 114 | + 本模式下不支持系统提示词,模型会自动构建提示词。 115 | 116 | 对话模式下,用户可以直接在侧边栏修改 top_p, temperature 等参数来调整模型的行为。 117 | 118 | 与模型对话时,模型将会自主决定进行工具调用。 119 | 120 | ![Tool calling](assets/tool.png) 121 | 122 | 由于原始结果可能较长,默认情况下工具调用结果被隐藏,可以通过展开折叠框查看原始的工具调用结果。 123 | 124 | 模型拥有进行网页搜索和 Python 代码执行的能力。同时,模型也可以连续调用多个工具。例如: 125 | 126 | ![Consecutive tool calling, 1](assets/web_plot_1.png) 127 | 128 | 此时模型通过调用浏览器工具进行搜索获取到了需要的数据,之后将会调用 Python 工具执行代码,利用 Matplotlib 绘图: 129 | 130 | ![Consecutive tool calling, 2](assets/web_plot_2.png) 131 | 132 | 如果提供了智谱开放平台 API Key,模型也可以调用 CogView 进行图像生成: 133 | 134 | ![Image generation](assets/cogview.png) 135 | 136 | #### 自定义工具 137 | 138 | 可以通过在 `tool_registry.py` 中注册新的工具来增强模型的能力。只需要使用 `@register_tool` 139 | 装饰函数即可完成注册。对于工具声明,函数名称即为工具的名称,函数 docstring 140 | 即为工具的说明;对于工具的参数,使用 `Annotated[typ: type, description: str, required: bool]` 标注参数的类型、描述和是否必须。 141 | 142 | 例如,`get_weather` 工具的注册如下: 143 | 144 | ```python 145 | @register_tool 146 | def get_weather( 147 | city_name: Annotated[str, 'The name of the city to be queried', True], 148 | ) -> str: 149 | """ 150 | Get the weather for `city_name` in the following week 151 | """ 152 | ... 
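# For illustration only: a hypothetical second tool (not part of this repository),
# showing how `Annotated[type, description, required]` marks a required parameter (True)
# versus an optional one (False) that carries a default value:
#
# @register_tool
# def get_news(
#     topic: Annotated[str, 'The topic to search news for', True],
#     limit: Annotated[int, 'Maximum number of headlines to return', False] = 5,
# ) -> str:
#     """
#     Get the latest news headlines about `topic`.
#     """
#     ...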
153 | ``` 154 | 155 | ![The model uses tool to query the weather of Bangkok.](assets/weather.png) 156 | 157 | ### 文档解读 158 | 159 | 用户可以上传文档,使用 GLM-4-9B的长文本能力,对文本进行理解。可以解析 pptx,docx,pdf等文件。 160 | 161 | + 本模式下不支持工具调用和系统提示词。 162 | + 如果文本很长,可能导致模型需要的显存较高,请确认你的硬件配置。 163 | 164 | ![Doc reader demo](assets/doc_reader.png) 165 | 166 | ### 多模态 167 | 168 | 多模态模式下,用户可以利用 GLM-4V 的多模态理解能力,上传图像并与 GLM-4V 进行多轮对话: 169 | 170 | 用户可以上传图片,使用 GLM-4-9B的图像理解能力,对图片进行理解。 171 | 172 | + 本模式必须使用 glm-4v-9b 模型。 173 | + 本模式下不支持工具调用和系统提示词。 174 | + 模型仅能对一张图片进行理解和联系对话,如需更换图片,需要开启一个新的对话。 175 | + 图像支持的分辨率为 1120 x 1120 176 | 177 | ![VLM demo](assets/vlm.png) 178 | -------------------------------------------------------------------------------- /demo/composite_demo/README_en.md: -------------------------------------------------------------------------------- 1 | # GLM-4-9B Web Demo 2 | 3 | ![Demo webpage](assets/demo.png) 4 | 5 | ## Installation 6 | 7 | We recommend using [Conda](https://docs.conda.io/en/latest/) for environment management. 8 | 9 | Execute the following commands to create a conda environment and install the required dependencies: 10 | 11 | ```bash 12 | conda create -n glm-4-demo python=3.12 13 | conda activate glm-4-demo 14 | pip install -r requirements.txt 15 | ``` 16 | 17 | Please note that this project requires Python 3.10 or higher. 18 | In addition, you need to install the Jupyter kernel to use the Code Interpreter: 19 | 20 | ```bash 21 | ipython kernel install --name glm-4-demo --user 22 | ``` 23 | 24 | You can modify `~/.local/share/jupyter/kernels/glm-4-demo/kernel.json` to change the configuration of the Jupyter 25 | kernel, including the kernel startup parameters. For example, if you want to use Matplotlib to draw when using the 26 | Python code execution capability of All Tools, you can add `"--matplotlib=inline"` to the `argv` array. 27 | 28 | To use the browser and search functions, you also need to start the browser backend. First, install Node.js according to 29 | the instructions on the [Node.js](https://nodejs.org/en/download/package-manager) 30 | official website, then install the package manager [PNPM](https://pnpm.io) and then install the browser service 31 | dependencies: 32 | 33 | ```bash 34 | cd browser 35 | npm install -g pnpm 36 | pnpm install 37 | ``` 38 | 39 | ## Run 40 | 41 | 1. Modify `BING_SEARCH_API_KEY` in `browser/src/config.ts` to configure the Bing Search API Key that the browser service 42 | needs to use: 43 | 44 | ```diff 45 | export default { 46 | 47 | BROWSER_TIMEOUT: 10000, 48 | BING_SEARCH_API_URL: 'https://api.bing.microsoft.com/v7.0', 49 | BING_SEARCH_API_KEY: '', 50 | 51 | HOST: 'localhost', 52 | PORT: 3000, 53 | }; 54 | ``` 55 | 56 | 2. The Wenshengtu function needs to call the CogView API. Modify `src/tools/config.py` 57 | , provide the [Zhipu AI Open Platform](https://open.bigmodel.cn) API Key required for the Wenshengtu function: 58 | 59 | ```diff 60 | BROWSER_SERVER_URL = 'http://localhost:3000' 61 | 62 | IPYKERNEL = 'glm4-demo' 63 | 64 | ZHIPU_AI_KEY = '' 65 | COGVIEW_MODEL = 'cogview-3' 66 | ``` 67 | 68 | 3. Start the browser backend in a separate shell: 69 | 70 | ```bash 71 | cd browser 72 | pnpm start 73 | ``` 74 | 75 | 4. Run the following commands to load the model locally and start the demo: 76 | 77 | ```bash 78 | streamlit run src/main.py 79 | ``` 80 | 81 | Then you can see the demo address from the command line and click it to access it. The first access requires downloading 82 | and loading the model, which may take some time. 
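If you prefer to fetch the weights ahead of time rather than on first launch, one optional approach (a minimal sketch, not part of the demo itself; it assumes `huggingface_hub` is installed and uses the default chat model id listed below) is:

```python
# Optional pre-download sketch; assumes `pip install huggingface_hub` and network access.
from huggingface_hub import snapshot_download

# Fetch the default chat model once; afterwards point CHAT_MODEL_PATH at this directory.
snapshot_download(repo_id="THUDM/glm-4-9b-chat", local_dir="glm-4-9b-chat")
```

The same can be done for `THUDM/glm-4v-9b` if you plan to use the VLM mode.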
83 | 84 | If you have downloaded the model locally, you can specify to load the model from the local 85 | by `export *_MODEL_PATH=/path/to/model`. The models that can be specified include: 86 | 87 | - `CHAT_MODEL_PATH`: used for All Tools mode and document interpretation mode, the default is `THUDM/glm-4-9b-chat`. 88 | 89 | - `VLM_MODEL_PATH`: used for VLM mode, the default is `THUDM/glm-4v-9b`. 90 | 91 | The Chat model supports reasoning using [vLLM](https://github.com/vllm-project/vllm). To use it, please install vLLM and 92 | set the environment variable `USE_VLLM=1`. 93 | 94 | The Chat model also supports reasoning using [OpenAI API](https://platform.openai.com/docs/api-reference/introduction). To use it, please run `openai_api_server.py` in `inference` and set the environment variable `USE_API=1`. This function is used to deploy inference server and demo server in different machine. 95 | 96 | If you need to customize the Jupyter kernel, you can specify it by `export IPYKERNEL=`. 97 | 98 | ## Usage 99 | 100 | GLM4 Demo has three modes: 101 | 102 | - All Tools mode 103 | - VLM mode 104 | - Text interpretation mode 105 | 106 | ### All Tools mode 107 | 108 | You can enhance the model's capabilities by registering new tools in `tool_registry.py`. Just use `@register_tool` 109 | decorated function to complete the registration. For tool declarations, the function name is the name of the tool, and 110 | the function docstring 111 | is the description of the tool; for tool parameters, use `Annotated[typ: type, description: str, required: bool]` to 112 | annotate the parameter type, description, and whether it is required. 113 | 114 | For example, the registration of the `get_weather` tool is as follows: 115 | 116 | ```python 117 | @register_tool 118 | def get_weather( 119 | city_name: Annotated[str, 'The name of the city to be queried', True], 120 | ) -> str: 121 | 122 | 123 | """ 124 | Get the weather for `city_name` in the following week 125 | """ 126 | ... 127 | ``` 128 | 129 | This mode is compatible with the tool registration process of ChatGLM3-6B. 130 | 131 | + Code capability, drawing capability, and networking capability have been automatically integrated. Users only need to 132 | configure the corresponding Key as required. 133 | + System prompt words are not supported in this mode. The model will automatically build prompt words. 134 | 135 | ## Text interpretation mode 136 | 137 | Users can upload documents and use the long text capability of GLM-4-9B to understand the text. It can parse pptx, docx, 138 | pdf and other files. 139 | 140 | + Tool calls and system prompt words are not supported in this mode. 141 | + If the text is very long, the model may require a high amount of GPU memory. Please confirm your hardware 142 | configuration. 143 | 144 | ## Image Understanding Mode 145 | 146 | Users can upload images and use the image understanding capabilities of GLM-4-9B to understand the images. 147 | 148 | + This mode must use the glm-4v-9b model. 149 | + Tool calls and system prompts are not supported in this mode. 150 | + The model can only understand and communicate with one image. If you need to change the image, you need to open a new 151 | conversation. 
152 | + The supported image resolution is 1120 x 1120 153 | -------------------------------------------------------------------------------- /demo/composite_demo/assets/cogview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zRzRzRzRzRzRzR/GLM-4/2e5efb1365eda83bead08afffb7a884d56724f7a/demo/composite_demo/assets/cogview.png -------------------------------------------------------------------------------- /demo/composite_demo/assets/demo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zRzRzRzRzRzRzR/GLM-4/2e5efb1365eda83bead08afffb7a884d56724f7a/demo/composite_demo/assets/demo.png -------------------------------------------------------------------------------- /demo/composite_demo/assets/doc_reader.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zRzRzRzRzRzRzR/GLM-4/2e5efb1365eda83bead08afffb7a884d56724f7a/demo/composite_demo/assets/doc_reader.png -------------------------------------------------------------------------------- /demo/composite_demo/assets/tool.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zRzRzRzRzRzRzR/GLM-4/2e5efb1365eda83bead08afffb7a884d56724f7a/demo/composite_demo/assets/tool.png -------------------------------------------------------------------------------- /demo/composite_demo/assets/vlm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zRzRzRzRzRzRzR/GLM-4/2e5efb1365eda83bead08afffb7a884d56724f7a/demo/composite_demo/assets/vlm.png -------------------------------------------------------------------------------- /demo/composite_demo/assets/weather.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zRzRzRzRzRzRzR/GLM-4/2e5efb1365eda83bead08afffb7a884d56724f7a/demo/composite_demo/assets/weather.png -------------------------------------------------------------------------------- /demo/composite_demo/assets/web_plot_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zRzRzRzRzRzRzR/GLM-4/2e5efb1365eda83bead08afffb7a884d56724f7a/demo/composite_demo/assets/web_plot_1.png -------------------------------------------------------------------------------- /demo/composite_demo/assets/web_plot_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zRzRzRzRzRzRzR/GLM-4/2e5efb1365eda83bead08afffb7a884d56724f7a/demo/composite_demo/assets/web_plot_2.png -------------------------------------------------------------------------------- /demo/composite_demo/browser/.gitignore: -------------------------------------------------------------------------------- 1 | # Created by https://www.toptal.com/developers/gitignore/api/node 2 | # Edit at https://www.toptal.com/developers/gitignore?templates=node 3 | 4 | ### Node ### 5 | # Logs 6 | logs 7 | *.log 8 | npm-debug.log* 9 | yarn-debug.log* 10 | yarn-error.log* 11 | lerna-debug.log* 12 | .pnpm-debug.log* 13 | 14 | # Diagnostic reports (https://nodejs.org/api/report.html) 15 | report.[0-9]*.[0-9]*.[0-9]*.[0-9]*.json 16 | 17 | # Runtime data 18 | pids 19 | *.pid 20 | *.seed 21 | *.pid.lock 22 | 23 | # Directory for instrumented libs generated by jscoverage/JSCover 24 | lib-cov 25 | 26 | # 
Coverage directory used by tools like istanbul 27 | coverage 28 | *.lcov 29 | 30 | # nyc test coverage 31 | .nyc_output 32 | 33 | # Grunt intermediate storage (https://gruntjs.com/creating-plugins#storing-task-files) 34 | .grunt 35 | 36 | # Bower dependency directory (https://bower.io/) 37 | bower_components 38 | 39 | # node-waf configuration 40 | .lock-wscript 41 | 42 | # Compiled binary addons (https://nodejs.org/api/addons.html) 43 | build/Release 44 | 45 | # Dependency directories 46 | node_modules/ 47 | jspm_packages/ 48 | 49 | # Snowpack dependency directory (https://snowpack.dev/) 50 | web_modules/ 51 | 52 | # TypeScript cache 53 | *.tsbuildinfo 54 | 55 | # Optional npm cache directory 56 | .npm 57 | 58 | # Optional eslint cache 59 | .eslintcache 60 | 61 | # Optional stylelint cache 62 | .stylelintcache 63 | 64 | # Microbundle cache 65 | .rpt2_cache/ 66 | .rts2_cache_cjs/ 67 | .rts2_cache_es/ 68 | .rts2_cache_umd/ 69 | 70 | # Optional REPL history 71 | .node_repl_history 72 | 73 | # Output of 'npm pack' 74 | *.tgz 75 | 76 | # Yarn Integrity file 77 | .yarn-integrity 78 | 79 | # dotenv environment variable files 80 | .env 81 | .env.development.local 82 | .env.test.local 83 | .env.production.local 84 | .env.local 85 | 86 | # parcel-bundler cache (https://parceljs.org/) 87 | .cache 88 | .parcel-cache 89 | 90 | # Next.js build output 91 | .next 92 | out 93 | 94 | # Nuxt.js build / generate output 95 | .nuxt 96 | dist 97 | 98 | # Gatsby files 99 | .cache/ 100 | # Comment in the public line in if your project uses Gatsby and not Next.js 101 | # https://nextjs.org/blog/next-9-1#public-directory-support 102 | # public 103 | 104 | # vuepress build output 105 | .vuepress/dist 106 | 107 | # vuepress v2.x temp and cache directory 108 | .temp 109 | 110 | # Docusaurus cache and generated files 111 | .docusaurus 112 | 113 | # Serverless directories 114 | .serverless/ 115 | 116 | # FuseBox cache 117 | .fusebox/ 118 | 119 | # DynamoDB Local files 120 | .dynamodb/ 121 | 122 | # TernJS port file 123 | .tern-port 124 | 125 | # Stores VSCode versions used for testing VSCode extensions 126 | .vscode-test 127 | 128 | # yarn v2 129 | .yarn/cache 130 | .yarn/unplugged 131 | .yarn/build-state.yml 132 | .yarn/install-state.gz 133 | .pnp.* 134 | 135 | ### Node Patch ### 136 | # Serverless Webpack directories 137 | .webpack/ 138 | 139 | # Optional stylelint cache 140 | 141 | # SvelteKit build / generate output 142 | .svelte-kit 143 | 144 | # End of https://www.toptal.com/developers/gitignore/api/node 145 | -------------------------------------------------------------------------------- /demo/composite_demo/browser/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "glm4-browser", 3 | "version": "1.0.0", 4 | "description": "Browser system for GLM-4", 5 | "main": "src/server.ts", 6 | "scripts": { 7 | "dev": "npx nodemon src/server", 8 | "start": "npx ts-node src/server.ts" 9 | }, 10 | "license": "MIT", 11 | "dependencies": { 12 | "express": "^4.18.3", 13 | "jsdom": "^24.0.0", 14 | "pnpm": "^9.1.2", 15 | "turndown": "^7.1.2", 16 | "winston": "^3.11.0" 17 | }, 18 | "devDependencies": { 19 | "@types/express": "^4.17.21", 20 | "@types/jsdom": "^21.1.6", 21 | "@types/node": "^20.11.20", 22 | "@types/turndown": "^5.0.4", 23 | "nodemon": "^3.1.0", 24 | "ts-node": "^10.9.2" 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /demo/composite_demo/browser/src/config.ts: 
-------------------------------------------------------------------------------- 1 | export default { 2 | LOG_LEVEL: 'debug', 3 | BROWSER_TIMEOUT: 10000, 4 | BING_SEARCH_API_URL: 'https://api.bing.microsoft.com/v7.0/custom/', 5 | BING_SEARCH_API_KEY: 'YOUR_BING_SEARCH_API_KEY', 6 | CUSTOM_CONFIG_ID : 'YOUR_CUSTOM_CONFIG_ID', // Put your Custom Configuration ID here 7 | HOST: 'localhost', 8 | PORT: 3000, 9 | }; 10 | -------------------------------------------------------------------------------- /demo/composite_demo/browser/src/server.ts: -------------------------------------------------------------------------------- 1 | import express, { Express, Request, Response } from 'express'; 2 | 3 | import { SimpleBrowser } from './browser'; 4 | import config from './config'; 5 | import { logger } from './utils'; 6 | 7 | const session_history: Record<string, SimpleBrowser> = {}; 8 | 9 | const app: Express = express(); 10 | 11 | app.use(express.json()); 12 | 13 | app.post('/', async (req: Request, res: Response) => { 14 | const { 15 | session_id, 16 | action, 17 | }: { 18 | session_id: string; 19 | action: string; 20 | } = req.body; 21 | logger.info(`session_id: ${session_id}`); 22 | logger.info(`action: ${action}`); 23 | 24 | if (!session_history[session_id]) { 25 | session_history[session_id] = new SimpleBrowser(); 26 | } 27 | 28 | const browser = session_history[session_id]; 29 | 30 | try { 31 | res.json(await browser.action(action)); 32 | } catch (err) { 33 | logger.error(err); 34 | res.status(400).json(err); 35 | } 36 | }) 37 | 38 | process.on('SIGINT', () => { 39 | process.exit(0); 40 | }); 41 | 42 | process.on('uncaughtException', e => { 43 | logger.error(e); 44 | }); 45 | 46 | const { HOST, PORT } = config; 47 | 48 | (async () => { 49 | app.listen(PORT, HOST, () => { 50 | logger.info(`⚡️[server]: Server is running at http://${HOST}:${PORT}`); 51 | try { 52 | (<any>process).send('ready'); 53 | } catch (err) {} 54 | }); 55 | })(); 56 | -------------------------------------------------------------------------------- /demo/composite_demo/browser/src/types.ts: -------------------------------------------------------------------------------- 1 | export interface File { 2 | id: string; 3 | name: string; 4 | size: number; 5 | } 6 | 7 | export interface Metadata { 8 | files?: File[]; 9 | reference?: string; 10 | } 11 | 12 | export interface Message { 13 | role: 'user' | 'assistant' | 'system' | 'observation'; 14 | metadata: string; 15 | content: string; 16 | request_metadata?: Metadata; 17 | } 18 | 19 | export interface ToolObservation { 20 | contentType: string; 21 | result: string; 22 | text?: string; 23 | roleMetadata?: string; // metadata for <|observation|>${metadata} 24 | metadata: any; // metadata for response 25 | } 26 | -------------------------------------------------------------------------------- /demo/composite_demo/browser/src/utils.ts: -------------------------------------------------------------------------------- 1 | import winston from 'winston'; 2 | 3 | import config from './config'; 4 | 5 | export class TimeoutError extends Error {} 6 | 7 | const logLevel = config.LOG_LEVEL; 8 | 9 | export const logger = winston.createLogger({ 10 | level: logLevel, 11 | format: winston.format.combine( 12 | winston.format.colorize(), 13 | winston.format.printf(info => { 14 | return `${info.level}: ${info.message}`; 15 | }), 16 | ), 17 | transports: [new winston.transports.Console()], 18 | }); 19 | 20 | console.log('LOG_LEVEL', logLevel); 21 | 22 | export const parseHrtimeToMillisecond = (hrtime: [number, number]): number => { 23 |
return (hrtime[0] + hrtime[1] / 1e9) * 1000; 24 | }; 25 | 26 | export const promiseWithTime = <T>( 27 | promise: Promise<T> 28 | ): Promise<{ 29 | value: T; 30 | time: number; 31 | }> => { 32 | return new Promise((resolve, reject) => { 33 | const startTime = process.hrtime(); 34 | promise 35 | .then(value => { 36 | resolve({ 37 | value: value, 38 | time: parseHrtimeToMillisecond(process.hrtime(startTime)) 39 | }); 40 | }) 41 | .catch(err => reject(err)); 42 | }); 43 | }; 44 | 45 | export const withTimeout = <T>( 46 | millis: number, 47 | promise: Promise<T> 48 | ): Promise<{ 49 | value: T; 50 | time: number; 51 | }> => { 52 | const timeout = new Promise<{ value: T; time: number }>((_, reject) => 53 | setTimeout(() => reject(new TimeoutError()), millis) 54 | ); 55 | return Promise.race([promiseWithTime(promise), timeout]); 56 | }; 57 | -------------------------------------------------------------------------------- /demo/composite_demo/browser/tsconfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "compilerOptions": { 3 | "target": "es2022", 4 | "lib": ["es2022", "dom"], 5 | "module": "commonjs", 6 | "rootDir": "./", 7 | "outDir": "./dist", 8 | "esModuleInterop": true, 9 | "forceConsistentCasingInFileNames": true, 10 | "strict": true, 11 | }, 12 | "ts-node": { 13 | "transpileOnly": true 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /demo/composite_demo/requirements.txt: -------------------------------------------------------------------------------- 1 | # Please install the requirements.txt in the inference directory first! 2 | 3 | ipykernel>=6.26.0 4 | ipython>=8.18.1 5 | jupyter_client>=8.6.0 6 | langchain>=0.2.12 7 | langchain-community>=0.2.11 8 | matplotlib>=3.9.1 9 | pymupdf>=1.24.9 10 | python-docx>=1.1.2 11 | python-pptx>=0.6.23 12 | pyyaml>=6.0.1 13 | requests>=2.31.0 14 | streamlit>=1.37.1 15 | zhipuai>=2.1.4 16 | -------------------------------------------------------------------------------- /demo/composite_demo/src/client.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3 | This is the client part of composite_demo. 4 | Three clients are provided: HFClient, VLLMClient, and APIClient. 5 | HFClient runs the model with the transformers backend, VLLMClient runs it with a local vLLM engine, and APIClient talks to an OpenAI-compatible API server. 6 | 7 | """ 8 | 9 | import json 10 | from collections.abc import Generator 11 | from copy import deepcopy 12 | from enum import Enum, auto 13 | from typing import Protocol 14 | 15 | import streamlit as st 16 | from conversation import Conversation, build_system_prompt 17 | from tools.tool_registry import ALL_TOOLS 18 | 19 | 20 | class ClientType(Enum): 21 | HF = auto() 22 | VLLM = auto() 23 | API = auto() 24 | 25 | 26 | class Client(Protocol): 27 | def __init__(self, model_path: str): ... 28 | 29 | def generate_stream( 30 | self, 31 | tools: list[dict], 32 | history: list[Conversation], 33 | **parameters, 34 | ) -> Generator[tuple[str | dict, list[dict]]]: ...
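# Illustrative sketch (hypothetical, not used anywhere in this repository): any backend
# can be plugged in by implementing the Client protocol above, for example:
#
#     class EchoClient:
#         def __init__(self, model_path: str):
#             self.model_path = model_path
#
#         def generate_stream(self, tools, history, **parameters):
#             # Yield (response, updated_history) tuples, matching what process_response() returns.
#             yield history[-1].content, []
#
# The concrete implementations used by the demo are HFClient, VLLMClient and APIClient
# in the clients/ package, selected by get_client() below.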
35 | 36 | 37 | def process_input(history: list[dict], tools: list[dict], role_name_replace: dict = None) -> list[dict]: 38 | chat_history = [] 39 | # if len(tools) > 0: 40 | chat_history.append({"role": "system", "content": build_system_prompt(list(ALL_TOOLS), tools)}) 41 | 42 | for conversation in history: 43 | role = str(conversation.role).removeprefix("<|").removesuffix("|>") 44 | if role_name_replace: 45 | role = role_name_replace.get(role, role) 46 | item = { 47 | "role": role, 48 | "content": conversation.content, 49 | } 50 | if conversation.metadata: 51 | item["metadata"] = conversation.metadata 52 | # Only append image for user 53 | if role == "user" and conversation.image: 54 | item["image"] = conversation.image 55 | chat_history.append(item) 56 | 57 | return chat_history 58 | 59 | 60 | def process_response(output, history): 61 | content = "" 62 | history = deepcopy(history) 63 | for response in output.split("<|assistant|>"): 64 | if "\n" in response: 65 | metadata, content = response.split("\n", maxsplit=1) 66 | else: 67 | metadata, content = "", response 68 | if not metadata.strip(): 69 | content = content.strip() 70 | history.append({"role": "assistant", "metadata": metadata, "content": content}) 71 | content = content.replace("[[训练时间]]", "2023年") 72 | else: 73 | history.append({"role": "assistant", "metadata": metadata, "content": content}) 74 | if history[0]["role"] == "system" and "tools" in history[0]: 75 | parameters = json.loads(content) 76 | content = {"name": metadata.strip(), "parameters": parameters} 77 | else: 78 | content = {"name": metadata.strip(), "content": content} 79 | return content, history 80 | 81 | 82 | # glm-4v-9b is not available in vLLM backend, use HFClient instead. 83 | @st.cache_resource(max_entries=1, show_spinner="Loading model...") 84 | def get_client(model_path, typ: ClientType) -> Client: 85 | match typ: 86 | case ClientType.HF: 87 | from clients.hf import HFClient 88 | 89 | return HFClient(model_path) 90 | case ClientType.VLLM: 91 | try: 92 | from clients.vllm import VLLMClient 93 | except ImportError as e: 94 | e.msg += "; did you forget to install vLLM?" 95 | raise 96 | return VLLMClient(model_path) 97 | case ClientType.API: 98 | from clients.openai import APIClient 99 | 100 | return APIClient(model_path) 101 | 102 | raise NotImplementedError(f"Client type {typ} is not supported.") 103 | -------------------------------------------------------------------------------- /demo/composite_demo/src/clients/hf.py: -------------------------------------------------------------------------------- 1 | """ 2 | HuggingFace client. 
3 | """ 4 | 5 | from collections.abc import Generator 6 | from threading import Thread 7 | 8 | import torch 9 | from client import Client, process_input, process_response 10 | from conversation import Conversation 11 | from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer 12 | 13 | 14 | class HFClient(Client): 15 | def __init__(self, model_path: str): 16 | self.tokenizer = AutoTokenizer.from_pretrained( 17 | model_path, 18 | trust_remote_code=True, 19 | ) 20 | self.model = AutoModelForCausalLM.from_pretrained( 21 | model_path, 22 | torch_dtype=torch.bfloat16, 23 | device_map="cuda", 24 | ).eval() 25 | 26 | def generate_stream( 27 | self, 28 | tools: list[dict], 29 | history: list[Conversation], 30 | **parameters, 31 | ) -> Generator[tuple[str | dict, list[dict]]]: 32 | chat_history = process_input(history, tools) 33 | model_inputs = self.tokenizer.apply_chat_template( 34 | chat_history, 35 | add_generation_prompt=True, 36 | tokenize=True, 37 | return_tensors="pt", 38 | return_dict=True, 39 | ).to(self.model.device) 40 | streamer = TextIteratorStreamer( 41 | tokenizer=self.tokenizer, 42 | timeout=5, 43 | skip_prompt=True, 44 | ) 45 | generate_kwargs = { 46 | **model_inputs, 47 | "streamer": streamer, 48 | "eos_token_id": [151329, 151336, 151338], 49 | "do_sample": True, 50 | } 51 | generate_kwargs.update(parameters) 52 | t = Thread(target=self.model.generate, kwargs=generate_kwargs) 53 | t.start() 54 | total_text = "" 55 | for token_text in streamer: 56 | total_text += token_text 57 | yield process_response(total_text, chat_history) 58 | -------------------------------------------------------------------------------- /demo/composite_demo/src/clients/openai.py: -------------------------------------------------------------------------------- 1 | """ 2 | OpenAI API client. 
3 | """ 4 | 5 | from collections.abc import Generator 6 | 7 | from client import Client, process_input, process_response 8 | from conversation import Conversation 9 | from openai import OpenAI 10 | 11 | 12 | def format_openai_tool(origin_tools): 13 | openai_tools = [] 14 | for tool in origin_tools: 15 | openai_param = {} 16 | for param in tool["params"]: 17 | openai_param[param["name"]] = {} 18 | openai_tool = { 19 | "type": "function", 20 | "function": { 21 | "name": tool["name"], 22 | "description": tool["description"], 23 | "parameters": { 24 | "type": "object", 25 | "properties": { 26 | param["name"]: {"type": param["type"], "description": param["description"]} 27 | for param in tool["params"] 28 | }, 29 | "required": [param["name"] for param in tool["params"] if param["required"]], 30 | }, 31 | }, 32 | } 33 | openai_tools.append(openai_tool) 34 | return openai_tools 35 | 36 | 37 | class APIClient(Client): 38 | def __init__(self, model_path: str): 39 | base_url = "http://127.0.0.1:8000/v1/" 40 | self.client = OpenAI(api_key="EMPTY", base_url=base_url) 41 | self.use_stream = False 42 | self.role_name_replace = {"observation": "tool"} 43 | 44 | def generate_stream( 45 | self, 46 | tools: list[dict], 47 | history: list[Conversation], 48 | **parameters, 49 | ) -> Generator[tuple[str | dict, list[dict]]]: 50 | chat_history = process_input(history, "", role_name_replace=self.role_name_replace) 51 | # messages = process_input(history, '', role_name_replace=self.role_name_replace) 52 | openai_tools = format_openai_tool(tools) 53 | response = self.client.chat.completions.create( 54 | model="glm-4", 55 | messages=chat_history, 56 | tools=openai_tools, 57 | stream=self.use_stream, 58 | max_tokens=parameters["max_new_tokens"], 59 | temperature=parameters["temperature"], 60 | presence_penalty=1.2, 61 | top_p=parameters["top_p"], 62 | tool_choice="auto", 63 | ) 64 | output = response.choices[0].message 65 | if output.tool_calls: 66 | glm4_output = output.tool_calls[0].function.name + "\n" + output.tool_calls[0].function.arguments 67 | else: 68 | glm4_output = output.content 69 | yield process_response(glm4_output, chat_history) 70 | -------------------------------------------------------------------------------- /demo/composite_demo/src/clients/vllm.py: -------------------------------------------------------------------------------- 1 | """ 2 | vLLM client. 3 | 4 | Please install [vLLM](https://github.com/vllm-project/vllm) according to its 5 | installation guide before running this client. 6 | """ 7 | 8 | import time 9 | from collections.abc import Generator 10 | 11 | from client import Client, process_input, process_response 12 | from conversation import Conversation 13 | from transformers import AutoTokenizer 14 | from vllm import EngineArgs, LLMEngine, SamplingParams 15 | 16 | 17 | class VLLMClient(Client): 18 | def __init__(self, model_path: str): 19 | self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) 20 | self.engine_args = EngineArgs( 21 | model=model_path, 22 | tensor_parallel_size=1, 23 | dtype="bfloat16", # torch.bfloat16 is needed. 
24 | gpu_memory_utilization=0.6, 25 | enforce_eager=True, 26 | worker_use_ray=False, 27 | ) 28 | self.engine = LLMEngine.from_engine_args(self.engine_args) 29 | 30 | def generate_stream( 31 | self, tools: list[dict], history: list[Conversation], **parameters 32 | ) -> Generator[tuple[str | dict, list[dict]]]: 33 | chat_history = process_input(history, tools) 34 | model_inputs = self.tokenizer.apply_chat_template(chat_history, add_generation_prompt=True, tokenize=False) 35 | parameters["max_tokens"] = parameters.pop("max_new_tokens") 36 | params_dict = { 37 | "n": 1, 38 | "best_of": 1, 39 | "top_p": 1, 40 | "top_k": -1, 41 | "length_penalty": 1, 42 | "stop_token_ids": [151329, 151336, 151338], 43 | } 44 | params_dict.update(parameters) 45 | sampling_params = SamplingParams(**params_dict) 46 | 47 | self.engine.add_request(request_id=str(time.time()), inputs=model_inputs, params=sampling_params) 48 | while self.engine.has_unfinished_requests(): 49 | request_outputs = self.engine.step() 50 | for request_output in request_outputs: 51 | yield process_response(request_output.outputs[0].text, chat_history) 52 | -------------------------------------------------------------------------------- /demo/composite_demo/src/conversation.py: -------------------------------------------------------------------------------- 1 | import json 2 | import re 3 | from dataclasses import dataclass 4 | from datetime import datetime 5 | from enum import Enum, auto 6 | 7 | import streamlit as st 8 | from PIL.Image import Image 9 | from streamlit.delta_generator import DeltaGenerator 10 | from tools.browser import Quote, quotes 11 | 12 | 13 | QUOTE_REGEX = re.compile(r"【(\d+)†(.+?)】") 14 | 15 | SELFCOG_PROMPT = "你是一个名为 GLM-4 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,你的任务是针对用户的问题和要求提供适当的答复和支持。" 16 | DATE_PROMPT = "当前日期: %Y-%m-%d" 17 | TOOL_SYSTEM_PROMPTS = { 18 | "python": "当你向 `python` 发送包含 Python 代码的消息时,该代码将会在一个有状态的 Jupyter notebook 环境中执行。\n`python` 返回代码执行的输出,或在执行 60 秒后返回超时。\n`/mnt/data` 将会持久化存储你的文件。在此会话中,`python` 无法访问互联网。不要使用 `python` 进行任何网络请求或者在线 API 调用,这些在线内容的访问将不会成功。", 19 | "simple_browser": "你可以使用 `simple_browser` 工具。该工具支持以下函数:\n`search(query: str, recency_days: int)`:使用搜索引擎进行查询并显示结果,可以使用 `recency_days` 参数控制搜索内容的时效性。\n`mclick(ids: list[int])`:获取一系列指定 id 的页面内容。每次调用时,须选择3-10个页面。选择多个角度的页面,同时尽可能选择可信任的信息来源。考虑到部分页面是无法加载的,你也可以多打开一些可能有用的页面而不用担心内容过多。\n`open_url(url: str)`:打开指定的 URL。\n\n使用 `【{引用 id}†{引用文本}】` 来引用内容。\n\n操作步骤:1. 使用 `search` 来获得信息列表; 2. 使用 `mclick` 来获取指定 ID 页面的内容; 3. 
根据获得的内容进行回复。在回复中应当引用信息来源。\n 如果用户提供了 URL,也可以用 `open_url` 直接打开页面。\n如果初次搜索结果没有找到合适的信息,也可以再次使用 `search` 进行搜索。", 20 | "cogview": "如果用户的请求中包含了对图像的描述,你可以使用 `cogview` 来生成图像并展示给用户。你需要向 `cogview` 发送图像描述,规则:\n- 发送给 `cogview` 的消息必须使用英语。用户的中文描述必须完全翻译为英语。\n- 应当尽可能详细地描述图像生成的需求,需求描述约 100 英文单词。\n- 保持用户原始描述的意图。不要虚构内容或者没见过的人物。\n- 如无特殊说明,所在地为中国,持有中国立场并遵循中国社会主义价值观。", 21 | } 22 | 23 | FILE_TEMPLATE = "[File Name]\n{file_name}\n[File Content]\n{file_content}" 24 | 25 | 26 | def build_system_prompt( 27 | enabled_tools: list[str], 28 | functions: list[dict], 29 | ): 30 | value = SELFCOG_PROMPT 31 | value += "\n\n" + datetime.now().strftime(DATE_PROMPT) 32 | if enabled_tools or functions: 33 | value += "\n\n# 可用工具" 34 | contents = [] 35 | for tool in enabled_tools: 36 | contents.append(f"\n\n## {tool}\n\n{TOOL_SYSTEM_PROMPTS[tool]}") 37 | for function in functions: 38 | content = f"\n\n## {function['name']}\n\n{json.dumps(function, ensure_ascii=False, indent=4)}" 39 | content += "\n在调用上述函数时,请使用 Json 格式表示调用的参数。" 40 | contents.append(content) 41 | value += "".join(contents) 42 | return value 43 | 44 | 45 | def response_to_str(response: str | dict[str, str]) -> str: 46 | """ 47 | Convert response to string. 48 | """ 49 | if isinstance(response, dict): 50 | return response.get("name", "") + response.get("content", "") 51 | return response 52 | 53 | 54 | class Role(Enum): 55 | SYSTEM = auto() 56 | USER = auto() 57 | ASSISTANT = auto() 58 | TOOL = auto() 59 | OBSERVATION = auto() 60 | 61 | def __str__(self): 62 | match self: 63 | case Role.SYSTEM: 64 | return "<|system|>" 65 | case Role.USER: 66 | return "<|user|>" 67 | case Role.ASSISTANT | Role.TOOL: 68 | return "<|assistant|>" 69 | case Role.OBSERVATION: 70 | return "<|observation|>" 71 | 72 | # Get the message block for the given role 73 | def get_message(self): 74 | # Compare by value here, because the enum object in the session state 75 | # is not the same as the enum cases here, due to streamlit's rerunning 76 | # behavior. 
77 | match self.value: 78 | case Role.SYSTEM.value: 79 | return 80 | case Role.USER.value: 81 | return st.chat_message(name="user", avatar="user") 82 | case Role.ASSISTANT.value: 83 | return st.chat_message(name="assistant", avatar="assistant") 84 | case Role.TOOL.value: 85 | return st.chat_message(name="tool", avatar="assistant") 86 | case Role.OBSERVATION.value: 87 | return st.chat_message(name="observation", avatar="assistant") 88 | case _: 89 | st.error(f"Unexpected role: {self}") 90 | 91 | 92 | @dataclass 93 | class Conversation: 94 | role: Role 95 | content: str | dict 96 | # Processed content 97 | saved_content: str | None = None 98 | metadata: str | None = None 99 | image: str | Image | None = None 100 | 101 | def __str__(self) -> str: 102 | metadata_str = self.metadata if self.metadata else "" 103 | return f"{self.role}{metadata_str}\n{self.content}" 104 | 105 | # Human readable format 106 | def get_text(self) -> str: 107 | text = self.saved_content or self.content 108 | match self.role.value: 109 | case Role.TOOL.value: 110 | text = f"Calling tool `{self.metadata}`:\n\n```python\n{text}\n```" 111 | case Role.OBSERVATION.value: 112 | text = f"```python\n{text}\n```" 113 | return text 114 | 115 | # Display as a markdown block 116 | def show(self, placeholder: DeltaGenerator | None = None) -> str: 117 | if placeholder: 118 | message = placeholder 119 | else: 120 | message = self.role.get_message() 121 | 122 | if self.image: 123 | message.image(self.image, width=512) 124 | 125 | if self.role == Role.OBSERVATION: 126 | metadata_str = f"from {self.metadata}" if self.metadata else "" 127 | message = message.expander(f"Observation {metadata_str}") 128 | 129 | text = self.get_text() 130 | if self.role != Role.USER: 131 | show_text = text 132 | else: 133 | splitted = text.split("files uploaded.\n") 134 | if len(splitted) == 1: 135 | show_text = text 136 | else: 137 | # Show expander for document content 138 | doc = splitted[0] 139 | show_text = splitted[-1] 140 | expander = message.expander("File Content") 141 | expander.markdown(doc) 142 | message.markdown(show_text) 143 | 144 | 145 | def postprocess_text(text: str, replace_quote: bool) -> str: 146 | text = text.replace(r"\(", "$") 147 | text = text.replace(r"\)", "$") 148 | text = text.replace(r"\[", "$$") 149 | text = text.replace(r"\]", "$$") 150 | text = text.replace("<|assistant|>", "") 151 | text = text.replace("<|observation|>", "") 152 | text = text.replace("<|system|>", "") 153 | text = text.replace("<|user|>", "") 154 | text = text.replace("<|endoftext|>", "") 155 | 156 | # Replace quotes 157 | if replace_quote: 158 | for match in QUOTE_REGEX.finditer(text): 159 | quote_id = match.group(1) 160 | quote = quotes.get(quote_id, Quote("未找到引用内容", "")) 161 | text = text.replace(match.group(0), f" (来源:[{quote.title}]({quote.url})) ") 162 | 163 | return text.strip() 164 | -------------------------------------------------------------------------------- /demo/composite_demo/src/main.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3 | This demo show the All tools and Long Context chat Capabilities of GLM-4. 4 | Please follow the Readme.md to run the demo. 
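Model paths and backends can be overridden through the CHAT_MODEL_PATH, VLM_MODEL_PATH, USE_VLLM and USE_API environment variables defined below.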
5 | 6 | """ 7 | 8 | import os 9 | import traceback 10 | from enum import Enum 11 | from io import BytesIO 12 | from uuid import uuid4 13 | 14 | import streamlit as st 15 | from client import Client, ClientType, get_client 16 | from conversation import ( 17 | FILE_TEMPLATE, 18 | Conversation, 19 | Role, 20 | postprocess_text, 21 | response_to_str, 22 | ) 23 | from PIL import Image 24 | from streamlit.delta_generator import DeltaGenerator 25 | from tools.tool_registry import dispatch_tool, get_tools 26 | from utils import extract_docx, extract_pdf, extract_pptx, extract_text 27 | 28 | 29 | CHAT_MODEL_PATH = os.environ.get("CHAT_MODEL_PATH", "THUDM/glm-4-9b-chat") 30 | VLM_MODEL_PATH = os.environ.get("VLM_MODEL_PATH", "THUDM/glm-4v-9b") 31 | 32 | USE_VLLM = os.environ.get("USE_VLLM", "0") == "1" 33 | USE_API = os.environ.get("USE_API", "0") == "1" 34 | 35 | 36 | class Mode(str, Enum): 37 | ALL_TOOLS = "🛠️ All Tools" 38 | LONG_CTX = "📝 文档解读" 39 | VLM = "🖼️ 多模态" 40 | 41 | 42 | def append_conversation( 43 | conversation: Conversation, 44 | history: list[Conversation], 45 | placeholder: DeltaGenerator | None = None, 46 | ) -> None: 47 | """ 48 | Append a conversation piece into history, meanwhile show it in a new markdown block 49 | """ 50 | history.append(conversation) 51 | conversation.show(placeholder) 52 | 53 | 54 | st.set_page_config( 55 | page_title="GLM-4 Demo", 56 | page_icon=":robot:", 57 | layout="centered", 58 | initial_sidebar_state="expanded", 59 | ) 60 | 61 | st.title("GLM-4 Demo") 62 | st.markdown( 63 | "智谱AI 公开在线技术文档: https://zhipu-ai.feishu.cn/wiki/RuMswanpkiRh3Ok4z5acOABBnjf \n\n 更多 GLM-4 开源模型的使用方法请参考文档。", 64 | unsafe_allow_html=True, 65 | ) 66 | 67 | with st.sidebar: 68 | top_p = st.slider("top_p", 0.0, 1.0, 0.8, step=0.01) 69 | top_k = st.slider("top_k", 1, 20, 10, step=1, key="top_k") 70 | temperature = st.slider("temperature", 0.0, 1.5, 0.95, step=0.01) 71 | repetition_penalty = st.slider("repetition_penalty", 0.0, 2.0, 1.0, step=0.01) 72 | max_new_tokens = st.slider("max_new_tokens", 1, 4096, 2048, step=1) 73 | cols = st.columns(2) 74 | export_btn = cols[0] 75 | clear_history = cols[1].button("Clear", use_container_width=True) 76 | retry = export_btn.button("Retry", use_container_width=True) 77 | 78 | if clear_history: 79 | page = st.session_state.page 80 | client = st.session_state.client 81 | st.session_state.clear() 82 | st.session_state.page = page 83 | st.session_state.client = client 84 | st.session_state.files_uploaded = False 85 | st.session_state.uploaded_texts = "" 86 | st.session_state.uploaded_file_nums = 0 87 | st.session_state.history = [] 88 | 89 | if "files_uploaded" not in st.session_state: 90 | st.session_state.files_uploaded = False 91 | 92 | if "session_id" not in st.session_state: 93 | st.session_state.session_id = uuid4() 94 | 95 | if "history" not in st.session_state: 96 | st.session_state.history = [] 97 | 98 | first_round = len(st.session_state.history) == 0 99 | 100 | 101 | def build_client(mode: Mode) -> Client: 102 | match mode: 103 | case Mode.ALL_TOOLS: 104 | st.session_state.top_k = 10 105 | typ = ClientType.VLLM if USE_VLLM else ClientType.HF 106 | typ = ClientType.API if USE_API else typ 107 | return get_client(CHAT_MODEL_PATH, typ) 108 | case Mode.LONG_CTX: 109 | st.session_state.top_k = 10 110 | typ = ClientType.VLLM if USE_VLLM else ClientType.HF 111 | return get_client(CHAT_MODEL_PATH, typ) 112 | case Mode.VLM: 113 | st.session_state.top_k = 1 114 | # vLLM is not available for VLM mode 115 | return get_client(VLM_MODEL_PATH, 
ClientType.HF) 116 | 117 | 118 | # Callback function for page change 119 | def page_changed() -> None: 120 | global client 121 | new_page: str = st.session_state.page 122 | st.session_state.history.clear() 123 | st.session_state.client = build_client(Mode(new_page)) 124 | 125 | 126 | page = st.radio( 127 | "选择功能", 128 | [mode.value for mode in Mode], 129 | key="page", 130 | horizontal=True, 131 | index=None, 132 | label_visibility="hidden", 133 | on_change=page_changed, 134 | ) 135 | 136 | HELP = """ 137 | ### 🎉 欢迎使用 GLM-4! 138 | 139 | 请在上方选取一个功能。每次切换功能时,将会重新加载模型并清空对话历史。 140 | 141 | 文档解读模式与 VLM 模式仅支持在第一轮传入文档或图像。 142 | """.strip() 143 | 144 | if page is None: 145 | st.markdown(HELP) 146 | exit() 147 | 148 | if page == Mode.LONG_CTX: 149 | if first_round: 150 | uploaded_files = st.file_uploader( 151 | "上传文件", 152 | type=["pdf", "txt", "py", "docx", "pptx", "json", "cpp", "md"], 153 | accept_multiple_files=True, 154 | ) 155 | if uploaded_files and not st.session_state.files_uploaded: 156 | uploaded_texts = [] 157 | for uploaded_file in uploaded_files: 158 | file_name: str = uploaded_file.name 159 | random_file_name = str(uuid4()) 160 | file_extension = os.path.splitext(file_name)[1] 161 | file_path = os.path.join("/tmp", random_file_name + file_extension) 162 | with open(file_path, "wb") as f: 163 | f.write(uploaded_file.getbuffer()) 164 | if file_name.endswith(".pdf"): 165 | content = extract_pdf(file_path) 166 | elif file_name.endswith(".docx"): 167 | content = extract_docx(file_path) 168 | elif file_name.endswith(".pptx"): 169 | content = extract_pptx(file_path) 170 | else: 171 | content = extract_text(file_path) 172 | uploaded_texts.append(FILE_TEMPLATE.format(file_name=file_name, file_content=content)) 173 | os.remove(file_path) 174 | st.session_state.uploaded_texts = "\n\n".join(uploaded_texts) 175 | st.session_state.uploaded_file_nums = len(uploaded_files) 176 | else: 177 | st.session_state.uploaded_texts = "" 178 | st.session_state.uploaded_file_nums = 0 179 | elif page == Mode.VLM: 180 | if first_round: 181 | uploaded_image = st.file_uploader( 182 | "上传图片", 183 | type=["png", "jpg", "jpeg", "bmp", "tiff", "webp"], 184 | accept_multiple_files=False, 185 | ) 186 | if uploaded_image: 187 | data: bytes = uploaded_image.read() 188 | image = Image.open(BytesIO(data)).convert("RGB") 189 | st.session_state.uploaded_image = image 190 | else: 191 | st.session_state.uploaded_image = None 192 | 193 | prompt_text = st.chat_input("Chat with GLM-4!", key="chat_input") 194 | 195 | if prompt_text == "" and retry == False: 196 | print("\n== Clean ==\n") 197 | st.session_state.history = [] 198 | exit() 199 | 200 | history: list[Conversation] = st.session_state.history 201 | 202 | if retry: 203 | print("\n== Retry ==\n") 204 | last_user_conversation_idx = None 205 | for idx, conversation in enumerate(history): 206 | if conversation.role.value == Role.USER.value: 207 | last_user_conversation_idx = idx 208 | if last_user_conversation_idx is not None: 209 | prompt_text = history[last_user_conversation_idx].content 210 | print(f"New prompt: {prompt_text}, idx = {last_user_conversation_idx}") 211 | del history[last_user_conversation_idx:] 212 | 213 | for conversation in history: 214 | conversation.show() 215 | 216 | tools = get_tools() if page == Mode.ALL_TOOLS else [] 217 | 218 | client: Client = st.session_state.client 219 | 220 | 221 | def main(prompt_text: str): 222 | global client 223 | assert client is not None 224 | 225 | if prompt_text: 226 | prompt_text = prompt_text.strip() 227 | 228 | # Append 
uploaded files 229 | uploaded_texts = st.session_state.get("uploaded_texts") 230 | if page == Mode.LONG_CTX and uploaded_texts and first_round: 231 | meta_msg = "{} files uploaded.\n".format(st.session_state.uploaded_file_nums) 232 | prompt_text = uploaded_texts + "\n\n\n" + meta_msg + prompt_text 233 | # Clear after first use 234 | st.session_state.files_uploaded = True 235 | st.session_state.uploaded_texts = "" 236 | st.session_state.uploaded_file_nums = 0 237 | 238 | image = st.session_state.get("uploaded_image") 239 | if page == Mode.VLM and image and first_round: 240 | st.session_state.uploaded_image = None 241 | 242 | role = Role.USER 243 | append_conversation(Conversation(role, prompt_text, image=image), history) 244 | 245 | placeholder = st.container() 246 | message_placeholder = placeholder.chat_message(name="assistant", avatar="assistant") 247 | markdown_placeholder = message_placeholder.empty() 248 | 249 | def add_new_block(): 250 | nonlocal message_placeholder, markdown_placeholder 251 | message_placeholder = placeholder.chat_message(name="assistant", avatar="assistant") 252 | markdown_placeholder = message_placeholder.empty() 253 | 254 | def commit_conversation( 255 | role: Role, 256 | text: str, 257 | metadata: str | None = None, 258 | image: str | None = None, 259 | new: bool = False, 260 | ): 261 | processed_text = postprocess_text(text, role.value == Role.ASSISTANT.value) 262 | conversation = Conversation(role, text, processed_text, metadata, image) 263 | 264 | # Use different placeholder for new block 265 | placeholder = message_placeholder if new else markdown_placeholder 266 | 267 | append_conversation( 268 | conversation, 269 | history, 270 | placeholder, 271 | ) 272 | 273 | response = "" 274 | for _ in range(10): 275 | last_response = None 276 | history_len = None 277 | 278 | try: 279 | for response, chat_history in client.generate_stream( 280 | tools=tools, 281 | history=history, 282 | temperature=temperature, 283 | top_p=top_p, 284 | top_k=top_k, 285 | repetition_penalty=repetition_penalty, 286 | max_new_tokens=max_new_tokens, 287 | ): 288 | if history_len is None: 289 | history_len = len(chat_history) 290 | elif history_len != len(chat_history): 291 | commit_conversation(Role.ASSISTANT, last_response) 292 | add_new_block() 293 | history_len = len(chat_history) 294 | last_response = response 295 | replace_quote = chat_history[-1]["role"] == "assistant" 296 | markdown_placeholder.markdown(postprocess_text(str(response) + "●", replace_quote=replace_quote)) 297 | else: 298 | metadata = page == Mode.ALL_TOOLS and isinstance(response, dict) and response.get("name") or None 299 | role = Role.TOOL if metadata else Role.ASSISTANT 300 | text = response.get("content") if metadata else response_to_str(response) 301 | commit_conversation(role, text, metadata) 302 | if metadata: 303 | add_new_block() 304 | try: 305 | with markdown_placeholder: 306 | with st.spinner(f"Calling tool {metadata}..."): 307 | observations = dispatch_tool(metadata, text, str(st.session_state.session_id)) 308 | except Exception as e: 309 | traceback.print_exc() 310 | st.error(f'Uncaught exception in `"{metadata}"`: {e}') 311 | break 312 | 313 | for observation in observations: 314 | observation.text = observation.text 315 | commit_conversation( 316 | Role.OBSERVATION, 317 | observation.text, 318 | observation.role_metadata, 319 | observation.image_url, 320 | new=True, 321 | ) 322 | add_new_block() 323 | continue 324 | else: 325 | break 326 | except Exception: 327 | traceback.print_exc() 328 | 
st.error(f"Uncaught exception: {traceback.format_exc()}") 329 | else: 330 | st.error("Too many chaining function calls!") 331 | 332 | 333 | main(prompt_text) 334 | -------------------------------------------------------------------------------- /demo/composite_demo/src/tools/browser.py: -------------------------------------------------------------------------------- 1 | """ 2 | Simple browser tool. 3 | 4 | # Usage 5 | 6 | Please start the backend browser server according to the instructions in the README. 7 | """ 8 | 9 | import re 10 | from dataclasses import dataclass 11 | from pprint import pprint 12 | 13 | import requests 14 | import streamlit as st 15 | 16 | from .config import BROWSER_SERVER_URL 17 | from .interface import ToolObservation 18 | 19 | 20 | QUOTE_REGEX = re.compile(r"\[(\d+)†(.+?)\]") 21 | 22 | 23 | @dataclass 24 | class Quote: 25 | title: str 26 | url: str 27 | 28 | 29 | # Quotes for displaying reference 30 | if "quotes" not in st.session_state: 31 | st.session_state.quotes = {} 32 | 33 | quotes: dict[str, Quote] = st.session_state.quotes 34 | 35 | 36 | def map_response(response: dict) -> ToolObservation: 37 | # Save quotes for reference 38 | print("===BROWSER_RESPONSE===") 39 | pprint(response) 40 | role_metadata = response.get("roleMetadata") 41 | metadata = response.get("metadata") 42 | 43 | if role_metadata.split()[0] == "quote_result" and metadata: 44 | quote_id = QUOTE_REGEX.search(role_metadata.split()[1]).group(1) 45 | quote: dict[str, str] = metadata["metadata_list"][0] 46 | quotes[quote_id] = Quote(quote["title"], quote["url"]) 47 | elif role_metadata == "browser_result" and metadata: 48 | for i, quote in enumerate(metadata["metadata_list"]): 49 | quotes[str(i)] = Quote(quote["title"], quote["url"]) 50 | 51 | return ToolObservation( 52 | content_type=response.get("contentType"), 53 | text=response.get("result"), 54 | role_metadata=role_metadata, 55 | metadata=metadata, 56 | ) 57 | 58 | 59 | def tool_call(code: str, session_id: str) -> list[ToolObservation]: 60 | request = { 61 | "session_id": session_id, 62 | "action": code, 63 | } 64 | response = requests.post(BROWSER_SERVER_URL, json=request).json() 65 | return list(map(map_response, response)) 66 | -------------------------------------------------------------------------------- /demo/composite_demo/src/tools/cogview.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | from zhipuai import ZhipuAI 3 | from zhipuai.types.image import GeneratedImage 4 | 5 | from .config import COGVIEW_MODEL, ZHIPU_AI_KEY 6 | from .interface import ToolObservation 7 | 8 | 9 | @st.cache_resource 10 | def get_zhipu_client(): 11 | return ZhipuAI(api_key=ZHIPU_AI_KEY) 12 | 13 | 14 | def map_response(img: GeneratedImage): 15 | return ToolObservation( 16 | content_type="image", 17 | text="CogView 已经生成并向用户展示了生成的图片。", 18 | image_url=img.url, 19 | role_metadata="cogview_result", 20 | ) 21 | 22 | 23 | def tool_call(prompt: str, session_id: str) -> list[ToolObservation]: 24 | client = get_zhipu_client() 25 | response = client.images.generations(model=COGVIEW_MODEL, prompt=prompt).data 26 | return list(map(map_response, response)) 27 | -------------------------------------------------------------------------------- /demo/composite_demo/src/tools/config.py: -------------------------------------------------------------------------------- 1 | BROWSER_SERVER_URL = "http://localhost:3000" 2 | 3 | IPYKERNEL = "glm-4-demo" 4 | 5 | ZHIPU_AI_KEY = "" 6 | COGVIEW_MODEL = "cogview-3" 7 | 
-------------------------------------------------------------------------------- /demo/composite_demo/src/tools/interface.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Any 3 | 4 | 5 | @dataclass 6 | class ToolObservation: 7 | content_type: str 8 | text: str 9 | image_url: str | None = None 10 | role_metadata: str | None = None 11 | metadata: Any = None 12 | -------------------------------------------------------------------------------- /demo/composite_demo/src/tools/python.py: -------------------------------------------------------------------------------- 1 | import queue 2 | import re 3 | from pprint import pprint 4 | from subprocess import PIPE 5 | from typing import Literal 6 | 7 | import jupyter_client 8 | import streamlit as st 9 | 10 | from .config import IPYKERNEL 11 | from .interface import ToolObservation 12 | 13 | 14 | ANSI_ESCAPE = re.compile(r"(\x9B|\x1B\[|\u001b\[)[0-?]*[ -/]*[@-~]") 15 | CODE = re.compile(r"```([^\n]*)\n(.*?)```") 16 | 17 | 18 | class CodeKernel: 19 | def __init__( 20 | self, 21 | kernel_name="kernel", 22 | kernel_id=None, 23 | kernel_config_path="", 24 | python_path=None, 25 | ipython_path=None, 26 | init_file_path="./startup.py", 27 | verbose=1, 28 | ): 29 | self.kernel_name = kernel_name 30 | self.kernel_id = kernel_id 31 | self.kernel_config_path = kernel_config_path 32 | self.python_path = python_path 33 | self.ipython_path = ipython_path 34 | self.init_file_path = init_file_path 35 | self.verbose = verbose 36 | 37 | if python_path is None and ipython_path is None: 38 | env = None 39 | else: 40 | env = {"PATH": self.python_path + ":$PATH", "PYTHONPATH": self.python_path} 41 | 42 | # Initialize the backend kernel 43 | self.kernel_manager = jupyter_client.KernelManager( 44 | kernel_name=IPYKERNEL, connection_file=self.kernel_config_path, exec_files=[self.init_file_path], env=env 45 | ) 46 | if self.kernel_config_path: 47 | self.kernel_manager.load_connection_file() 48 | self.kernel_manager.start_kernel(stdout=PIPE, stderr=PIPE) 49 | print("Backend kernel started with the configuration: {}".format(self.kernel_config_path)) 50 | else: 51 | self.kernel_manager.start_kernel(stdout=PIPE, stderr=PIPE) 52 | print("Backend kernel started with the configuration: {}".format(self.kernel_manager.connection_file)) 53 | 54 | if verbose: 55 | pprint(self.kernel_manager.get_connection_info()) 56 | 57 | # Initialize the code kernel 58 | self.kernel = self.kernel_manager.blocking_client() 59 | # self.kernel.load_connection_file() 60 | self.kernel.start_channels() 61 | print("Code kernel started.") 62 | 63 | def execute(self, code): 64 | self.kernel.execute(code) 65 | try: 66 | shell_msg = self.kernel.get_shell_msg(timeout=30) 67 | io_msg_content = self.kernel.get_iopub_msg(timeout=30)["content"] 68 | while True: 69 | msg_out = io_msg_content 70 | ### Poll the message 71 | try: 72 | io_msg_content = self.kernel.get_iopub_msg(timeout=30)["content"] 73 | if "execution_state" in io_msg_content and io_msg_content["execution_state"] == "idle": 74 | break 75 | except queue.Empty: 76 | break 77 | 78 | return shell_msg, msg_out 79 | except Exception as e: 80 | print(e) 81 | return None 82 | 83 | def execute_interactive(self, code, verbose=False): 84 | shell_msg = self.kernel.execute_interactive(code) 85 | if shell_msg is queue.Empty: 86 | if verbose: 87 | print("Timeout waiting for shell message.") 88 | self.check_msg(shell_msg, verbose=verbose) 89 | 90 | return shell_msg 91 | 92 | 
def inspect(self, code, verbose=False): 93 | msg_id = self.kernel.inspect(code) 94 | shell_msg = self.kernel.get_shell_msg(timeout=30) 95 | if shell_msg is queue.Empty: 96 | if verbose: 97 | print("Timeout waiting for shell message.") 98 | self.check_msg(shell_msg, verbose=verbose) 99 | 100 | return shell_msg 101 | 102 | def get_error_msg(self, msg, verbose=False) -> str | None: 103 | if msg["content"]["status"] == "error": 104 | try: 105 | error_msg = msg["content"]["traceback"] 106 | except: 107 | try: 108 | error_msg = msg["content"]["traceback"][-1].strip() 109 | except: 110 | error_msg = "Traceback Error" 111 | if verbose: 112 | print("Error: ", error_msg) 113 | return error_msg 114 | return None 115 | 116 | def check_msg(self, msg, verbose=False): 117 | status = msg["content"]["status"] 118 | if status == "ok": 119 | if verbose: 120 | print("Execution succeeded.") 121 | elif status == "error": 122 | for line in msg["content"]["traceback"]: 123 | if verbose: 124 | print(line) 125 | 126 | def shutdown(self): 127 | # Shutdown the backend kernel 128 | self.kernel_manager.shutdown_kernel() 129 | print("Backend kernel shutdown.") 130 | # Shutdown the code kernel 131 | self.kernel.shutdown() 132 | print("Code kernel shutdown.") 133 | 134 | def restart(self): 135 | # Restart the backend kernel 136 | self.kernel_manager.restart_kernel() 137 | # print("Backend kernel restarted.") 138 | 139 | def interrupt(self): 140 | # Interrupt the backend kernel 141 | self.kernel_manager.interrupt_kernel() 142 | # print("Backend kernel interrupted.") 143 | 144 | def is_alive(self): 145 | return self.kernel.is_alive() 146 | 147 | 148 | def clean_ansi_codes(input_string): 149 | return ANSI_ESCAPE.sub("", input_string) 150 | 151 | 152 | def extract_code(text: str) -> str: 153 | matches = CODE.findall(text, re.DOTALL) 154 | return matches[-1][1] 155 | 156 | 157 | def execute(code: str, kernel: CodeKernel) -> tuple[Literal["text", "image"] | None, str]: 158 | res = "" 159 | res_type = None 160 | code = code.replace("<|observation|>", "") 161 | code = code.replace("<|assistant|>python", "") 162 | code = code.replace("<|assistant|>", "") 163 | code = code.replace("<|user|>", "") 164 | code = code.replace("<|system|>", "") 165 | msg, output = kernel.execute(code) 166 | 167 | if msg["metadata"]["status"] == "timeout": 168 | return res_type, "Timed out" 169 | elif msg["metadata"]["status"] == "error": 170 | return res_type, clean_ansi_codes("\n".join(kernel.get_error_msg(msg, verbose=True))) 171 | 172 | if "text" in output: 173 | res_type = "text" 174 | res = output["text"] 175 | elif "data" in output: 176 | for key in output["data"]: 177 | if "text/plain" in key: 178 | res_type = "text" 179 | res = output["data"][key] 180 | elif "image/png" in key: 181 | res_type = "image" 182 | res = output["data"][key] 183 | break 184 | 185 | return res_type, res 186 | 187 | 188 | @st.cache_resource 189 | def get_kernel() -> CodeKernel: 190 | return CodeKernel() 191 | 192 | 193 | def tool_call(code: str, session_id: str) -> list[ToolObservation]: 194 | kernel = get_kernel() 195 | res_type, res = execute(code, kernel) 196 | 197 | # Convert base64 to data uri 198 | text = "[Image]" if res_type == "image" else res 199 | image = f"data:image/png;base64,{res}" if res_type == "image" else None 200 | 201 | return [ToolObservation(res_type, text, image)] 202 | -------------------------------------------------------------------------------- /demo/composite_demo/src/tools/tool_registry.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | This code is the tool registration part. By registering the tool, the model can call the tool. 3 | This code provides extended functionality to the model, enabling it to call and interact with a variety of utilities 4 | through defined interfaces. 5 | """ 6 | 7 | import copy 8 | import inspect 9 | import json 10 | import subprocess 11 | import traceback 12 | from collections.abc import Callable 13 | from types import GenericAlias 14 | from typing import Annotated, get_origin 15 | 16 | from .browser import tool_call as browser 17 | from .cogview import tool_call as cogview 18 | from .interface import ToolObservation 19 | from .python import tool_call as python 20 | 21 | 22 | ALL_TOOLS = { 23 | "simple_browser": browser, 24 | "python": python, 25 | "cogview": cogview, 26 | } 27 | 28 | _TOOL_HOOKS = {} 29 | _TOOL_DESCRIPTIONS = [] 30 | 31 | 32 | def register_tool(func: Callable): 33 | tool_name = func.__name__ 34 | tool_description = inspect.getdoc(func).strip() 35 | python_params = inspect.signature(func).parameters 36 | tool_params = [] 37 | for name, param in python_params.items(): 38 | annotation = param.annotation 39 | if annotation is inspect.Parameter.empty: 40 | raise TypeError(f"Parameter `{name}` missing type annotation") 41 | if get_origin(annotation) != Annotated: 42 | raise TypeError(f"Annotation type for `{name}` must be typing.Annotated") 43 | 44 | typ, (description, required) = annotation.__origin__, annotation.__metadata__ 45 | typ: str = str(typ) if isinstance(typ, GenericAlias) else typ.__name__ 46 | if not isinstance(description, str): 47 | raise TypeError(f"Description for `{name}` must be a string") 48 | if not isinstance(required, bool): 49 | raise TypeError(f"Required for `{name}` must be a bool") 50 | 51 | tool_params.append( 52 | { 53 | "name": name, 54 | "description": description, 55 | "type": typ, 56 | "required": required, 57 | } 58 | ) 59 | tool_def = { 60 | "name": tool_name, 61 | "description": tool_description, 62 | "params": tool_params, 63 | } 64 | # print("[registered tool] " + pformat(tool_def)) 65 | _TOOL_HOOKS[tool_name] = func 66 | _TOOL_DESCRIPTIONS.append(tool_def) 67 | 68 | return func 69 | 70 | 71 | def dispatch_tool(tool_name: str, code: str, session_id: str) -> list[ToolObservation]: 72 | # Dispatch predefined tools 73 | if tool_name in ALL_TOOLS: 74 | return ALL_TOOLS[tool_name](code, session_id) 75 | 76 | code = code.strip().rstrip("<|observation|>").strip() 77 | 78 | # Dispatch custom tools 79 | try: 80 | tool_params = json.loads(code) 81 | except json.JSONDecodeError as e: 82 | err = f"Error decoding JSON: {e}" 83 | return [ToolObservation("system_error", err)] 84 | 85 | if tool_name not in _TOOL_HOOKS: 86 | err = f"Tool `{tool_name}` not found. Please use a provided tool." 
87 | return [ToolObservation("system_error", err)] 88 | 89 | tool_hook = _TOOL_HOOKS[tool_name] 90 | try: 91 | ret: str = tool_hook(**tool_params) 92 | return [ToolObservation(tool_name, str(ret))] 93 | except: 94 | err = traceback.format_exc() 95 | return [ToolObservation("system_error", err)] 96 | 97 | 98 | def get_tools() -> list[dict]: 99 | return copy.deepcopy(_TOOL_DESCRIPTIONS) 100 | 101 | 102 | # Tool Definitions 103 | 104 | 105 | @register_tool 106 | def random_number_generator( 107 | seed: Annotated[int, "The random seed used by the generator", True], 108 | range: Annotated[tuple[int, int], "The range of the generated numbers", True], 109 | ) -> int: 110 | """ 111 | Generates a random number x, s.t. range[0] <= x < range[1] 112 | """ 113 | if not isinstance(seed, int): 114 | raise TypeError("Seed must be an integer") 115 | if not isinstance(range, tuple): 116 | raise TypeError("Range must be a tuple") 117 | if not isinstance(range[0], int) or not isinstance(range[1], int): 118 | raise TypeError("Range must be a tuple of integers") 119 | 120 | import random 121 | 122 | return random.Random(seed).randint(*range) 123 | 124 | 125 | @register_tool 126 | def get_weather( 127 | city_name: Annotated[str, "The name of the city to be queried", True], 128 | ) -> str: 129 | """ 130 | Get the current weather for `city_name` 131 | """ 132 | 133 | if not isinstance(city_name, str): 134 | raise TypeError("City name must be a string") 135 | 136 | key_selection = { 137 | "current_condition": [ 138 | "temp_C", 139 | "FeelsLikeC", 140 | "humidity", 141 | "weatherDesc", 142 | "observation_time", 143 | ], 144 | } 145 | import requests 146 | 147 | try: 148 | resp = requests.get(f"https://wttr.in/{city_name}?format=j1") 149 | resp.raise_for_status() 150 | resp = resp.json() 151 | ret = {k: {_v: resp[k][0][_v] for _v in v} for k, v in key_selection.items()} 152 | except: 153 | import traceback 154 | 155 | ret = "Error encountered while fetching weather data!\n" + traceback.format_exc() 156 | 157 | return str(ret) 158 | 159 | 160 | @register_tool 161 | def get_shell( 162 | query: Annotated[str, "The command should run in Linux shell", True], 163 | ) -> str: 164 | """ 165 | Use shell to run command 166 | """ 167 | if not isinstance(query, str): 168 | raise TypeError("Command must be a string") 169 | try: 170 | result = subprocess.run( 171 | query, 172 | shell=True, 173 | check=True, 174 | stdout=subprocess.PIPE, 175 | stderr=subprocess.PIPE, 176 | text=True, 177 | ) 178 | return result.stdout 179 | except subprocess.CalledProcessError as e: 180 | return e.stderr 181 | 182 | 183 | if __name__ == "__main__": 184 | # print(dispatch_tool("get_shell", {"query": "pwd"})) 185 | print(get_tools()) 186 | -------------------------------------------------------------------------------- /demo/composite_demo/src/utils.py: -------------------------------------------------------------------------------- 1 | import docx 2 | from langchain_community.document_loaders import PyMuPDFLoader 3 | from pptx import Presentation 4 | 5 | 6 | def extract_text(path): 7 | return open(path, "r").read() 8 | 9 | 10 | def extract_pdf(path): 11 | loader = PyMuPDFLoader(path) 12 | data = loader.load() 13 | data = [x.page_content for x in data] 14 | content = "\n\n".join(data) 15 | return content 16 | 17 | 18 | def extract_docx(path): 19 | doc = docx.Document(path) 20 | data = [] 21 | for paragraph in doc.paragraphs: 22 | data.append(paragraph.text) 23 | content = "\n\n".join(data) 24 | return content 25 | 26 | 27 | def extract_pptx(path): 28 
| prs = Presentation(path) 29 | text = "" 30 | for slide in prs.slides: 31 | for shape in slide.shapes: 32 | if hasattr(shape, "text"): 33 | text += shape.text + "\n" 34 | return text 35 | -------------------------------------------------------------------------------- /demo/intel_device_demo/itrex/README.md: -------------------------------------------------------------------------------- 1 | # 使用 Intel® Extension for Transformers 推理 GLM-4-9B-Chat 模型 2 | 3 | 本示例介绍如何使用 Intel® Extension for Transformers 推理 GLM-4-9B-Chat 模型。 4 | 5 | ## 设备和依赖检查 6 | 7 | ### 相关推理测试数据 8 | 9 | **本文档的数据均在以下硬件环境测试,实际运行环境需求和运行占用的显存略有不同,请以实际运行环境为准。** 10 | 11 | 测试硬件信息: 12 | 13 | + OS: Ubuntu 22.04 (本教程一定需要在Linux环境下执行) 14 | + Memory: 512GB 15 | + Python: 3.10.12 16 | + CPU: Intel(R) Xeon(R) Platinum 8358 CPU / 12th Gen Intel i5-12400 17 | 18 | ## 安装依赖 19 | 20 | 在开始推理之前,请你先安装`inference`中的依赖,同时您需要安装本目录下的依赖项: 21 | ```shell 22 | pip install -r requirements.txt 23 | ``` 24 | 25 | ## 运行模型推理 26 | 27 | ```shell 28 | python itrex_cli_demo.py 29 | ``` 30 | 31 | 如果您是第一次推理,会有一次模型转换权重的过程,转换后的模型权重存放在`runtime_outputs`文件夹下,这大概会消耗`60G`的硬盘空间。 32 | 转换完成后,文件夹下有两个文件: 33 | + ne_chatglm2_f32.bin 52G(如果您不使用FP32进行推理,可以删掉这个文件) 34 | + ne_chatglm2_q_nf4_bestla_cfp32_sym_sfp32_g32.bin 8.1G 35 | 36 | 如果您不是第一次推理,则会跳过这个步骤,直接开始对话,推理效果如下: 37 | ```shell 38 | Welcome to the CLI chat. Type your messages below. 39 | 40 | User: 你好 41 | AVX:1 AVX2:1 AVX512F:1 AVX512BW:1 AVX_VNNI:0 AVX512_VNNI:1 AMX_INT8:0 AMX_BF16:0 AVX512_BF16:0 AVX512_FP16:0 42 | beam_size: 1, do_sample: 1, top_k: 40, top_p: 0.900, continuous_batching: 0, max_request_num: 1, early_stopping: 0, scratch_size_ratio: 1.000 43 | model_file_loader: loading model from runtime_outs/ne_chatglm2_q_nf4_bestla_cfp32_sym_sfp32_g32.bin 44 | Loading the bin file with NE format... 
45 | load_ne_hparams 0.hparams.n_vocab = 151552 46 | load_ne_hparams 1.hparams.n_embd = 4096 47 | load_ne_hparams 2.hparams.n_mult = 0 48 | load_ne_hparams 3.hparams.n_head = 32 49 | load_ne_hparams 4.hparams.n_head_kv = 0 50 | load_ne_hparams 5.hparams.n_layer = 40 51 | load_ne_hparams 6.hparams.n_rot = 0 52 | load_ne_hparams 7.hparams.ftype = 0 53 | load_ne_hparams 8.hparams.max_seq_len = 131072 54 | load_ne_hparams 9.hparams.alibi_bias_max = 0.000 55 | load_ne_hparams 10.hparams.clip_qkv = 0.000 56 | load_ne_hparams 11.hparams.par_res = 0 57 | load_ne_hparams 12.hparams.word_embed_proj_dim = 0 58 | load_ne_hparams 13.hparams.do_layer_norm_before = 0 59 | load_ne_hparams 14.hparams.multi_query_group_num = 2 60 | load_ne_hparams 15.hparams.ffn_hidden_size = 13696 61 | load_ne_hparams 16.hparams.inner_hidden_size = 0 62 | load_ne_hparams 17.hparams.n_experts = 0 63 | load_ne_hparams 18.hparams.n_experts_used = 0 64 | load_ne_hparams 19.hparams.n_embd_head_k = 0 65 | load_ne_hparams 20.hparams.norm_eps = 0.000000 66 | load_ne_hparams 21.hparams.freq_base = 5000000.000 67 | load_ne_hparams 22.hparams.freq_scale = 1.000 68 | load_ne_hparams 23.hparams.rope_scaling_factor = 0.000 69 | load_ne_hparams 24.hparams.original_max_position_embeddings = 0 70 | load_ne_hparams 25.hparams.use_yarn = 0 71 | load_ne_vocab 26.vocab.bos_token_id = 1 72 | load_ne_vocab 27.vocab.eos_token_id = 151329 73 | load_ne_vocab 28.vocab.pad_token_id = 151329 74 | load_ne_vocab 29.vocab.sep_token_id = -1 75 | init: hparams.n_vocab = 151552 76 | init: hparams.n_embd = 4096 77 | init: hparams.n_mult = 0 78 | init: hparams.n_head = 32 79 | init: hparams.n_layer = 40 80 | init: hparams.n_rot = 0 81 | init: hparams.ffn_hidden_size = 13696 82 | init: n_parts = 1 83 | load: ctx size = 16528.38 MB 84 | load: layers[0].ffn_fusion = 1 85 | load: scratch0 = 4096.00 MB 86 | load: scratch1 = 2048.00 MB 87 | load: scratch2 = 4096.00 MB 88 | load: mem required = 26768.38 MB (+ memory per state) 89 | ............................................................................................. 90 | model_init_from_file: support_bestla_kv = 1 91 | kv_cache_init: run_mha_reordered = 1 92 | model_init_from_file: kv self size = 690.00 MB 93 | Assistant: 94 | 你好👋!我是人工智能助手,很高兴为你服务。有什么可以帮助你的吗? 95 | ``` 96 | -------------------------------------------------------------------------------- /demo/intel_device_demo/itrex/README_en.md: -------------------------------------------------------------------------------- 1 | 2 | # Using Intel® Extension for Transformers to Inference the GLM-4-9B-Chat Model 3 | 4 | This example introduces how to use Intel® Extension for Transformers to inference the GLM-4-9B-Chat model. 5 | 6 | ## Device and Dependency Check 7 | 8 | ### Relevant Inference Test Data 9 | 10 | **The data in this document is tested on the following hardware environment. The actual running environment requirements and memory usage may vary slightly. 
Please refer to the actual running environment.** 11 | 12 | Test hardware information: 13 | 14 | + OS: Ubuntu 22.04 (This tutorial must be executed in a Linux environment) 15 | + Memory: 512GB 16 | + Python: 3.10.12 17 | + CPU: Intel(R) Xeon(R) Platinum 8358 CPU / 12th Gen Intel i5-12400 18 | 19 | ## Installing Dependencies 20 | 21 | Before starting the inference, please install the dependencies in `inference`, and you need to install the dependencies in this directory: 22 | ```shell 23 | pip install -r requirements.txt 24 | ``` 25 | 26 | ## Running Model Inference 27 | 28 | ```shell 29 | python itrex_cli_demo.py 30 | ``` 31 | 32 | If this is your first inference, there will be a process of converting model weights. The converted model weights are stored in the `runtime_outputs` folder, which will consume about `60G` of disk space. 33 | After the conversion is completed, there are two files in the folder: 34 | + ne_chatglm2_f32.bin 52G (If you do not use FP32 for inference, you can delete this file) 35 | + ne_chatglm2_q_nf4_bestla_cfp32_sym_sfp32_g32.bin 8.1G 36 | 37 | If this is not your first inference, this step will be skipped, and you will directly start the conversation. The inference result is as follows: 38 | ```shell 39 | Welcome to the CLI chat. Type your messages below. 40 | 41 | User: Hello 42 | AVX:1 AVX2:1 AVX512F:1 AVX512BW:1 AVX_VNNI:0 AVX512_VNNI:1 AMX_INT8:0 AMX_BF16:0 AVX512_BF16:0 AVX512_FP16:0 43 | beam_size: 1, do_sample: 1, top_k: 40, top_p: 0.900, continuous_batching: 0, max_request_num: 1, early_stopping: 0, scratch_size_ratio: 1.000 44 | model_file_loader: loading model from runtime_outs/ne_chatglm2_q_nf4_bestla_cfp32_sym_sfp32_g32.bin 45 | Loading the bin file with NE format... 46 | load_ne_hparams 0.hparams.n_vocab = 151552 47 | load_ne_hparams 1.hparams.n_embd = 4096 48 | load_ne_hparams 2.hparams.n_mult = 0 49 | load_ne_hparams 3.hparams.n_head = 32 50 | load_ne_hparams 4.hparams.n_head_kv = 0 51 | load_ne_hparams 5.hparams.n_layer = 40 52 | load_ne_hparams 6.hparams.n_rot = 0 53 | load_ne_hparams 7.hparams.ftype = 0 54 | load_ne_hparams 8.hparams.max_seq_len = 131072 55 | load_ne_hparams 9.hparams.alibi_bias_max = 0.000 56 | load_ne_hparams 10.hparams.clip_qkv = 0.000 57 | load_ne_hparams 11.hparams.multi_query_group_num = 2 58 | load_ne_hparams 12.hparams.ffn_hidden_size = 13696 59 | load_ne_hparams 13.hparams.inner_hidden_size = 0 60 | load_ne_hparams 14.hparams.n_experts = 0 61 | load_ne_hparams 15.hparams.n_experts_used = 0 62 | load_ne_hparams 16.hparams.n_embd_head_k = 0 63 | load_ne_hparams 17.hparams.norm_eps = 0.000000 64 | load_ne_hparams 18.hparams.freq_base = 5000000.000 65 | load_ne_hparams 19.hparams.freq_scale = 1.000 66 | load_ne_hparams 20.hparams.rope_scaling_factor = 0.000 67 | load_ne_hparams 21.hparams.original_max_position_embeddings = 0 68 | load_ne_hparams 22.hparams.use_yarn = 0 69 | load_ne_vocab 23.vocab.bos_token_id = 1 70 | load_ne_vocab 24.vocab.eos_token_id = 151329 71 | load_ne_vocab 25.vocab.pad_token_id = 151329 72 | load_ne_vocab 26.vocab.sep_token_id = -1 73 | init: hparams.n_vocab = 151552 74 | init: hparams.n_embd = 4096 75 | init: hparams.n_mult = 0 76 | init: hparams.n_head = 32 77 | init: hparams.n_layer = 40 78 | init: hparams.n_rot = 0 79 | init: hparams.ffn_hidden_size = 13696 80 | init: n_parts = 1 81 | load: ctx size = 16528.38 MB 82 | load: layers[0].ffn_fusion = 1 83 | load: scratch0 = 4096.00 MB 84 | load: scratch1 = 2048.00 MB 85 | load: scratch2 = 4096.00 MB 86 | load: mem required = 26768.38 MB (+ memory per 
state) 87 | ............................................................................................. 88 | model_init_from_file: support_bestla_kv = 1 89 | kv_cache_init: run_mha_reordered = 1 90 | model_init_from_file: kv self size = 690.00 MB 91 | Assistant: 92 | Hello👋! I am an AI assistant. How can I help you today? 93 | ``` 94 | -------------------------------------------------------------------------------- /demo/intel_device_demo/itrex/itrex_cli_demo.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script creates a CLI demo with transformers backend for the glm-4-9b model with Intel® Extension for Transformers 3 | """ 4 | 5 | import os 6 | 7 | 8 | MODEL_PATH = os.environ.get("MODEL_PATH", "THUDM/GLM-4-9B-Chat-0414") 9 | 10 | from threading import Thread 11 | 12 | import torch 13 | from intel_extension_for_transformers.transformers import AutoModelForCausalLM 14 | from transformers import AutoTokenizer, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer 15 | 16 | 17 | class StopOnTokens(StoppingCriteria): 18 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: 19 | stop_ids = [151329, 151336, 151338] 20 | for stop_id in stop_ids: 21 | if input_ids[0][-1] == stop_id: 22 | return True 23 | return False 24 | 25 | 26 | def initialize_model_and_tokenizer(): 27 | tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH) 28 | model = AutoModelForCausalLM.from_pretrained( 29 | MODEL_PATH, 30 | device_map="cpu", # Use Intel CPU for inference 31 | trust_remote_code=True, 32 | load_in_4bit=True, 33 | ) 34 | return tokenizer, model 35 | 36 | 37 | def get_user_input(): 38 | return input("\nUser: ") 39 | 40 | 41 | def main(): 42 | tokenizer, model = initialize_model_and_tokenizer() 43 | 44 | history = [] 45 | max_length = 100 46 | top_p = 0.9 47 | temperature = 0.8 48 | stop = StopOnTokens() 49 | 50 | print("Welcome to the CLI chat. 
Type your messages below.") 51 | while True: 52 | user_input = get_user_input() 53 | if user_input.lower() in ["exit", "quit"]: 54 | break 55 | history.append([user_input, ""]) 56 | 57 | messages = [] 58 | for idx, (user_msg, model_msg) in enumerate(history): 59 | if idx == len(history) - 1 and not model_msg: 60 | messages.append({"role": "user", "content": user_msg}) 61 | break 62 | if user_msg: 63 | messages.append({"role": "user", "content": user_msg}) 64 | if model_msg: 65 | messages.append({"role": "assistant", "content": model_msg}) 66 | 67 | model_inputs = tokenizer.apply_chat_template( 68 | messages, add_generation_prompt=True, tokenize=True, return_tensors="pt" 69 | ) 70 | 71 | streamer = TextIteratorStreamer(tokenizer=tokenizer, timeout=60, skip_prompt=True, skip_special_tokens=True) 72 | 73 | generate_kwargs = { 74 | "input_ids": model_inputs, 75 | "streamer": streamer, 76 | "max_new_tokens": max_length, 77 | "do_sample": True, 78 | "top_p": top_p, 79 | "temperature": temperature, 80 | "stopping_criteria": StoppingCriteriaList([stop]), 81 | "repetition_penalty": 1.2, 82 | "eos_token_id": model.config.eos_token_id, 83 | } 84 | 85 | t = Thread(target=model.generate, kwargs=generate_kwargs) 86 | t.start() 87 | print("Assistant:", end="", flush=True) 88 | for new_token in streamer: 89 | if new_token: 90 | print(new_token, end="", flush=True) 91 | history[-1][1] += new_token 92 | 93 | history[-1][1] = history[-1][1].strip() 94 | 95 | 96 | if __name__ == "__main__": 97 | main() 98 | -------------------------------------------------------------------------------- /demo/intel_device_demo/itrex/requirements.txt: -------------------------------------------------------------------------------- 1 | cmake>=3.29.5.1 2 | huggingface-hub>=0.23.4 3 | git+https://github.com/intel/neural-speed.git@main#egg=neural-speed 4 | intel-extension-for-transformers>=1.4.2 5 | -------------------------------------------------------------------------------- /demo/intel_device_demo/openvino/README.md: -------------------------------------------------------------------------------- 1 | # 使用 OpenVINO 部署 GLM-4-9B-Chat 模型 2 | 3 | Read this in [English](README_en.md). 4 | 5 | [OpenVINO](https://www.intel.com/content/www/us/en/developer/tools/openvino-toolkit/overview.html) 6 | 是 Intel 为深度学习推理而设计的开源工具包。它可以帮助开发者优化模型,提高推理性能,减少模型的内存占用。 7 | 本示例将展示如何使用 OpenVINO 部署 GLM-4-9B-Chat 模型。 8 | 9 | ## 1. 环境配置 10 | 11 | 首先,你需要安装依赖 12 | 13 | ```bash 14 | pip install -r requirements.txt 15 | ``` 16 | 17 | ## 2. 转换模型 18 | 19 | 由于需要将Huggingface模型转换为OpenVINO IR模型,因此您需要下载模型并转换。 20 | 21 | ``` 22 | python3 convert.py --model_id THUDM/glm-4-9b-chat --output {your_path}/glm-4-9b-chat-ov 23 | ``` 24 | 25 | ### 可以选择的参数 26 | 27 | * `--model_id` - 模型所在目录的路径(绝对路径)。 28 | * `--output` - 转换后模型保存的地址。 29 | * `--precision` - 转换的精度。 30 | 31 | 32 | 转换过程如下: 33 | ``` 34 | ====Exporting IR===== 35 | Framework not specified. Using pt to export the model. 36 | Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:04<00:00, 2.14it/s] 37 | Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained. 38 | Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained. 
39 | Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained. 40 | Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained. 41 | Using framework PyTorch: 2.3.1+cu121 42 | Mixed-Precision assignment ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 160/160 • 0:01:45 • 0:00:00 43 | INFO:nncf:Statistics of the bitwidth distribution: 44 | ┍━━━━━━━━━━━━━━━━┯━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┯━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┑ 45 | │ Num bits (N) │ % all parameters (layers) │ % ratio-defining parameters (layers) │ 46 | ┝━━━━━━━━━━━━━━━━┿━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┿━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┥ 47 | │ 8 │ 31% (76 / 163) │ 20% (73 / 160) │ 48 | ├────────────────┼─────────────────────────────┼────────────────────────────────────────┤ 49 | │ 4 │ 69% (87 / 163) │ 80% (87 / 160) │ 50 | ┕━━━━━━━━━━━━━━━━┷━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┷━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┙ 51 | Applying Weight Compression ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 163/163 • 0:03:46 • 0:00:00 52 | Configuration saved in glm-4-9b-ov/openvino_config.json 53 | ====Exporting tokenizer===== 54 | Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained. 55 | ``` 56 | ## 3. 运行 GLM-4-9B-Chat 模型 57 | 58 | ``` 59 | python3 chat.py --model_path {your_path}/glm-4-9b-chat-ov --max_sequence_length 4096 --device CPU 60 | ``` 61 | 62 | ### 可以选择的参数 63 | 64 | * `--model_path` - OpenVINO IR 模型所在目录的路径。 65 | * `--max_sequence_length` - 输出标记的最大大小。 66 | * `--device` - 运行推理的设备。 67 | 68 | ### 参考代码 69 | 70 | 本代码参考 [OpenVINO 官方示例](https://github.com/OpenVINO-dev-contest/chatglm3.openvino) 进行修改。 71 | -------------------------------------------------------------------------------- /demo/intel_device_demo/openvino/README_en.md: -------------------------------------------------------------------------------- 1 | # Deploy the GLM-4-9B-Chat model using OpenVINO 2 | 3 | [OpenVINO](https://www.intel.com/content/www/us/en/developer/tools/openvino-toolkit/overview.html) 4 | is an open source toolkit designed by Intel for deep learning inference. It can help developers optimize models, improve inference performance, and reduce model memory usage. 5 | This example will show how to deploy the GLM-4-9B-Chat model using OpenVINO. 6 | 7 | ## 1. Environment configuration 8 | 9 | First, you need to install the dependencies 10 | 11 | ```bash 12 | pip install -r requirements.txt 13 | ``` 14 | 15 | ## 2. Convert the model 16 | 17 | Since the Huggingface model needs to be converted to an OpenVINO IR model, you need to download the model and convert it. 18 | 19 | ``` 20 | python3 convert.py --model_id THUDM/glm-4-9b-chat --output {your_path}/glm-4-9b-chat-ov 21 | ``` 22 | The conversion process is as follows: 23 | ``` 24 | ====Exporting IR===== 25 | Framework not specified. Using pt to export the model. 
26 | Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:04<00:00, 2.14it/s] 27 | Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained. 28 | Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained. 29 | Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained. 30 | Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained. 31 | Using framework PyTorch: 2.3.1+cu121 32 | Mixed-Precision assignment ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 160/160 • 0:01:45 • 0:00:00 33 | INFO:nncf:Statistics of the bitwidth distribution: 34 | ┍━━━━━━━━━━━━━━━━┯━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┯━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┑ 35 | │ Num bits (N) │ % all parameters (layers) │ % ratio-defining parameters (layers) │ 36 | ┝━━━━━━━━━━━━━━━━┿━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┿━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┥ 37 | │ 8 │ 31% (76 / 163) │ 20% (73 / 160) │ 38 | ├────────────────┼─────────────────────────────┼────────────────────────────────────────┤ 39 | │ 4 │ 69% (87 / 163) │ 80% (87 / 160) │ 40 | ┕━━━━━━━━━━━━━━━━┷━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┷━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┙ 41 | Applying Weight Compression ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 163/163 • 0:03:46 • 0:00:00 42 | Configuration saved in glm-4-9b-ov/openvino_config.json 43 | ====Exporting tokenizer===== 44 | Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained. 45 | ``` 46 | 47 | ### Optional parameters 48 | 49 | * `--model_id` - Path to the directory where the model is located (absolute path). 50 | 51 | * `--output` - Path to where the converted model is saved. 52 | 53 | * `--precision` - Precision of the conversion. 54 | 55 | ## 3. Run the GLM-4-9B-Chat model 56 | 57 | ``` 58 | python3 chat.py --model_path {your_path}glm-4-9b-chat-ov --max_sequence_length 4096 --device CPU 59 | ``` 60 | 61 | ### Optional parameters 62 | 63 | * `--model_path` - Path to the directory where the OpenVINO IR model is located. 64 | 65 | * `--max_sequence_length` - Maximum size of the output token. 66 | * `--device` - the device to run inference on. 67 | 68 | ### Reference code 69 | 70 | This code is modified based on the [OpenVINO official example](https://github.com/OpenVINO-dev-contest/chatglm3.openvino). 71 | -------------------------------------------------------------------------------- /demo/intel_device_demo/openvino/convert.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script is used to convert the original model to OpenVINO IR format. 
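Example: python3 convert.py --model_id THUDM/glm-4-9b-chat --output {your_path}/glm-4-9b-chat-ov --precision int4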
3 | The Origin Code can check https://github.com/OpenVINO-dev-contest/chatglm3.openvino/blob/main/convert.py 4 | """ 5 | 6 | import argparse 7 | import os 8 | from pathlib import Path 9 | 10 | from optimum.intel import OVWeightQuantizationConfig 11 | from optimum.intel.openvino import OVModelForCausalLM 12 | from transformers import AutoConfig, AutoTokenizer 13 | 14 | 15 | if __name__ == "__main__": 16 | parser = argparse.ArgumentParser(add_help=False) 17 | parser.add_argument("-h", "--help", action="help", help="Show this help message and exit.") 18 | parser.add_argument( 19 | "-m", "--model_id", default="THUDM/GLM-4-9B-Chat-0414", required=False, type=str, help="orignal model path" 20 | ) 21 | parser.add_argument( 22 | "-p", 23 | "--precision", 24 | required=False, 25 | default="int4", 26 | type=str, 27 | choices=["fp16", "int8", "int4"], 28 | help="fp16, int8 or int4", 29 | ) 30 | parser.add_argument( 31 | "-o", "--output", default="./glm-4-9b-ov", required=False, type=str, help="Required. path to save the ir model" 32 | ) 33 | args = parser.parse_args() 34 | 35 | ir_model_path = Path(args.output) 36 | if ir_model_path.exists() == False: 37 | os.mkdir(ir_model_path) 38 | 39 | model_kwargs = { 40 | "trust_remote_code": True, 41 | "config": AutoConfig.from_pretrained(args.model_id, trust_remote_code=True), 42 | } 43 | compression_configs = { 44 | "sym": False, 45 | "group_size": 128, 46 | "ratio": 0.8, 47 | } 48 | 49 | print("====Exporting IR=====") 50 | if args.precision == "int4": 51 | ov_model = OVModelForCausalLM.from_pretrained( 52 | args.model_id, 53 | export=True, 54 | compile=False, 55 | quantization_config=OVWeightQuantizationConfig(bits=4, **compression_configs), 56 | **model_kwargs, 57 | ) 58 | elif args.precision == "int8": 59 | ov_model = OVModelForCausalLM.from_pretrained( 60 | args.model_id, export=True, compile=False, load_in_8bit=True, **model_kwargs 61 | ) 62 | else: 63 | ov_model = OVModelForCausalLM.from_pretrained( 64 | args.model_id, export=True, compile=False, load_in_8bit=False, **model_kwargs 65 | ) 66 | 67 | ov_model.save_pretrained(ir_model_path) 68 | 69 | print("====Exporting tokenizer=====") 70 | tokenizer = AutoTokenizer.from_pretrained(args.model_id, trust_remote_code=True) 71 | tokenizer.save_pretrained(ir_model_path) 72 | -------------------------------------------------------------------------------- /demo/intel_device_demo/openvino/openvino_cli_demo.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from threading import Thread 3 | from typing import List, Tuple 4 | 5 | import torch 6 | from optimum.intel.openvino import OVModelForCausalLM 7 | from transformers import AutoConfig, AutoTokenizer, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer 8 | 9 | 10 | class StopOnTokens(StoppingCriteria): 11 | def __init__(self, token_ids): 12 | self.token_ids = token_ids 13 | 14 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: 15 | for stop_id in self.token_ids: 16 | if input_ids[0][-1] == stop_id: 17 | return True 18 | return False 19 | 20 | 21 | if __name__ == "__main__": 22 | parser = argparse.ArgumentParser(add_help=False) 23 | parser.add_argument("-h", "--help", action="help", help="Show this help message and exit.") 24 | parser.add_argument("-m", "--model_path", required=True, type=str, help="Required. model path") 25 | parser.add_argument( 26 | "-l", "--max_sequence_length", default=256, required=False, type=int, help="Required. 
maximun length of output" 27 | ) 28 | parser.add_argument( 29 | "-d", "--device", default="CPU", required=False, type=str, help="Required. device for inference" 30 | ) 31 | args = parser.parse_args() 32 | model_dir = args.model_path 33 | 34 | ov_config = {"PERFORMANCE_HINT": "LATENCY", "NUM_STREAMS": "1", "CACHE_DIR": ""} 35 | 36 | tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True) 37 | 38 | print("====Compiling model====") 39 | ov_model = OVModelForCausalLM.from_pretrained( 40 | model_dir, 41 | device=args.device, 42 | ov_config=ov_config, 43 | config=AutoConfig.from_pretrained(model_dir, trust_remote_code=True), 44 | trust_remote_code=True, 45 | ) 46 | 47 | streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True) 48 | stop_tokens = [StopOnTokens([151329, 151336, 151338])] 49 | 50 | def convert_history_to_token(history: List[Tuple[str, str]]): 51 | messages = [] 52 | for idx, (user_msg, model_msg) in enumerate(history): 53 | if idx == len(history) - 1 and not model_msg: 54 | messages.append({"role": "user", "content": user_msg}) 55 | break 56 | if user_msg: 57 | messages.append({"role": "user", "content": user_msg}) 58 | if model_msg: 59 | messages.append({"role": "assistant", "content": model_msg}) 60 | 61 | model_inputs = tokenizer.apply_chat_template( 62 | messages, add_generation_prompt=True, tokenize=True, return_tensors="pt" 63 | ) 64 | return model_inputs 65 | 66 | history = [] 67 | print("====Starting conversation====") 68 | while True: 69 | input_text = input("用户: ") 70 | if input_text.lower() == "stop": 71 | break 72 | 73 | if input_text.lower() == "clear": 74 | history = [] 75 | print("AI助手: 对话历史已清空") 76 | continue 77 | 78 | print("GLM-4-9B-OpenVINO:", end=" ") 79 | history = history + [[input_text, ""]] 80 | model_inputs = convert_history_to_token(history) 81 | generate_kwargs = dict( 82 | input_ids=model_inputs, 83 | max_new_tokens=args.max_sequence_length, 84 | temperature=0.1, 85 | do_sample=True, 86 | top_p=1.0, 87 | top_k=50, 88 | repetition_penalty=1.1, 89 | streamer=streamer, 90 | stopping_criteria=StoppingCriteriaList(stop_tokens), 91 | ) 92 | 93 | t1 = Thread(target=ov_model.generate, kwargs=generate_kwargs) 94 | t1.start() 95 | 96 | partial_text = "" 97 | for new_text in streamer: 98 | new_text = new_text 99 | print(new_text, end="", flush=True) 100 | partial_text += new_text 101 | print("\n") 102 | history[-1][1] = partial_text 103 | -------------------------------------------------------------------------------- /demo/intel_device_demo/openvino/requirements.txt: -------------------------------------------------------------------------------- 1 | optimum>=1.20.0 2 | optimum-intel @ git+https://github.com/huggingface/optimum-intel.git@c1ee8ac0864e25e22ea56b5a37a35451531da0e6 3 | -------------------------------------------------------------------------------- /finetune/README.md: -------------------------------------------------------------------------------- 1 | # GLM-4-9B Chat Fine-tuning 2 | 3 | [中文阅读](README_zh.md) 4 | 5 | ## Hardware Check 6 | 7 | All fine-tuning tests were performed in the following environment: 8 | 9 | > OS: Ubuntu 22.04 10 | > 11 | > Memory: 512GB 12 | > 13 | > Python: 3.12.3 14 | > 15 | > CUDA Version: 12.4 16 | > 17 | > GPU Driver: 535.104.05 18 | > 19 | > GPU: NVIDIA H100 80GB HBM3 (hereafter referred to as GPU) 20 | 21 | + Fine-tuning based on Llama-Factory 22 | 23 | | Fine-tuning Model | Fine-tuning solution | GPU memory usage | 24 | 
|---------------------------|----------------------|------------------------------| 25 | | GLM-4-9B-Chat-0414 | lora | 22G (Each GPU, Need 1 GPU) | 26 | | GLM-4-9B-Chat-0414 | SFT (Zero3 method) | 55G (Each GPU, Need 4 GPUs) | 27 | | GLM-4-9B-Chat-0414 | lora | 80G (Each GPU, Need 8 GPUs) | 28 | | GLM-4-32B-Chat-0414 | SFT (Zero3 method) | 80G (Each GPU, Need 16 GPUs) | 29 | 30 | + Fine-tuning based on this repository 31 | 32 | | Fine-tuning Model | Fine-tuning solution | GPU memory usage | 33 | |--------------------------|------------------------------------|-------------------------------| 34 | | GLM-4V-9B | lora (PEFT), Include EVA2CLIPModel | 75G (Each GPU, Need 1 GPU) | 35 | | GLM-4-9B-Chat | lora (PEFT) | 22G (Each GPU, Need 1 GPU) | 36 | | GLM-4-9B-Chat | SFT (Zero3 method) | 80G (Each GPU, Need 8 GPUs) | 37 | 38 | 39 | ## Preparation 40 | 41 | Before starting fine-tuning, please install the dependencies in \`basic_demo\`, ensure you have cloned the latest version of the model repository, and install the dependencies in this directory: 42 | 43 | ```bash 44 | pip install -r requirements.txt 45 | ``` 46 | 47 | ## Multi-round dialogue format 48 | 49 | The multi-round dialogue fine-tuning example uses the GLM-4 dialogue format convention, adding different `loss_mask` to 50 | different roles to calculate `loss` for multiple rounds of replies in one calculation. 51 | 52 | For data files, the sample uses the following format: 53 | 54 | ```json 55 | [ 56 | { 57 | "messages": [ 58 | { 59 | "role": "system", 60 | "content": "", 61 | "tools": [ 62 | { 63 | "name": "", 64 | "args": { 65 | "": "" 66 | } 67 | } 68 | // Add more tools if needed 69 | ] 70 | }, 71 | { 72 | "role": "user", 73 | "content": "" 74 | }, 75 | { 76 | "role": "assistant", 77 | "content": "" 78 | }, 79 | // If Tool Using 80 | { 81 | "role": "user", 82 | "content": "" 83 | }, 84 | { 85 | "role": "assistant", 86 | "content": "" 87 | }, 88 | { 89 | "role": "observation", 90 | "content": "" 91 | }, 92 | { 93 | "role": "assistant", 94 | "content": "" 95 | }, 96 | // Multi_turns 97 | { 98 | "role": "user", 99 | "content": "" 100 | }, 101 | { 102 | "role": "assistant", 103 | "content": "" 104 | } 105 | ] 106 | } 107 | ] 108 | ``` 109 | 110 | This is a sample without tools: 111 | 112 | ```json 113 | { 114 | "messages": [ 115 | { 116 | "role": "user", 117 | "content": "类型#裤*材质#牛仔布*风格#性感" 118 | }, 119 | { 120 | "role": "assistant", 121 | "content": "3x1的这款牛仔裤采用浅白的牛仔面料为裤身材质,其柔然的手感和细腻的质地,在穿着舒适的同时,透露着清纯甜美的个性气质。除此之外,流畅的裤身剪裁将性感的腿部曲线彰显的淋漓尽致,不失为一款随性出街的必备单品。" 122 | } 123 | ] 124 | } 125 | ``` 126 | 127 | This is a sample with tools: 128 | 129 | ```json 130 | { 131 | "messages": [ 132 | { 133 | "role": "system", 134 | "content": "", 135 | "tools": [ 136 | { 137 | "type": "function", 138 | "function": { 139 | "name": "get_recommended_books", 140 | "description": "Get recommended books based on user's interests", 141 | "parameters": { 142 | "type": "object", 143 | "properties": { 144 | "interests": { 145 | "type": "array", 146 | "items": { 147 | "type": "string" 148 | }, 149 | "description": "The interests to recommend books for" 150 | } 151 | }, 152 | "required": [ 153 | "interests" 154 | ] 155 | } 156 | } 157 | } 158 | ] 159 | }, 160 | { 161 | "role": "user", 162 | "content": "Hi, I am looking for some book recommendations. I am interested in history and science fiction." 
163 | }, 164 | { 165 | "role": "assistant", 166 | "content": "{\"name\": \"get_recommended_books\", \"arguments\": {\"interests\": [\"history\", \"science fiction\"]}}" 167 | }, 168 | { 169 | "role": "observation", 170 | "content": "{\"books\": [\"Sapiens: A Brief History of Humankind by Yuval Noah Harari\", \"A Brief History of Time by Stephen Hawking\", \"Dune by Frank Herbert\", \"The Martian by Andy Weir\"]}" 171 | }, 172 | { 173 | "role": "assistant", 174 | "content": "Based on your interests in history and science fiction, I would recommend the following books: \"Sapiens: A Brief History of Humankind\" by Yuval Noah Harari, \"A Brief History of Time\" by Stephen Hawking, \"Dune\" by Frank Herbert, and \"The Martian\" by Andy Weir." 175 | } 176 | ] 177 | } 178 | ``` 179 | 180 | This is a sample for a VQA task: 181 | 182 | ```json 183 | { 184 | "messages": [ 185 | { 186 | "role": "user", 187 | "content": "图片中的动物是什么?", 188 | "image": "/root/images/0001.jpg" 189 | }, 190 | { 191 | "role": "assistant", 192 | "content": "图片中有一只猫。" 193 | }, 194 | { 195 | "role": "user", 196 | "content": "图片中的猫在做什么?" 197 | }, 198 | { 199 | "role": "assistant", 200 | "content": "这只猫坐在或站在桌子上,桌上有很多食物。" 201 | } 202 | ] 203 | } 204 | ``` 205 | 206 | - The `system` role is optional, but if it exists, it must appear before the `user` role, and the `system` role can only 207 | appear once in a complete conversation (whether it is a single-round or a multi-round conversation). 208 | - The `tools` field is optional, but if it exists, it must appear after the `system` role, and the `tools` field can 209 | only appear once in a complete conversation (whether it is a single-round or a multi-round conversation). When 210 | the `tools` field exists, the `system` role must exist and its `content` field must be empty. 211 | - `GLM-4V-9B` does not support the `tools` field or the `system` role, and `image` must be placed in the first 212 | message. The `image` field must contain the `absolute path` of the image. 213 | 214 | ## Configuration file 215 | 216 | The fine-tuning configuration files are located in the `configs` directory and include the following files: 217 | 218 | 1. `ds_zero_2.json / ds_zero_3.json`: DeepSpeed configuration files. 219 | 220 | 2. `lora.yaml / ptuning_v2.yaml / sft.yaml`: Configuration files for the different fine-tuning modes, including model parameters, optimizer 221 | parameters, training parameters, etc. Some important parameters are explained as follows: 222 | 223 | + data_config section 224 | + train_file: File path of the training dataset. 225 | + val_file: File path of the validation dataset. 226 | + test_file: File path of the test dataset. 227 | + num_proc: Number of processes to use when loading data. 228 | + max_input_length: Maximum length of the input sequence. 229 | + max_output_length: Maximum length of the output sequence. 230 | + training_args section 231 | + output_dir: Directory for saving the model and other outputs. 232 | + max_steps: Maximum number of training steps. 233 | + per_device_train_batch_size: Training batch size per device (such as a GPU). 234 | + dataloader_num_workers: Number of worker threads to use when loading data. 235 | + remove_unused_columns: Whether to remove unused columns in the data. 236 | + save_strategy: Model saving strategy (for example, how many steps between saves). 237 | + save_steps: How many steps between model saves. 238 | + log_level: Log level (such as info). 239 | + logging_strategy: Logging strategy. 240 | + logging_steps: How many steps between log entries.
241 | + per_device_eval_batch_size: Evaluation batch size per device. 242 | + evaluation_strategy: Evaluation strategy (e.g., how many steps between evaluations). 243 | + eval_steps: How many steps between evaluations. 244 | + predict_with_generate: Whether to use generation mode for prediction. 245 | + generation_config section 246 | + max_new_tokens: Maximum number of new tokens to generate. 247 | + peft_config section 248 | + peft_type: Type of parameter-efficient tuning to use (supports LORA and PREFIX_TUNING). 249 | + task_type: Task type; here it is causal language model (do not change). 250 | + LoRA parameters: 251 | + r: Rank of LoRA. 252 | + lora_alpha: Scaling factor of LoRA. 253 | + lora_dropout: Dropout probability used in the LoRA layers. 254 | + P-TuningV2 parameters: + num_virtual_tokens: The number of virtual tokens. 255 | + num_attention_heads: 2: The number of attention heads of P-TuningV2 (do not change). 256 | + token_dim: 256: The token dimension of P-TuningV2 (do not change). 257 | 258 | ## Start fine-tuning 259 | 260 | Run **single-node multi-GPU / multi-node multi-GPU** fine-tuning with the following commands, which use `deepspeed` as 261 | the acceleration solution; you need to install `deepspeed` first. 262 | 263 | ```shell 264 | OMP_NUM_THREADS=1 torchrun --standalone --nnodes=1 --nproc_per_node=8 finetune.py data/AdvertiseGen/ THUDM/GLM-4-9B-Chat-0414 configs/lora.yaml # For Chat Fine-tune 265 | OMP_NUM_THREADS=1 torchrun --standalone --nnodes=1 --nproc_per_node=8 finetune_vision.py data/CogVLM-311K/ THUDM/glm-4v-9b configs/lora.yaml # For VQA Fine-tune 266 | ``` 267 | 268 | Run **single-node single-GPU** fine-tuning with the following commands. 269 | 270 | ```shell 271 | python finetune.py data/AdvertiseGen/ THUDM/GLM-4-9B-Chat-0414 configs/lora.yaml # For Chat Fine-tune 272 | python finetune_vision.py data/CogVLM-311K/ THUDM/glm-4v-9b configs/lora.yaml # For VQA Fine-tune 273 | ``` 274 | 275 | ## Fine-tune from a saved point 276 | 277 | If you train as described above, each fine-tuning run starts from the beginning. If you want to resume fine-tuning from a 278 | partially trained model, you can add a fourth parameter, which can be passed in two ways: 279 | 280 | 1. `yes`: automatically resume training from the last saved checkpoint 281 | 282 | 2. `XX`: a checkpoint number, for example `600`, to resume training from checkpoint 600 283 | 284 | For example, this command continues fine-tuning from the last saved checkpoint: 285 | 286 | ```shell 287 | python finetune.py data/AdvertiseGen/ THUDM/GLM-4-9B-Chat-0414 configs/lora.yaml yes 288 | ``` 289 | 290 | ## Use the fine-tuned model 291 | 292 | ### Use the fine-tuned model in other demos in this repository or external repositories 293 | 294 | You can use our `LORA` and fully fine-tuned models in any demo. This requires you to modify the code yourself according 295 | to the following tutorial. 296 | 297 | 1. Replace the model-loading code in the demo with the model-loading code from `finetune_demo/inference.py`. 298 | 299 | > Please note that for LORA and P-TuningV2 we did not merge the trained models; instead, the path of the base model is recorded 300 | > in `adapter_config.json`. 301 | > If the location of your original model changes, you should modify the `base_model_name_or_path` entry 302 | > in `adapter_config.json`.
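For instance, a minimal sketch of the relevant entry in `adapter_config.json` (the path below is illustrative; the other keys generated by PEFT are omitted):

```json
{
  "base_model_name_or_path": "/new/path/to/base_model"
}
```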
303 | 304 | ```python 305 | def load_model_and_tokenizer(model_dir: Union[str, Path]) -> tuple[ModelType, TokenizerType]: 306 | model_dir = _resolve_path(model_dir) 307 | if (model_dir / "adapter_config.json").exists(): 308 | model = AutoPeftModelForCausalLM.from_pretrained(model_dir, device_map="auto") 309 | tokenizer_dir = model.peft_config["default"].base_model_name_or_path 310 | else: 311 | model = AutoModelForCausalLM.from_pretrained(model_dir, device_map="auto") 312 | tokenizer_dir = model_dir 313 | tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir) 314 | return model, tokenizer 315 | ``` 316 | 317 | 2. Read the fine-tuned model. Please note that you should use the location of the fine-tuned model. For example, if your 318 | model location is `/path/to/finetune_adapter_model` 319 | and the original model address is `path/to/base_model`, you should use `/path/to/finetune_adapter_model` 320 | as `model_dir`. 321 | 3. After completing the above operations, you can use the fine-tuned model normally. Other calling methods remain 322 | unchanged. 323 | 4. This fine-tuning script has not been tested on long texts of 128K or 1M tokens. Fine-tuning long texts requires GPU 324 | devices with larger memory and more efficient fine-tuning solutions, which developers need to handle on their own. 325 | 326 | ## Reference 327 | 328 | ``` 329 | @inproceedings{liu2022p, 330 | title={P-tuning: Prompt tuning can be comparable to fine-tuning across scales and tasks}, 331 | author={Liu, Xiao and Ji, Kaixuan and Fu, Yicheng and Tam, Weng and Du, Zhengxiao and Yang, Zhilin and Tang, Jie}, 332 | booktitle={Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 2: Short 333 | Papers)}, 334 | pages={61--68}, 335 | year={2022} 336 | } 337 | 338 | @misc{tang2023toolalpaca, 339 | title={ToolAlpaca: Generalized Tool Learning for Language Models with 3000 Simulated Cases}, 340 | author={Qiaoyu Tang and Ziliang Deng and Hongyu Lin and Xianpei Han and Qiao Liang and Le Sun}, 341 | year={2023}, 342 | eprint={2306.05301}, 343 | archivePrefix={arXiv}, 344 | primaryClass={cs.CL} 345 | } 346 | ``` 347 | -------------------------------------------------------------------------------- /finetune/README_zh.md: -------------------------------------------------------------------------------- 1 | # GLM-4-9B Chat 对话模型微调 2 | 3 | Read this in [English](README) 4 | 5 | ## 硬件检查 6 | 7 | 所有微调测试均在以下环境和硬件下测试: 8 | 9 | > OS: Ubuntu 22.04 10 | > 11 | > Memory: 512GB 12 | > 13 | > Python: 3.12.3 14 | > 15 | > CUDA Version: 12.4 16 | > 17 | > GPU Driver: 535.104.05 18 | > 19 | > GPU: NVIDIA H100 80GB HBM3 (以下简称 GPU) 20 | 21 | 22 | + 基于 Llama-Factory 进行微调 23 | 24 | | Fine-tuning Model | Fine-tuning solution | GPU memory usage | 25 | |---------------------------|----------------------|------------------------------| 26 | | GLM-4-9B-Chat-0414 | lora | 22G (Each GPU, Need 1 GPU) | 27 | | GLM-4-9B-Chat-0414 | SFT (Zero3 method) | 55G (Each GPU, Need 4 GPUs) | 28 | | GLM-4-9B-Chat-0414 | lora | 80G (Each GPU, Need 8 GPUs) | 29 | | GLM-4-32B-Chat-0414 | SFT (Zero3 method) | 80G (Each GPU, Need 16 GPUs) | 30 | 31 | + 基于本仓库代码微调 32 | 33 | | Fine-tuning Model | Fine-tuning solution | GPU memory usage | 34 | |--------------------------|------------------------------------|-------------------------------| 35 | | GLM-4V-9B | lora (PEFT), Include EVA2CLIPModel | 75G (Each GPU, Need 1 GPU) | 36 | | GLM-4-9B-Chat | lora (PEFT) | 22G (Each GPU, Need 1 GPU) | 37 | | GLM-4-9B-Chat | SFT (Zero3 method) | 80G (Each 
GPU, Need 8 GPUs) | 38 | 39 | 40 | ## 准备工作 41 | 42 | 在开始微调之前,请你先安装 `inference` 中的依赖,并保证克隆了最新版本的模型仓库,同时您需要安装本目录下的依赖项: 43 | 44 | ```bash 45 | pip install -r requirements.txt 46 | ``` 47 | 48 | ## 多轮对话格式 49 | 50 | 多轮对话微调示例采用 GLM-4 对话格式约定,对不同角色添加不同 `loss_mask` 从而在一遍计算中为多轮回复计算 `loss`。 51 | 52 | 对于数据文件,样例采用如下格式 53 | 54 | 如果您仅希望微调模型的对话能力,而非工具能力,您应该按照以下格式整理数据。 55 | 56 | ```json 57 | [ 58 | { 59 | "messages": [ 60 | { 61 | "role": "system", 62 | "content": "", 63 | "tools": [ 64 | { 65 | "name": "", 66 | "args": { 67 | "": "" 68 | } 69 | } 70 | // Add more tools if needed 71 | ] 72 | }, 73 | { 74 | "role": "user", 75 | "content": "" 76 | }, 77 | { 78 | "role": "assistant", 79 | "content": "" 80 | }, 81 | // If Tool Using 82 | { 83 | "role": "user", 84 | "content": "" 85 | }, 86 | { 87 | "role": "assistant", 88 | "content": "" 89 | }, 90 | { 91 | "role": "observation", 92 | "content": "" 93 | }, 94 | { 95 | "role": "assistant", 96 | "content": "" 97 | }, 98 | // Multi_turns 99 | { 100 | "role": "user", 101 | "content": "" 102 | }, 103 | { 104 | "role": "assistant", 105 | "content": "" 106 | } 107 | ] 108 | } 109 | ] 110 | ``` 111 | 112 | 这里是一个不带有工具的例子: 113 | 114 | ```json 115 | { 116 | "messages": [ 117 | { 118 | "role": "user", 119 | "content": "类型#裤*材质#牛仔布*风格#性感" 120 | }, 121 | { 122 | "role": "assistant", 123 | "content": "3x1的这款牛仔裤采用浅白的牛仔面料为裤身材质,其柔然的手感和细腻的质地,在穿着舒适的同时,透露着清纯甜美的个性气质。除此之外,流畅的裤身剪裁将性感的腿部曲线彰显的淋漓尽致,不失为一款随性出街的必备单品。" 124 | } 125 | ] 126 | } 127 | ``` 128 | 129 | 这是一个带有工具调用的例子: 130 | 131 | ```json 132 | { 133 | "messages": [ 134 | { 135 | "role": "system", 136 | "content": "", 137 | "tools": [ 138 | { 139 | "type": "function", 140 | "function": { 141 | "name": "get_recommended_books", 142 | "description": "Get recommended books based on user's interests", 143 | "parameters": { 144 | "type": "object", 145 | "properties": { 146 | "interests": { 147 | "type": "array", 148 | "items": { 149 | "type": "string" 150 | }, 151 | "description": "The interests to recommend books for" 152 | } 153 | }, 154 | "required": [ 155 | "interests" 156 | ] 157 | } 158 | } 159 | } 160 | ] 161 | }, 162 | { 163 | "role": "user", 164 | "content": "Hi, I am looking for some book recommendations. I am interested in history and science fiction." 165 | }, 166 | { 167 | "role": "assistant", 168 | "content": "{\"name\": \"get_recommended_books\", \"arguments\": {\"interests\": [\"history\", \"science fiction\"]}}" 169 | }, 170 | { 171 | "role": "observation", 172 | "content": "{\"books\": [\"Sapiens: A Brief History of Humankind by Yuval Noah Harari\", \"A Brief History of Time by Stephen Hawking\", \"Dune by Frank Herbert\", \"The Martian by Andy Weir\"]}" 173 | }, 174 | { 175 | "role": "assistant", 176 | "content": "Based on your interests in history and science fiction, I would recommend the following books: \"Sapiens: A Brief History of Humankind\" by Yuval Noah Harari, \"A Brief History of Time\" by Stephen Hawking, \"Dune\" by Frank Herbert, and \"The Martian\" by Andy Weir." 177 | } 178 | ] 179 | } 180 | ``` 181 | 182 | 这是一个视觉VQA微调的例子: 183 | 184 | ```json 185 | { 186 | "messages": [ 187 | { 188 | "role": "user", 189 | "content": "图片中的动物是什么?", 190 | "image": "/root/images/0001.jpg" 191 | }, 192 | { 193 | "role": "assistant", 194 | "content": "图片中有一只猫。" 195 | }, 196 | { 197 | "role": "user", 198 | "content": "图片中的猫在做什么?" 
199 | }, 200 | { 201 | "role": "assistant", 202 | "content": "这只猫坐在或站在桌子上,桌上有很多食物。" 203 | } 204 | ] 205 | } 206 | ``` 207 | 208 | - `system` 角色为可选角色,但若存在 `system` 角色,其必须出现在 `user` 209 | 角色之前,且一个完整的对话数据(无论单轮或者多轮对话)只能出现一次 `system` 角色。 210 | - `tools` 字段为可选字段,若存在 `tools` 字段,其必须出现在 `system` 211 | 角色之后,且一个完整的对话数据(无论单轮或者多轮对话)只能出现一次 `tools` 字段。当 `tools` 字段存在时,`system` 212 | 角色必须存在并且 `content` 字段为空。 213 | - `GLM-4V-9B` 不支持 `tools` 字段和 `system` 字段。并且 `image` 必须放在第一条消息中。 `image` 214 | 字段需要放置置图片的 `绝对路径`。 215 | 216 | ## 配置文件 217 | 218 | 微调配置文件位于 `config` 目录下,包括以下文件: 219 | 220 | 1. `ds_zereo_2 / ds_zereo_3.json`: deepspeed 配置文件。 221 | 2. `lora.yaml 222 | 3. .yaml / sft.yaml`: 模型不同方式的配置文件,包括模型参数、优化器参数、训练参数等。 部分重要参数解释如下: 223 | + data_config 部分 224 | + train_file: 训练数据集的文件路径。 225 | + val_file: 验证数据集的文件路径。 226 | + test_file: 测试数据集的文件路径。 227 | + num_proc: 在加载数据时使用的进程数量。 228 | + max_input_length: 输入序列的最大长度。 229 | + max_output_length: 输出序列的最大长度。 230 | + training_args 部分 231 | + output_dir: 用于保存模型和其他输出的目录。 232 | + max_steps: 训练的最大步数。 233 | + per_device_train_batch_size: 每个设备(如 GPU)的训练批次大小。 234 | + dataloader_num_workers: 加载数据时使用的工作线程数量。 235 | + remove_unused_columns: 是否移除数据中未使用的列。 236 | + save_strategy: 模型保存策略(例如,每隔多少步保存一次)。 237 | + save_steps: 每隔多少步保存一次模型。 238 | + log_level: 日志级别(如 info)。 239 | + logging_strategy: 日志记录策略。 240 | + logging_steps: 每隔多少步记录一次日志。 241 | + per_device_eval_batch_size: 每个设备的评估批次大小。 242 | + evaluation_strategy: 评估策略(例如,每隔多少步进行一次评估)。 243 | + eval_steps: 每隔多少步进行一次评估。 244 | + predict_with_generate: 是否使用生成模式进行预测。 245 | + generation_config 部分 246 | + max_new_tokens: 生成的最大新 token 数量。 247 | + peft_config 部分 248 | + peft_type: 使用的参数有效调整类型 (支持 LORA 和 PREFIX_TUNING)。 249 | + task_type: 任务类型,这里是因果语言模型 (不要改动)。 250 | + Lora 参数: 251 | + r: LoRA 的秩。 252 | + lora_alpha: LoRA 的缩放因子。 253 | + lora_dropout: 在 LoRA 层使用的 dropout 概率。 254 | + P-TuningV2 参数: 255 | + num_virtual_tokens: 虚拟 token 的数量。 256 | + num_attention_heads: 2: P-TuningV2 的注意力头数(不要改动)。 257 | + token_dim: 256: P-TuningV2 的 token 维度(不要改动)。 258 | 259 | ## 开始微调 260 | 261 | 通过以下代码执行 **单机多卡/多机多卡** 运行,这是使用 `deepspeed` 作为加速方案的,您需要安装 `deepspeed`。接着,按照此命令运行: 262 | 263 | ```shell 264 | OMP_NUM_THREADS=1 torchrun --standalone --nnodes=1 --nproc_per_node=8 finetune.py data/AdvertiseGen/ THUDM/GLM-4-9B-Chat-0414 configs/lora.yaml # For Chat Fine-tune 265 | OMP_NUM_THREADS=1 torchrun --standalone --nnodes=1 --nproc_per_node=8 finetune_vision.py data/CogVLM-311K/ THUDM/glm-4v-9b configs/lora.yaml # For VQA Fine-tune 266 | ``` 267 | 268 | 通过以下代码执行 **单机单卡** 运行。 269 | 270 | ```shell 271 | python finetune.py data/AdvertiseGen/ THUDM/GLM-4-9B-Chat-0414 configs/lora.yaml # For Chat Fine-tune 272 | python finetune_vision.py data/CogVLM-311K/ THUDM/glm-4v-9b configs/lora.yaml # For VQA Fine-tune 273 | ``` 274 | 275 | ## 从保存点进行微调 276 | 277 | 如果按照上述方式进行训练,每次微调都会从头开始,如果你想从训练一半的模型开始微调,你可以加入第四个参数,这个参数有两种传入方式: 278 | 279 | 1. `yes`, 自动从最后一个保存的 Checkpoint开始训练 280 | 2. `XX`, 断点号数字 例 `600` 则从序号600 Checkpoint开始训练 281 | 282 | 例如,这就是一个从最后一个保存点继续微调的示例代码 283 | 284 | ```shell 285 | python finetune.py data/AdvertiseGen/ THUDM/GLM-4-9B-Chat-0414 configs/lora.yaml yes 286 | ``` 287 | 288 | ## 使用微调后的模型 289 | 290 | 您可以在任何一个 demo 内使用我们的 `LORA` 和 全参微调的模型。这需要你自己按照以下教程进行修改代码。 291 | 292 | 1. 
使用`finetune_demo/inference.py`中读入模型的方式替换 demo 中读入模型的方式。 293 | 294 | > 请注意,对于 LORA 和 P-TuningV2 我们没有合并训练后的模型,而是在`adapter_config.json` 295 | > 中记录了微调型的路径,如果你的原始模型位置发生更改,则你应该修改`adapter_config.json`中`base_model_name_or_path`的路径。 296 | 297 | ```python 298 | def load_model_and_tokenizer(model_dir: Union[str, Path]) -> tuple[ModelType, TokenizerType]: 299 | model_dir = _resolve_path(model_dir) 300 | if (model_dir / "adapter_config.json").exists(): 301 | model = AutoPeftModelForCausalLM.from_pretrained(model_dir, device_map="auto") 302 | tokenizer_dir = model.peft_config["default"].base_model_name_or_path 303 | else: 304 | model = AutoModelForCausalLM.from_pretrained(model_dir, device_map="auto") 305 | tokenizer_dir = model_dir 306 | tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir) 307 | return model, tokenizer 308 | ``` 309 | 310 | 2. 读取微调的模型,请注意,你应该使用微调模型的位置,例如,若你的模型位置为`/path/to/finetune_adapter_model` 311 | ,原始模型地址为`path/to/base_model`,则你应该使用`/path/to/finetune_adapter_model`作为`model_dir`。 312 | 3. 完成上述操作后,就能正常使用微调的模型了,其他的调用方式没有变化。 313 | 4. 本微调脚本没有测试过128K 1M等长文本的微调,长文本的微调需要更大显存的GPU设备,并且需要更高效的微调方案,需要开发者自行解决。 314 | 315 | ## 参考文献 316 | 317 | ``` 318 | @inproceedings{liu2022p, 319 | title={P-tuning: Prompt tuning can be comparable to fine-tuning across scales and tasks}, 320 | author={Liu, Xiao and Ji, Kaixuan and Fu, Yicheng and Tam, Weng and Du, Zhengxiao and Yang, Zhilin and Tang, Jie}, 321 | booktitle={Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 2: Short 322 | Papers)}, 323 | pages={61--68}, 324 | year={2022} 325 | } 326 | 327 | @misc{tang2023toolalpaca, 328 | title={ToolAlpaca: Generalized Tool Learning for Language Models with 3000 Simulated Cases}, 329 | author={Qiaoyu Tang and Ziliang Deng and Hongyu Lin and Xianpei Han and Qiao Liang and Le Sun}, 330 | year={2023}, 331 | eprint={2306.05301}, 332 | archivePrefix={arXiv}, 333 | primaryClass={cs.CL} 334 | } 335 | ``` 336 | -------------------------------------------------------------------------------- /finetune/configs/ds_zero_2.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "zero_optimization": { 14 | "stage": 2, 15 | "allgather_partitions": true, 16 | "allgather_bucket_size": 5e8, 17 | "overlap_comm": true, 18 | "reduce_scatter": true, 19 | "reduce_bucket_size": 5e8, 20 | "contiguous_gradients": true 21 | }, 22 | 23 | "gradient_accumulation_steps": "auto", 24 | "gradient_clipping": "auto", 25 | "steps_per_print": 2000, 26 | "train_batch_size": "auto", 27 | "train_micro_batch_size_per_gpu": "auto", 28 | "wall_clock_breakdown": false 29 | } 30 | -------------------------------------------------------------------------------- /finetune/configs/ds_zero_3.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_micro_batch_size_per_gpu": "auto", 3 | "zero_allow_untested_optimizer": true, 4 | "bf16": { 5 | "enabled": "auto" 6 | }, 7 | "optimizer": { 8 | "type": "AdamW", 9 | "params": { 10 | "lr": "auto", 11 | "betas": "auto", 12 | "eps": "auto", 13 | "weight_decay": "auto" 14 | } 15 | }, 16 | "zero_optimization": { 17 | "stage": 3, 18 | "allgather_partitions": true, 19 | "allgather_bucket_size": 5e8, 20 | "reduce_scatter": true, 21 | "contiguous_gradients": true, 22 | 
"overlap_comm": true, 23 | "sub_group_size": 1e9, 24 | "reduce_bucket_size": "auto", 25 | "stage3_prefetch_bucket_size": "auto", 26 | "stage3_param_persistence_threshold": "auto", 27 | "stage3_max_live_parameters": 1e9, 28 | "stage3_max_reuse_distance": 1e9, 29 | "stage3_gather_16bit_weights_on_model_save": true 30 | } 31 | } 32 | -------------------------------------------------------------------------------- /finetune/configs/lora.yaml: -------------------------------------------------------------------------------- 1 | data_config: 2 | train_file: train.jsonl 3 | val_file: dev.jsonl 4 | test_file: dev.jsonl 5 | num_proc: 1 6 | 7 | combine: True 8 | freezeV: True 9 | max_input_length: 512 10 | max_output_length: 512 11 | 12 | training_args: 13 | # see `transformers.Seq2SeqTrainingArguments` 14 | output_dir: ./output 15 | max_steps: 3000 16 | # needed to be fit for the dataset 17 | learning_rate: 5e-4 18 | # settings for data loading 19 | per_device_train_batch_size: 1 20 | dataloader_num_workers: 16 21 | remove_unused_columns: false 22 | # settings for saving checkpoints 23 | save_strategy: steps 24 | save_steps: 500 25 | # settings for logging 26 | log_level: info 27 | logging_strategy: steps 28 | logging_steps: 10 29 | # settings for evaluation 30 | per_device_eval_batch_size: 4 31 | eval_strategy: steps 32 | eval_steps: 500 33 | # settings for optimizer 34 | # adam_epsilon: 1e-6 35 | # uncomment the following line to detect nan or inf values 36 | # debug: underflow_overflow 37 | predict_with_generate: true 38 | # see `transformers.GenerationConfig` 39 | generation_config: 40 | max_new_tokens: 512 41 | # set your absolute deepspeed path here 42 | # deepspeed: configs/ds_zero_3.json 43 | deepspeed: /data/yuxuan/GLM-4/finetune/configs/ds_zero_2.json 44 | 45 | peft_config: 46 | peft_type: LORA 47 | task_type: CAUSAL_LM 48 | r: 8 49 | lora_alpha: 32 50 | lora_dropout: 0.1 51 | target_modules: ["q_proj", "k_proj", "v_proj"] 52 | -------------------------------------------------------------------------------- /finetune/configs/sft.yaml: -------------------------------------------------------------------------------- 1 | data_config: 2 | train_file: train.jsonl 3 | val_file: dev.jsonl 4 | test_file: dev.jsonl 5 | num_proc: 1 6 | 7 | combine: True 8 | freezeV: True 9 | max_input_length: 512 10 | max_output_length: 512 11 | 12 | training_args: 13 | # see `transformers.Seq2SeqTrainingArguments` 14 | output_dir: ./output 15 | max_steps: 3000 16 | # needed to be fit for the dataset 17 | learning_rate: 5e-5 18 | # settings for data loading 19 | per_device_train_batch_size: 1 20 | dataloader_num_workers: 16 21 | remove_unused_columns: false 22 | # settings for saving checkpoints 23 | save_strategy: steps 24 | save_steps: 500 25 | # settings for logging 26 | log_level: info 27 | logging_strategy: steps 28 | logging_steps: 10 29 | # settings for evaluation 30 | per_device_eval_batch_size: 16 31 | eval_strategy: steps 32 | eval_steps: 500 33 | # settings for optimizer 34 | # adam_epsilon: 1e-6 35 | # uncomment the following line to detect nan or inf values 36 | # debug: underflow_overflow 37 | predict_with_generate: true 38 | generation_config: 39 | max_new_tokens: 512 40 | # set your absolute deepspeed path here 41 | deepspeed: configs/ds_zero_3.json 42 | -------------------------------------------------------------------------------- /finetune/requirements.txt: -------------------------------------------------------------------------------- 1 | jieba>=0.42.1 2 | datasets>=2.20.0 3 | peft>=0.15.1 4 | 
deepspeed>=0.16.5 5 | nltk==3.8.1 6 | rouge_chinese==1.0.3 7 | ruamel.yaml>=0.18.6 8 | typer>=0.13.0 9 | tqdm>=4.67.0 10 | -------------------------------------------------------------------------------- /inference/README.md: -------------------------------------------------------------------------------- 1 | # Inference 2 | 3 | [中文阅读](README_zh.md) 4 | 5 | Please follow the steps in the document strictly to avoid unnecessary errors. 6 | 7 | ## Device and dependency check 8 | 9 | ### Install dependencies 10 | 11 | ```shell 12 | pip install -r requirements.txt 13 | ``` 14 | 15 | ### Related Inference Benchmark Data 16 | 17 | **All benchmark data in this document was collected under the hardware environment listed below. Actual memory usage and runtime may vary depending on your deployment setup. Please refer to your actual environment.** 18 | 19 | Test Hardware: 20 | 21 | + OS: Ubuntu 22.04 22 | + Memory: 512GB 23 | + Python: 3.12.3 24 | + Cmake 3.23.0 25 | + CUDA Version: 12.4 26 | + GPU Driver: 535.104.05 27 | + GPU: NVIDIA H100 80GB HBM3 * 8 28 | 29 | The following stress test results show memory usage and latency during inference. If multiple GPUs are used, "Memory Usage" refers to the maximum usage on a single GPU. 30 | 31 | #### GLM-4-32B-Chat-0414 32 | 33 | | Precision | #GPUs | Memory Usage | First Token Latency | Token Output Speed | Input Tokens | 34 | |-------------|-------|---------------|---------------------|-------------------|--------------| 35 | | BF16 | 1 | 68 GB | 0.16s | 24.4 tokens/s | 1000 | 36 | | BF16 | 1 | 72 GB | 1.37s | 16.9 tokens/s | 8000 | 37 | | BF16 | 2 | 50 GB | 6.75s | 8.1 tokens/s | 32000 | 38 | | BF16 | 4 | 55 GB | 37.83s | 3.0 tokens/s | 100000 | 39 | 40 | #### GLM-4-9B-Chat-0414 41 | 42 | | Precision | #GPUs | Memory Usage | First Token Latency | Token Output Speed | Input Tokens | 43 | |-----------|-------|---------------|----------------------|---------------------|---------------| 44 | | BF16 | 1 | 19 GB | 0.05s | 44.4 tokens/s | 1000 | 45 | | BF16 | 1 | 25 GB | 0.39s | 39.0 tokens/s | 8000 | 46 | | BF16 | 1 | 31 GB | 2.29s | 18.7 tokens/s | 32000 | 47 | | BF16 | 1 | 55 GB | 6.80s | 14.1 tokens/s | 100000 | 48 | 49 | #### GLM-4-9B-Chat-1M 50 | 51 | | Precision | #GPUs | Memory Usage | First Token Latency | Token Output Speed | Input Tokens | 52 | |-----------|-------|---------------|----------------------|---------------------|---------------| 53 | | BF16 | 1 | 75 GB | 98.4s | 2.3 tokens/s | 200000 | 54 | 55 | #### GLM-4V-9B 56 | 57 | | Precision | #GPUs | Memory Usage | First Token Latency | Token Output Speed | Input Tokens | 58 | |-----------|-------|---------------|----------------------|---------------------|---------------| 59 | | BF16 | 1 | 28 GB | 0.1s | 33.4 tokens/s | 1000 | 60 | | BF16 | 1 | 33 GB | 0.7s | 39.2 tokens/s | 8000 | 61 | 62 | | Precision | #GPUs | Memory Usage | First Token Latency | Token Output Speed | Input Tokens | 63 | |-----------|-------|---------------|----------------------|---------------------|---------------| 64 | | INT4 | 1 | 10 GB | 0.1s | 28.7 tokens/s | 1000 | 65 | | INT4 | 1 | 15 GB | 0.8s | 24.2 tokens/s | 8000 | 66 | 67 | ## Quick Start 68 | 69 | ### Use transformers backend code 70 | 71 | + Use the command line to communicate with the GLM-4-9B model. 72 | 73 | ```shell 74 | python trans_cli_demo.py # LLM Such as GLM-4-9B-Chat-0414 75 | python trans_cli_vision_demo.py # GLM-4V-9B 76 | ``` 77 | 78 | + Use the Gradio web client to communicate with the GLM-4-9B model. 
79 | 80 | ```shell 81 | python trans_web_demo.py # LLM Such as GLM-4-9B-Chat-0414 82 | python trans_web_vision_demo.py # GLM-4V-9B 83 | ``` 84 | 85 | + Use Batch inference. 86 | 87 | ```shell 88 | python trans_batch_demo.py # LLM Such as GLM-4-9B-Chat-0414 89 | ``` 90 | 91 | ### Use vLLM backend code 92 | 93 | + Use the command line to communicate with the GLM-4-9B-Chat model. 94 | 95 | ```shell 96 | python vllm_cli_demo.py # LLM Such as GLM-4-9B-Chat-0414 97 | ``` 98 | 99 | + Launch an OpenAI-compatible API service. 100 | 101 | ```shell 102 | vllm serve THUDM/GLM-4-9B-Chat-0414 --tensor_parallel_size 2 103 | ``` 104 | 105 | ### Use glm-4v to build an OpenAI-compatible service 106 | 107 | Start the server: 108 | 109 | ```shell 110 | python glm4v_server.py THUDM/glm-4v-9b 111 | ``` 112 | 113 | Client request: 114 | 115 | ```shell 116 | python glm4v_api_request.py 117 | ``` 118 | 119 | ## Stress test 120 | 121 | Users can use this code to test the generation speed of the model on the transformers backend on their own devices: 122 | 123 | ```shell 124 | python trans_stress_test.py 125 | ``` 126 | 127 | ## Use Ascend card to run code 128 | 129 | Users can run the above code in the Ascend hardware environment. They only need to change the transformers to openmind and the cuda device in device to npu. 130 | 131 | ```shell 132 | #from transformers import AutoModelForCausalLM, AutoTokenizer 133 | from openmind import AutoModelForCausalLM, AutoTokenizer 134 | 135 | #device = 'cuda' 136 | device = 'npu' 137 | ``` 138 | -------------------------------------------------------------------------------- /inference/README_zh.md: -------------------------------------------------------------------------------- 1 | # Inference 2 | 3 | Read this in [English](README.md) 4 | 5 | 请严格按照文档的步骤进行操作,以避免不必要的错误。 6 | 7 | ## 设备和依赖检查 8 | 9 | ### 安装依赖 10 | 11 | ```shell 12 | pip install -r requirements.txt 13 | ``` 14 | 15 | ### 相关推理测试数据 16 | 17 | **本文档的数据均在以下硬件环境测试,实际运行环境需求和运行占用的显存略有不同,请以实际运行环境为准。** 18 | 19 | 测试硬件信息: 20 | 21 | + OS: Ubuntu 22.04 22 | + Memory: 512GB 23 | + Python: 3.12.3 24 | + CUDA Version: 12.4 25 | + Cmake 3.23.0 26 | + GPU Driver: 535.104.05 27 | + GPU: NVIDIA H100 80GB HBM3 * 8 28 | 29 | 推理的压力测试数据如下,如有多张显卡,则显存占用代表显存占用最大一张显卡的显存消耗。 30 | 31 | #### GLM-4-32B-Chat-0414 32 | 33 | | 精度 | 显卡数量 | 显存占用 | 首 Token 延迟 | Token 输出速度 | 输入token数 | 34 | |------|------|-------|------------|---------------|----------| 35 | | BF16 | 1 | 68 GB | 0.16s | 24.4 tokens/s | 1000 | 36 | | BF16 | 1 | 72 GB | 1.37s | 16.9 tokens/s | 8000 | 37 | | BF16 | 2 | 50 GB | 6.75s | 8.1 tokens/s | 32000 | 38 | | BF16 | 4 | 55 GB | 37.83s | 3.0 tokens/s | 100000 | 39 | 40 | #### GLM-4-9B-Chat-0414 41 | 42 | | 精度 | 显卡数量 | 显存占用 | 首 Token 延迟 | Token 输出速度 | 输入token数 | 43 | |------|------|-------|------------|---------------|---------| 44 | | BF16 | 1 | 19 GB | 0.05s | 44.4 tokens/s | 1000 | 45 | | BF16 | 1 | 25 GB | 0.39s | 39.0 tokens/s | 8000 | 46 | | BF16 | 1 | 31 GB | 2.29s | 18.7 tokens/s | 32000 | 47 | | BF16 | 1 | 55 GB | 6.80s | 14.1 tokens/s | 100000 | 48 | 49 | 50 | #### GLM-4-9B-Chat-1M 51 | 52 | | 精度 | 显卡数量 | 显存占用 | 首 Token 延迟 | Token 输出速度 | 输入token数 | 53 | |--------|------|------|------------|--------------|-------------| 54 | | BF16 | 1 | 75 GB | 98.4s | 2.3 tokens/s | 200000 | 55 | 56 | #### GLM-4V-9B 57 | 58 | | 精度 | 显卡数量 | 显存占用 | 首 Token 延迟 | Token 输出速度 | 输入token数 | 59 | |--------|------|------|------------|--------------|-------------| 60 | | BF16 | 1 | 28 GB | 0.1s | 33.4 tokens/s | 1000 | 61 | | BF16 | 1 | 33 GB | 0.7s | 
39.2 tokens/s | 8000 | 62 | 63 | | 精度 | 显卡数量 | 显存占用 | 首 Token 延迟 | Token 输出速度 | 输入token数 | 64 | |--------|-------|--------|------------|--------------|-------------| 65 | | INT4 | 1 | 10 GB | 0.1s | 28.7 tokens/s | 1000 | 66 | | INT4 | 1 | 15 GB | 0.8s | 24.2 tokens/s | 8000 | 67 | 68 | ## 快速开始 69 | 70 | ### 使用 transformers 后端代码 71 | 72 | + 使用命令行与 GLM-4-9B 模型进行对话。 73 | 74 | ```shell 75 | python trans_cli_demo.py # LLM Such as GLM-4-9B-Chat-0414 76 | python trans_cli_vision_demo.py # GLM-4V-9B 77 | ``` 78 | 79 | + 使用 Gradio 网页端与 GLM-4-9B 模型进行对话。 80 | 81 | ```shell 82 | python trans_web_demo.py # LLM Such as GLM-4-9B-Chat-0414 83 | python trans_web_vision_demo.py # GLM-4V-9B 84 | ``` 85 | 86 | + 使用 Batch 推理。 87 | 88 | ```shell 89 | python trans_batch_demo.py 90 | ``` 91 | 92 | ### 使用 vLLM 后端代码 93 | 94 | + 使用命令行与 GLM-4-9B-Chat 模型进行对话。 95 | 96 | ```shell 97 | python vllm_cli_demo.py # LLM Such as GLM-4-9B-Chat-0414 98 | ``` 99 | 100 | + 构建 OpenAI 类 API 服务。 101 | ```shell 102 | vllm serve THUDM/GLM-4-9B-Chat-0414 --tensor_parallel_size 2 103 | ``` 104 | 105 | ### 使用 glm-4v 构建 OpenAI 服务 106 | 107 | 启动服务端 108 | 109 | ```shell 110 | python glm4v_server.py THUDM/glm-4v-9b 111 | ``` 112 | 113 | 客户端请求: 114 | 115 | ```shell 116 | python glm4v_api_request.py 117 | ``` 118 | 119 | ## 压力测试 120 | 121 | 用户可以在自己的设备上使用本代码测试模型在 transformers后端的生成速度: 122 | 123 | ```shell 124 | python trans_stress_test.py 125 | ``` 126 | 127 | ## 使用昇腾NPU运行代码 128 | 129 | 用户可以在昇腾硬件环境下运行以上代码,只需将transformers修改为openmind,将device中的cuda设备修改为npu: 130 | 131 | ```shell 132 | #from transformers import AutoModelForCausalLM, AutoTokenizer 133 | from openmind import AutoModelForCausalLM, AutoTokenizer 134 | 135 | #device = 'cuda' 136 | device = 'npu' 137 | ``` 138 | -------------------------------------------------------------------------------- /inference/demo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zRzRzRzRzRzRzR/GLM-4/2e5efb1365eda83bead08afffb7a884d56724f7a/inference/demo.jpg -------------------------------------------------------------------------------- /inference/glm4v_api_request.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script creates a OpenAI Request demo for the glm-4v-9b model, just Use OpenAI API to interact with the model. 3 | For LLM such as GLM-4-9B-0414, using with vLLM OpenAI Server. 4 | 5 | vllm serve THUDM/GLM-4-32B-Chat-0414 --tensor_parallel_size 4 6 | 7 | """ 8 | 9 | import base64 10 | 11 | from openai import OpenAI 12 | 13 | 14 | base_url = "http://127.0.0.1:8000/v1/" 15 | client = OpenAI(api_key="EMPTY", base_url=base_url) 16 | 17 | 18 | def create_chat_completion(messages, use_stream=False): 19 | response = client.chat.completions.create( 20 | model="glm-4v", 21 | messages=messages, 22 | stream=use_stream, 23 | max_tokens=256, 24 | temperature=0.4, 25 | presence_penalty=1.2, 26 | top_p=0.8, 27 | ) 28 | if response: 29 | if use_stream: 30 | for chunk in response: 31 | print(chunk) 32 | else: 33 | print(response) 34 | else: 35 | print("Error:", response.status_code) 36 | 37 | 38 | def encode_image(image_path): 39 | """ 40 | Encodes an image file into a base64 string. 41 | Args: 42 | image_path (str): The path to the image file. 43 | 44 | This function opens the specified image file, reads its content, and encodes it into a base64 string. 45 | The base64 encoding is used to send images over HTTP as text. 
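Returns: str: The base64-encoded content of the image file as a UTF-8 string.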
46 | """ 47 | 48 | with open(image_path, "rb") as image_file: 49 | return base64.b64encode(image_file.read()).decode("utf-8") 50 | 51 | 52 | def glm4v_simple_image_chat(use_stream=False, img_path=None): 53 | """ 54 | Facilitates a simple chat interaction involving an image. 55 | 56 | Args: 57 | use_stream (bool): Specifies whether to use streaming for chat responses. 58 | img_path (str): Path to the image file to be included in the chat. 59 | 60 | This function encodes the specified image and constructs a predefined conversation involving the image. 61 | It then calls `create_chat_completion` to generate a response from the model. 62 | The conversation includes asking about the content of the image and a follow-up question. 63 | """ 64 | 65 | img_url = f"data:image/jpeg;base64,{encode_image(img_path)}" 66 | messages = [ 67 | { 68 | "role": "user", 69 | "content": [ 70 | { 71 | "type": "text", 72 | "text": "What’s in this image?", 73 | }, 74 | { 75 | "type": "image_url", 76 | "image_url": {"url": img_url}, 77 | }, 78 | ], 79 | }, 80 | { 81 | "role": "assistant", 82 | "content": "The image displays a wooden boardwalk extending through a vibrant green grassy wetland. The sky is partly cloudy with soft, wispy clouds, indicating nice weather. Vegetation is seen on either side of the boardwalk, and trees are present in the background, suggesting that this area might be a natural reserve or park designed for ecological preservation and outdoor recreation. The boardwalk allows visitors to explore the area without disturbing the natural habitat.", 83 | }, 84 | {"role": "user", "content": "Do you think this is a spring or winter photo?"}, 85 | ] 86 | create_chat_completion(messages=messages, use_stream=use_stream) 87 | 88 | 89 | if __name__ == "__main__": 90 | glm4v_simple_image_chat(use_stream=False, img_path="demo.jpg") 91 | -------------------------------------------------------------------------------- /inference/glm4v_server.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import gc 3 | import sys 4 | import threading 5 | import time 6 | from contextlib import asynccontextmanager 7 | from io import BytesIO 8 | from pathlib import Path 9 | from typing import List, Literal, Optional, Tuple, Union 10 | 11 | import requests 12 | import torch 13 | import uvicorn 14 | from fastapi import FastAPI, HTTPException 15 | from fastapi.middleware.cors import CORSMiddleware 16 | from peft import PeftModelForCausalLM 17 | from PIL import Image 18 | from pydantic import BaseModel, Field 19 | from sse_starlette.sse import EventSourceResponse 20 | from transformers import AutoModel, AutoTokenizer, TextIteratorStreamer 21 | 22 | 23 | TORCH_TYPE = ( 24 | torch.bfloat16 if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8 else torch.float16 25 | ) 26 | 27 | 28 | @asynccontextmanager 29 | async def lifespan(app: FastAPI): 30 | """ 31 | An asynchronous context manager for managing the lifecycle of the FastAPI app. 32 | It ensures that GPU memory is cleared after the app's lifecycle ends, which is essential for efficient resource management in GPU environments. 
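On shutdown, it calls torch.cuda.empty_cache() and torch.cuda.ipc_collect() to release cached GPU memory.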
33 | """ 34 | yield 35 | if torch.cuda.is_available(): 36 | torch.cuda.empty_cache() 37 | torch.cuda.ipc_collect() 38 | 39 | 40 | app = FastAPI(lifespan=lifespan) 41 | 42 | app.add_middleware( 43 | CORSMiddleware, 44 | allow_origins=["*"], 45 | allow_credentials=True, 46 | allow_methods=["*"], 47 | allow_headers=["*"], 48 | ) 49 | 50 | 51 | class ModelCard(BaseModel): 52 | """ 53 | A Pydantic model representing a model card, which provides metadata about a machine learning model. 54 | It includes fields like model ID, owner, and creation time. 55 | """ 56 | 57 | id: str 58 | object: str = "model" 59 | created: int = Field(default_factory=lambda: int(time.time())) 60 | owned_by: str = "owner" 61 | root: Optional[str] = None 62 | parent: Optional[str] = None 63 | permission: Optional[list] = None 64 | 65 | 66 | class ModelList(BaseModel): 67 | object: str = "list" 68 | data: List[ModelCard] = [] 69 | 70 | 71 | class ImageUrl(BaseModel): 72 | url: str 73 | 74 | 75 | class TextContent(BaseModel): 76 | type: Literal["text"] 77 | text: str 78 | 79 | 80 | class ImageUrlContent(BaseModel): 81 | type: Literal["image_url"] 82 | image_url: ImageUrl 83 | 84 | 85 | ContentItem = Union[TextContent, ImageUrlContent] 86 | 87 | 88 | class ChatMessageInput(BaseModel): 89 | role: Literal["user", "assistant", "system"] 90 | content: Union[str, List[ContentItem]] 91 | name: Optional[str] = None 92 | 93 | 94 | class ChatMessageResponse(BaseModel): 95 | role: Literal["assistant"] 96 | content: str = None 97 | name: Optional[str] = None 98 | 99 | 100 | class DeltaMessage(BaseModel): 101 | role: Optional[Literal["user", "assistant", "system"]] = None 102 | content: Optional[str] = None 103 | 104 | 105 | class ChatCompletionRequest(BaseModel): 106 | model: str 107 | messages: List[ChatMessageInput] 108 | temperature: Optional[float] = 0.8 109 | top_p: Optional[float] = 0.8 110 | max_tokens: Optional[int] = None 111 | stream: Optional[bool] = False 112 | # Additional parameters 113 | repetition_penalty: Optional[float] = 1.0 114 | 115 | 116 | class ChatCompletionResponseChoice(BaseModel): 117 | index: int 118 | message: ChatMessageResponse 119 | 120 | 121 | class ChatCompletionResponseStreamChoice(BaseModel): 122 | index: int 123 | delta: DeltaMessage 124 | 125 | 126 | class UsageInfo(BaseModel): 127 | prompt_tokens: int = 0 128 | total_tokens: int = 0 129 | completion_tokens: Optional[int] = 0 130 | 131 | 132 | class ChatCompletionResponse(BaseModel): 133 | model: str 134 | object: Literal["chat.completion", "chat.completion.chunk"] 135 | choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]] 136 | created: Optional[int] = Field(default_factory=lambda: int(time.time())) 137 | usage: Optional[UsageInfo] = None 138 | 139 | 140 | @app.get("/v1/models", response_model=ModelList) 141 | async def list_models(): 142 | """ 143 | An endpoint to list available models. It returns a list of model cards. 144 | This is useful for clients to query and understand what models are available for use. 
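In this demo, a single model card with id "GLM-4v-9b" is returned.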
145 | """ 146 | model_card = ModelCard(id="GLM-4v-9b") 147 | return ModelList(data=[model_card]) 148 | 149 | 150 | @app.post("/v1/chat/completions", response_model=ChatCompletionResponse) 151 | async def create_chat_completion(request: ChatCompletionRequest): 152 | global model, tokenizer 153 | 154 | if len(request.messages) < 1 or request.messages[-1].role == "assistant": 155 | raise HTTPException(status_code=400, detail="Invalid request") 156 | 157 | gen_params = dict( 158 | messages=request.messages, 159 | temperature=request.temperature, 160 | top_p=request.top_p, 161 | max_tokens=request.max_tokens or 1024, 162 | echo=False, 163 | stream=request.stream, 164 | repetition_penalty=request.repetition_penalty, 165 | ) 166 | 167 | if request.stream: 168 | generate = predict(request.model, gen_params) 169 | return EventSourceResponse(generate, media_type="text/event-stream") 170 | response = generate_glm4v(model, tokenizer, gen_params) 171 | 172 | usage = UsageInfo() 173 | 174 | message = ChatMessageResponse( 175 | role="assistant", 176 | content=response["text"], 177 | ) 178 | choice_data = ChatCompletionResponseChoice( 179 | index=0, 180 | message=message, 181 | ) 182 | task_usage = UsageInfo.model_validate(response["usage"]) 183 | for usage_key, usage_value in task_usage.model_dump().items(): 184 | setattr(usage, usage_key, getattr(usage, usage_key) + usage_value) 185 | return ChatCompletionResponse(model=request.model, choices=[choice_data], object="chat.completion", usage=usage) 186 | 187 | 188 | def predict(model_id: str, params: dict): 189 | global model, tokenizer 190 | 191 | choice_data = ChatCompletionResponseStreamChoice(index=0, delta=DeltaMessage(role="assistant"), finish_reason=None) 192 | chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk") 193 | yield "{}".format(chunk.model_dump_json(exclude_unset=True)) 194 | 195 | previous_text = "" 196 | for new_response in generate_stream_glm4v(model, tokenizer, params): 197 | decoded_unicode = new_response["text"] 198 | delta_text = decoded_unicode[len(previous_text) :] 199 | previous_text = decoded_unicode 200 | delta = DeltaMessage(content=delta_text, role="assistant") 201 | choice_data = ChatCompletionResponseStreamChoice(index=0, delta=delta) 202 | chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk") 203 | yield "{}".format(chunk.model_dump_json(exclude_unset=True)) 204 | 205 | choice_data = ChatCompletionResponseStreamChoice(index=0, delta=DeltaMessage()) 206 | chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk") 207 | yield "{}".format(chunk.model_dump_json(exclude_unset=True)) 208 | 209 | 210 | def generate_glm4v(model: AutoModel, tokenizer: AutoTokenizer, params: dict): 211 | """ 212 | Generates a response using the GLM-4v-9b model. It processes the chat history and image data, if any, 213 | and then invokes the model to generate a response. 214 | """ 215 | 216 | response = None 217 | 218 | for response in generate_stream_glm4v(model, tokenizer, params): 219 | pass 220 | return response 221 | 222 | 223 | def process_history_and_images( 224 | messages: List[ChatMessageInput], 225 | ) -> Tuple[Optional[str], Optional[List[Tuple[str, str]]], Optional[List[Image.Image]]]: 226 | """ 227 | Process history messages to extract text, identify the last user query, 228 | and convert base64 encoded image URLs to PIL images. 
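Image URLs may be either base64 data URIs (data:image/jpeg;base64,...) or plain HTTP(S) URLs, which are downloaded with requests.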
229 | 230 | Args: 231 | messages(List[ChatMessageInput]): List of ChatMessageInput objects. 232 | return: A tuple of three elements: 233 | - The last user query as a string. 234 | - Text history formatted as a list of tuples for the model. 235 | - List of PIL Image objects extracted from the messages. 236 | """ 237 | 238 | formatted_history = [] 239 | image_list = [] 240 | last_user_query = "" 241 | 242 | for i, message in enumerate(messages): 243 | role = message.role 244 | content = message.content 245 | 246 | if isinstance(content, list): # text 247 | text_content = " ".join(item.text for item in content if isinstance(item, TextContent)) 248 | else: 249 | text_content = content 250 | 251 | if isinstance(content, list): # image 252 | for item in content: 253 | if isinstance(item, ImageUrlContent): 254 | image_url = item.image_url.url 255 | if image_url.startswith("data:image/jpeg;base64,"): 256 | base64_encoded_image = image_url.split("data:image/jpeg;base64,")[1] 257 | image_data = base64.b64decode(base64_encoded_image) 258 | image = Image.open(BytesIO(image_data)).convert("RGB") 259 | else: 260 | response = requests.get(image_url, verify=False) 261 | image = Image.open(BytesIO(response.content)).convert("RGB") 262 | image_list.append(image) 263 | 264 | if role == "user": 265 | if i == len(messages) - 1: # 最后一条用户消息 266 | last_user_query = text_content 267 | else: 268 | formatted_history.append((text_content, "")) 269 | elif role == "assistant": 270 | if formatted_history: 271 | if formatted_history[-1][1] != "": 272 | assert False, f"the last query is answered. answer again. {formatted_history[-1][0]}, {formatted_history[-1][1]}, {text_content}" 273 | formatted_history[-1] = (formatted_history[-1][0], text_content) 274 | else: 275 | assert False, "assistant reply before user" 276 | else: 277 | assert False, f"unrecognized role: {role}" 278 | 279 | return last_user_query, formatted_history, image_list 280 | 281 | 282 | @torch.inference_mode() 283 | def generate_stream_glm4v(model: AutoModel, tokenizer: AutoTokenizer, params: dict): 284 | uploaded = False 285 | messages = params["messages"] 286 | temperature = float(params.get("temperature", 1.0)) 287 | repetition_penalty = float(params.get("repetition_penalty", 1.0)) 288 | top_p = float(params.get("top_p", 1.0)) 289 | max_new_tokens = int(params.get("max_tokens", 256)) 290 | query, history, image_list = process_history_and_images(messages) 291 | 292 | inputs = [] 293 | for idx, (user_msg, model_msg) in enumerate(history): 294 | if idx == len(history) - 1 and not model_msg: 295 | inputs.append({"role": "user", "content": user_msg}) 296 | if image_list and not uploaded: 297 | inputs[-1].update({"image": image_list[0]}) 298 | uploaded = True 299 | break 300 | if user_msg: 301 | inputs.append({"role": "user", "content": user_msg}) 302 | if model_msg: 303 | inputs.append({"role": "assistant", "content": model_msg}) 304 | if len(image_list) >= 1: 305 | inputs.append({"role": "user", "content": query, "image": image_list[0]}) 306 | else: 307 | inputs.append({"role": "user", "content": query}) 308 | 309 | model_inputs = tokenizer.apply_chat_template( 310 | inputs, add_generation_prompt=True, tokenize=True, return_tensors="pt", return_dict=True 311 | ).to(next(model.parameters()).device) 312 | 313 | input_echo_len = len(model_inputs["input_ids"][0]) 314 | streamer = TextIteratorStreamer(tokenizer=tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True) 315 | gen_kwargs = { 316 | "repetition_penalty": repetition_penalty, 317 | 
"max_new_tokens": max_new_tokens, 318 | "do_sample": True if temperature > 1e-5 else False, 319 | "top_p": top_p if temperature > 1e-5 else 0, 320 | "top_k": 1, 321 | "streamer": streamer, 322 | "eos_token_id": [151329, 151336, 151338], 323 | } 324 | if temperature > 1e-5: 325 | gen_kwargs["temperature"] = temperature 326 | 327 | generated_text = "" 328 | 329 | def generate_text(): 330 | with torch.no_grad(): 331 | model.generate(**model_inputs, **gen_kwargs) 332 | 333 | generation_thread = threading.Thread(target=generate_text) 334 | generation_thread.start() 335 | 336 | total_len = input_echo_len 337 | for next_text in streamer: 338 | generated_text += next_text 339 | total_len = len(tokenizer.encode(generated_text)) 340 | yield { 341 | "text": generated_text, 342 | "usage": { 343 | "prompt_tokens": input_echo_len, 344 | "completion_tokens": total_len - input_echo_len, 345 | "total_tokens": total_len, 346 | }, 347 | } 348 | generation_thread.join() 349 | print("\033[91m--generated_text\033[0m", generated_text) 350 | yield { 351 | "text": generated_text, 352 | "usage": { 353 | "prompt_tokens": input_echo_len, 354 | "completion_tokens": total_len - input_echo_len, 355 | "total_tokens": total_len, 356 | }, 357 | } 358 | 359 | 360 | gc.collect() 361 | torch.cuda.empty_cache() 362 | 363 | if __name__ == "__main__": 364 | MODEL_PATH = sys.argv[1] 365 | model_dir = Path(MODEL_PATH).expanduser().resolve() 366 | if (model_dir / "adapter_config.json").exists(): 367 | import json 368 | 369 | with open(model_dir / "adapter_config.json", "r", encoding="utf-8") as file: 370 | config = json.load(file) 371 | model = AutoModel.from_pretrained( 372 | config.get("base_model_name_or_path"), device_map="auto", torch_dtype=TORCH_TYPE 373 | ) 374 | model = PeftModelForCausalLM.from_pretrained( 375 | model=model, 376 | model_id=model_dir, 377 | ) 378 | tokenizer = AutoTokenizer.from_pretrained(config.get("base_model_name_or_path"), encode_special_tokens=True) 379 | model.eval() 380 | else: 381 | tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, encode_special_tokens=True) 382 | model = AutoModel.from_pretrained( 383 | MODEL_PATH, 384 | torch_dtype=TORCH_TYPE, 385 | device_map="auto", 386 | ).eval() 387 | 388 | uvicorn.run(app, host="0.0.0.0", port=8000, workers=1) 389 | -------------------------------------------------------------------------------- /inference/requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=2.6.0 2 | torchvision>=0.21.0 3 | transformers>=4.51.3 4 | sentencepiece>=0.2.0 5 | jinja2>=3.1.4 6 | pydantic>=2.11.1 7 | timm>=1.0.15 8 | tiktoken>=0.9.0 9 | numpy<2 10 | accelerate>=1.6.0 11 | sentence_transformers>=3.1.1 12 | gradio>=5.23.3 13 | openai>=1.70.0 14 | einops>=0.8.0 15 | pillow>=10.4.0 16 | sse-starlette>=2.1.3 17 | bitsandbytes>=0.44.1 # INT4 Loading, Not support for NPU 18 | peft>=0.15.0 # Using with finetune model 19 | 20 | # git+https://github.com/vllm-project/vllm.git For vLLM 21 | -------------------------------------------------------------------------------- /inference/trans_batch_demo.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3 | Here is an example of using batch request GLM-4-0414 Models and glm-4-9b-chat-hf models with the transformers library., 4 | here you need to build the conversation format yourself and then call the batch function to make batch requests. 5 | Please note that in this demo, the memory consumption is significantly higher. 
6 | 7 | """ 8 | 9 | from typing import Union 10 | 11 | from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessorList 12 | 13 | 14 | MODEL_PATH = "THUDM/GLM-4-9B-Chat-0414" 15 | 16 | tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH) 17 | model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, device_map="auto").eval() 18 | 19 | 20 | def process_model_outputs(inputs, outputs, tokenizer): 21 | responses = [] 22 | for input_ids, output_ids in zip(inputs.input_ids, outputs): 23 | response = tokenizer.decode(output_ids[len(input_ids) :], skip_special_tokens=True).strip() 24 | responses.append(response) 25 | return responses 26 | 27 | 28 | def batch( 29 | model, 30 | tokenizer, 31 | messages: Union[str, list[str]], 32 | max_input_tokens: int = 8192, 33 | max_new_tokens: int = 8192, 34 | num_beams: int = 1, 35 | do_sample: bool = True, 36 | top_p: float = 0.8, 37 | temperature: float = 0.8, 38 | logits_processor=None, 39 | ): 40 | if logits_processor is None: 41 | logits_processor = LogitsProcessorList() 42 | messages = [messages] if isinstance(messages, str) else messages 43 | batched_inputs = tokenizer( 44 | messages, return_tensors="pt", padding="max_length", truncation=True, max_length=max_input_tokens 45 | ).to(model.device) 46 | 47 | gen_kwargs = { 48 | "max_new_tokens": max_new_tokens, 49 | "num_beams": num_beams, 50 | "do_sample": do_sample, 51 | "top_p": top_p, 52 | "temperature": temperature, 53 | "logits_processor": logits_processor, 54 | "eos_token_id": model.config.eos_token_id, 55 | } 56 | batched_outputs = model.generate(**batched_inputs, **gen_kwargs) 57 | batched_response = process_model_outputs(batched_inputs, batched_outputs, tokenizer) 58 | return batched_response 59 | 60 | 61 | if __name__ == "__main__": 62 | batch_message = [ 63 | [ 64 | {"role": "user", "content": "我的爸爸和妈妈结婚为什么不能带我去"}, 65 | {"role": "assistant", "content": "因为他们结婚时你还没有出生"}, 66 | {"role": "user", "content": "我刚才的提问是"}, 67 | ], 68 | [{"role": "user", "content": "你好,你是谁"}], 69 | ] 70 | 71 | batch_inputs = [] 72 | max_input_tokens = 128 73 | for i, messages in enumerate(batch_message): 74 | new_batch_input = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)[12:] 75 | max_input_tokens = max(max_input_tokens, len(new_batch_input)) 76 | batch_inputs.append(new_batch_input) 77 | gen_kwargs = { 78 | "max_input_tokens": max_input_tokens, 79 | "max_new_tokens": 256, 80 | "do_sample": True, 81 | "top_p": 0.8, 82 | "temperature": 0.8, 83 | "num_beams": 1, 84 | } 85 | 86 | batch_responses = batch(model, tokenizer, batch_inputs, **gen_kwargs) 87 | for response in batch_responses: 88 | print("=" * 10) 89 | print(response) 90 | -------------------------------------------------------------------------------- /inference/trans_cli_demo.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script creates a CLI demo with transformers backend for the glm-4-9b-chat model, 3 | allowing users to interact with the model through a command-line interface. 4 | 5 | Usage: 6 | - Run the script to start the CLI demo. 7 | - Interact with the model by typing questions and receiving responses. 8 | 9 | Note: The script includes a modification to handle markdown to plain text conversion, 10 | ensuring that the CLI interface displays formatted text correctly. 11 | 12 | If you use flash attention, you should install the flash-attn and add attn_implementation="flash_attention_2" in model loading. 
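A minimal loading sketch for that case (attn_implementation is a standard transformers argument; flash-attn must be installed separately, e.g. with pip install flash-attn):

    model = AutoModelForCausalLM.from_pretrained(
        MODEL_PATH,
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
        device_map="auto",
    ).eval()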
13 | 14 | """ 15 | 16 | from threading import Thread 17 | 18 | import torch 19 | from transformers import ( 20 | AutoModelForCausalLM, 21 | AutoTokenizer, 22 | StoppingCriteria, 23 | StoppingCriteriaList, 24 | TextIteratorStreamer, 25 | ) 26 | 27 | 28 | MODEL_PATH = "THUDM/GLM-4-9B-Chat-0414" 29 | 30 | tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True) 31 | 32 | model = AutoModelForCausalLM.from_pretrained( 33 | MODEL_PATH, 34 | torch_dtype=torch.bfloat16, 35 | device_map="auto", 36 | ).eval() 37 | 38 | 39 | class StopOnTokens(StoppingCriteria): 40 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: 41 | stop_ids = model.config.eos_token_id 42 | for stop_id in stop_ids: 43 | if input_ids[0][-1] == stop_id: 44 | return True 45 | return False 46 | 47 | 48 | if __name__ == "__main__": 49 | history = [] 50 | max_length = 8192 51 | top_p = 0.8 52 | temperature = 0.6 53 | stop = StopOnTokens() 54 | 55 | print("Welcome to the GLM-4-9B CLI chat. Type your messages below.") 56 | while True: 57 | user_input = input("\nYou: ") 58 | if user_input.lower() in ["exit", "quit"]: 59 | break 60 | history.append([user_input, ""]) 61 | 62 | messages = [] 63 | for idx, (user_msg, model_msg) in enumerate(history): 64 | if idx == len(history) - 1 and not model_msg: 65 | messages.append({"role": "user", "content": user_msg}) 66 | break 67 | if user_msg: 68 | messages.append({"role": "user", "content": user_msg}) 69 | if model_msg: 70 | messages.append({"role": "assistant", "content": model_msg}) 71 | model_inputs = tokenizer.apply_chat_template( 72 | messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt" 73 | ).to(model.device) 74 | streamer = TextIteratorStreamer(tokenizer=tokenizer, timeout=60, skip_prompt=True, skip_special_tokens=True) 75 | generate_kwargs = { 76 | "input_ids": model_inputs["input_ids"], 77 | "attention_mask": model_inputs["attention_mask"], 78 | "streamer": streamer, 79 | "max_new_tokens": max_length, 80 | "do_sample": True, 81 | "top_p": top_p, 82 | "temperature": temperature, 83 | "stopping_criteria": StoppingCriteriaList([stop]), 84 | "repetition_penalty": 1.2, 85 | "eos_token_id": model.config.eos_token_id, 86 | } 87 | t = Thread(target=model.generate, kwargs=generate_kwargs) 88 | t.start() 89 | print("GLM-4:", end="", flush=True) 90 | for new_token in streamer: 91 | if new_token: 92 | print(new_token, end="", flush=True) 93 | history[-1][1] += new_token 94 | 95 | history[-1][1] = history[-1][1].strip() 96 | -------------------------------------------------------------------------------- /inference/trans_cli_vision_demo.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script creates a CLI demo with transformers backend for the glm-4v-9b model, 3 | allowing users to interact with the model through a command-line interface. 4 | 5 | Usage: 6 | - Run the script to start the CLI demo. 7 | - Interact with the model by typing questions and receiving responses. 8 | 9 | Note: The script includes a modification to handle markdown to plain text conversion, 10 | ensuring that the CLI interface displays formatted text correctly. 
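The image is attached by adding an extra "image" key to the latest user message before applying the chat template, as in the loop below. A minimal sketch of the expected message shape (the file name and prompt text are illustrative only):

    image = Image.open("example.jpg").convert("RGB")
    messages = [{"role": "user", "content": "Describe this picture", "image": image}]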
11 | """ 12 | 13 | from threading import Thread 14 | 15 | import torch 16 | from PIL import Image 17 | from transformers import ( 18 | AutoModel, 19 | AutoTokenizer, 20 | StoppingCriteria, 21 | StoppingCriteriaList, 22 | TextIteratorStreamer, 23 | ) 24 | 25 | 26 | MODEL_PATH = "THUDM/glm-4v-9b" 27 | 28 | tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True, encode_special_tokens=True) 29 | 30 | ## For BF16 inference 31 | model = AutoModel.from_pretrained( 32 | MODEL_PATH, 33 | trust_remote_code=True, 34 | torch_dtype=torch.bfloat16, 35 | device_map="auto", 36 | ).eval() 37 | 38 | ## For INT4 inference 39 | # model = AutoModel.from_pretrained( 40 | # MODEL_PATH, 41 | # trust_remote_code=True, 42 | # quantization_config=BitsAndBytesConfig(load_in_4bit=True), 43 | # torch_dtype=torch.bfloat16, 44 | # low_cpu_mem_usage=True 45 | # ).eval() 46 | 47 | 48 | class StopOnTokens(StoppingCriteria): 49 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: 50 | stop_ids = model.config.eos_token_id 51 | for stop_id in stop_ids: 52 | if input_ids[0][-1] == stop_id: 53 | return True 54 | return False 55 | 56 | 57 | if __name__ == "__main__": 58 | history = [] 59 | max_length = 1024 60 | top_p = 0.8 61 | temperature = 0.6 62 | stop = StopOnTokens() 63 | uploaded = False 64 | image = None 65 | print("Welcome to the GLM-4-9B CLI chat. Type your messages below.") 66 | image_path = input("Image Path:") 67 | try: 68 | image = Image.open(image_path).convert("RGB") 69 | except: 70 | print("Invalid image path. Continuing with text conversation.") 71 | while True: 72 | user_input = input("\nYou: ") 73 | if user_input.lower() in ["exit", "quit"]: 74 | break 75 | history.append([user_input, ""]) 76 | 77 | messages = [] 78 | for idx, (user_msg, model_msg) in enumerate(history): 79 | if idx == len(history) - 1 and not model_msg: 80 | messages.append({"role": "user", "content": user_msg}) 81 | if image and not uploaded: 82 | messages[-1].update({"image": image}) 83 | uploaded = True 84 | break 85 | if user_msg: 86 | messages.append({"role": "user", "content": user_msg}) 87 | if model_msg: 88 | messages.append({"role": "assistant", "content": model_msg}) 89 | model_inputs = tokenizer.apply_chat_template( 90 | messages, add_generation_prompt=True, tokenize=True, return_tensors="pt", return_dict=True 91 | ).to(next(model.parameters()).device) 92 | streamer = TextIteratorStreamer(tokenizer=tokenizer, timeout=60, skip_prompt=True, skip_special_tokens=True) 93 | generate_kwargs = { 94 | **model_inputs, 95 | "streamer": streamer, 96 | "max_new_tokens": max_length, 97 | "do_sample": True, 98 | "top_p": top_p, 99 | "temperature": temperature, 100 | "stopping_criteria": StoppingCriteriaList([stop]), 101 | "repetition_penalty": 1.2, 102 | "eos_token_id": [151329, 151336, 151338], 103 | } 104 | t = Thread(target=model.generate, kwargs=generate_kwargs) 105 | t.start() 106 | print("GLM-4V:", end="", flush=True) 107 | for new_token in streamer: 108 | if new_token: 109 | print(new_token, end="", flush=True) 110 | history[-1][1] += new_token 111 | 112 | history[-1][1] = history[-1][1].strip() 113 | -------------------------------------------------------------------------------- /inference/trans_stress_test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | from threading import Thread 4 | 5 | import torch 6 | from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer 7 | 8 | 
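# This stress test builds prompts from random token ids wrapped in the model's chat special tokens,
# runs one warmup generation, and then, for each of the n iterations, measures (a) the time until
# the first streamed token arrives (prefill latency) and (b) decode throughput in tokens per second,
# reporting the averages over all iterations at the end.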
9 | MODEL_PATH = "THUDM/GLM-4-9B-Chat-0414"
10 | 
11 | 
12 | def stress_test(input_token_len, n, output_token_len):
13 |     tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, padding_side="left")
14 |     model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, torch_dtype=torch.bfloat16, device_map="auto").eval()
15 |     device = model.device
16 | 
17 |     # Use INT4 weights for inference
18 |     # model = AutoModelForCausalLM.from_pretrained(
19 |     #     MODEL_PATH,
20 |     #     trust_remote_code=True,
21 |     #     quantization_config=BitsAndBytesConfig(load_in_4bit=True),
22 |     #     low_cpu_mem_usage=True,
23 |     # ).eval()
24 | 
25 |     times = []
26 |     decode_times = []
27 | 
28 |     print("Warming up...")
29 |     vocab_size = tokenizer.vocab_size
30 |     warmup_token_len = 20
31 |     random_token_ids = torch.randint(3, vocab_size - 200, (warmup_token_len - 5,), dtype=torch.long)
32 |     start_tokens = [151331, 151333, 151336, 198]
33 |     end_tokens = [151337]
34 |     input_ids = (
35 |         torch.tensor(start_tokens + random_token_ids.tolist() + end_tokens, dtype=torch.long).unsqueeze(0).to(device)
36 |     )
37 |     attention_mask = torch.ones_like(input_ids, dtype=torch.bfloat16).to(device)
38 |     position_ids = torch.arange(len(input_ids[0]), dtype=torch.bfloat16).unsqueeze(0).to(device)
39 |     warmup_inputs = {"input_ids": input_ids, "attention_mask": attention_mask, "position_ids": position_ids}
40 |     with torch.no_grad():
41 |         _ = model.generate(
42 |             input_ids=warmup_inputs["input_ids"],
43 |             attention_mask=warmup_inputs["attention_mask"],
44 |             max_new_tokens=512,
45 |             do_sample=False,
46 |             repetition_penalty=0.1,
47 |             eos_token_id=[151329, 151336, 151338],
48 |         )
49 |     print("Warming up complete. Starting stress test...")
50 | 
51 |     for i in range(n):
52 |         random_token_ids = torch.randint(3, vocab_size - 200, (input_token_len - 5,), dtype=torch.long)
53 |         input_ids = (
54 |             torch.tensor(start_tokens + random_token_ids.tolist() + end_tokens, dtype=torch.long)
55 |             .unsqueeze(0)
56 |             .to(device)
57 |         )
58 |         attention_mask = torch.ones_like(input_ids, dtype=torch.bfloat16).to(device)
59 |         position_ids = torch.arange(len(input_ids[0]), dtype=torch.bfloat16).unsqueeze(0).to(device)
60 |         test_inputs = {"input_ids": input_ids, "attention_mask": attention_mask, "position_ids": position_ids}
61 | 
62 |         streamer = TextIteratorStreamer(tokenizer=tokenizer, timeout=36000, skip_prompt=True, skip_special_tokens=True)
63 | 
64 |         generate_kwargs = {
65 |             "input_ids": test_inputs["input_ids"],
66 |             "attention_mask": test_inputs["attention_mask"],
67 |             "max_new_tokens": output_token_len,
68 |             "do_sample": False,
69 |             "repetition_penalty": 0.1,  # Deliberately low so the model keeps generating for the full test length.
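            # These stop ids appear to be GLM-4's <|endoftext|>, <|user|> and <|observation|> special tokens.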
70 | "eos_token_id": [151329, 151336, 151338], 71 | "streamer": streamer, 72 | } 73 | 74 | start_time = time.time() 75 | t = Thread(target=model.generate, kwargs=generate_kwargs) 76 | t.start() 77 | 78 | first_token_time = None 79 | all_token_times = [] 80 | 81 | for token in streamer: 82 | current_time = time.time() 83 | if first_token_time is None: 84 | first_token_time = current_time 85 | times.append(first_token_time - start_time) 86 | all_token_times.append(current_time) 87 | 88 | t.join() 89 | end_time = time.time() 90 | 91 | avg_decode_time_per_token = len(all_token_times) / (end_time - first_token_time) if all_token_times else 0 92 | decode_times.append(avg_decode_time_per_token) 93 | print( 94 | f"Iteration {i + 1}/{n} - Prefilling Time: {times[-1]:.4f} seconds - Average Decode Time: {avg_decode_time_per_token:.4f} tokens/second" 95 | ) 96 | 97 | torch.cuda.empty_cache() 98 | 99 | avg_first_token_time = sum(times) / n 100 | avg_decode_time = sum(decode_times) / n 101 | print(f"\nAverage First Token Time over {n} iterations: {avg_first_token_time:.4f} seconds") 102 | print(f"Average Decode Time per Token over {n} iterations: {avg_decode_time:.4f} tokens/second") 103 | return times, avg_first_token_time, decode_times, avg_decode_time 104 | 105 | 106 | if __name__ == "__main__": 107 | parser = argparse.ArgumentParser(description="Stress test for model inference") 108 | parser.add_argument("--input_token_len", type=int, default=100000, help="Number of tokens for each test") 109 | parser.add_argument("--output_token_len", type=int, default=128, help="Number of output tokens for each test") 110 | parser.add_argument("--n", type=int, default=3, help="Number of iterations for the stress test") 111 | args = parser.parse_args() 112 | stress_test(args.input_token_len, args.n, args.output_token_len) 113 | -------------------------------------------------------------------------------- /inference/trans_web_demo.py: -------------------------------------------------------------------------------- 1 | from threading import Thread 2 | 3 | import gradio as gr 4 | from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer 5 | 6 | 7 | MODEL_PATH = "THUDM/GLM-4-9B-Chat-0414" 8 | tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH) 9 | model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, device_map="auto") 10 | 11 | 12 | def preprocess_messages(history, system_prompt): 13 | messages = [] 14 | 15 | if system_prompt: 16 | messages.append({"role": "system", "content": system_prompt}) 17 | 18 | for idx, (user_msg, model_msg) in enumerate(history): 19 | if idx == len(history) - 1 and not model_msg: 20 | messages.append({"role": "user", "content": user_msg}) 21 | break 22 | if user_msg: 23 | messages.append({"role": "user", "content": user_msg}) 24 | if model_msg: 25 | messages.append({"role": "assistant", "content": model_msg}) 26 | 27 | return messages 28 | 29 | 30 | def predict(history, system_prompt, max_length, top_p, top_k, temperature): 31 | messages = preprocess_messages(history, system_prompt) 32 | model_inputs = tokenizer.apply_chat_template( 33 | messages, add_generation_prompt=True, tokenize=True, return_tensors="pt", return_dict=True 34 | ).to(model.device) 35 | streamer = TextIteratorStreamer(tokenizer, timeout=60, skip_prompt=True, skip_special_tokens=True) 36 | generate_kwargs = { 37 | "input_ids": model_inputs["input_ids"], 38 | "attention_mask": model_inputs["attention_mask"], 39 | "streamer": streamer, 40 | "max_new_tokens": max_length, 41 | "do_sample": True, 
42 | "top_p": top_p, 43 | "top_k": top_k, 44 | "temperature": temperature, 45 | "repetition_penalty": 1.2, 46 | } 47 | 48 | generate_kwargs["eos_token_id"] = tokenizer.encode("<|user|>") 49 | 50 | t = Thread(target=model.generate, kwargs=generate_kwargs) 51 | t.start() 52 | for new_token in streamer: 53 | if new_token: 54 | history[-1][1] += new_token 55 | yield history 56 | 57 | 58 | def main(): 59 | with gr.Blocks() as demo: 60 | gr.HTML("""

GLM-4-0414 Gradio Demo

""") 61 | 62 | with gr.Row(): 63 | with gr.Column(scale=3): 64 | system_prompt = gr.Textbox( 65 | show_label=True, placeholder="Enter system prompt here...", label="System Prompt", lines=2 66 | ) 67 | 68 | with gr.Row(): 69 | with gr.Column(scale=3): 70 | chatbot = gr.Chatbot() 71 | 72 | with gr.Row(): 73 | with gr.Column(scale=2): 74 | user_input = gr.Textbox(show_label=True, placeholder="Input...", label="User Input") 75 | submitBtn = gr.Button("Submit") 76 | emptyBtn = gr.Button("Clear History") 77 | with gr.Column(scale=1): 78 | max_length = gr.Slider(0, 8192, value=4096, step=1.0, label="Maximum length", interactive=True) 79 | top_p = gr.Slider(0, 1, value=0.8, step=0.01, label="Top P", interactive=True) 80 | top_k = gr.Slider(0, 100, value=50, step=1, label="Top K", interactive=True) 81 | temperature = gr.Slider(0.01, 1, value=0.6, step=0.01, label="Temperature", interactive=True) 82 | 83 | def user(query, history): 84 | return "", history + [[query, ""]] 85 | 86 | def clear_history(): 87 | return None 88 | 89 | submitBtn.click(user, [user_input, chatbot], [user_input, chatbot], queue=False).then( 90 | predict, [chatbot, system_prompt, max_length, top_p, top_k, temperature], chatbot 91 | ) 92 | emptyBtn.click(clear_history, None, [chatbot], queue=False) 93 | 94 | demo.queue() 95 | demo.launch() 96 | 97 | 98 | if __name__ == "__main__": 99 | main() 100 | -------------------------------------------------------------------------------- /inference/trans_web_vision_demo.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script creates a Gradio demo with a Transformers backend for the glm-4v-9b model, allowing users to interact with the model through a Gradio web UI. 3 | 4 | Usage: 5 | - Run the script to start the Gradio server. 6 | - Interact with the model via the web UI. 7 | 8 | Requirements: 9 | - Gradio package 10 | - Type `pip install gradio==4.44.1` to install Gradio. 
11 | """ 12 | 13 | import os 14 | from io import BytesIO 15 | from threading import Thread 16 | 17 | import gradio as gr 18 | import requests 19 | import torch 20 | from PIL import Image 21 | from transformers import AutoModel, AutoTokenizer, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer 22 | 23 | 24 | MODEL_PATH = os.environ.get("MODEL_PATH", "THUDM/glm-4v-9b") 25 | 26 | tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True, encode_special_tokens=True) 27 | model = AutoModel.from_pretrained( 28 | MODEL_PATH, trust_remote_code=True, device_map="auto", torch_dtype=torch.bfloat16 29 | ).eval() 30 | 31 | 32 | class StopOnTokens(StoppingCriteria): 33 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: 34 | stop_ids = model.config.eos_token_id 35 | for stop_id in stop_ids: 36 | if input_ids[0][-1] == stop_id: 37 | return True 38 | return False 39 | 40 | 41 | def get_image(image_path=None, image_url=None): 42 | if image_path: 43 | return Image.open(image_path).convert("RGB") 44 | elif image_url: 45 | response = requests.get(image_url) 46 | return Image.open(BytesIO(response.content)).convert("RGB") 47 | return None 48 | 49 | 50 | def chatbot(image_path=None, image_url=None, assistant_prompt=""): 51 | image = get_image(image_path, image_url) 52 | 53 | messages = [{"role": "assistant", "content": assistant_prompt}, {"role": "user", "content": "", "image": image}] 54 | 55 | model_inputs = tokenizer.apply_chat_template( 56 | messages, add_generation_prompt=True, tokenize=True, return_tensors="pt", return_dict=True 57 | ).to(next(model.parameters()).device) 58 | 59 | streamer = TextIteratorStreamer(tokenizer=tokenizer, timeout=60, skip_prompt=True, skip_special_tokens=True) 60 | 61 | generate_kwargs = { 62 | **model_inputs, 63 | "streamer": streamer, 64 | "max_new_tokens": 1024, 65 | "do_sample": True, 66 | "top_p": 0.8, 67 | "temperature": 0.6, 68 | "stopping_criteria": StoppingCriteriaList([StopOnTokens()]), 69 | "repetition_penalty": 1.2, 70 | "eos_token_id": [151329, 151336, 151338], 71 | } 72 | 73 | t = Thread(target=model.generate, kwargs=generate_kwargs) 74 | t.start() 75 | 76 | response = "" 77 | for new_token in streamer: 78 | if new_token: 79 | response += new_token 80 | 81 | return image, response.strip() 82 | 83 | 84 | with gr.Blocks() as demo: 85 | demo.title = "GLM-4V-9B Image Recognition Demo" 86 | demo.description = """ 87 | This demo uses the GLM-4V-9B model to got image infomation. 
88 | """ 89 | with gr.Row(): 90 | with gr.Column(): 91 | image_path_input = gr.File(label="Upload Image (High-Priority)", type="filepath") 92 | image_url_input = gr.Textbox(label="Image URL (Low-Priority)") 93 | assistant_prompt_input = gr.Textbox(label="Assistant Prompt (You Can Change It)", value="这是什么?") 94 | submit_button = gr.Button("Submit") 95 | with gr.Column(): 96 | chatbot_output = gr.Textbox(label="GLM-4V-9B Model Response") 97 | image_output = gr.Image(label="Image Preview") 98 | 99 | submit_button.click( 100 | chatbot, 101 | inputs=[image_path_input, image_url_input, assistant_prompt_input], 102 | outputs=[image_output, chatbot_output], 103 | ) 104 | 105 | demo.launch(server_name="127.0.0.1", server_port=8911, inbrowser=True, share=False) 106 | -------------------------------------------------------------------------------- /inference/vllm_cli_demo.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script creates a CLI demo with vllm backand for the glm-4-9b model, 3 | allowing users to interact with the model through a command-line interface. 4 | 5 | Usage: 6 | - Run the script to start the CLI demo. 7 | - Interact with the model by typing questions and receiving responses. 8 | 9 | Note: The script includes a modification to handle markdown to plain text conversion, 10 | ensuring that the CLI interface displays formatted text correctly. 11 | """ 12 | 13 | import asyncio 14 | import time 15 | from typing import Dict, List 16 | 17 | from transformers import AutoTokenizer 18 | from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams 19 | from vllm.lora.request import LoRARequest 20 | 21 | 22 | MODEL_PATH = "THUDM/GLM-4-9B-Chat-0414" 23 | LORA_PATH = "" 24 | 25 | 26 | def load_model_and_tokenizer(model_dir: str, enable_lora: bool): 27 | tokenizer = AutoTokenizer.from_pretrained(model_dir) 28 | 29 | engine_args = AsyncEngineArgs( 30 | model=model_dir, 31 | tokenizer=model_dir, 32 | enable_lora=enable_lora, 33 | tensor_parallel_size=1, 34 | dtype="bfloat16", 35 | gpu_memory_utilization=0.9, 36 | enforce_eager=True, 37 | disable_log_requests=True, 38 | ) 39 | 40 | engine = AsyncLLMEngine.from_engine_args(engine_args) 41 | return engine, tokenizer 42 | 43 | 44 | enable_lora = False 45 | if LORA_PATH: 46 | enable_lora = True 47 | 48 | engine, tokenizer = load_model_and_tokenizer(MODEL_PATH, enable_lora) 49 | 50 | 51 | async def vllm_gen( 52 | lora_path: str, 53 | enable_lora: bool, 54 | messages: List[Dict[str, str]], 55 | top_p: float, 56 | temperature: float, 57 | max_dec_len: int, 58 | ): 59 | inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False) 60 | params_dict = { 61 | "n": 1, 62 | "best_of": 1, 63 | "presence_penalty": 1.0, 64 | "frequency_penalty": 0.0, 65 | "temperature": temperature, 66 | "top_p": top_p, 67 | "max_tokens": max_dec_len, 68 | "skip_special_tokens": True, 69 | } 70 | sampling_params = SamplingParams(**params_dict) 71 | if enable_lora: 72 | async for output in engine.generate( 73 | prompt=inputs, 74 | sampling_params=sampling_params, 75 | request_id=f"{time.time()}", 76 | lora_request=LoRARequest("glm-4-lora", 1, lora_path=lora_path), 77 | ): 78 | yield output.outputs[0].text 79 | else: 80 | async for output in engine.generate( 81 | prompt=inputs, sampling_params=sampling_params, request_id=f"{time.time()}" 82 | ): 83 | yield output.outputs[0].text 84 | 85 | 86 | async def chat(): 87 | history = [] 88 | max_length = 8192 89 | top_p = 0.8 90 | temperature = 0.6 91 | 
92 | print("Welcome to the GLM-4-9B CLI chat. Type your messages below.") 93 | while True: 94 | user_input = input("\nYou: ") 95 | if user_input.lower() in ["exit", "quit"]: 96 | break 97 | history.append([user_input, ""]) 98 | 99 | messages = [] 100 | for idx, (user_msg, model_msg) in enumerate(history): 101 | if idx == len(history) - 1 and not model_msg: 102 | messages.append({"role": "user", "content": user_msg}) 103 | break 104 | if user_msg: 105 | messages.append({"role": "user", "content": user_msg}) 106 | if model_msg: 107 | messages.append({"role": "assistant", "content": model_msg}) 108 | 109 | print("\nGLM-4: ", end="") 110 | current_length = 0 111 | output = "" 112 | async for output in vllm_gen(LORA_PATH, enable_lora, messages, top_p, temperature, max_length): 113 | print(output[current_length:], end="", flush=True) 114 | current_length = len(output) 115 | history[-1][1] = output 116 | 117 | 118 | if __name__ == "__main__": 119 | asyncio.run(chat()) 120 | -------------------------------------------------------------------------------- /inference/vllm_cli_vision_demo.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script creates a CLI demo with vllm backand for the glm-4v-9b model, 3 | allowing users to interact with the model through a command-line interface. 4 | 5 | Usage: 6 | - Run the script to start the CLI demo. 7 | - Interact with the model by typing questions and receiving responses. 8 | 9 | Note: The script includes a modification to handle markdown to plain text conversion, 10 | ensuring that the CLI interface displays formatted text correctly. 11 | """ 12 | 13 | import asyncio 14 | import time 15 | from typing import Dict, List 16 | 17 | from PIL import Image 18 | from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams 19 | 20 | 21 | MODEL_PATH = "THUDM/glm-4v-9b" 22 | 23 | 24 | def load_model_and_tokenizer(model_dir: str): 25 | engine_args = AsyncEngineArgs( 26 | model=model_dir, 27 | tokenizer=model_dir, 28 | tensor_parallel_size=1, 29 | dtype="bfloat16", 30 | gpu_memory_utilization=0.9, 31 | enforce_eager=True, 32 | disable_log_requests=True, 33 | ) 34 | engine = AsyncLLMEngine.from_engine_args(engine_args) 35 | return engine 36 | 37 | 38 | engine = load_model_and_tokenizer(MODEL_PATH) 39 | 40 | 41 | async def vllm_gen(messages: List[Dict[str, str]], top_p: float, temperature: float, max_dec_len: int): 42 | inputs = messages[-1] 43 | params_dict = { 44 | "n": 1, 45 | "best_of": 1, 46 | "presence_penalty": 1.0, 47 | "frequency_penalty": 0.0, 48 | "temperature": temperature, 49 | "top_p": top_p, 50 | "max_tokens": max_dec_len, 51 | "skip_special_tokens": True, 52 | } 53 | sampling_params = SamplingParams(**params_dict) 54 | 55 | async for output in engine.generate(prompt=inputs, sampling_params=sampling_params, request_id=f"{time.time()}"): 56 | yield output.outputs[0].text 57 | 58 | 59 | async def chat(): 60 | history = [] 61 | max_length = 8192 62 | top_p = 0.8 63 | temperature = 0.6 64 | image = None 65 | 66 | print("Welcome to the GLM-4v-9B CLI chat. Type your messages below.") 67 | image_path = input("Image Path:") 68 | try: 69 | image = Image.open(image_path).convert("RGB") 70 | except: 71 | print("Invalid image path. 
Continuing with text conversation.") 72 | while True: 73 | user_input = input("\nYou: ") 74 | if user_input.lower() in ["exit", "quit"]: 75 | break 76 | history.append([user_input, ""]) 77 | 78 | messages = [] 79 | for idx, (user_msg, model_msg) in enumerate(history): 80 | if idx == len(history) - 1 and not model_msg: 81 | messages.append( 82 | { 83 | "prompt": user_msg, 84 | "multi_modal_data": {"image": image}, 85 | } 86 | ) 87 | break 88 | if user_msg: 89 | messages.append({"role": "user", "prompt": user_msg}) 90 | if model_msg: 91 | messages.append({"role": "assistant", "prompt": model_msg}) 92 | 93 | print("\nGLM-4v: ", end="") 94 | current_length = 0 95 | output = "" 96 | async for output in vllm_gen(messages, top_p, temperature, max_length): 97 | print(output[current_length:], end="", flush=True) 98 | current_length = len(output) 99 | history[-1][1] = output 100 | 101 | 102 | if __name__ == "__main__": 103 | asyncio.run(chat()) 104 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.ruff] 2 | line-length = 119 3 | 4 | [tool.ruff.lint] 5 | # Never enforce `E501` (line length violations). 6 | ignore = ["C901", "E501", "E741", "F402", "F823"] 7 | select = ["C", "E", "F", "I", "W"] 8 | 9 | # Ignore import violations in all `__init__.py` files. 10 | [tool.ruff.lint.per-file-ignores] 11 | "__init__.py" = ["E402", "F401", "F403", "F811"] 12 | 13 | [tool.ruff.lint.isort] 14 | lines-after-imports = 2 15 | 16 | [tool.ruff.format] 17 | # Like Black, use double quotes for strings. 18 | quote-style = "double" 19 | 20 | # Like Black, indent with spaces, rather than tabs. 21 | indent-style = "space" 22 | 23 | # Like Black, respect magic trailing commas. 24 | skip-magic-trailing-comma = false 25 | 26 | # Like Black, automatically detect the appropriate line ending. 27 | line-ending = "auto" 28 | -------------------------------------------------------------------------------- /resources/Bench-32B.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zRzRzRzRzRzRzR/GLM-4/2e5efb1365eda83bead08afffb7a884d56724f7a/resources/Bench-32B.png -------------------------------------------------------------------------------- /resources/Bench-Z1-32B.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zRzRzRzRzRzRzR/GLM-4/2e5efb1365eda83bead08afffb7a884d56724f7a/resources/Bench-Z1-32B.png -------------------------------------------------------------------------------- /resources/Bench-Z1-9B.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zRzRzRzRzRzRzR/GLM-4/2e5efb1365eda83bead08afffb7a884d56724f7a/resources/Bench-Z1-9B.png -------------------------------------------------------------------------------- /resources/WECHAT.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | 4 |

扫码加入「GLM-4交流群」

5 |

Scan the QR code to join the "GLM-4 Discussion Group"

6 |
7 | -------------------------------------------------------------------------------- /resources/eval_needle.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zRzRzRzRzRzRzR/GLM-4/2e5efb1365eda83bead08afffb7a884d56724f7a/resources/eval_needle.jpeg -------------------------------------------------------------------------------- /resources/longbench.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zRzRzRzRzRzRzR/GLM-4/2e5efb1365eda83bead08afffb7a884d56724f7a/resources/longbench.png -------------------------------------------------------------------------------- /resources/wechat.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zRzRzRzRzRzRzR/GLM-4/2e5efb1365eda83bead08afffb7a884d56724f7a/resources/wechat.jpg --------------------------------------------------------------------------------