├── .gitignore ├── .vscode └── settings.json ├── LICENSE ├── MODEL_LICENSE ├── README.md ├── cli-demo.py ├── config.json ├── configuration_chatglm.py ├── gitattributes ├── ice_text.model ├── imgs ├── cli-demo.jpg ├── lowcode-page3.png ├── lowcode.jpg └── web-demo.jpg ├── modeling_chatglm.py ├── prompts ├── prompt-compress.md └── prompt.md ├── quantization.py ├── quantization_kernels.c ├── quantization_kernels_parallel.c ├── requirements.txt ├── tokenization_chatglm.py ├── tokenizer.model ├── tokenizer_config.json └── web-demo.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | pytorch_model.bin 7 | node_modules 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | cover/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | .pybuilder/ 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | # For a library or package, you might want to ignore these files since the code is 90 | # intended to run in multiple environments; otherwise, check them in: 91 | # .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 98 | #Pipfile.lock 99 | 100 | # poetry 101 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 102 | # This is especially recommended for binary packages to ensure reproducibility, and is more 103 | # commonly ignored for libraries. 104 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 105 | #poetry.lock 106 | 107 | # pdm 108 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 109 | #pdm.lock 110 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 111 | # in version control. 112 | # https://pdm.fming.dev/#use-with-ide 113 | .pdm.toml 114 | 115 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 116 | __pypackages__/ 117 | 118 | # Celery stuff 119 | celerybeat-schedule 120 | celerybeat.pid 121 | 122 | # SageMath parsed files 123 | *.sage.py 124 | 125 | # Environments 126 | .env 127 | .venv 128 | env/ 129 | venv/ 130 | ENV/ 131 | env.bak/ 132 | venv.bak/ 133 | 134 | # Spyder project settings 135 | .spyderproject 136 | .spyproject 137 | 138 | # Rope project settings 139 | .ropeproject 140 | 141 | # mkdocs documentation 142 | /site 143 | 144 | # mypy 145 | .mypy_cache/ 146 | .dmypy.json 147 | dmypy.json 148 | 149 | # Pyre type checker 150 | .pyre/ 151 | 152 | # pytype static type analyzer 153 | .pytype/ 154 | 155 | # Cython debug symbols 156 | cython_debug/ 157 | 158 | # PyCharm 159 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 160 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 161 | # and can be added to the global gitignore or merged into this file. For a more nuclear 162 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 163 | #.idea/ 164 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "python.analysis.autoImportCompletions": false, 3 | "python.analysis.typeCheckingMode": "off" 4 | } -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright Zhengxiao Du 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. -------------------------------------------------------------------------------- /MODEL_LICENSE: -------------------------------------------------------------------------------- 1 | The ChatGLM2-6B License 2 | 3 | 1. Definitions 4 | 5 | “Licensor” means the ChatGLM2-6B Model Team that distributes its Software. 6 | 7 | “Software” means the ChatGLM2-6B model parameters made available under this license. 8 | 9 | 2. 
License Grant 10 | 11 | Subject to the terms and conditions of this License, the Licensor hereby grants to you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty-free copyright license to use the Software solely for your non-commercial research purposes. 12 | 13 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 14 | 15 | 3. Restriction 16 | 17 | You will not use, copy, modify, merge, publish, distribute, reproduce, or create derivative works of the Software, in whole or in part, for any commercial, military, or illegal purposes. 18 | 19 | You will not use the Software for any act that may undermine China's national security and national unity, harm the public interest of society, or infringe upon the rights and interests of human beings. 20 | 21 | 4. Disclaimer 22 | 23 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 24 | 25 | 5. Limitation of Liability 26 | 27 | EXCEPT TO THE EXTENT PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER BASED IN TORT, NEGLIGENCE, CONTRACT, LIABILITY, OR OTHERWISE WILL ANY LICENSOR BE LIABLE TO YOU FOR ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES, OR ANY OTHER COMMERCIAL LOSSES, EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. 28 | 29 | 6. Dispute Resolution 30 | 31 | This license shall be governed and construed in accordance with the laws of People’s Republic of China. Any dispute arising from or in connection with this License shall be submitted to Haidian District People's Court in Beijing. 32 | 33 | Note that the license is subject to update to a more comprehensive version. For any questions related to the license and copyright, please contact us at glm-130b@googlegroups.com. 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # lowcode-llm-demo 2 | 3 | A hands-on demo combining low-code with large language models. This branch contains the sample code that uses the open-source chatglm-6b model; for the chatgpt sample code, see the [openai](https://github.com/woai3c/lowcode-llm-demo/tree/openai) branch. 4 | 5 | ## Low-code articles 6 | * [Analysis of key technical points of a visual drag-and-drop component library](https://github.com/woai3c/Front-end-articles/issues/19) 7 | * [Analysis of key technical points of a visual drag-and-drop component library (Part 2)](https://github.com/woai3c/Front-end-articles/issues/20) 8 | * [Analysis of key technical points of a visual drag-and-drop component library (Part 3)](https://github.com/woai3c/Front-end-articles/issues/21) 9 | * [Analysis of key technical points of a visual drag-and-drop component library (Part 4)](https://github.com/woai3c/Front-end-articles/issues/33) 10 | * [Exploring the combination of low-code and large language models](https://github.com/woai3c/Front-end-articles/issues/45) 11 | 12 | ## Installation 13 | 14 | First read this article to install the required dependencies: [A step-by-step guide to deploying Tsinghua KEG's ChatGLM-6B model locally, for Windows + 6GB GPU and CPU-only setups](https://zhuanlan.zhihu.com/p/620455056). The code in this repository uses the Windows + CPU version. 15 | 16 | The model file is too large to upload to GitHub, so you need to download it yourself. Open this address, find the model file `pytorch_model.bin`, click the download arrow to the right of the file, and then place it in the project root directory. 17 | 18 | ### TDM-GCC compilation error 19 | After installing TDM-GCC as the article instructs, compiling `quantization_kernels_parallel.c` failed; uninstalling TDM-GCC and switching to [MinGW-w64](https://www.mingw-w64.org/downloads/) fixed it. 20 | 21 | ## Usage 22 | **Before use, replace `D:\\res\\lowcode-llm-demo` in the repository code with the current path of your project.** 23 | 24 | From the command line: 25 | 26 | ```sh 27 | python ./cli-demo.py 28 | ``` 29 | 30 | From the browser: 31 | 32 | ```sh 33 | python ./web-demo.py 34 | ``` 35 | 36 | The web page does not display the response correctly (a known bug), but the returned content can be seen through the API; there has been no time to fix it yet. 37 | 38 | **Note**: if the program exits right after starting, with no output and no error, the system is short on resources; close unrelated programs and run it again. 39 | 40 | ### Low-code example 41 | 42 | The `prompts` directory contains two markdown files with prompts for generating low-code pages. You can copy their text directly to interact with the model. The model returns a JSON string; import that JSON into the [low-code platform](https://woai3c.github.io/visual-drag-demo/) to generate a page directly. 43 | 44 | ![lowcode page screenshot](imgs/lowcode-page3.png) 45 | 46 | However, the output of the chatglm-6b model is not ideal and the generated JSON may not be usable as-is; if you have access to chatgpt, it is better to use chatgpt for generation. 47 | 48 | ## Demo screenshots 49 | 50 | ![lowcode screenshot](imgs/lowcode.jpg) 51 | 52 | ![cli screenshot](imgs/cli-demo.jpg) 53 | 54 | ![web screenshot](imgs/web-demo.jpg) 55 | -------------------------------------------------------------------------------- /cli-demo.py: -------------------------------------------------------------------------------- 1 | import os 2 | import platform 3 | import signal 4 | from transformers import AutoTokenizer, AutoModel 5 | 6 | # Replace this with the path to your ChatGLM-6B model 7 | tokenizer = AutoTokenizer.from_pretrained("D:\\res\\lowcode-llm-demo", trust_remote_code=True, revision="") 8 | model = AutoModel.from_pretrained("D:\\res\\lowcode-llm-demo", trust_remote_code=True, revision="").float() 9 | model = model.eval() 10 | 11 | os_name = platform.system() 12 | clear_command = 'cls' if os_name == 'Windows' else 'clear' 13 | stop_stream = False 14 | 15 | 16 | def build_prompt(history): 17 | prompt = "欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序" 18 | for query, response in history: 19 | prompt += f"\n\n用户:{query}" 20 | prompt += f"\n\nChatGLM-6B:{response}" 21 | return prompt 22 | 23 | 24 | def signal_handler(signal, frame): 25 | global stop_stream 26 | stop_stream = True 27 | 28 | 29 | def main(): 30 | history = [] 31 | global stop_stream 32 | print("欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序") 33 | while True: 34 | query = input("\n用户:") 35 | if query.strip() == "stop": 36 | break 37 | if query.strip() == "clear": 38 | history = [] 39 | os.system(clear_command) 40 | print("欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序") 41 | continue 42 | count = 0 43 | for response, history in model.stream_chat(tokenizer, query,
history=history): 44 | if stop_stream: 45 | stop_stream = False 46 | break 47 | else: 48 | count += 1 49 | if count % 8 == 0: 50 | os.system(clear_command) 51 | print(build_prompt(history), flush=True) 52 | signal.signal(signal.SIGINT, signal_handler) 53 | os.system(clear_command) 54 | print(build_prompt(history), flush=True) 55 | 56 | 57 | if __name__ == "__main__": 58 | print("Loading ChatGLM-6B model...") 59 | main() -------------------------------------------------------------------------------- /config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_name_or_path": "D:\\res\\lowcode-llm-demo", 3 | "architectures": [ 4 | "ChatGLMModel" 5 | ], 6 | "auto_map": { 7 | "AutoConfig": "configuration_chatglm.ChatGLMConfig", 8 | "AutoModel": "modeling_chatglm.ChatGLMForConditionalGeneration", 9 | "AutoModelForSeq2SeqLM": "modeling_chatglm.ChatGLMForConditionalGeneration" 10 | }, 11 | "bos_token_id": 130004, 12 | "eos_token_id": 130005, 13 | "mask_token_id": 130000, 14 | "gmask_token_id": 130001, 15 | "pad_token_id": 3, 16 | "hidden_size": 4096, 17 | "inner_hidden_size": 16384, 18 | "layernorm_epsilon": 1e-05, 19 | "max_sequence_length": 2048, 20 | "model_type": "chatglm", 21 | "num_attention_heads": 32, 22 | "num_layers": 28, 23 | "position_encoding_2d": true, 24 | "quantization_bit": 4, 25 | "quantization_embeddings": false, 26 | "torch_dtype": "float16", 27 | "transformers_version": "4.27.1", 28 | "use_cache": true, 29 | "vocab_size": 130528 30 | } -------------------------------------------------------------------------------- /configuration_chatglm.py: -------------------------------------------------------------------------------- 1 | """ ChatGLM model configuration """ 2 | 3 | from transformers.configuration_utils import PretrainedConfig 4 | from transformers.utils import logging 5 | 6 | logger = logging.get_logger(__name__) 7 | 8 | 9 | class ChatGLMConfig(PretrainedConfig): 10 | r""" 11 | This is the configuration class to store the configuration of a [`~ChatGLMModel`]. 12 | It is used to instantiate an ChatGLM model according to the specified arguments, defining the model 13 | architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of 14 | the ChatGLM-6B [D:\\res\\lowcode-llm-demo](https://huggingface.co/D:\\res\\lowcode-llm-demo) architecture. 15 | 16 | Configuration objects inherit from [`PretrainedConfig`] and can be used 17 | to control the model outputs. Read the documentation from [`PretrainedConfig`] 18 | for more information. 19 | 20 | 21 | Args: 22 | vocab_size (`int`, *optional*, defaults to 150528): 23 | Vocabulary size of the ChatGLM-6B model. Defines the number of different tokens that can be represented by the 24 | `inputs_ids` passed when calling [`~ChatGLMModel`] or 25 | [`~TFChatGLMModel`]. 26 | hidden_size (`int`, *optional*, defaults to 4096): 27 | Dimension of the encoder layers and the pooler layer. 28 | num_hidden_layers (`int`, *optional*, defaults to 28): 29 | Number of hidden layers in the Transformer encoder. 30 | num_attention_heads (`int`, *optional*, defaults to 32): 31 | Number of attention heads for each attention layer in the Transformer encoder. 32 | inner_hidden_size (`int`, *optional*, defaults to 16384): 33 | Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. 34 | max_sequence_length (`int`, *optional*, defaults to 512): 35 | The maximum sequence length that this model might ever be used with. 
36 | Typically set this to something large just in case (e.g., 512 or 1024 or 2048). 37 | layernorm_epsilon (`float`, *optional*, defaults to 1e-5): 38 | The epsilon used by the layer normalization layers. 39 | use_cache (`bool`, *optional*, defaults to `True`): 40 | Whether the model should return the last key/values attentions (not used by all models). 41 | Example: 42 | 43 | ```python 44 | >>> from configuration_chatglm import ChatGLMConfig 45 | >>> from modeling_chatglm import ChatGLMModel 46 | 47 | >>> # Initializing a ChatGLM-6B D:\\res\\lowcode-llm-demo style configuration 48 | >>> configuration = ChatGLMConfig() 49 | 50 | >>> # Initializing a model from the D:\\res\\lowcode-llm-demo style configuration 51 | >>> model = ChatGLMModel(configuration) 52 | 53 | >>> # Accessing the model configuration 54 | >>> configuration = model.config 55 | ``` 56 | """ 57 | model_type = "chatglm" 58 | 59 | def __init__( 60 | self, 61 | vocab_size=150528, 62 | hidden_size=4096, 63 | num_layers=28, 64 | num_attention_heads=32, 65 | layernorm_epsilon=1e-5, 66 | use_cache=False, 67 | bos_token_id=150004, 68 | eos_token_id=150005, 69 | mask_token_id=150000, 70 | gmask_token_id=150001, 71 | pad_token_id=0, 72 | max_sequence_length=2048, 73 | inner_hidden_size=16384, 74 | position_encoding_2d=True, 75 | quantization_bit=0, 76 | quantization_embeddings=False, 77 | pre_seq_len=None, 78 | prefix_projection=False, 79 | **kwargs 80 | ): 81 | self.num_layers = num_layers 82 | self.vocab_size = vocab_size 83 | self.hidden_size = hidden_size 84 | self.num_attention_heads = num_attention_heads 85 | self.max_sequence_length = max_sequence_length 86 | self.layernorm_epsilon = layernorm_epsilon 87 | self.inner_hidden_size = inner_hidden_size 88 | self.use_cache = use_cache 89 | self.bos_token_id = bos_token_id 90 | self.eos_token_id = eos_token_id 91 | self.pad_token_id = pad_token_id 92 | self.mask_token_id = mask_token_id 93 | self.gmask_token_id = gmask_token_id 94 | self.position_encoding_2d = position_encoding_2d 95 | self.quantization_bit = quantization_bit 96 | self.quantization_embeddings = quantization_embeddings 97 | self.pre_seq_len = pre_seq_len 98 | self.prefix_projection = prefix_projection 99 | 100 | super().__init__( 101 | pad_token_id=pad_token_id, 102 | bos_token_id=bos_token_id, 103 | eos_token_id=eos_token_id, 104 | **kwargs 105 | ) 106 | -------------------------------------------------------------------------------- /gitattributes: -------------------------------------------------------------------------------- 1 | *.7z filter=lfs diff=lfs merge=lfs -text 2 | *.arrow filter=lfs diff=lfs merge=lfs -text 3 | *.bin filter=lfs diff=lfs merge=lfs -text 4 | *.bz2 filter=lfs diff=lfs merge=lfs -text 5 | *.ckpt filter=lfs diff=lfs merge=lfs -text 6 | *.ftz filter=lfs diff=lfs merge=lfs -text 7 | *.gz filter=lfs diff=lfs merge=lfs -text 8 | *.h5 filter=lfs diff=lfs merge=lfs -text 9 | *.joblib filter=lfs diff=lfs merge=lfs -text 10 | *.lfs.* filter=lfs diff=lfs merge=lfs -text 11 | *.mlmodel filter=lfs diff=lfs merge=lfs -text 12 | *.model filter=lfs diff=lfs merge=lfs -text 13 | *.msgpack filter=lfs diff=lfs merge=lfs -text 14 | *.npy filter=lfs diff=lfs merge=lfs -text 15 | *.npz filter=lfs diff=lfs merge=lfs -text 16 | *.onnx filter=lfs diff=lfs merge=lfs -text 17 | *.ot filter=lfs diff=lfs merge=lfs -text 18 | *.parquet filter=lfs diff=lfs merge=lfs -text 19 | *.pb filter=lfs diff=lfs merge=lfs -text 20 | *.pickle filter=lfs diff=lfs merge=lfs -text 21 | *.pkl filter=lfs diff=lfs merge=lfs -text 22 | 
*.pt filter=lfs diff=lfs merge=lfs -text 23 | *.pth filter=lfs diff=lfs merge=lfs -text 24 | *.rar filter=lfs diff=lfs merge=lfs -text 25 | *.safetensors filter=lfs diff=lfs merge=lfs -text 26 | saved_model/**/* filter=lfs diff=lfs merge=lfs -text 27 | *.tar.* filter=lfs diff=lfs merge=lfs -text 28 | *.tflite filter=lfs diff=lfs merge=lfs -text 29 | *.tgz filter=lfs diff=lfs merge=lfs -text 30 | *.wasm filter=lfs diff=lfs merge=lfs -text 31 | *.xz filter=lfs diff=lfs merge=lfs -text 32 | *.zip filter=lfs diff=lfs merge=lfs -text 33 | *.zst filter=lfs diff=lfs merge=lfs -text 34 | *tfevents* filter=lfs diff=lfs merge=lfs -text 35 | -------------------------------------------------------------------------------- /ice_text.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/woai3c/lowcode-llm-demo/5a1ffa6f9ed9b2d2eef78650c021ffa0b74dcd62/ice_text.model -------------------------------------------------------------------------------- /imgs/cli-demo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/woai3c/lowcode-llm-demo/5a1ffa6f9ed9b2d2eef78650c021ffa0b74dcd62/imgs/cli-demo.jpg -------------------------------------------------------------------------------- /imgs/lowcode-page3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/woai3c/lowcode-llm-demo/5a1ffa6f9ed9b2d2eef78650c021ffa0b74dcd62/imgs/lowcode-page3.png -------------------------------------------------------------------------------- /imgs/lowcode.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/woai3c/lowcode-llm-demo/5a1ffa6f9ed9b2d2eef78650c021ffa0b74dcd62/imgs/lowcode.jpg -------------------------------------------------------------------------------- /imgs/web-demo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/woai3c/lowcode-llm-demo/5a1ffa6f9ed9b2d2eef78650c021ffa0b74dcd62/imgs/web-demo.jpg -------------------------------------------------------------------------------- /modeling_chatglm.py: -------------------------------------------------------------------------------- 1 | """ PyTorch ChatGLM model. 
""" 2 | 3 | import math 4 | import copy 5 | import os 6 | import warnings 7 | import re 8 | import sys 9 | 10 | import torch 11 | import torch.utils.checkpoint 12 | import torch.nn.functional as F 13 | from torch import nn 14 | from torch.nn import CrossEntropyLoss, LayerNorm 15 | from torch.nn.utils import skip_init 16 | from typing import Optional, Tuple, Union, List, Callable, Dict, Any 17 | 18 | from transformers.utils import ( 19 | add_code_sample_docstrings, 20 | add_start_docstrings, 21 | add_start_docstrings_to_model_forward, 22 | ) 23 | from transformers.modeling_outputs import ( 24 | BaseModelOutputWithPast, 25 | CausalLMOutputWithPast, 26 | BaseModelOutputWithPastAndCrossAttentions, 27 | ) 28 | from transformers.modeling_utils import PreTrainedModel 29 | from transformers.utils import logging 30 | from transformers.generation.logits_process import LogitsProcessor 31 | from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig, ModelOutput 32 | 33 | from .configuration_chatglm import ChatGLMConfig 34 | 35 | 36 | # flags required to enable jit fusion kernels 37 | 38 | if sys.platform != 'darwin': 39 | torch._C._jit_set_profiling_mode(False) 40 | torch._C._jit_set_profiling_executor(False) 41 | torch._C._jit_override_can_fuse_on_cpu(True) 42 | torch._C._jit_override_can_fuse_on_gpu(True) 43 | 44 | logger = logging.get_logger(__name__) 45 | 46 | _CHECKPOINT_FOR_DOC = "D:\\res\\lowcode-llm-demo" 47 | _CONFIG_FOR_DOC = "ChatGLM6BConfig" 48 | 49 | CHATGLM_6B_PRETRAINED_MODEL_ARCHIVE_LIST = [ 50 | "D:\\res\\lowcode-llm-demo", 51 | # See all ChatGLM-6B models at https://huggingface.co/models?filter=chatglm 52 | ] 53 | 54 | 55 | class InvalidScoreLogitsProcessor(LogitsProcessor): 56 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: 57 | if torch.isnan(scores).any() or torch.isinf(scores).any(): 58 | scores.zero_() 59 | scores[..., 5] = 5e4 60 | return scores 61 | 62 | 63 | def load_tf_weights_in_chatglm_6b(model, config, tf_checkpoint_path): 64 | """Load tf checkpoints in a pytorch model.""" 65 | try: 66 | import re 67 | 68 | import numpy as np 69 | import tensorflow as tf 70 | except ImportError: 71 | logger.error( 72 | "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " 73 | "https://www.tensorflow.org/install/ for installation instructions." 
74 | ) 75 | raise 76 | tf_path = os.path.abspath(tf_checkpoint_path) 77 | logger.info(f"Converting TensorFlow checkpoint from {tf_path}") 78 | # Load weights from TF model 79 | init_vars = tf.train.list_variables(tf_path) 80 | names = [] 81 | arrays = [] 82 | for name, shape in init_vars: 83 | logger.info(f"Loading TF weight {name} with shape {shape}") 84 | array = tf.train.load_variable(tf_path, name) 85 | names.append(name) 86 | arrays.append(array) 87 | 88 | for name, array in zip(names, arrays): 89 | name = name.split("/") 90 | # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v 91 | # which are not required for using pretrained model 92 | if any( 93 | n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] 94 | for n in name 95 | ): 96 | logger.info(f"Skipping {'/'.join(name)}") 97 | continue 98 | pointer = model 99 | for m_name in name: 100 | if re.fullmatch(r"[A-Za-z]+_\d+", m_name): 101 | scope_names = re.split(r"_(\d+)", m_name) 102 | else: 103 | scope_names = [m_name] 104 | if scope_names[0] == "kernel" or scope_names[0] == "gamma": 105 | pointer = getattr(pointer, "weight") 106 | elif scope_names[0] == "output_bias" or scope_names[0] == "beta": 107 | pointer = getattr(pointer, "bias") 108 | elif scope_names[0] == "output_weights": 109 | pointer = getattr(pointer, "weight") 110 | elif scope_names[0] == "squad": 111 | pointer = getattr(pointer, "classifier") 112 | else: 113 | try: 114 | pointer = getattr(pointer, scope_names[0]) 115 | except AttributeError: 116 | logger.info(f"Skipping {'/'.join(name)}") 117 | continue 118 | if len(scope_names) >= 2: 119 | num = int(scope_names[1]) 120 | pointer = pointer[num] 121 | if m_name[-11:] == "_embeddings": 122 | pointer = getattr(pointer, "weight") 123 | elif m_name == "kernel": 124 | array = np.transpose(array) 125 | try: 126 | assert ( 127 | pointer.shape == array.shape 128 | ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched" 129 | except AssertionError as e: 130 | e.args += (pointer.shape, array.shape) 131 | raise 132 | logger.info(f"Initialize PyTorch weight {name}") 133 | pointer.data = torch.from_numpy(array) 134 | return model 135 | 136 | 137 | class PrefixEncoder(torch.nn.Module): 138 | """ 139 | The torch.nn model to encode the prefix 140 | Input shape: (batch-size, prefix-length) 141 | Output shape: (batch-size, prefix-length, 2*layers*hidden) 142 | """ 143 | 144 | def __init__(self, config): 145 | super().__init__() 146 | self.prefix_projection = config.prefix_projection 147 | if self.prefix_projection: 148 | # Use a two-layer MLP to encode the prefix 149 | self.embedding = torch.nn.Embedding(config.pre_seq_len, config.hidden_size) 150 | self.trans = torch.nn.Sequential( 151 | torch.nn.Linear(config.hidden_size, config.hidden_size), 152 | torch.nn.Tanh(), 153 | torch.nn.Linear(config.hidden_size, config.num_layers * config.hidden_size * 2) 154 | ) 155 | else: 156 | self.embedding = torch.nn.Embedding(config.pre_seq_len, config.num_layers * config.hidden_size * 2) 157 | 158 | def forward(self, prefix: torch.Tensor): 159 | if self.prefix_projection: 160 | prefix_tokens = self.embedding(prefix) 161 | past_key_values = self.trans(prefix_tokens) 162 | else: 163 | past_key_values = self.embedding(prefix) 164 | return past_key_values 165 | 166 | 167 | @torch.jit.script 168 | def gelu_impl(x): 169 | """OpenAI's gelu implementation.""" 170 | return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x * 171 | (1.0 + 0.044715 * x * 
x))) 172 | 173 | 174 | def gelu(x): 175 | return gelu_impl(x) 176 | 177 | 178 | class RotaryEmbedding(torch.nn.Module): 179 | def __init__(self, dim, base=10000, precision=torch.half, learnable=False): 180 | super().__init__() 181 | inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim)) 182 | inv_freq = inv_freq.half() 183 | self.learnable = learnable 184 | if learnable: 185 | self.inv_freq = torch.nn.Parameter(inv_freq) 186 | self.max_seq_len_cached = None 187 | else: 188 | self.register_buffer('inv_freq', inv_freq) 189 | self.max_seq_len_cached = None 190 | self.cos_cached = None 191 | self.sin_cached = None 192 | self.precision = precision 193 | 194 | def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, 195 | error_msgs): 196 | pass 197 | 198 | def forward(self, x, seq_dim=1, seq_len=None): 199 | if seq_len is None: 200 | seq_len = x.shape[seq_dim] 201 | if self.max_seq_len_cached is None or (seq_len > self.max_seq_len_cached): 202 | self.max_seq_len_cached = None if self.learnable else seq_len 203 | t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype) 204 | freqs = torch.einsum('i,j->ij', t, self.inv_freq) 205 | # Different from paper, but it uses a different permutation in order to obtain the same calculation 206 | emb = torch.cat((freqs, freqs), dim=-1).to(x.device) 207 | if self.precision == torch.bfloat16: 208 | emb = emb.float() 209 | 210 | # [sx, 1 (b * np), hn] 211 | cos_cached = emb.cos()[:, None, :] 212 | sin_cached = emb.sin()[:, None, :] 213 | if self.precision == torch.bfloat16: 214 | cos_cached = cos_cached.bfloat16() 215 | sin_cached = sin_cached.bfloat16() 216 | if self.learnable: 217 | return cos_cached, sin_cached 218 | self.cos_cached, self.sin_cached = cos_cached, sin_cached 219 | return self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...] 
220 | 221 | def _apply(self, fn): 222 | if self.cos_cached is not None: 223 | self.cos_cached = fn(self.cos_cached) 224 | if self.sin_cached is not None: 225 | self.sin_cached = fn(self.sin_cached) 226 | return super()._apply(fn) 227 | 228 | def rotate_half(x): 229 | x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:] 230 | return torch.cat((-x2, x1), dim=x1.ndim - 1) # dim=-1 triggers a bug in earlier torch versions 231 | 232 | 233 | @torch.jit.script 234 | def apply_rotary_pos_emb_index(q, k, cos, sin, position_id): 235 | # position_id: [sq, b], q, k: [sq, b, np, hn], cos: [sq, 1, hn] -> [sq, b, 1, hn] 236 | cos, sin = F.embedding(position_id, cos.squeeze(1)).unsqueeze(2), \ 237 | F.embedding(position_id, sin.squeeze(1)).unsqueeze(2) 238 | q, k = (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin) 239 | return q, k 240 | 241 | 242 | def attention_fn( 243 | self, 244 | query_layer, 245 | key_layer, 246 | value_layer, 247 | attention_mask, 248 | hidden_size_per_partition, 249 | layer_id, 250 | layer_past=None, 251 | scaling_attention_score=True, 252 | use_cache=False, 253 | ): 254 | if layer_past is not None: 255 | past_key, past_value = layer_past[0], layer_past[1] 256 | key_layer = torch.cat((past_key, key_layer), dim=0) 257 | value_layer = torch.cat((past_value, value_layer), dim=0) 258 | 259 | # seqlen, batch, num_attention_heads, hidden_size_per_attention_head 260 | seq_len, b, nh, hidden_size = key_layer.shape 261 | 262 | if use_cache: 263 | present = (key_layer, value_layer) 264 | else: 265 | present = None 266 | 267 | query_key_layer_scaling_coeff = float(layer_id + 1) 268 | if scaling_attention_score: 269 | query_layer = query_layer / (math.sqrt(hidden_size) * query_key_layer_scaling_coeff) 270 | 271 | # =================================== 272 | # Raw attention scores. [b, np, s, s] 273 | # =================================== 274 | 275 | # [b, np, sq, sk] 276 | output_size = (query_layer.size(1), query_layer.size(2), query_layer.size(0), key_layer.size(0)) 277 | 278 | # [sq, b, np, hn] -> [sq, b * np, hn] 279 | query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1) 280 | # [sk, b, np, hn] -> [sk, b * np, hn] 281 | key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1) 282 | 283 | matmul_result = torch.zeros( 284 | 1, 1, 1, 285 | dtype=query_layer.dtype, 286 | device=query_layer.device, 287 | ) 288 | 289 | matmul_result = torch.baddbmm( 290 | matmul_result, 291 | query_layer.transpose(0, 1), # [b * np, sq, hn] 292 | key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] 293 | beta=0.0, 294 | alpha=1.0, 295 | ) 296 | 297 | # change view to [b, np, sq, sk] 298 | attention_scores = matmul_result.view(*output_size) 299 | 300 | if self.scale_mask_softmax: 301 | self.scale_mask_softmax.scale = query_key_layer_scaling_coeff 302 | attention_probs = self.scale_mask_softmax(attention_scores, attention_mask.contiguous()) 303 | else: 304 | if not (attention_mask == 0).all(): 305 | # if auto-regressive, skip 306 | attention_scores.masked_fill_(attention_mask, -10000.0) 307 | dtype = attention_scores.dtype 308 | attention_scores = attention_scores.float() 309 | attention_scores = attention_scores * query_key_layer_scaling_coeff 310 | 311 | attention_probs = F.softmax(attention_scores, dim=-1) 312 | 313 | attention_probs = attention_probs.type(dtype) 314 | 315 | # ========================= 316 | # Context layer. [sq, b, hp] 317 | # ========================= 318 | 319 | # value_layer -> context layer. 
320 | # [sk, b, np, hn] --> [b, np, sq, hn] 321 | 322 | # context layer shape: [b, np, sq, hn] 323 | output_size = (value_layer.size(1), value_layer.size(2), query_layer.size(0), value_layer.size(3)) 324 | 325 | # change view [sk, b * np, hn] 326 | value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1) 327 | 328 | # change view [b * np, sq, sk] 329 | attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1) 330 | 331 | # matmul: [b * np, sq, hn] 332 | context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1)) 333 | 334 | # change view [b, np, sq, hn] 335 | context_layer = context_layer.view(*output_size) 336 | 337 | # [b, np, sq, hn] --> [sq, b, np, hn] 338 | context_layer = context_layer.permute(2, 0, 1, 3).contiguous() 339 | 340 | # [sq, b, np, hn] --> [sq, b, hp] 341 | new_context_layer_shape = context_layer.size()[:-2] + (hidden_size_per_partition,) 342 | context_layer = context_layer.view(*new_context_layer_shape) 343 | 344 | outputs = (context_layer, present, attention_probs) 345 | 346 | return outputs 347 | 348 | 349 | def default_init(cls, *args, **kwargs): 350 | return cls(*args, **kwargs) 351 | 352 | 353 | class SelfAttention(torch.nn.Module): 354 | def __init__(self, hidden_size, num_attention_heads, 355 | layer_id, hidden_size_per_attention_head=None, bias=True, 356 | params_dtype=torch.float, position_encoding_2d=True, empty_init=True): 357 | if empty_init: 358 | init_method = skip_init 359 | else: 360 | init_method = default_init 361 | super(SelfAttention, self).__init__() 362 | 363 | self.layer_id = layer_id 364 | self.hidden_size = hidden_size 365 | self.hidden_size_per_partition = hidden_size 366 | self.num_attention_heads = num_attention_heads 367 | self.num_attention_heads_per_partition = num_attention_heads 368 | self.position_encoding_2d = position_encoding_2d 369 | self.rotary_emb = RotaryEmbedding( 370 | self.hidden_size // (self.num_attention_heads * 2) 371 | if position_encoding_2d 372 | else self.hidden_size // self.num_attention_heads, 373 | base=10000, 374 | precision=torch.half, 375 | learnable=False, 376 | ) 377 | 378 | self.scale_mask_softmax = None 379 | 380 | if hidden_size_per_attention_head is None: 381 | self.hidden_size_per_attention_head = hidden_size // num_attention_heads 382 | else: 383 | self.hidden_size_per_attention_head = hidden_size_per_attention_head 384 | 385 | self.inner_hidden_size = num_attention_heads * self.hidden_size_per_attention_head 386 | 387 | # Strided linear layer. 388 | self.query_key_value = init_method( 389 | torch.nn.Linear, 390 | hidden_size, 391 | 3 * self.inner_hidden_size, 392 | bias=bias, 393 | dtype=params_dtype, 394 | ) 395 | 396 | self.dense = init_method( 397 | torch.nn.Linear, 398 | self.inner_hidden_size, 399 | hidden_size, 400 | bias=bias, 401 | dtype=params_dtype, 402 | ) 403 | 404 | @staticmethod 405 | def attention_mask_func(attention_scores, attention_mask): 406 | attention_scores.masked_fill_(attention_mask, -10000.0) 407 | return attention_scores 408 | 409 | def split_tensor_along_last_dim(self, tensor, num_partitions, 410 | contiguous_split_chunks=False): 411 | """Split a tensor along its last dimension. 412 | Arguments: 413 | tensor: input tensor. 414 | num_partitions: number of partitions to split the tensor 415 | contiguous_split_chunks: If True, make each chunk contiguous 416 | in memory. 417 | """ 418 | # Get the size and dimension. 
419 | last_dim = tensor.dim() - 1 420 | last_dim_size = tensor.size()[last_dim] // num_partitions 421 | # Split. 422 | tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) 423 | # Note: torch.split does not create contiguous tensors by default. 424 | if contiguous_split_chunks: 425 | return tuple(chunk.contiguous() for chunk in tensor_list) 426 | 427 | return tensor_list 428 | 429 | def forward( 430 | self, 431 | hidden_states: torch.Tensor, 432 | position_ids, 433 | attention_mask: torch.Tensor, 434 | layer_id, 435 | layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, 436 | use_cache: bool = False, 437 | output_attentions: bool = False, 438 | ): 439 | """ 440 | hidden_states: [seq_len, batch, hidden_size] 441 | attention_mask: [(1, 1), seq_len, seq_len] 442 | """ 443 | 444 | # [seq_len, batch, 3 * hidden_size] 445 | mixed_raw_layer = self.query_key_value(hidden_states) 446 | 447 | # [seq_len, batch, 3 * hidden_size] --> [seq_len, batch, num_attention_heads, 3 * hidden_size_per_attention_head] 448 | new_tensor_shape = mixed_raw_layer.size()[:-1] + ( 449 | self.num_attention_heads_per_partition, 450 | 3 * self.hidden_size_per_attention_head, 451 | ) 452 | mixed_raw_layer = mixed_raw_layer.view(*new_tensor_shape) 453 | 454 | # [seq_len, batch, num_attention_heads, hidden_size_per_attention_head] 455 | (query_layer, key_layer, value_layer) = self.split_tensor_along_last_dim(mixed_raw_layer, 3) 456 | 457 | if self.position_encoding_2d: 458 | q1, q2 = query_layer.chunk(2, dim=(query_layer.ndim - 1)) 459 | k1, k2 = key_layer.chunk(2, dim=(key_layer.ndim - 1)) 460 | cos, sin = self.rotary_emb(q1, seq_len=position_ids.max() + 1) 461 | position_ids, block_position_ids = position_ids[:, 0, :].transpose(0, 1).contiguous(), \ 462 | position_ids[:, 1, :].transpose(0, 1).contiguous() 463 | q1, k1 = apply_rotary_pos_emb_index(q1, k1, cos, sin, position_ids) 464 | q2, k2 = apply_rotary_pos_emb_index(q2, k2, cos, sin, block_position_ids) 465 | query_layer = torch.concat([q1, q2], dim=(q1.ndim - 1)) 466 | key_layer = torch.concat([k1, k2], dim=(k1.ndim - 1)) 467 | else: 468 | position_ids = position_ids.transpose(0, 1) 469 | cos, sin = self.rotary_emb(value_layer, seq_len=position_ids.max() + 1) 470 | # [seq_len, batch, num_attention_heads, hidden_size_per_attention_head] 471 | query_layer, key_layer = apply_rotary_pos_emb_index(query_layer, key_layer, cos, sin, position_ids) 472 | 473 | # [seq_len, batch, hidden_size] 474 | context_layer, present, attention_probs = attention_fn( 475 | self=self, 476 | query_layer=query_layer, 477 | key_layer=key_layer, 478 | value_layer=value_layer, 479 | attention_mask=attention_mask, 480 | hidden_size_per_partition=self.hidden_size_per_partition, 481 | layer_id=layer_id, 482 | layer_past=layer_past, 483 | use_cache=use_cache 484 | ) 485 | 486 | output = self.dense(context_layer) 487 | 488 | outputs = (output, present) 489 | 490 | if output_attentions: 491 | outputs += (attention_probs,) 492 | 493 | return outputs # output, present, attention_probs 494 | 495 | 496 | class GEGLU(torch.nn.Module): 497 | def __init__(self): 498 | super().__init__() 499 | self.activation_fn = F.gelu 500 | 501 | def forward(self, x): 502 | # dim=-1 breaks in jit for pt<1.10 503 | x1, x2 = x.chunk(2, dim=(x.ndim - 1)) 504 | return x1 * self.activation_fn(x2) 505 | 506 | 507 | class GLU(torch.nn.Module): 508 | def __init__(self, hidden_size, inner_hidden_size=None, 509 | layer_id=None, bias=True, activation_func=gelu, params_dtype=torch.float, empty_init=True): 510 | 
super(GLU, self).__init__() 511 | if empty_init: 512 | init_method = skip_init 513 | else: 514 | init_method = default_init 515 | self.layer_id = layer_id 516 | self.activation_func = activation_func 517 | 518 | # Project to 4h. 519 | self.hidden_size = hidden_size 520 | if inner_hidden_size is None: 521 | inner_hidden_size = 4 * hidden_size 522 | self.inner_hidden_size = inner_hidden_size 523 | self.dense_h_to_4h = init_method( 524 | torch.nn.Linear, 525 | self.hidden_size, 526 | self.inner_hidden_size, 527 | bias=bias, 528 | dtype=params_dtype, 529 | ) 530 | # Project back to h. 531 | self.dense_4h_to_h = init_method( 532 | torch.nn.Linear, 533 | self.inner_hidden_size, 534 | self.hidden_size, 535 | bias=bias, 536 | dtype=params_dtype, 537 | ) 538 | 539 | def forward(self, hidden_states): 540 | """ 541 | hidden_states: [seq_len, batch, hidden_size] 542 | """ 543 | 544 | # [seq_len, batch, inner_hidden_size] 545 | intermediate_parallel = self.dense_h_to_4h(hidden_states) 546 | 547 | intermediate_parallel = self.activation_func(intermediate_parallel) 548 | 549 | output = self.dense_4h_to_h(intermediate_parallel) 550 | 551 | return output 552 | 553 | 554 | class GLMBlock(torch.nn.Module): 555 | def __init__( 556 | self, 557 | hidden_size, 558 | num_attention_heads, 559 | layernorm_epsilon, 560 | layer_id, 561 | inner_hidden_size=None, 562 | hidden_size_per_attention_head=None, 563 | layernorm=LayerNorm, 564 | use_bias=True, 565 | params_dtype=torch.float, 566 | num_layers=28, 567 | position_encoding_2d=True, 568 | empty_init=True 569 | ): 570 | super(GLMBlock, self).__init__() 571 | # Set output layer initialization if not provided. 572 | 573 | self.layer_id = layer_id 574 | 575 | # Layernorm on the input data. 576 | self.input_layernorm = layernorm(hidden_size, eps=layernorm_epsilon) 577 | 578 | self.position_encoding_2d = position_encoding_2d 579 | 580 | # Self attention. 581 | self.attention = SelfAttention( 582 | hidden_size, 583 | num_attention_heads, 584 | layer_id, 585 | hidden_size_per_attention_head=hidden_size_per_attention_head, 586 | bias=use_bias, 587 | params_dtype=params_dtype, 588 | position_encoding_2d=self.position_encoding_2d, 589 | empty_init=empty_init 590 | ) 591 | 592 | # Layernorm on the input data. 593 | self.post_attention_layernorm = layernorm(hidden_size, eps=layernorm_epsilon) 594 | 595 | self.num_layers = num_layers 596 | 597 | # GLU 598 | self.mlp = GLU( 599 | hidden_size, 600 | inner_hidden_size=inner_hidden_size, 601 | bias=use_bias, 602 | layer_id=layer_id, 603 | params_dtype=params_dtype, 604 | empty_init=empty_init 605 | ) 606 | 607 | def forward( 608 | self, 609 | hidden_states: torch.Tensor, 610 | position_ids, 611 | attention_mask: torch.Tensor, 612 | layer_id, 613 | layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, 614 | use_cache: bool = False, 615 | output_attentions: bool = False, 616 | ): 617 | """ 618 | hidden_states: [seq_len, batch, hidden_size] 619 | attention_mask: [(1, 1), seq_len, seq_len] 620 | """ 621 | 622 | # Layer norm at the begining of the transformer layer. 623 | # [seq_len, batch, hidden_size] 624 | attention_input = self.input_layernorm(hidden_states) 625 | 626 | # Self attention. 
627 | attention_outputs = self.attention( 628 | attention_input, 629 | position_ids, 630 | attention_mask=attention_mask, 631 | layer_id=layer_id, 632 | layer_past=layer_past, 633 | use_cache=use_cache, 634 | output_attentions=output_attentions 635 | ) 636 | 637 | attention_output = attention_outputs[0] 638 | 639 | outputs = attention_outputs[1:] 640 | 641 | # Residual connection. 642 | alpha = (2 * self.num_layers) ** 0.5 643 | hidden_states = attention_input * alpha + attention_output 644 | 645 | mlp_input = self.post_attention_layernorm(hidden_states) 646 | 647 | # MLP. 648 | mlp_output = self.mlp(mlp_input) 649 | 650 | # Second residual connection. 651 | output = mlp_input * alpha + mlp_output 652 | 653 | if use_cache: 654 | outputs = (output,) + outputs 655 | else: 656 | outputs = (output,) + outputs[1:] 657 | 658 | return outputs # hidden_states, present, attentions 659 | 660 | 661 | class ChatGLMPreTrainedModel(PreTrainedModel): 662 | """ 663 | An abstract class to handle weights initialization and 664 | a simple interface for downloading and loading pretrained models. 665 | """ 666 | 667 | is_parallelizable = False 668 | supports_gradient_checkpointing = True 669 | config_class = ChatGLMConfig 670 | base_model_prefix = "transformer" 671 | _no_split_modules = ["GLMBlock"] 672 | 673 | def __init__(self, *inputs, **kwargs): 674 | super().__init__(*inputs, **kwargs) 675 | 676 | def _init_weights(self, module: nn.Module): 677 | """Initialize the weights.""" 678 | return 679 | 680 | def get_masks(self, input_ids, device): 681 | batch_size, seq_length = input_ids.shape 682 | context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids] 683 | attention_mask = torch.ones((batch_size, seq_length, seq_length), device=device) 684 | attention_mask.tril_() 685 | for i, context_length in enumerate(context_lengths): 686 | attention_mask[i, :, :context_length] = 1 687 | attention_mask.unsqueeze_(1) 688 | attention_mask = (attention_mask < 0.5).bool() 689 | 690 | return attention_mask 691 | 692 | def get_position_ids(self, input_ids, mask_positions, device, use_gmasks=None): 693 | batch_size, seq_length = input_ids.shape 694 | if use_gmasks is None: 695 | use_gmasks = [False] * batch_size 696 | context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids] 697 | if self.position_encoding_2d: 698 | position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1) 699 | for i, context_length in enumerate(context_lengths): 700 | position_ids[i, context_length:] = mask_positions[i] 701 | block_position_ids = [torch.cat(( 702 | torch.zeros(context_length, dtype=torch.long, device=device), 703 | torch.arange(seq_length - context_length, dtype=torch.long, device=device) + 1 704 | )) for context_length in context_lengths] 705 | block_position_ids = torch.stack(block_position_ids, dim=0) 706 | position_ids = torch.stack((position_ids, block_position_ids), dim=1) 707 | else: 708 | position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1) 709 | for i, context_length in enumerate(context_lengths): 710 | if not use_gmasks[i]: 711 | position_ids[context_length:] = mask_positions[i] 712 | 713 | return position_ids 714 | 715 | def _set_gradient_checkpointing(self, module, value=False): 716 | if isinstance(module, ChatGLMModel): 717 | module.gradient_checkpointing = value 718 | 719 | 720 | CHATGLM_6B_START_DOCSTRING = r""" 721 | This model is a PyTorch 
[torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. 722 | Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general 723 | usage and behavior. 724 | 725 | Parameters: 726 | config ([`~ChatGLM6BConfig`]): Model configuration class with all the parameters of the model. 727 | Initializing with a config file does not load the weights associated with the model, only the configuration. 728 | Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. 729 | """ 730 | 731 | CHATGLM_6B_INPUTS_DOCSTRING = r""" 732 | Args: 733 | input_ids (`torch.LongTensor` of shape `({0})`): 734 | Indices of input sequence tokens in the vocabulary. 735 | 736 | Indices can be obtained using [`ChatGLM6BTokenizer`]. 737 | See [`PreTrainedTokenizer.encode`] and 738 | [`PreTrainedTokenizer.__call__`] for details. 739 | 740 | [What are input IDs?](../glossary#input-ids) 741 | attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*): 742 | Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: 743 | 744 | - 1 for tokens that are **not masked**, 745 | - 0 for tokens that are **masked**. 746 | 747 | [What are attention masks?](../glossary#attention-mask) 748 | token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*): 749 | Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, 1]`: 750 | 751 | - 0 corresponds to a *sentence A* token, 752 | - 1 corresponds to a *sentence B* token. 753 | 754 | [What are token type IDs?](../glossary#token-type-ids) 755 | position_ids (`torch.LongTensor` of shape `({0})`, *optional*): 756 | Indices of positions of each input sequence tokens in the position embeddings. 757 | Selected in the range `[0, config.max_position_embeddings - 1]`. 758 | 759 | [What are position IDs?](../glossary#position-ids) 760 | head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): 761 | Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: 762 | 763 | - 1 indicates the head is **not masked**, 764 | - 0 indicates the head is **masked**. 765 | 766 | inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*): 767 | Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. 768 | This is useful if you want more control over how to convert *input_ids* indices into associated vectors 769 | than the model's internal embedding lookup matrix. 770 | output_attentions (`bool`, *optional*): 771 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned 772 | tensors for more detail. 773 | output_hidden_states (`bool`, *optional*): 774 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for 775 | more detail. 776 | return_dict (`bool`, *optional*): 777 | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
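    Example (a minimal sketch; the `THUDM/chatglm-6b` hub id, `trust_remote_code=True` and the
    availability of a CUDA device are assumptions — substitute the path of your own checkpoint
    and drop `.cuda()` for CPU-only use):

    ```python
    >>> from transformers import AutoModel, AutoTokenizer
    >>> tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
    >>> model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
    >>> response, history = model.chat(tokenizer, "你好", history=[])
    ```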
778 | """ 779 | 780 | 781 | @add_start_docstrings( 782 | "The bare ChatGLM-6B Model transformer outputting raw hidden-states without any specific head on top.", 783 | CHATGLM_6B_START_DOCSTRING, 784 | ) 785 | class ChatGLMModel(ChatGLMPreTrainedModel): 786 | """ 787 | 788 | The model can behave as an encoder (with only self-attention) as well 789 | as a decoder, in which case a layer of cross-attention is added between 790 | the self-attention layers, following the architecture described in [Attention is 791 | all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, 792 | Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. 793 | 794 | To behave as an decoder the model needs to be initialized with the 795 | `is_decoder` argument of the configuration set to `True`. 796 | To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` 797 | argument and `add_cross_attention` set to `True`; an 798 | `encoder_hidden_states` is then expected as an input to the forward pass. 799 | """ 800 | 801 | def __init__(self, config: ChatGLMConfig, empty_init=True): 802 | super().__init__(config) 803 | if empty_init: 804 | init_method = skip_init 805 | else: 806 | init_method = default_init 807 | # recording parameters 808 | self.max_sequence_length = config.max_sequence_length 809 | self.hidden_size = config.hidden_size 810 | self.params_dtype = torch.half 811 | self.num_attention_heads = config.num_attention_heads 812 | self.vocab_size = config.vocab_size 813 | self.num_layers = config.num_layers 814 | self.layernorm_epsilon = config.layernorm_epsilon 815 | self.inner_hidden_size = config.inner_hidden_size 816 | self.hidden_size_per_attention_head = self.hidden_size // self.num_attention_heads 817 | self.position_encoding_2d = config.position_encoding_2d 818 | self.pre_seq_len = config.pre_seq_len 819 | self.prefix_projection = config.prefix_projection 820 | 821 | self.word_embeddings = init_method( 822 | torch.nn.Embedding, 823 | num_embeddings=self.vocab_size, embedding_dim=self.hidden_size, 824 | dtype=self.params_dtype 825 | ) 826 | self.gradient_checkpointing = False 827 | 828 | def get_layer(layer_id): 829 | return GLMBlock( 830 | self.hidden_size, 831 | self.num_attention_heads, 832 | self.layernorm_epsilon, 833 | layer_id, 834 | inner_hidden_size=self.inner_hidden_size, 835 | hidden_size_per_attention_head=self.hidden_size_per_attention_head, 836 | layernorm=LayerNorm, 837 | use_bias=True, 838 | params_dtype=self.params_dtype, 839 | position_encoding_2d=self.position_encoding_2d, 840 | empty_init=empty_init 841 | ) 842 | 843 | self.layers = torch.nn.ModuleList( 844 | [get_layer(layer_id) for layer_id in range(self.num_layers)] 845 | ) 846 | 847 | # Final layer norm before output. 
848 | self.final_layernorm = LayerNorm(self.hidden_size, eps=self.layernorm_epsilon) 849 | 850 | if self.pre_seq_len is not None: 851 | for param in self.parameters(): 852 | param.requires_grad = False 853 | self.prefix_tokens = torch.arange(self.pre_seq_len).long() 854 | self.prefix_encoder = PrefixEncoder(config) 855 | self.dropout = torch.nn.Dropout(0.1) 856 | 857 | # total_params = sum(p.numel() for p in self.parameters()) 858 | # trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad) 859 | # print("Using p-tuning v2: # trainable_params = {} / {}".format(trainable_params, total_params)) 860 | 861 | def get_input_embeddings(self): 862 | return self.word_embeddings 863 | 864 | def set_input_embeddings(self, new_embeddings: torch.Tensor): 865 | self.word_embeddings = new_embeddings 866 | 867 | def get_prompt(self, batch_size, device, dtype=torch.half): 868 | prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device) 869 | past_key_values = self.prefix_encoder(prefix_tokens).type(dtype) 870 | past_key_values = past_key_values.view( 871 | batch_size, 872 | self.pre_seq_len, 873 | self.num_layers * 2, 874 | self.num_attention_heads, 875 | self.hidden_size // self.num_attention_heads 876 | ) 877 | # seq_len, b, nh, hidden_size 878 | past_key_values = self.dropout(past_key_values) 879 | past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2) 880 | # past_key_values = [(v[0], v[1]) for v in past_key_values] 881 | return past_key_values 882 | 883 | @add_start_docstrings_to_model_forward(CHATGLM_6B_INPUTS_DOCSTRING.format("batch_size, sequence_length")) 884 | @add_code_sample_docstrings( 885 | checkpoint=_CHECKPOINT_FOR_DOC, 886 | output_type=BaseModelOutputWithPastAndCrossAttentions, 887 | config_class=_CONFIG_FOR_DOC, 888 | ) 889 | def forward( 890 | self, 891 | input_ids: Optional[torch.LongTensor] = None, 892 | position_ids: Optional[torch.LongTensor] = None, 893 | attention_mask: Optional[torch.Tensor] = None, 894 | past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, 895 | inputs_embeds: Optional[torch.LongTensor] = None, 896 | use_cache: Optional[bool] = None, 897 | output_attentions: Optional[bool] = None, 898 | output_hidden_states: Optional[bool] = None, 899 | return_dict: Optional[bool] = None, 900 | ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPast]: 901 | 902 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 903 | output_hidden_states = ( 904 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 905 | ) 906 | use_cache = use_cache if use_cache is not None else self.config.use_cache 907 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 908 | 909 | if self.gradient_checkpointing and self.training: 910 | if use_cache: 911 | logger.warning_once( 912 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
913 | ) 914 | use_cache = False 915 | 916 | if input_ids is not None and inputs_embeds is not None: 917 | raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") 918 | elif input_ids is not None: 919 | batch_size, seq_length = input_ids.shape[:2] 920 | elif inputs_embeds is not None: 921 | batch_size, seq_length = inputs_embeds.shape[:2] 922 | else: 923 | raise ValueError("You have to specify either input_ids or inputs_embeds") 924 | 925 | if inputs_embeds is None: 926 | inputs_embeds = self.word_embeddings(input_ids) 927 | 928 | if past_key_values is None: 929 | if self.pre_seq_len is not None: 930 | past_key_values = self.get_prompt(batch_size=input_ids.shape[0], device=input_ids.device, 931 | dtype=inputs_embeds.dtype) 932 | else: 933 | past_key_values = tuple([None] * len(self.layers)) 934 | 935 | if attention_mask is None: 936 | attention_mask = self.get_masks( 937 | input_ids, 938 | device=input_ids.device 939 | ) 940 | 941 | 942 | if position_ids is None: 943 | MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id 944 | seqs = input_ids.tolist() 945 | 946 | mask_positions, use_gmasks = [], [] 947 | for seq in seqs: 948 | mask_token = gMASK if gMASK in seq else MASK 949 | use_gmask = mask_token == gMASK 950 | mask_positions.append(seq.index(mask_token)) 951 | use_gmasks.append(use_gmask) 952 | 953 | position_ids = self.get_position_ids( 954 | input_ids, 955 | mask_positions=mask_positions, 956 | device=input_ids.device, 957 | use_gmasks=use_gmasks 958 | ) 959 | 960 | if self.pre_seq_len is not None and attention_mask is not None: 961 | prefix_attention_mask = torch.ones(batch_size, 1, input_ids.size(-1), self.pre_seq_len).to( 962 | attention_mask.device) 963 | prefix_attention_mask = (prefix_attention_mask < 0.5).bool() 964 | attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=3) 965 | 966 | # [seq_len, batch, hidden_size] 967 | hidden_states = inputs_embeds.transpose(0, 1) 968 | 969 | presents = () if use_cache else None 970 | all_self_attentions = () if output_attentions else None 971 | all_hidden_states = () if output_hidden_states else None 972 | 973 | if attention_mask is None: 974 | attention_mask = torch.zeros(1, 1, device=input_ids.device).bool() 975 | else: 976 | attention_mask = attention_mask.to(hidden_states.device) 977 | 978 | for i, layer in enumerate(self.layers): 979 | 980 | if output_hidden_states: 981 | all_hidden_states = all_hidden_states + (hidden_states,) 982 | layer_past = past_key_values[i] 983 | 984 | if self.gradient_checkpointing and self.training: 985 | layer_ret = torch.utils.checkpoint.checkpoint( 986 | layer, 987 | hidden_states, 988 | position_ids, 989 | attention_mask, 990 | torch.tensor(i), 991 | layer_past, 992 | use_cache, 993 | output_attentions 994 | ) 995 | else: 996 | layer_ret = layer( 997 | hidden_states, 998 | position_ids=position_ids, 999 | attention_mask=attention_mask, 1000 | layer_id=torch.tensor(i), 1001 | layer_past=layer_past, 1002 | use_cache=use_cache, 1003 | output_attentions=output_attentions 1004 | ) 1005 | 1006 | hidden_states = layer_ret[0] 1007 | 1008 | if use_cache: 1009 | presents = presents + (layer_ret[1],) 1010 | 1011 | if output_attentions: 1012 | all_self_attentions = all_self_attentions + (layer_ret[2 if use_cache else 1],) 1013 | 1014 | # Final layer norm. 
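# At this point `hidden_states` is still laid out as [seq_len, batch, hidden_size];
# `presents` accumulates one (key, value) cache tuple per layer when use_cache=True,
# and `all_hidden_states` / `all_self_attentions` are only populated when the caller
# requested them. The final layer norm below is applied before the outputs are returned.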
1015 | hidden_states = self.final_layernorm(hidden_states) 1016 | 1017 | if output_hidden_states: 1018 | all_hidden_states = all_hidden_states + (hidden_states,) 1019 | 1020 | if not return_dict: 1021 | return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) 1022 | 1023 | return BaseModelOutputWithPast( 1024 | last_hidden_state=hidden_states, 1025 | past_key_values=presents, 1026 | hidden_states=all_hidden_states, 1027 | attentions=all_self_attentions, 1028 | ) 1029 | 1030 | 1031 | class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel): 1032 | def __init__(self, config: ChatGLMConfig, empty_init=True): 1033 | super().__init__(config) 1034 | if empty_init: 1035 | init_method = skip_init 1036 | else: 1037 | init_method = default_init 1038 | 1039 | # self.hidden_size = config.hidden_size 1040 | # self.params_dtype = torch.half 1041 | # self.vocab_size = config.vocab_size 1042 | self.max_sequence_length = config.max_sequence_length 1043 | 1044 | self.position_encoding_2d = config.position_encoding_2d 1045 | 1046 | self.transformer = ChatGLMModel(config, empty_init=empty_init) 1047 | 1048 | self.lm_head = init_method( 1049 | nn.Linear, 1050 | config.hidden_size, 1051 | config.vocab_size, 1052 | bias=False, 1053 | dtype=torch.half 1054 | ) 1055 | 1056 | self.config = config 1057 | 1058 | self.quantized = False 1059 | 1060 | if self.config.quantization_bit: 1061 | self.quantize(self.config.quantization_bit, self.config.quantization_embeddings, use_quantization_cache=True, empty_init=True) 1062 | 1063 | def get_output_embeddings(self): 1064 | return self.lm_head 1065 | 1066 | def set_output_embeddings(self, new_embeddings): 1067 | self.lm_head = new_embeddings 1068 | 1069 | def _update_model_kwargs_for_generation( 1070 | self, 1071 | outputs: ModelOutput, 1072 | model_kwargs: Dict[str, Any], 1073 | is_encoder_decoder: bool = False, 1074 | standardize_cache_format: bool = False, 1075 | ) -> Dict[str, Any]: 1076 | # update past_key_values 1077 | model_kwargs["past_key_values"] = self._extract_past_from_model_output( 1078 | outputs, standardize_cache_format=standardize_cache_format 1079 | ) 1080 | 1081 | # update attention mask 1082 | if "attention_mask" in model_kwargs: 1083 | attention_mask = model_kwargs["attention_mask"] 1084 | if attention_mask is not None and attention_mask.dtype == torch.bool: 1085 | attention_mask = torch.cat( 1086 | [attention_mask, attention_mask.new_ones((*attention_mask.shape[:3], 1))], dim=3) 1087 | new_attention_mask = attention_mask[:, :, -1:].clone() 1088 | new_attention_mask[..., -1] = False 1089 | model_kwargs["attention_mask"] = torch.cat( 1090 | [attention_mask, new_attention_mask], dim=2 1091 | ) 1092 | 1093 | # update position ids 1094 | if "position_ids" in model_kwargs: 1095 | position_ids = model_kwargs["position_ids"] 1096 | new_position_id = position_ids[..., -1:].clone() 1097 | new_position_id[:, 1, :] += 1 1098 | model_kwargs["position_ids"] = torch.cat( 1099 | [position_ids, new_position_id], dim=-1 1100 | ) 1101 | 1102 | return model_kwargs 1103 | 1104 | def prepare_inputs_for_generation( 1105 | self, 1106 | input_ids: torch.LongTensor, 1107 | past: Optional[torch.Tensor] = None, 1108 | past_key_values: Optional[torch.Tensor] = None, 1109 | attention_mask: Optional[torch.Tensor] = None, 1110 | position_ids: Optional[torch.Tensor] = None, 1111 | **kwargs 1112 | ) -> dict: 1113 | batch_size, seq_length = input_ids.shape 1114 | MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id 
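# GLM prompts end with a [gMASK] (free-length blank) or [MASK] (short blank) token followed
# by the BOS/<sop> token, and every generated token is anchored at that mask slot in the
# first dimension of the 2D position ids. The scan below records, for each sequence in the
# batch, where the mask token sits and whether it is the gMASK variant.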
1115 | seqs = input_ids.tolist() 1116 | mask_positions, use_gmasks = [], [] 1117 | for seq in seqs: 1118 | mask_token = gMASK if gMASK in seq else MASK 1119 | use_gmask = mask_token == gMASK 1120 | mask_positions.append(seq.index(mask_token)) 1121 | use_gmasks.append(use_gmask) 1122 | 1123 | # only last token for input_ids if past is not None 1124 | if past is not None or past_key_values is not None: 1125 | last_token = input_ids[:, -1].unsqueeze(-1) 1126 | if attention_mask is not None and attention_mask.dtype == torch.bool: 1127 | attention_mask = attention_mask[:, :, -1:] 1128 | else: 1129 | attention_mask = None 1130 | if position_ids is not None: 1131 | position_ids = position_ids[..., -1:] 1132 | else: 1133 | context_lengths = [seq.index(self.config.bos_token_id) for seq in seqs] 1134 | if self.position_encoding_2d: 1135 | position_ids = torch.tensor( 1136 | [[mask_position, seq_length - context_length] for mask_position, context_length in 1137 | zip(mask_positions, context_lengths)], dtype=torch.long, device=input_ids.device).unsqueeze(-1) 1138 | else: 1139 | position_ids = torch.tensor([mask_position for mask_position in mask_positions], dtype=torch.long, 1140 | device=input_ids.device).unsqueeze(-1) 1141 | 1142 | if past is None: 1143 | past = past_key_values 1144 | return { 1145 | "input_ids": last_token, 1146 | "past_key_values": past, 1147 | "position_ids": position_ids, 1148 | "attention_mask": attention_mask 1149 | } 1150 | else: 1151 | if attention_mask is not None and attention_mask.dtype != torch.bool: 1152 | logger.warning_once(f"The dtype of attention mask ({attention_mask.dtype}) is not bool") 1153 | attention_mask = None 1154 | if attention_mask is None: 1155 | attention_mask = self.get_masks( 1156 | input_ids, 1157 | device=input_ids.device 1158 | ) 1159 | if position_ids is None: 1160 | position_ids = self.get_position_ids( 1161 | input_ids, 1162 | device=input_ids.device, 1163 | mask_positions=mask_positions, 1164 | use_gmasks=use_gmasks 1165 | ) 1166 | 1167 | return { 1168 | "input_ids": input_ids, 1169 | "past_key_values": past, 1170 | "position_ids": position_ids, 1171 | "attention_mask": attention_mask 1172 | } 1173 | 1174 | def forward( 1175 | self, 1176 | input_ids: Optional[torch.Tensor] = None, 1177 | position_ids: Optional[torch.Tensor] = None, 1178 | attention_mask: Optional[torch.Tensor] = None, 1179 | past_key_values: Optional[Tuple[torch.FloatTensor]] = None, 1180 | inputs_embeds: Optional[torch.Tensor] = None, 1181 | labels: Optional[torch.Tensor] = None, 1182 | use_cache: Optional[bool] = None, 1183 | output_attentions: Optional[bool] = None, 1184 | output_hidden_states: Optional[bool] = None, 1185 | return_dict: Optional[bool] = None, 1186 | ): 1187 | use_cache = use_cache if use_cache is not None else self.config.use_cache 1188 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1189 | 1190 | transformer_outputs = self.transformer( 1191 | input_ids=input_ids, 1192 | position_ids=position_ids, 1193 | attention_mask=attention_mask, 1194 | past_key_values=past_key_values, 1195 | inputs_embeds=inputs_embeds, 1196 | use_cache=use_cache, 1197 | output_attentions=output_attentions, 1198 | output_hidden_states=output_hidden_states, 1199 | return_dict=return_dict, 1200 | ) 1201 | 1202 | hidden_states = transformer_outputs[0] 1203 | 1204 | lm_logits = self.lm_head(hidden_states).permute(1, 0, 2).contiguous() 1205 | 1206 | loss = None 1207 | if labels is not None: 1208 | lm_logits = lm_logits.to(torch.float32) 1209 | 1210 | 
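# Standard causal-LM shift: the logit at step t is trained against the label at step t+1,
# so the last logit and the first label are dropped before the cross entropy, and positions
# labelled -100 (prompt / padding) are skipped via ignore_index. E.g. labels [l0, l1, l2]
# are compared with the logits of steps [0, 1] predicting [l1, l2].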
# Shift so that tokens < n predict n 1211 | shift_logits = lm_logits[..., :-1, :].contiguous() 1212 | shift_labels = labels[..., 1:].contiguous() 1213 | # Flatten the tokens 1214 | loss_fct = CrossEntropyLoss(ignore_index=-100) 1215 | loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) 1216 | 1217 | lm_logits = lm_logits.to(hidden_states.dtype) 1218 | loss = loss.to(hidden_states.dtype) 1219 | 1220 | if not return_dict: 1221 | output = (lm_logits,) + transformer_outputs[1:] 1222 | return ((loss,) + output) if loss is not None else output 1223 | 1224 | return CausalLMOutputWithPast( 1225 | loss=loss, 1226 | logits=lm_logits, 1227 | past_key_values=transformer_outputs.past_key_values, 1228 | hidden_states=transformer_outputs.hidden_states, 1229 | attentions=transformer_outputs.attentions, 1230 | ) 1231 | 1232 | @staticmethod 1233 | def _reorder_cache( 1234 | past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor 1235 | ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]: 1236 | """ 1237 | This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or 1238 | [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct 1239 | beam_idx at every generation step. 1240 | 1241 | Output shares the same memory storage as `past`. 1242 | """ 1243 | return tuple( 1244 | ( 1245 | layer_past[0].index_select(1, beam_idx.to(layer_past[0].device)), 1246 | layer_past[1].index_select(1, beam_idx.to(layer_past[1].device)), 1247 | ) 1248 | for layer_past in past 1249 | ) 1250 | 1251 | def process_response(self, response): 1252 | response = response.strip() 1253 | response = response.replace("[[训练时间]]", "2023年") 1254 | punkts = [ 1255 | [",", ","], 1256 | ["!", "!"], 1257 | [":", ":"], 1258 | [";", ";"], 1259 | ["\?", "?"], 1260 | ] 1261 | for item in punkts: 1262 | response = re.sub(r"([\u4e00-\u9fff])%s" % item[0], r"\1%s" % item[1], response) 1263 | response = re.sub(r"%s([\u4e00-\u9fff])" % item[0], r"%s\1" % item[1], response) 1264 | return response 1265 | 1266 | @torch.no_grad() 1267 | def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048, num_beams=1, 1268 | do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs): 1269 | if history is None: 1270 | history = [] 1271 | if logits_processor is None: 1272 | logits_processor = LogitsProcessorList() 1273 | logits_processor.append(InvalidScoreLogitsProcessor()) 1274 | gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p, 1275 | "temperature": temperature, "logits_processor": logits_processor, **kwargs} 1276 | if not history: 1277 | prompt = query 1278 | else: 1279 | prompt = "" 1280 | for i, (old_query, response) in enumerate(history): 1281 | prompt += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response) 1282 | prompt += "[Round {}]\n问:{}\n答:".format(len(history), query) 1283 | inputs = tokenizer([prompt], return_tensors="pt") 1284 | inputs = inputs.to(self.device) 1285 | outputs = self.generate(**inputs, **gen_kwargs) 1286 | outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):] 1287 | response = tokenizer.decode(outputs) 1288 | response = self.process_response(response) 1289 | history = history + [(query, response)] 1290 | return response, history 1291 | 1292 | @torch.no_grad() 1293 | def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048, 1294 | 
do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs): 1295 | if history is None: 1296 | history = [] 1297 | if logits_processor is None: 1298 | logits_processor = LogitsProcessorList() 1299 | logits_processor.append(InvalidScoreLogitsProcessor()) 1300 | gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p, 1301 | "temperature": temperature, "logits_processor": logits_processor, **kwargs} 1302 | if not history: 1303 | prompt = query 1304 | else: 1305 | prompt = "" 1306 | for i, (old_query, response) in enumerate(history): 1307 | prompt += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response) 1308 | prompt += "[Round {}]\n问:{}\n答:".format(len(history), query) 1309 | inputs = tokenizer([prompt], return_tensors="pt") 1310 | inputs = inputs.to(self.device) 1311 | for outputs in self.stream_generate(**inputs, **gen_kwargs): 1312 | outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):] 1313 | response = tokenizer.decode(outputs) 1314 | response = self.process_response(response) 1315 | new_history = history + [(query, response)] 1316 | yield response, new_history 1317 | 1318 | @torch.no_grad() 1319 | def stream_generate( 1320 | self, 1321 | input_ids, 1322 | generation_config: Optional[GenerationConfig] = None, 1323 | logits_processor: Optional[LogitsProcessorList] = None, 1324 | stopping_criteria: Optional[StoppingCriteriaList] = None, 1325 | prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, 1326 | **kwargs, 1327 | ): 1328 | batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1] 1329 | 1330 | if generation_config is None: 1331 | generation_config = self.generation_config 1332 | generation_config = copy.deepcopy(generation_config) 1333 | model_kwargs = generation_config.update(**kwargs) 1334 | bos_token_id, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id 1335 | 1336 | if isinstance(eos_token_id, int): 1337 | eos_token_id = [eos_token_id] 1338 | 1339 | has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None 1340 | if has_default_max_length and generation_config.max_new_tokens is None: 1341 | warnings.warn( 1342 | f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. " 1343 | "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we" 1344 | " recommend using `max_new_tokens` to control the maximum length of the generation.", 1345 | UserWarning, 1346 | ) 1347 | elif generation_config.max_new_tokens is not None: 1348 | generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length 1349 | if not has_default_max_length: 1350 | logger.warn( 1351 | f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(=" 1352 | f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. " 1353 | "Please refer to the documentation for more information. " 1354 | "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)", 1355 | UserWarning, 1356 | ) 1357 | 1358 | if input_ids_seq_length >= generation_config.max_length: 1359 | input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids" 1360 | logger.warning( 1361 | f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to" 1362 | f" {generation_config.max_length}. This can lead to unexpected behavior. 
You should consider" 1363 | " increasing `max_new_tokens`." 1364 | ) 1365 | 1366 | # 2. Set generation parameters if not already defined 1367 | logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() 1368 | stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() 1369 | 1370 | logits_processor = self._get_logits_processor( 1371 | generation_config=generation_config, 1372 | input_ids_seq_length=input_ids_seq_length, 1373 | encoder_input_ids=input_ids, 1374 | prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, 1375 | logits_processor=logits_processor, 1376 | ) 1377 | 1378 | stopping_criteria = self._get_stopping_criteria( 1379 | generation_config=generation_config, stopping_criteria=stopping_criteria 1380 | ) 1381 | logits_warper = self._get_logits_warper(generation_config) 1382 | 1383 | unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) 1384 | scores = None 1385 | while True: 1386 | model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) 1387 | # forward pass to get next token 1388 | outputs = self( 1389 | **model_inputs, 1390 | return_dict=True, 1391 | output_attentions=False, 1392 | output_hidden_states=False, 1393 | ) 1394 | 1395 | next_token_logits = outputs.logits[:, -1, :] 1396 | 1397 | # pre-process distribution 1398 | next_token_scores = logits_processor(input_ids, next_token_logits) 1399 | next_token_scores = logits_warper(input_ids, next_token_scores) 1400 | 1401 | # sample 1402 | probs = nn.functional.softmax(next_token_scores, dim=-1) 1403 | if generation_config.do_sample: 1404 | next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) 1405 | else: 1406 | next_tokens = torch.argmax(probs, dim=-1) 1407 | 1408 | # update generated ids, model inputs, and length for next step 1409 | input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) 1410 | model_kwargs = self._update_model_kwargs_for_generation( 1411 | outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder 1412 | ) 1413 | unfinished_sequences = unfinished_sequences.mul((sum(next_tokens != i for i in eos_token_id)).long()) 1414 | 1415 | # stop when each sentence is finished, or if we exceed the maximum length 1416 | if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores): 1417 | break 1418 | yield input_ids 1419 | 1420 | def quantize(self, bits: int, quantize_embeddings=False, use_quantization_cache=False, empty_init=False, **kwargs): 1421 | if bits == 0: 1422 | return 1423 | 1424 | from .quantization import quantize, QuantizedEmbedding, QuantizedLinear, load_cpu_kernel 1425 | 1426 | if self.quantized: 1427 | if self.device == torch.device("cpu"): 1428 | logger.info("Already quantized, reloading cpu kernel.") 1429 | load_cpu_kernel(**kwargs) 1430 | else: 1431 | logger.info("Already quantized.") 1432 | return self 1433 | 1434 | self.quantized = True 1435 | 1436 | self.config.quantization_bit = bits 1437 | self.config.quantization_embeddings = quantize_embeddings 1438 | 1439 | self.transformer = quantize(self.transformer, bits, use_quantization_cache=use_quantization_cache, empty_init=empty_init, **kwargs) 1440 | 1441 | if self.device == torch.device("cpu"): 1442 | dtype = torch.float32 1443 | else: 1444 | dtype = torch.half 1445 | 1446 | if quantize_embeddings: 1447 | logger.info("Applying quantization to embeddings") 1448 | self.transformer.word_embeddings = QuantizedEmbedding( 1449 | weight_bit_width=bits, 1450 | 
weight_tensor=self.transformer.word_embeddings.weight.to(self.device), 1451 | num_embeddings=self.transformer.word_embeddings.num_embeddings, 1452 | embedding_dim=self.transformer.word_embeddings.embedding_dim, 1453 | dtype=dtype, 1454 | empty_init=empty_init, 1455 | device=self.transformer.word_embeddings.weight.device, 1456 | ) 1457 | self.lm_head = QuantizedLinear( 1458 | weight_bit_width=bits, 1459 | weight_tensor=self.lm_head.weight.to(self.device), 1460 | bias_tensor=None, 1461 | in_features=self.lm_head.in_features, 1462 | out_features=self.lm_head.out_features, 1463 | bias=False, 1464 | quantized_weight=self.transformer.word_embeddings.weight, 1465 | quantized_weight_scale=self.transformer.word_embeddings.weight_scale, 1466 | dtype=dtype, 1467 | empty_init=empty_init, 1468 | device=self.lm_head.weight.device, 1469 | ) 1470 | 1471 | return self 1472 | -------------------------------------------------------------------------------- /prompts/prompt-compress.md: -------------------------------------------------------------------------------- 1 | 我有一个低代码平台项目,它可以根据符合规范的 JSON 数据生成页面,这个 JSON 数据是一个数组,里面的每一项都是一个 JSON 对象,每个 JSON 对象都对应着一个组件。 2 | 下面用 ``` 包括起来的代码就是所有的组件列表。 3 | 4 | ```json 5 | [{"animations":[],"events":{},"groupStyle":{},"isLock":false,"collapseName":"style","linkage":{"duration":0,"data":[{"id":"","label":"","event":"","style":[{"key":"","value":""}]}]},"component":"VText","label":"文字","propValue":"双击编辑文字","icon":"wenben","request":{"method":"GET","data":[],"url":"","series":false,"time":1000,"paramType":"","requestCount":0},"style":{"rotate":0,"opacity":1,"width":200,"height":28,"fontSize":"","fontWeight":400,"lineHeight":"","letterSpacing":0,"textAlign":"","color":""}},{"animations":[],"events":{},"groupStyle":{},"isLock":false,"collapseName":"style","linkage":{"duration":0,"data":[{"id":"","label":"","event":"","style":[{"key":"","value":""}]}]},"component":"VButton","label":"按钮","propValue":"按钮","icon":"button","style":{"rotate":0,"opacity":1,"width":100,"height":34,"borderWidth":1,"borderColor":"","borderRadius":"","fontSize":"","fontWeight":400,"lineHeight":"","letterSpacing":0,"textAlign":"","color":"","backgroundColor":""}},{"animations":[],"events":{},"groupStyle":{},"isLock":false,"collapseName":"style","linkage":{"duration":0,"data":[{"id":"","label":"","event":"","style":[{"key":"","value":""}]}]},"component":"Picture","label":"图片","icon":"tupian","propValue":{"url":"img/title.07a15c19.jpg","flip":{"horizontal":false,"vertical":false}},"style":{"rotate":0,"opacity":1,"width":300,"height":200,"borderRadius":""}},{"animations":[],"events":{},"groupStyle":{},"isLock":false,"collapseName":"style","linkage":{"duration":0,"data":[{"id":"","label":"","event":"","style":[{"key":"","value":""}]}]},"component":"RectShape","label":"矩形","propValue":" 
","icon":"juxing","style":{"rotate":0,"opacity":1,"width":200,"height":200,"fontSize":"","fontWeight":400,"lineHeight":"","letterSpacing":0,"textAlign":"center","color":"","borderColor":"#000","borderWidth":1,"backgroundColor":"","borderStyle":"solid","borderRadius":"","verticalAlign":"middle"}},{"animations":[],"events":{},"groupStyle":{},"isLock":false,"collapseName":"style","linkage":{"duration":0,"data":[{"id":"","label":"","event":"","style":[{"key":"","value":""}]}]},"component":"LineShape","label":"直线","propValue":"","icon":"zhixian","style":{"rotate":0,"opacity":1,"width":200,"height":2,"backgroundColor":"#000"}},{"animations":[],"events":{},"groupStyle":{},"isLock":false,"collapseName":"style","linkage":{"duration":0,"data":[{"id":"","label":"","event":"","style":[{"key":"","value":""}]}]},"component":"CircleShape","label":"圆形","propValue":" ","icon":"24gl-circle","style":{"rotate":0,"opacity":1,"width":200,"height":200,"fontSize":"","fontWeight":400,"lineHeight":"","letterSpacing":0,"textAlign":"center","color":"","borderColor":"#000","borderWidth":1,"backgroundColor":"","borderStyle":"solid","borderRadius":"","verticalAlign":"middle"}},{"animations":[],"events":{},"groupStyle":{},"isLock":false,"collapseName":"style","linkage":{"duration":0,"data":[{"id":"","label":"","event":"","style":[{"key":"","value":""}]}]},"component":"SVGStar","label":"星形","icon":"kongwujiaoxing","propValue":"","style":{"rotate":0,"opacity":1,"width":80,"height":80,"fontSize":"","fontWeight":400,"lineHeight":"","letterSpacing":0,"textAlign":"center","color":"","borderColor":"#000","backgroundColor":"rgba(255, 255, 255, 1)"}},{"animations":[],"events":{},"groupStyle":{},"isLock":false,"collapseName":"style","linkage":{"duration":0,"data":[{"id":"","label":"","event":"","style":[{"key":"","value":""}]}]},"component":"SVGTriangle","label":"三角形","icon":"xingzhuang-sanjiaoxing","propValue":"","style":{"rotate":0,"opacity":1,"width":80,"height":80,"fontSize":"","fontWeight":400,"lineHeight":"","letterSpacing":0,"textAlign":"center","color":"","borderColor":"#000","backgroundColor":"rgba(255, 255, 255, 1)"}},{"animations":[],"events":{},"groupStyle":{},"isLock":false,"collapseName":"style","linkage":{"duration":0,"data":[{"id":"","label":"","event":"","style":[{"key":"","value":""}]}]},"component":"VTable","label":"表格","icon":"biaoge","propValue":{"data":[["表头1","表头2","表头3"],["内容1","内容2","内容3"]],"stripe":true,"thBold":true},"request":{"method":"GET","data":[],"url":"","series":false,"time":1000,"paramType":"","requestCount":0},"style":{"rotate":0,"opacity":1,"width":600,"height":200,"fontSize":"","fontWeight":400,"textAlign":"center","color":"","backgroundColor":"rgba(255, 255, 255, 1)"}},{"animations":[],"events":{},"groupStyle":{},"isLock":false,"collapseName":"style","linkage":{"duration":0,"data":[{"id":"","label":"","event":"","style":[{"key":"","value":""}]}]},"component":"VChart","label":"图表","icon":"el-icon-data-analysis","propValue":{"chart":"VChart","option":{"title":{"text":"柱状图","show":true},"legend":{"show":true},"tooltip":{"show":true,"trigger":"item"},"xAxis":{"show":true,"data":["A","B","C","D","E"]},"yAxis":{},"series":{"type":"bar","name":"销量","data":[23,61,35,77,35],"itemStyle":{"barBorderRadius":5,"borderWidth":1,"borderType":"solid","borderColor":"#73c0de","shadowColor":"#5470c6","shadowBlur":3}}}},"style":{"rotate":0,"opacity":1,"width":800,"height":500,"borderRadius":""}}] 6 | ``` 7 | 8 | 如果一个页面包含了一个文本和按钮组件,那么这个页面的 JSON 代码如下: 9 | 10 | ```json 11 | 
[{"animations":[],"events":{},"groupStyle":{},"isLock":false,"collapseName":"style","linkage":{"duration":0,"data":[{"id":"","label":"","event":"","style":[{"key":"","value":""}]}]},"component":"VText","label":"文字","propValue":"双击编辑文字","icon":"wenben","request":{"method":"GET","data":[],"url":"","series":false,"time":1000,"paramType":"","requestCount":0},"style":{"rotate":0,"opacity":1,"width":200,"height":28,"fontSize":"","fontWeight":400,"lineHeight":"","letterSpacing":0,"textAlign":"","color":"","top":96,"left":253},"id":"aDznLRoPXX3LMHgvfsc1c"},{"animations":[],"events":{},"groupStyle":{},"isLock":false,"collapseName":"style","linkage":{"duration":0,"data":[{"id":"","label":"","event":"","style":[{"key":"","value":""}]}]},"component":"VButton","label":"按钮","propValue":"按钮","icon":"button","style":{"rotate":0,"opacity":1,"width":100,"height":34,"borderWidth":1,"borderColor":"","borderRadius":"","fontSize":"","fontWeight":400,"lineHeight":"","letterSpacing":0,"textAlign":"","color":"","backgroundColor":"","top":162,"left":308},"id":"DaaGKgxWdTXV3tOiBI_h4"}] 12 | ``` 13 | 14 | 你作为一个技术专家,现在需要按照上面的规则来为我生成页面,并且生成的页面中每一个组件的属性都不能忽略,也不需要解释,只需要返回 JSON 数据即可。要注意的是,有些数值的单位是没有 px 的。 15 | 16 | 现在我需要生成一个海报页面,主要用于宣传编程有什么用。 17 | -------------------------------------------------------------------------------- /prompts/prompt.md: -------------------------------------------------------------------------------- 1 | 我有一个低代码平台项目,它可以根据符合规范的 JSON 数据生成页面,这个 JSON 数据是一个数组,里面的每一项都是一个 JSON 对象,每个 JSON 对象都对应着一个组件。 2 | 下面用 ``` 包括起来的代码就是所有的组件列表。 3 | 4 | ```json 5 | [ 6 | { 7 | "animations": [], // 动画 8 | "events": {}, // 事件 9 | "groupStyle": {}, // 组合样式 10 | "isLock": false, // 是否锁定 11 | "collapseName": "style", 12 | "linkage": { // 联动属性 13 | "duration": 0, 14 | "data": [ 15 | { 16 | "id": "", 17 | "label": "", 18 | "event": "", 19 | "style": [ 20 | { 21 | "key": "", 22 | "value": "" 23 | } 24 | ] 25 | } 26 | ] 27 | }, 28 | "component": "VText", // 组件类型 29 | "label": "文字", // 组件名称 30 | "propValue": "双击编辑文字", // 组件值 31 | "icon": "wenben", // 组件图标 32 | "request": { // 组件请求 33 | "method": "GET", 34 | "data": [], 35 | "url": "", 36 | "series": false, 37 | "time": 1000, 38 | "paramType": "", 39 | "requestCount": 0 40 | }, 41 | "style": { // 组件样式 42 | "rotate": 0, 43 | "opacity": 1, 44 | "width": 200, 45 | "height": 28, 46 | "fontSize": "", 47 | "fontWeight": 400, 48 | "lineHeight": "", 49 | "letterSpacing": 0, 50 | "textAlign": "", 51 | "color": "" 52 | } 53 | }, 54 | { 55 | "animations": [], 56 | "events": {}, 57 | "groupStyle": {}, 58 | "isLock": false, 59 | "collapseName": "style", 60 | "linkage": { 61 | "duration": 0, 62 | "data": [ 63 | { 64 | "id": "", 65 | "label": "", 66 | "event": "", 67 | "style": [ 68 | { 69 | "key": "", 70 | "value": "" 71 | } 72 | ] 73 | } 74 | ] 75 | }, 76 | "component": "VButton", 77 | "label": "按钮", 78 | "propValue": "按钮", 79 | "icon": "button", 80 | "style": { 81 | "rotate": 0, 82 | "opacity": 1, 83 | "width": 100, 84 | "height": 34, 85 | "borderWidth": 1, 86 | "borderColor": "", 87 | "borderRadius": "", 88 | "fontSize": "", 89 | "fontWeight": 400, 90 | "lineHeight": "", 91 | "letterSpacing": 0, 92 | "textAlign": "", 93 | "color": "", 94 | "backgroundColor": "" 95 | } 96 | }, 97 | { 98 | "animations": [], 99 | "events": {}, 100 | "groupStyle": {}, 101 | "isLock": false, 102 | "collapseName": "style", 103 | "linkage": { 104 | "duration": 0, 105 | "data": [ 106 | { 107 | "id": "", 108 | "label": "", 109 | "event": "", 110 | "style": [ 111 | { 112 | "key": "", 113 | "value": "" 114 | } 115 | ] 
116 | } 117 | ] 118 | }, 119 | "component": "Picture", 120 | "label": "图片", 121 | "icon": "tupian", 122 | "propValue": { 123 | "url": "img/title.07a15c19.jpg", 124 | "flip": { 125 | "horizontal": false, 126 | "vertical": false 127 | } 128 | }, 129 | "style": { 130 | "rotate": 0, 131 | "opacity": 1, 132 | "width": 300, 133 | "height": 200, 134 | "borderRadius": "" 135 | } 136 | }, 137 | { 138 | "animations": [], 139 | "events": {}, 140 | "groupStyle": {}, 141 | "isLock": false, 142 | "collapseName": "style", 143 | "linkage": { 144 | "duration": 0, 145 | "data": [ 146 | { 147 | "id": "", 148 | "label": "", 149 | "event": "", 150 | "style": [ 151 | { 152 | "key": "", 153 | "value": "" 154 | } 155 | ] 156 | } 157 | ] 158 | }, 159 | "component": "RectShape", 160 | "label": "矩形", 161 | "propValue": " ", 162 | "icon": "juxing", 163 | "style": { 164 | "rotate": 0, 165 | "opacity": 1, 166 | "width": 200, 167 | "height": 200, 168 | "fontSize": "", 169 | "fontWeight": 400, 170 | "lineHeight": "", 171 | "letterSpacing": 0, 172 | "textAlign": "center", 173 | "color": "", 174 | "borderColor": "#000", 175 | "borderWidth": 1, 176 | "backgroundColor": "", 177 | "borderStyle": "solid", 178 | "borderRadius": "", 179 | "verticalAlign": "middle" 180 | } 181 | }, 182 | { 183 | "animations": [], 184 | "events": {}, 185 | "groupStyle": {}, 186 | "isLock": false, 187 | "collapseName": "style", 188 | "linkage": { 189 | "duration": 0, 190 | "data": [ 191 | { 192 | "id": "", 193 | "label": "", 194 | "event": "", 195 | "style": [ 196 | { 197 | "key": "", 198 | "value": "" 199 | } 200 | ] 201 | } 202 | ] 203 | }, 204 | "component": "LineShape", 205 | "label": "直线", 206 | "propValue": "", 207 | "icon": "zhixian", 208 | "style": { 209 | "rotate": 0, 210 | "opacity": 1, 211 | "width": 200, 212 | "height": 2, 213 | "backgroundColor": "#000" 214 | } 215 | }, 216 | { 217 | "animations": [], 218 | "events": {}, 219 | "groupStyle": {}, 220 | "isLock": false, 221 | "collapseName": "style", 222 | "linkage": { 223 | "duration": 0, 224 | "data": [ 225 | { 226 | "id": "", 227 | "label": "", 228 | "event": "", 229 | "style": [ 230 | { 231 | "key": "", 232 | "value": "" 233 | } 234 | ] 235 | } 236 | ] 237 | }, 238 | "component": "CircleShape", 239 | "label": "圆形", 240 | "propValue": " ", 241 | "icon": "24gl-circle", 242 | "style": { 243 | "rotate": 0, 244 | "opacity": 1, 245 | "width": 200, 246 | "height": 200, 247 | "fontSize": "", 248 | "fontWeight": 400, 249 | "lineHeight": "", 250 | "letterSpacing": 0, 251 | "textAlign": "center", 252 | "color": "", 253 | "borderColor": "#000", 254 | "borderWidth": 1, 255 | "backgroundColor": "", 256 | "borderStyle": "solid", 257 | "borderRadius": "", 258 | "verticalAlign": "middle" 259 | } 260 | }, 261 | { 262 | "animations": [], 263 | "events": {}, 264 | "groupStyle": {}, 265 | "isLock": false, 266 | "collapseName": "style", 267 | "linkage": { 268 | "duration": 0, 269 | "data": [ 270 | { 271 | "id": "", 272 | "label": "", 273 | "event": "", 274 | "style": [ 275 | { 276 | "key": "", 277 | "value": "" 278 | } 279 | ] 280 | } 281 | ] 282 | }, 283 | "component": "SVGStar", 284 | "label": "星形", 285 | "icon": "kongwujiaoxing", 286 | "propValue": "", 287 | "style": { 288 | "rotate": 0, 289 | "opacity": 1, 290 | "width": 80, 291 | "height": 80, 292 | "fontSize": "", 293 | "fontWeight": 400, 294 | "lineHeight": "", 295 | "letterSpacing": 0, 296 | "textAlign": "center", 297 | "color": "", 298 | "borderColor": "#000", 299 | "backgroundColor": "rgba(255, 255, 255, 1)" 300 | } 301 | }, 302 | { 303 | 
"animations": [], 304 | "events": {}, 305 | "groupStyle": {}, 306 | "isLock": false, 307 | "collapseName": "style", 308 | "linkage": { 309 | "duration": 0, 310 | "data": [ 311 | { 312 | "id": "", 313 | "label": "", 314 | "event": "", 315 | "style": [ 316 | { 317 | "key": "", 318 | "value": "" 319 | } 320 | ] 321 | } 322 | ] 323 | }, 324 | "component": "SVGTriangle", 325 | "label": "三角形", 326 | "icon": "xingzhuang-sanjiaoxing", 327 | "propValue": "", 328 | "style": { 329 | "rotate": 0, 330 | "opacity": 1, 331 | "width": 80, 332 | "height": 80, 333 | "fontSize": "", 334 | "fontWeight": 400, 335 | "lineHeight": "", 336 | "letterSpacing": 0, 337 | "textAlign": "center", 338 | "color": "", 339 | "borderColor": "#000", 340 | "backgroundColor": "rgba(255, 255, 255, 1)" 341 | } 342 | }, 343 | { 344 | "animations": [], 345 | "events": {}, 346 | "groupStyle": {}, 347 | "isLock": false, 348 | "collapseName": "style", 349 | "linkage": { 350 | "duration": 0, 351 | "data": [ 352 | { 353 | "id": "", 354 | "label": "", 355 | "event": "", 356 | "style": [ 357 | { 358 | "key": "", 359 | "value": "" 360 | } 361 | ] 362 | } 363 | ] 364 | }, 365 | "component": "VTable", 366 | "label": "表格", 367 | "icon": "biaoge", 368 | "propValue": { 369 | "data": [ 370 | [ 371 | "表头1", 372 | "表头2", 373 | "表头3" 374 | ], 375 | [ 376 | "内容1", 377 | "内容2", 378 | "内容3" 379 | ] 380 | ], 381 | "stripe": true, 382 | "thBold": true 383 | }, 384 | "request": { 385 | "method": "GET", 386 | "data": [], 387 | "url": "", 388 | "series": false, 389 | "time": 1000, 390 | "paramType": "", 391 | "requestCount": 0 392 | }, 393 | "style": { 394 | "rotate": 0, 395 | "opacity": 1, 396 | "width": 600, 397 | "height": 200, 398 | "fontSize": "", 399 | "fontWeight": 400, 400 | "textAlign": "center", 401 | "color": "", 402 | "backgroundColor": "rgba(255, 255, 255, 1)" 403 | } 404 | }, 405 | { 406 | "animations": [], 407 | "events": {}, 408 | "groupStyle": {}, 409 | "isLock": false, 410 | "collapseName": "style", 411 | "linkage": { 412 | "duration": 0, 413 | "data": [ 414 | { 415 | "id": "", 416 | "label": "", 417 | "event": "", 418 | "style": [ 419 | { 420 | "key": "", 421 | "value": "" 422 | } 423 | ] 424 | } 425 | ] 426 | }, 427 | "component": "VChart", 428 | "label": "图表", 429 | "icon": "el-icon-data-analysis", 430 | "propValue": { 431 | "chart": "VChart", 432 | "option": { 433 | "title": { 434 | "text": "柱状图", 435 | "show": true 436 | }, 437 | "legend": { 438 | "show": true 439 | }, 440 | "tooltip": { 441 | "show": true, 442 | "trigger": "item" 443 | }, 444 | "xAxis": { 445 | "show": true, 446 | "data": [ 447 | "A", 448 | "B", 449 | "C", 450 | "D", 451 | "E" 452 | ] 453 | }, 454 | "yAxis": {}, 455 | "series": { 456 | "type": "bar", 457 | "name": "销量", 458 | "data": [ 459 | 23, 460 | 61, 461 | 35, 462 | 77, 463 | 35 464 | ], 465 | "itemStyle": { 466 | "barBorderRadius": 5, 467 | "borderWidth": 1, 468 | "borderType": "solid", 469 | "borderColor": "#73c0de", 470 | "shadowColor": "#5470c6", 471 | "shadowBlur": 3 472 | } 473 | } 474 | } 475 | }, 476 | "style": { 477 | "rotate": 0, 478 | "opacity": 1, 479 | "width": 800, 480 | "height": 500, 481 | "borderRadius": "" 482 | } 483 | } 484 | ] 485 | ``` 486 | 487 | 如果一个页面包含了一个文本和按钮组件,那么这个页面的 JSON 代码如下: 488 | 489 | ```json 490 | [ 491 | { 492 | "animations": [], 493 | "events": {}, 494 | "groupStyle": {}, 495 | "isLock": false, 496 | "collapseName": "style", 497 | "linkage": { 498 | "duration": 0, 499 | "data": [ 500 | { 501 | "id": "", 502 | "label": "", 503 | "event": "", 504 | "style": [ 505 | { 506 | 
"key": "", 507 | "value": "" 508 | } 509 | ] 510 | } 511 | ] 512 | }, 513 | "component": "VText", 514 | "label": "文字", 515 | "propValue": "双击编辑文字", 516 | "icon": "wenben", 517 | "request": { 518 | "method": "GET", 519 | "data": [], 520 | "url": "", 521 | "series": false, 522 | "time": 1000, 523 | "paramType": "", 524 | "requestCount": 0 525 | }, 526 | "style": { 527 | "rotate": 0, 528 | "opacity": 1, 529 | "width": 200, 530 | "height": 28, 531 | "fontSize": "", 532 | "fontWeight": 400, 533 | "lineHeight": "", 534 | "letterSpacing": 0, 535 | "textAlign": "", 536 | "color": "", 537 | "top": 96, 538 | "left": 253 539 | }, 540 | "id": "aDznLRoPXX3LMHgvfsc1c" 541 | }, 542 | { 543 | "animations": [], 544 | "events": {}, 545 | "groupStyle": {}, 546 | "isLock": false, 547 | "collapseName": "style", 548 | "linkage": { 549 | "duration": 0, 550 | "data": [ 551 | { 552 | "id": "", 553 | "label": "", 554 | "event": "", 555 | "style": [ 556 | { 557 | "key": "", 558 | "value": "" 559 | } 560 | ] 561 | } 562 | ] 563 | }, 564 | "component": "VButton", 565 | "label": "按钮", 566 | "propValue": "按钮", 567 | "icon": "button", 568 | "style": { 569 | "rotate": 0, 570 | "opacity": 1, 571 | "width": 100, 572 | "height": 34, 573 | "borderWidth": 1, 574 | "borderColor": "", 575 | "borderRadius": "", 576 | "fontSize": "", 577 | "fontWeight": 400, 578 | "lineHeight": "", 579 | "letterSpacing": 0, 580 | "textAlign": "", 581 | "color": "", 582 | "backgroundColor": "", 583 | "top": 162, 584 | "left": 308 585 | }, 586 | "id": "DaaGKgxWdTXV3tOiBI_h4" 587 | } 588 | ] 589 | ``` 590 | 591 | 你作为一个技术专家,现在需要按照上面的规则来为我生成页面,并且生成的页面中每一个组件的属性都不能忽略。 592 | 593 | 现在我需要生成一个海报页面,主要用于宣传编程有什么用。 594 | -------------------------------------------------------------------------------- /quantization.py: -------------------------------------------------------------------------------- 1 | from torch.nn import Linear, Embedding 2 | from torch.nn.parameter import Parameter 3 | import torch.nn.functional as F 4 | 5 | import os 6 | import bz2 7 | import torch 8 | import base64 9 | import ctypes 10 | import sys 11 | from transformers.utils import logging 12 | 13 | from typing import List 14 | from functools import partial 15 | 16 | logger = logging.get_logger(__name__) 17 | 18 | try: 19 | from cpm_kernels.kernels.base import LazyKernelCModule, KernelFunction, round_up 20 | 21 | 22 | class Kernel: 23 | def __init__(self, code: bytes, function_names: List[str]): 24 | self.code = code 25 | self._function_names = function_names 26 | self._cmodule = LazyKernelCModule(self.code) 27 | 28 | for name in self._function_names: 29 | setattr(self, name, KernelFunction(self._cmodule, name)) 30 | 31 | 32 | quantization_code = 
"$QlpoOTFBWSZTWU9yuJUAQHN//////////f/n/8/n///n//bt4dTidcVx8X3V9FV/92/v4B7/AD5FBQFAAAChSgKpFCFAFVSigUAAAEKhSgUUqgFBKigqVREQAABQBQIANDTTIGI00BkZBkNGE0A0BkBkGQGRkaNAaAGQNBoGgDIAAYIGTI0DQAQAaGmmQMRpoDIyDIaMJoBoDIDIMgMjI0aA0AMgaDQNAGQAAwQMmRoGgAgA0NNMgYjTQGRkGQ0YTQDQGQGQZAZGRo0BoAZA0GgaAMgABggZMjQNABABoaaZAxGmgMjIMhowmgGgMgMgyAyMjRoDQAyBoNA0AZAADBAyZGgaAAmqU1NEgJqnptU/Sn4jRR6J6epk2pqb1Q/SgAPUGgyNNGjQ2SBpoAZAAGg0NB6mgDIAAAAA2oaApSREBNAARhGiYEaEwU8pvImlP0k2aam1GaGqbFNM1MHpTwmkepmyU9R6nqPKekHqNNPUxNGhp6n6p6QaZ6o9TG1GMqcoV9ly6nRanHlq6zPNbnGZNi6HSug+2nPiZ13XcnFYZW+45W11CumhzYhchOJ2GLLV1OBjBjGf4TptOddTSOcVxhqYZMYwZXZZY00zI1paX5X9J+b+f4e+x43RXSxXPOdquiGpduatGyXneN696M9t4HU2eR5XX/kPhP261NTx3JO1Ow7LyuDmeo9a7d351T1ZxnvnrvYnrXv/hXxPCeuYx2XsNmO003eg9J3Z6U7b23meJ4ri01OdzTk9BNO96brz+qT5nuvvH3ds/G+m/JcG/F2XYuhXlvO+jP7U3XgrzPN/lr8Sf1n6j4j7jZs+s/T0tNaNNYzTs12rxjwztHlnire3Nzc3N1wuBwOBwXBvZfoHpD7rFmR99V5vj3aXza3xdBbXMalubTg/jIv5dfAi54Pdc75j4z412n3Npj3Ld/ENm7a3b/Cod6h/ret1/5vn/C+l+gdslMvgPSLJ8d8q+U66fevYn/tW1chleEtNTGlcHCbLRlq0tHzF5tsbbZZfHjjLgZu42XCuC3NrdjTasZGNzgxPIrGqp7r3p7L2p5XjnpPSmTd5XtzqnB6U87zzg1Ol0zd0zsLszxR6lkxp35u6/teL0L0W922cR7Lu1lpL9CsHirzuM2T+BgsyViT6LHcm0/Vr6U/7LGGyJeqTEjt0PHWhF5mCT7R9mtlDwriYv0Tyr/OxYt6qp5r0mPVT0608TqnqMZaarU2nFwrTzzlrs1ed7z1ux60wyr4ydCaTi3enW8x68x0zU7tXSlcmPSW1mGpWJMg4zmPC2lK96tp0OE80y4MfEvnZj8zGluR6b22ki1Ou9V2nCd9xovcPvcYMZYy0lvN60ScZ45vN6yeCeeXFb1lVjnnCar5fwXwE2bzJ4HI1XVPXfXZMm44GUsMpYsmLB65TuVdm0cl0b+i/wGNN66XjeV7zuPpHcnK/juhhjdfId5jMdE5nN0dGmmm2zZs2cexD5n9p/dY352XsvXHaZNWWsmmS1atjR452nYudzvqv2HMRyvNNnlMcDl3R2+yx2uVrBubTW9icHDVtbNXlZm7jma1rM4VurZZd2y6nUau7ZXZ7bVU+mnoOVxZGMrVmvX60605JwmzGZhhhjTWtaaaMaaGTGmNMZasY0iX8VMUl8eepaIrzGSpemWOQyZORk2bNpjUybMmxqYmknCGCFynutfksaZpjTNMaaatM0xsxcGR0sociNqxNSmhhR1ZJPbsn8qyF0t2qH6iYBclclalbtTTcHTDsPaX6rlnElph2Jyumumtynv2Kk8GI7rsvXbIcJgHJOSaSXnnGaI3m87RtVXJOZ/YtgdTE6Wpha6ZlE8ayXkef1fh602r2WwvfMXtMdLlkfnLFdYYwYso+bWqm7yJqHXZGw2nrS5ZanSYnWlxBxMF1V940K2wdrI7R6OYf7DGGamMmTSbRhlS45xmVOumF1EyPCmHrrN8wwZOOrdNtLeMtzFzDlWnfTBxMk2NaXIZHBYxYLD4w8yju0ao65Vz1OIXoS9dLanwCe1PWrYuWMqf1if1z2k2yYfKJ741PDgno1ZQ8DRqvUny3mNoWTzGO6m1DkrJI8JiR5cSd+vZdGOO8nrMoc5+NDUFsMSXaZJeNlMmGLtJsovOsUp7I9S5VojKxF6bTVEelXqlfJobQr3LozSh2Jk7VcrVMfhXqszGWMzNqGhqZY0OadxkyyMssKugZR0KNFXBHlqwmJgTE/BNVMk6ItJXZMR0H47GpXv/DMOvNkmVuaV1PRfEdxuqc7Hcd+ZV/zTLaRxWk0nl9CdCeM6mn5rstHIBcpiuwmUZXeq81DacHI2rmrZ5SuE5mOZd6LQrZg9mx32TprA8BMo5jKN6yLTCi3WzQaZSuhzTtM1fUTGVpG8Tw+KXI0tjEpiWxtLYynOlktSbVlaI5kxP8TDH8kx50xoxi5KcA4pcja8KWLRlO/Ks6q06ergnvm1ca3Tq8Uw7LTUsmWyctXPWmpitl/uvGcWTGXGuAXDfhqazGmjkxcJW5hMMMMpYsXl2TZYtVOddG3XCarUt6Ptq9CZXSNzyuRzqRZOjsxdBbFVz6OA5HI43r1jityVlVpVkxmOsyaYWE1NTGq1sOVh36mHMcxtSvcy70edG0ZGR3I1Go1GRlV7mWWo1G0ZGRqlvH40l7o4m5xMWLLLYyNjnqc8556mdPqLJ31n/1nWOncxzG1tizrHs/Z+d2vP/B/l8wdJ6rHUn2nbbDq4p6htFtYzMMMTaZis1K5GKzGNmxhmUx2DDlZ/qNnIx41xnaMfCZWYaZWtNLTNW8ND4Fw1MyZOCdM428suKG1ehW8TesOydg7J+YYcD4cYR+8dFK6M4E3HM9ZfRNNL+Sn6rsl4DsrDl2HpPCnfxjGXtbZtYys1ttlyJ4T+BvexjGWRjMszK4Jpc77D3GyuVD7q0+G8m9G+2+rGm7cOR2y7FdtY2XUYx/oNlfRYxhMYyYZkyyg55enna9Kt/FFi6GMMwYwdwxWgxGMLKYmUyGExTKMZkMFhkymKuh0NOBNnBu+23LdwDoZYYzGGMxtORaTU1pjTGWTTGGtMrNWUsyyTTLLG1qy2ZjbK2DBllWqxMtBMaYZQmcE7zvvRcTkclUwdkxTaSdyySt/7fpL+T1v516Ji97fwr5JbLu305zMn5+GMTTZ9F+y7ExwmGVfG44yxn3dLv6l5i+Wth1jCrDq21nW9LqvvDzz3Vf3LLH/O/32TJ/erx3bXftO4eF+G956D952K/An4NfvOpjFjExjevP/UmE0fIoZXx6/w6lX/no3D0bLt+ixjieBM6ksRd0yB4Lt2SwYNE+gd1detlZWUnpiZfGfFaK+4PyCa/v18V8X75pe9fLXzp7l3VjF76vWZmHwGz1IZNWT7b8yddJ4q5kyrVdfru6atWc7bVYztL9Jf4GXvT+Y8m9/YsXP6H018a8D4XVOqvfzqeR+6yZOD8dPv0+U7/q5Pl+2dNb0MjzGVH5p6MNQ7cOWvw62U9aHE8DprDek+McLyvDz+te+9Zhq5+YTruufMcWMabqy
sTmZVWjKPfnK0wyVcrsuhjZRdLkHNvD72b9abriOSGIxiLixMOoalNPXzy+wT/tf+U6HHONfsz+xe8ufHBdQWWGWLA9if0rsnmrxK5LvRZQeWsTCsrmOYy8VteVfuRfcVTtDLItLIsMYxZLdU/DbtSemxF6Z6Zo5WBXE4tFdCyVMMXMTEMZXVlS6Xec2T4e0tHsRcEuWshcJ2YsNF5rUx1E8ifCq6Z+ZP7qdCeu/aTwFd53l16/o0NOw6O3dLavP4Hbi4RdmuDk6DoYaninC0+o4uZjbJ7Rxeu0/FbuFg+q7DVS6fQe0rZ6NDGUNNU6DEqOaLTicKnYZMnBWruljQxoaS3dZhocDge0bSTyOvdAbG5hxe2xji7E/L55xX13wWNDi6HCekcFxfCPGxY0MXC+s7afWaMdDyjyr+o8Rudm/NabOZvdl274zH4f5XK9z6On1Pe/K5TdPAslg77BjuO6Y3eO7GqvOPG/stknp1leyvLL0Z7bl9I4noMvLkzytLhWYzrOZzLXCORe028rORzOg4N/L0HlMOQ3Pgmnbb6KczlabORpu980q37TBqRu0/p3PO6234Bl03Ynuz+9W7gnsEcmvYaYY3aMYY0wx3pYd+ujsXauWdaY5Xkbtl23fPzFHiDB/QMo0yFjBllYxTQYYyxkrwn7JufwJ/PfgJ+C83X69ni6zvXcnyXabv0ncbLwsceS+RNlyN2mnneJtX0ngYO0+e+0+UnA+Wch3ji8hj5an4h+i6XBySU4n+R0roVcbw5yvHrmr4Yw8Y7x6c+9POPYHI5HI5HI5HI5HGXGww4nE4nrVyOR8XeqPEO7PLOiukYa3Novk5hV4cdtYZLI93e+uxff2jRo0aNGjRo0aNG1bVtW1dy3m83m8+tQ5ZzHw3nObwOu8La9Rc1dtkdS8A3eTk823tnktXWlxN6Oixe06zrN70Isd9jiOgZFq9yfkPqP/SLhN2Myl8jDM43bl1nbcb4cO57jlh8Jow6pzXZdL4dyODTuuhu77FyO27DdwdRxmvO+O+3N2+BdqyTwLHVczDVY4UPE4O66/ZO2cx1LFzVdSXtF7G4HMbrauOHRw6c8FdZ5m9fHZHYZXfTlZquyynSyTTKke6vcffSD9pzPA/G7n7jxPmuhc1DHMynPMrGL6AdewYmwu5ko+UUyTwrMv27rPH1v1nGqd87+p6N6LU8k3NEng53xXyHS97+44OSg/sy/hn+Se6yfYNjW0/uTgP+PvWYzLMmjhcLB/gGpri6H83/84eUXWT6T9Hsv7785z/7z4icpW+zfXypuR7rx/gMdZb1/wC678pcs8/2a3mDitGHxl9mfPlll5MafWWqxk/eYuTDgcNMzDGWLWvsuglNxs53GtN6uWpktlW1tZZYcuinMMWmnNnJydze3b2Y1McBxrBkXw799izLMZZYyy0TkbsGM4p03S2uVu5s/XXUdSdec6smVxZYYGpVmT8A+8ajuEyV5FatkvVru2x6uxGXXbH4A+jvgP4GMYy3iPLXzq/6z65+E005ey+cwMZD3fZcqc6xpjTFjQ0P3U+e++cPYmTIwj0nrK5NPTfl3WvpfLtXDcb2HQMudYOxFXQBor4L4T6vrOauFctYXJQ++NUWmJe5bmx1jDiZS1dTqWxo4GR8jm3fttpmPHppk9PEyv4/y8/sO07XacOmcqc0x2Vi9BvNJvN5oW8x4mOsydpidRxMYJPx06m1bqPzq9KtK8sxXNXFodD/+MYYaJTLwOhc9brCsV18oOR1i4tXChyTkq4lf4y1Ke+9axjDHqs1mfBbMXuP4Hzi+X7t8vzv7bHerrUPgPCxhjre4fXdfLNtNM+Jd+Zdh8xd8wP87uNPoPgv4W7/5P2BuxfsMabNnMnza+54Pdi5U671GPZY8CehX8Voeoo7FHpkeEc6715FwHZrIrUrHaviPUbPZHND+IhczrP6FcYvhOZ0Di/ETt0OI+YwNWR9r7tpf6WDeZKZDB1+z2IthOl1mPyb5FluvEx9h9d0NnM0Y1XPFkWIsk1WotJ0PBMmkvjvQTd0e71tfeV+8r8lQ/tpzpsmxJ+InrI/dj2UajUajVTUajatRqNRtGo1Go1Go4wjeMpZFMVV9CHbofPraLsJ3JpWV2XOoanCuFky4y3PPNxucK2uKC1Lbdb1eo+m5XomN6HfeZsabHLHRX/K+offtNGGmHWctcVcG44MdSqsOLY9VzX+Zxfxn2HPdWTpzWvkrtJ8M5zorrKcquRytJ5N5DZmcaW02l76nWO+BqPXm1A2Ry/0q71dH/mqrqeFjkYxjEXtsX8qubTk67rGycyqsdm4tZx5D6D5hhi0waaWmiaMP81Yjii5qxPlPuU/GfTL1Y5E6Jyfiq63qTa39A4J0sOGDgO9WF9bOXl0XfPRbsY2bPNKPy1YrFYrFYmRhhlTIyMjJWJYZHXuCXI8OoXsvfljGLFicNifpp2XunoPiG1wtx3p1Tah+/DD66OnVtVXP9rKbVxOnL0tR/rHtqB5UDErUVcl11D4qqvjpOcxX7armUNJB3LpW6bxVvD08e8h3odKKvyCFZBdSh2FVcST9xV3n3T8t1j7Kr9qgrqXg+13Pt5U7JCvFXVIV1YG5lRhkVYZJYYDDD4KOIMoHCp26WS8GB7uBh2zIdgq/PKyInjV2STShuoapUdCpX1yTwqq/z1VvET7Kh5nVPkO8YyxjLt2MaaMmWTLQvx3qnzltnXW0p2jxgbEtSny/Osv8Y9pLMXYoHVPAhkVdWVeODhR6q9/Sxe2liwwZWMVvFXfRkeIDxAePUPIrdJ4ey6yquzH+PD/bUOWAu05qVHtFd8rrKHSoeNIOUqrYr3FXyToqfYJgwmJdKpXXOwYYegNNGMzfZPp/t3t/DVs4zjNTN61rRqaWaa4NYbRjTa0tWwy2Y2tGN8ZO8ofNKq4j9SL7I+cSm4/6ovLV5HNXLI0jJidwrtk6ynCaP6Z++GjRlWS3tLeW129Mi9evxU9mtz6s5J3Z7M2ngTgnKvmpomxpaLCzPfmx0JWE+m3NLDDGOX47RctdYYNK5jakdqLkRlI39n590T5zctGSwwZZDJj6kW8XSi6ot2MmWWJ0DUT3nuvebBudScjZ79g8cWJ8av0k+/bE5WKd5MdbFpbDVMxu1DVMmtNZGJvq1mtRbn6M+g/kP0FwDwr7quZs7xosNGpbscyxhhd9TyJyFwbLcxlTasg75vW7TsV5K7ji44XPMMrdoj+Y3rT0Hie62nlYV/pwczzOmdLqLhYkzGMzCZWGMQzGMSsZYY6Di1t4nlJ+Em63mJxrVLxPbYxNEdgc1dU2iOKyoYYWjNrEeHTYybVk0atSa7ehuwsWMWTqn1TrnS6hYsi71d1+s+k+ic70e20fzE/VaTdxT9ZtU4GIXdeNx3X77guYYfpHeTQjaMX6brOu4OY4K7Y2d9mbHarI5ox3p4GpJ2Vd/Tst60f7j999pppjR+Q/Qf8J/VaORs3cji7FfFuN61+ui9s8hix1OCh5KGVV23BPXvZfz3CLyHpix+exi8z/KnCnosY2euno
r+cxyPO/xJ0vKey9OvE9VjqaYu0x3Z3jd6o2b1T12D+F8l232lwaaacD5LE8LBxu7WTlbWraWpew8Xexjel3E+wWD4APITdNqR8F3R3T0lunCQ4GaE9R37DxeCYfcHi4xci5ovKfxVs55y2hf+65E/Xdp6jR5nrebTmi5incpkyOjs50JvrZwstbbW6kfuuQw+2mykf/EXNFzxfKTrxew929TR6bWnGL//F3JFOFCQT3K4lQ" 33 | 34 | kernels = Kernel( 35 | bz2.decompress(base64.b64decode(quantization_code)), 36 | [ 37 | "int4WeightCompression", 38 | "int4WeightExtractionFloat", 39 | "int4WeightExtractionHalf", 40 | "int8WeightExtractionFloat", 41 | "int8WeightExtractionHalf", 42 | ], 43 | ) 44 | except Exception as exception: 45 | kernels = None 46 | logger.warning("Failed to load cpm_kernels:", exception) 47 | 48 | 49 | class W8A16Linear(torch.autograd.Function): 50 | @staticmethod 51 | def forward(ctx, inp: torch.Tensor, quant_w: torch.Tensor, scale_w: torch.Tensor, weight_bit_width): 52 | ctx.inp_shape = inp.size() 53 | ctx.weight_bit_width = weight_bit_width 54 | out_features = quant_w.size(0) 55 | inp = inp.contiguous().view(-1, inp.size(-1)) 56 | weight = extract_weight_to_half(quant_w, scale_w, weight_bit_width) 57 | ctx.weight_shape = weight.size() 58 | output = inp.mm(weight.t()) 59 | ctx.save_for_backward(inp, quant_w, scale_w) 60 | return output.view(*(ctx.inp_shape[:-1] + (out_features,))) 61 | 62 | @staticmethod 63 | def backward(ctx, grad_output: torch.Tensor): 64 | inp, quant_w, scale_w = ctx.saved_tensors 65 | weight = extract_weight_to_half(quant_w, scale_w, ctx.weight_bit_width) 66 | grad_output = grad_output.contiguous().view(-1, weight.size(0)) 67 | grad_input = grad_output.mm(weight) 68 | grad_weight = grad_output.t().mm(inp) 69 | return grad_input.view(ctx.inp_shape), grad_weight.view(ctx.weight_shape), None, None 70 | 71 | 72 | class W8A16LinearCPU(torch.autograd.Function): 73 | @staticmethod 74 | def forward(ctx, inp: torch.Tensor, quant_w: torch.Tensor, scale_w: torch.Tensor, weight_bit_width, 75 | quantization_cache=None): 76 | ctx.inp_shape = inp.size() 77 | ctx.weight_bit_width = weight_bit_width 78 | out_features = quant_w.size(0) 79 | inp = inp.contiguous().view(-1, inp.size(-1)) 80 | weight = extract_weight_to_float(quant_w, scale_w, weight_bit_width, quantization_cache=quantization_cache) 81 | ctx.weight_shape = weight.size() 82 | output = inp.mm(weight.t()) 83 | ctx.save_for_backward(inp, quant_w, scale_w) 84 | return output.view(*(ctx.inp_shape[:-1] + (out_features,))) 85 | 86 | @staticmethod 87 | def backward(ctx, grad_output: torch.Tensor): 88 | inp, quant_w, scale_w = ctx.saved_tensors 89 | weight = extract_weight_to_float(quant_w, scale_w, ctx.weight_bit_width) 90 | grad_output = grad_output.contiguous().view(-1, weight.size(0)) 91 | grad_input = grad_output.mm(weight) 92 | grad_weight = grad_output.t().mm(inp) 93 | return grad_input.view(ctx.inp_shape), grad_weight.view(ctx.weight_shape), None, None 94 | 95 | 96 | default_cpu_kernel_code_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "quantization_kernels.c") 97 | default_cpu_kernel_code = "QlpoOTFBWSZTWXLbSoQAAgzbgERwQXxmTwAAr/ff3kABt0Q2oRVT0hpo9RtEAAAAyBEiSQ9EGjQGQAAAwANGhowjJoNGmgMEUplMTNSMJ5TQaDJpsoMyRMj8P4mZzFSVVwqSXG8GG7MlVwiToYEQwVD7noBxMhNfkeZYtYFtbgOBUSIGtIQjhNHCEnPJsadhb3yBmRIOD3TeAtNLSaU5GgvKUBWSNuuOIHmVt0YhW6rsmDMDUjeUJGJ64R1Jm5lrh0Aa0tKjhFwPdWcGogxLDSXPWQUWTM8Sd3Qz1HMYNxx3HMeiNqNo4jeRDEfZ3gUSHIcU/heomq0vEzL1Msz5KKGxH8FrNOYw3KaxdqaEmNHYMxJFgQbR0DyRknL2L4kwUSxKRdhjRpEtUqilVfggFL1klaMS3PPRDfNqbBOPWO7m4JTVGhS9QTBDDJaEbLbrUQNB+IpJSKQbG5SZZ5gkwJEhJ3aYKJipZ/i7kinChIOW2lQg" 98 | default_cpu_parallel_kernel_code_path = 
os.path.join(os.path.dirname(os.path.abspath(__file__)), 99 | "quantization_kernels_parallel.c") 100 | default_cpu_parallel_kernel_code = "QlpoOTFBWSZTWUzax5EAALXbgERwSX1mTwAAr/ff3kACNyXUbZYwBpoaNGIyAaADQwRSaVP9QoMg0A2oAPU0AEUkU9GaaKMaQB6gA09T1ARRKnpk0niaJkaaNDJ6g0DTIKVKfZ/g6v1Kem5LJLa0WmkukkuCIHUqWbtJGJMsCSQFiPEIYHgBIZDzR8R6REbYxIqD2Cu7lMkFoPu6LmHeOAy0GF83Tc40jgmTs4HnCe60QfJa2bDBZ0Y1lhgbiZjW8SNsAKCk42UOEdjWN3KoiCIYeQUCCKWIyHewhtSoInLKSG22l4jKM2ZDCVKtBm3OTYBl3jsVqMImtj7PQw7xKxLXQzwgJaPPgW1fRhrvPJICl4YFDYfNbkbBh5JDgrazFml50xEQQwQUjxNwE0IDSofLzSg7UNVKn+Rr1KErzBHUxBqdHRlXzqYsIa5K9Y0UuE2ugw3g5KYofm7AaGNTzJSMhcchhxdaU4JZ0F1UNgQ8XcGDguypqYza8yFaEoGgNRcLej+g2t0feGKFE5OY2PFluQ3q4HgycxlfvzHqo0KcM0JI8OKXtzayJFgsqC1NdUQVu8rChnA6FO3MFyGOoC9KO8ITPpYM5pRqTlczFkLES/4u5IpwoSCZtY8i" 101 | 102 | cpu_kernels = None 103 | 104 | 105 | class CPUKernel: 106 | def __init__(self, kernel_file="", source_code=default_cpu_kernel_code_path, compile_parallel_kernel=None, 107 | parallel_num=None): 108 | self.load = False 109 | self.int8WeightExtractionFloat = None 110 | self.int4WeightExtractionFloat = None 111 | self.int4WeightCompression = None 112 | self.SetNumThreads = lambda x: x 113 | 114 | try: 115 | if not os.path.exists(default_cpu_kernel_code_path): 116 | with open(default_cpu_kernel_code_path, "w", encoding="utf-8") as file: 117 | code = default_cpu_kernel_code 118 | cpu_quantization_code = bz2.decompress(base64.b64decode(code)).decode() 119 | file.write(cpu_quantization_code) 120 | 121 | if not os.path.exists(default_cpu_parallel_kernel_code_path): 122 | with open(default_cpu_parallel_kernel_code_path, "w", encoding="utf-8") as file: 123 | code = default_cpu_parallel_kernel_code 124 | cpu_quantization_code = bz2.decompress(base64.b64decode(code)).decode() 125 | file.write(cpu_quantization_code) 126 | 127 | except Exception as ex: 128 | print("Error when generating default cpu kernel code(can be ignored when using custom kernels).") 129 | 130 | if compile_parallel_kernel is None: 131 | compile_parallel_kernel = bool(int(os.cpu_count()) >= 4) 132 | 133 | if compile_parallel_kernel and source_code == default_cpu_kernel_code_path: 134 | source_code = default_cpu_parallel_kernel_code_path 135 | 136 | kernels = None 137 | 138 | if (not kernel_file) or (not os.path.exists(kernel_file)): 139 | print("No compiled kernel found.") 140 | try: 141 | if os.path.exists(source_code): 142 | print("Compiling kernels :", source_code) 143 | kernel_file = source_code[:-2] + ".so" 144 | 145 | if compile_parallel_kernel: 146 | if sys.platform != 'darwin': 147 | compile_command = "gcc -O3 -fPIC -pthread -fopenmp -std=c99 {} -shared -o {}".format( 148 | source_code, kernel_file) 149 | else: 150 | compile_command = "clang -O3 -fPIC -pthread -Xclang -fopenmp -lomp -std=c99 {} -shared -o {}".format( 151 | source_code, kernel_file) 152 | print("Compiling", compile_command) 153 | exit_state = os.system(compile_command) 154 | if not exit_state: 155 | try: 156 | kernels = ctypes.CDLL(kernel_file,winmode=0) 157 | print("Load kernel :", kernel_file) 158 | except: 159 | kernels = None 160 | print("Load parallel cpu kernel failed, using default cpu kernel code:") 161 | import traceback 162 | exception = traceback.format_exc() 163 | print(exception) 164 | else: 165 | print("Compile default cpu kernel failed, using default cpu kernel code.") 166 | 167 | if kernels is None: # adjust config, use default cpu kernel 168 | compile_parallel_kernel = False 169 | source_code = default_cpu_kernel_code_path 170 | kernel_file = source_code[:-2] + ".so" 171 
| 172 | if kernels is None: 173 | compile_command = "gcc -O3 -fPIC -std=c99 {} -shared -o {}".format(source_code, kernel_file) 174 | print("Compiling", compile_command) 175 | exit_state = os.system(compile_command) 176 | if not exit_state: 177 | try: 178 | kernels = ctypes.CDLL(kernel_file,winmode=0) 179 | print("Load kernel :", kernel_file) 180 | except: 181 | kernels = None 182 | print("Load default cpu kernel failed:") 183 | import traceback 184 | exception = traceback.format_exc() 185 | print(exception) 186 | else: 187 | print("Compile default cpu kernel failed.") 188 | else: 189 | print("Kernel source code not found.") 190 | return 191 | except: 192 | print("Failed to build cpu kernel:") 193 | import traceback 194 | exception = traceback.format_exc() 195 | print(exception) 196 | return 197 | else: 198 | try: 199 | kernels = ctypes.CDLL(kernel_file,winmode=0) 200 | print("Load kernel :", kernel_file) 201 | except: 202 | kernels = None 203 | print("Load custom cpu kernel failed:") 204 | import traceback 205 | exception = traceback.format_exc() 206 | print(exception) 207 | 208 | if kernels is not None: 209 | self.int8WeightExtractionFloat = kernels.extract_int8_weight_to_float 210 | self.int4WeightExtractionFloat = kernels.extract_int4_weight_to_float 211 | self.int4WeightCompression = kernels.compress_int4_weight 212 | if compile_parallel_kernel: 213 | try: 214 | self.SetNumThreads = kernels.set_num_threads 215 | except: 216 | print("No set_num_threads() found in kernel.") 217 | self.load = True 218 | else: 219 | print("Failed to load kernel.") 220 | return 221 | 222 | if compile_parallel_kernel: 223 | if parallel_num is None: 224 | parallel_num = max(os.cpu_count() // 2, 1) 225 | print("Setting CPU quantization kernel threads to", parallel_num) 226 | if parallel_num < 4: 227 | print("Parallel kernel is not recommended when parallel num < 4.") 228 | self.SetNumThreads(parallel_num) 229 | 230 | self.parallel_num = parallel_num 231 | 232 | 233 | def compress_int4_weight(weight: torch.Tensor): # (n, m) 234 | """compress weight on cpu or cuda to int4""" 235 | if weight.device == torch.device("cpu"): 236 | assert isinstance(cpu_kernels, CPUKernel) 237 | n, m = weight.size(0), weight.size(1) 238 | assert m % 2 == 0 239 | m = m // 2 240 | out = torch.empty(n, m, dtype=torch.int8, device="cpu") 241 | cpu_kernels.int4WeightCompression( 242 | ctypes.c_void_p(weight.data_ptr()), 243 | ctypes.c_void_p(out.data_ptr()), 244 | ctypes.c_int32(n), 245 | ctypes.c_int32(m) 246 | ) 247 | return out 248 | else: 249 | with torch.cuda.device(weight.device): 250 | n, m = weight.size(0), weight.size(1) 251 | assert m % 2 == 0 252 | m = m // 2 253 | out = torch.empty(n, m, dtype=torch.int8, device="cuda") 254 | stream = torch.cuda.current_stream() 255 | 256 | gridDim = (n, 1, 1) 257 | blockDim = (min(round_up(m, 32), 1024), 1, 1) 258 | 259 | kernels.int4WeightCompression( 260 | gridDim, 261 | blockDim, 262 | 0, 263 | stream, 264 | [ctypes.c_void_p(weight.data_ptr()), ctypes.c_void_p(out.data_ptr()), ctypes.c_int32(n), 265 | ctypes.c_int32(m)], 266 | ) 267 | return out 268 | 269 | 270 | def extract_weight_to_half(weight: torch.Tensor, scale_list: torch.Tensor, source_bit_width: int): 271 | if source_bit_width == 8: 272 | func = kernels.int8WeightExtractionHalf 273 | elif source_bit_width == 4: 274 | func = kernels.int4WeightExtractionHalf 275 | else: 276 | assert False, "Unsupported bit-width" 277 | 278 | with torch.cuda.device(weight.device): 279 | n, m = weight.size(0), weight.size(1) 280 | out = torch.empty(n, 
m * (8 // source_bit_width), dtype=torch.half, device="cuda") 281 | stream = torch.cuda.current_stream() 282 | 283 | gridDim = (n, 1, 1) 284 | blockDim = (min(round_up(m, 32), 1024), 1, 1) 285 | 286 | func( 287 | gridDim, 288 | blockDim, 289 | 0, 290 | stream, 291 | [ 292 | ctypes.c_void_p(weight.data_ptr()), 293 | ctypes.c_void_p(scale_list.data_ptr()), 294 | ctypes.c_void_p(out.data_ptr()), 295 | ctypes.c_int32(n), 296 | ctypes.c_int32(m), 297 | ], 298 | ) 299 | return out 300 | 301 | 302 | def extract_weight_to_float(weight: torch.Tensor, scale_list: torch.Tensor, source_bit_width: int, 303 | quantization_cache=None): 304 | """extract weight on cpu to float32""" 305 | if source_bit_width == 8: 306 | func = cpu_kernels.int8WeightExtractionFloat 307 | elif source_bit_width == 4: 308 | func = cpu_kernels.int4WeightExtractionFloat 309 | else: 310 | assert False, "Unsupported bit-width" 311 | 312 | n, m = weight.size(0), weight.size(1) 313 | 314 | if quantization_cache is not None: 315 | out = quantization_cache 316 | func( 317 | ctypes.c_void_p(weight.data_ptr()), 318 | ctypes.c_void_p(scale_list.data_ptr()), 319 | ctypes.c_void_p(out.data_ptr()), 320 | ctypes.c_int32(n), 321 | ctypes.c_int32(m) 322 | ) 323 | return out.tensor 324 | else: 325 | out = torch.empty(n, m * (8 // source_bit_width), dtype=torch.float, device="cpu") 326 | func( 327 | ctypes.c_void_p(weight.data_ptr()), 328 | ctypes.c_void_p(scale_list.data_ptr()), 329 | ctypes.c_void_p(out.data_ptr()), 330 | ctypes.c_int32(n), 331 | ctypes.c_int32(m) 332 | ) 333 | return out 334 | 335 | 336 | class CacheTensor(): 337 | def __init__(self, *args, **kwargs): 338 | self.tensor = torch.empty(*args, **kwargs) 339 | 340 | def to(self, *args, **kwargs): 341 | self.tensor = self.tensor.to(*args, **kwargs) 342 | 343 | def data_ptr(self): 344 | return self.tensor.data_ptr() 345 | 346 | 347 | class QuantizedLinear(Linear): 348 | def __init__(self, weight_bit_width: int, weight_tensor=None, bias_tensor=None, quantized_weight=None, 349 | quantized_weight_scale=None, quantization_cache=None, empty_init=False, *args, **kwargs): 350 | super(QuantizedLinear, self).__init__(*args, **kwargs) 351 | self.weight_bit_width = weight_bit_width 352 | self.quantization_cache = quantization_cache 353 | 354 | if (quantized_weight is not None) and (quantized_weight_scale is not None): 355 | del self.weight 356 | self.weight = Parameter(quantized_weight.to(kwargs["device"]), requires_grad=False) 357 | self.weight_scale = Parameter(quantized_weight_scale.to(kwargs["device"]), requires_grad=False) 358 | else: 359 | shape = self.weight.shape 360 | del self.weight 361 | 362 | if weight_tensor is None or empty_init: 363 | self.weight = torch.empty( 364 | shape[0], shape[1] * weight_bit_width // 8, dtype=torch.int8, device=kwargs["device"] 365 | ) 366 | self.weight_scale = torch.empty(shape[0], dtype=kwargs["dtype"], device=kwargs["device"]) 367 | else: 368 | self.weight_scale = (weight_tensor.abs().max(dim=-1).values / ((2 ** (weight_bit_width - 1)) - 1)).to( 369 | kwargs["dtype"]) 370 | self.weight = torch.round(weight_tensor / self.weight_scale[:, None]).to(torch.int8) 371 | if weight_bit_width == 4: 372 | self.weight = compress_int4_weight(self.weight) 373 | 374 | self.weight = Parameter(self.weight.to(kwargs["device"]), requires_grad=False) 375 | self.weight_scale = Parameter(self.weight_scale.to(kwargs["device"]), requires_grad=False) 376 | 377 | if bias_tensor is not None: 378 | self.bias = Parameter(bias_tensor.to(kwargs["device"]), requires_grad=False) 379 | 
else: 380 | self.bias = None 381 | 382 | def reset_parameters(self): 383 | """To accelerate initialization""" 384 | pass 385 | 386 | def forward(self, input): 387 | if self.weight.device == torch.device("cpu"): 388 | output = W8A16LinearCPU.apply(input, self.weight, self.weight_scale, self.weight_bit_width, 389 | self.quantization_cache) 390 | else: 391 | output = W8A16Linear.apply(input, self.weight, self.weight_scale, self.weight_bit_width) 392 | if self.bias is not None: 393 | output = output + self.bias 394 | return output 395 | 396 | def _apply(self, fn): 397 | self_obj = super()._apply(fn) 398 | if self.quantization_cache is not None: 399 | self.quantization_cache.to(self_obj.weight.device) 400 | self.quantization_cache.to(self_obj.weight_scale.dtype) 401 | return self_obj 402 | 403 | 404 | class QuantizedEmbedding(Embedding): # TODO: backward, check empty_init 405 | def __init__(self, weight_bit_width: int, weight_tensor=None, quantized_weight=None, quantized_weight_scale=None, 406 | empty_init=False, *args, **kwargs): 407 | super(QuantizedEmbedding, self).__init__(*args, **kwargs) 408 | self.weight_bit_width = weight_bit_width 409 | 410 | if (quantized_weight is not None) and (quantized_weight_scale is not None): 411 | del self.weight 412 | self.weight = Parameter(quantized_weight.to(kwargs["device"]), requires_grad=False) 413 | self.weight_scale = Parameter(quantized_weight_scale.to(kwargs["device"]), requires_grad=False) 414 | else: 415 | shape = self.weight.shape 416 | del self.weight 417 | 418 | if weight_tensor is None or empty_init: 419 | self.weight = torch.empty( 420 | shape[0], shape[1] * weight_bit_width // 8, dtype=torch.int8, device=kwargs["device"] 421 | ) 422 | self.weight_scale = torch.empty(shape[0], dtype=kwargs["dtype"], device=kwargs["device"]) 423 | else: 424 | self.weight_scale = (weight_tensor.abs().max(dim=-1).values / ((2 ** (weight_bit_width - 1)) - 1)).to( 425 | kwargs["dtype"]) 426 | self.weight = torch.round(weight_tensor / self.weight_scale[:, None]).to(torch.int8) 427 | if weight_bit_width == 4: 428 | self.weight = compress_int4_weight(self.weight) 429 | 430 | self.weight = Parameter(self.weight.to(kwargs["device"]), requires_grad=False) 431 | self.weight_scale = Parameter(self.weight_scale.to(kwargs["device"]), requires_grad=False) 432 | 433 | def forward(self, input): 434 | if self.weight.device == torch.device("cpu"): 435 | original_weight = extract_weight_to_float(weight=self.weight, scale_list=self.weight_scale, 436 | source_bit_width=self.weight_bit_width) 437 | else: 438 | original_weight = extract_weight_to_half(weight=self.weight, scale_list=self.weight_scale, 439 | source_bit_width=self.weight_bit_width) 440 | output = F.embedding( 441 | input, original_weight, self.padding_idx, self.max_norm, 442 | self.norm_type, self.scale_grad_by_freq, self.sparse 443 | ) 444 | return output 445 | 446 | 447 | def load_cpu_kernel(**kwargs): 448 | global cpu_kernels 449 | cpu_kernels = CPUKernel(**kwargs) 450 | 451 | 452 | def quantize(model, weight_bit_width, use_quantization_cache=False, empty_init=False, **kwargs): 453 | """Replace fp16 linear with quantized linear""" 454 | 455 | query_key_value_quantization_cache = None 456 | dense_quantization_cache = None 457 | dense_h_to_4h_quantization_cache = None 458 | dense_4h_to_h_quantization_cache = None 459 | 460 | load_cpu_kernel(**kwargs) 461 | if not cpu_kernels.load: 462 | if kernels is None: # CUDA kernels failed 463 | print("Cannot load cpu or cuda kernel, quantization failed:") 464 | assert kernels is 
not None 465 | print("Cannot load cpu kernel, don't use quantized model on cpu.") 466 | 467 | current_device = model.device 468 | 469 | if model.device == torch.device("cpu"): 470 | dtype = torch.float32 471 | else: 472 | dtype = torch.half 473 | 474 | QuantizedLinearWithPara = partial( 475 | QuantizedLinear, 476 | weight_bit_width=weight_bit_width, 477 | bias=True, 478 | dtype=dtype, 479 | empty_init=empty_init 480 | ) 481 | 482 | if use_quantization_cache: 483 | print("Using quantization cache") 484 | layer = model.layers[0] 485 | weight = layer.attention.query_key_value.weight 486 | n, m = weight.size(0), weight.size(1) 487 | query_key_value_quantization_cache = CacheTensor(n, m, dtype=dtype, device=current_device, requires_grad=False) 488 | weight = layer.attention.dense.weight 489 | n, m = weight.size(0), weight.size(1) 490 | dense_quantization_cache = CacheTensor(n, m, dtype=dtype, device=current_device, requires_grad=False) 491 | weight = layer.mlp.dense_h_to_4h.weight 492 | n, m = weight.size(0), weight.size(1) 493 | dense_h_to_4h_quantization_cache = CacheTensor(n, m, dtype=dtype, device=current_device, requires_grad=False) 494 | weight = layer.mlp.dense_4h_to_h.weight 495 | n, m = weight.size(0), weight.size(1) 496 | dense_4h_to_h_quantization_cache = CacheTensor(n, m, dtype=dtype, device=current_device, requires_grad=False) 497 | 498 | print("Applying quantization to glm layers") 499 | 500 | for layer in model.layers: 501 | layer.attention.query_key_value = QuantizedLinearWithPara( 502 | weight_tensor=layer.attention.query_key_value.weight.to(current_device), 503 | bias_tensor=layer.attention.query_key_value.bias, 504 | in_features=layer.attention.query_key_value.in_features, 505 | out_features=layer.attention.query_key_value.out_features, 506 | device=layer.attention.query_key_value.weight.device, 507 | quantization_cache=query_key_value_quantization_cache 508 | ) 509 | layer.attention.dense = QuantizedLinearWithPara( 510 | weight_tensor=layer.attention.dense.weight.to(current_device), 511 | bias_tensor=layer.attention.dense.bias, 512 | in_features=layer.attention.dense.in_features, 513 | out_features=layer.attention.dense.out_features, 514 | device=layer.attention.dense.weight.device, 515 | quantization_cache=dense_quantization_cache 516 | ) 517 | layer.mlp.dense_h_to_4h = QuantizedLinearWithPara( 518 | weight_tensor=layer.mlp.dense_h_to_4h.weight.to(current_device), 519 | bias_tensor=layer.mlp.dense_h_to_4h.bias, 520 | in_features=layer.mlp.dense_h_to_4h.in_features, 521 | out_features=layer.mlp.dense_h_to_4h.out_features, 522 | device=layer.mlp.dense_h_to_4h.weight.device, 523 | quantization_cache=dense_h_to_4h_quantization_cache 524 | ) 525 | layer.mlp.dense_4h_to_h = QuantizedLinearWithPara( 526 | weight_tensor=layer.mlp.dense_4h_to_h.weight.to(current_device), 527 | bias_tensor=layer.mlp.dense_4h_to_h.bias, 528 | in_features=layer.mlp.dense_4h_to_h.in_features, 529 | out_features=layer.mlp.dense_4h_to_h.out_features, 530 | device=layer.mlp.dense_4h_to_h.weight.device, 531 | quantization_cache=dense_4h_to_h_quantization_cache 532 | ) 533 | return model 534 | -------------------------------------------------------------------------------- /quantization_kernels.c: -------------------------------------------------------------------------------- 1 | void compress_int4_weight(void *weight, void *out, int n, int m) 2 | { 3 | for(int i=0;i> 4); 27 | out += sizeof(float); 28 | (*(float*)(out)) = (*(float*)(scale_list)) * (((char)((*(unsigned char*)(weight)) << 4))>> 4); 29 | out 
+= sizeof(float); 30 | weight += sizeof(char); 31 | } 32 | scale_list += sizeof(float); 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /quantization_kernels_parallel.c: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | void set_num_threads(int n_threads) 4 | { 5 | omp_set_num_threads(n_threads); 6 | } 7 | 8 | int get_num_threads() 9 | { 10 | return omp_get_num_threads(); 11 | } 12 | 13 | void compress_int4_weight(void *weight, void *out, int n, int m) 14 | { 15 | #pragma omp parallel for 16 | for(int i=0;i> 4); 44 | (*(float*)(out + sizeof(float) * (i * (m << 1) + ((j << 1) | 1)))) = (*(float*)(scale_list + sizeof(float) * i)) * (((char)((*(unsigned char*)(weight + sizeof(char) * (i * m + j))) << 4))>> 4); 45 | } 46 | } 47 | } 48 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | protobuf 2 | transformers==4.27.1 3 | cpm_kernels 4 | torch>=1.10 5 | gradio 6 | mdtex2html 7 | sentencepiece 8 | accelerate -------------------------------------------------------------------------------- /tokenization_chatglm.py: -------------------------------------------------------------------------------- 1 | """Tokenization classes for ChatGLM.""" 2 | from typing import List, Optional, Union 3 | import os 4 | 5 | from transformers.tokenization_utils import PreTrainedTokenizer 6 | from transformers.utils import logging, PaddingStrategy 7 | from transformers.tokenization_utils_base import EncodedInput, BatchEncoding 8 | from typing import Dict 9 | import sentencepiece as spm 10 | import numpy as np 11 | 12 | logger = logging.get_logger(__name__) 13 | 14 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { 15 | "D:\\res\\lowcode-llm-demo": 2048, 16 | } 17 | 18 | 19 | class TextTokenizer: 20 | def __init__(self, model_path): 21 | self.sp = spm.SentencePieceProcessor() 22 | self.sp.Load(model_path) 23 | self.num_tokens = self.sp.vocab_size() 24 | 25 | def encode(self, text): 26 | return self.sp.EncodeAsIds(text) 27 | 28 | def decode(self, ids: List[int]): 29 | return self.sp.DecodeIds(ids) 30 | 31 | def tokenize(self, text): 32 | return self.sp.EncodeAsPieces(text) 33 | 34 | def convert_tokens_to_string(self, tokens): 35 | return self.sp.DecodePieces(tokens) 36 | 37 | def convert_tokens_to_ids(self, tokens): 38 | return [self.sp.PieceToId(token) for token in tokens] 39 | 40 | def convert_token_to_id(self, token): 41 | return self.sp.PieceToId(token) 42 | 43 | def convert_id_to_token(self, idx): 44 | return self.sp.IdToPiece(idx) 45 | 46 | def __len__(self): 47 | return self.num_tokens 48 | 49 | 50 | class SPTokenizer: 51 | def __init__( 52 | self, 53 | vocab_file, 54 | num_image_tokens=20000, 55 | max_blank_length=80, 56 | byte_fallback=True, 57 | ): 58 | assert vocab_file is not None 59 | self.vocab_file = vocab_file 60 | self.num_image_tokens = num_image_tokens 61 | self.special_tokens = ["[MASK]", "[gMASK]", "[sMASK]", "", "", "", "", ""] 62 | self.max_blank_length = max_blank_length 63 | self.byte_fallback = byte_fallback 64 | self.text_tokenizer = TextTokenizer(vocab_file) 65 | 66 | def _get_text_tokenizer(self): 67 | return self.text_tokenizer 68 | 69 | @staticmethod 70 | def get_blank_token(length: int): 71 | assert length >= 2 72 | return f"<|blank_{length}|>" 73 | 74 | @staticmethod 75 | def get_tab_token(): 76 | return f"<|tab|>" 77 | 78 | @property 79 | def 
num_text_tokens(self): 80 | return self.text_tokenizer.num_tokens 81 | 82 | @property 83 | def num_tokens(self): 84 | return self.num_image_tokens + self.num_text_tokens 85 | 86 | @staticmethod 87 | def _encode_whitespaces(text: str, max_len: int = 80): 88 | text = text.replace("\t", SPTokenizer.get_tab_token()) 89 | for i in range(max_len, 1, -1): 90 | text = text.replace(" " * i, SPTokenizer.get_blank_token(i)) 91 | return text 92 | 93 | def _preprocess(self, text: str, linebreak=True, whitespaces=True): 94 | if linebreak: 95 | text = text.replace("\n", "") 96 | if whitespaces: 97 | text = self._encode_whitespaces(text, max_len=self.max_blank_length) 98 | return text 99 | 100 | def encode( 101 | self, text: str, linebreak=True, whitespaces=True, add_dummy_prefix=True 102 | ) -> List[int]: 103 | """ 104 | @param text: Text to encode. 105 | @param linebreak: Whether to encode newline (\n) in text. 106 | @param whitespaces: Whether to encode multiple whitespaces or tab in text, useful for source code encoding. 107 | @param special_tokens: Whether to encode special token ([MASK], [gMASK], etc.) in text. 108 | @param add_dummy_prefix: Whether to add dummy blank space in the beginning. 109 | """ 110 | text = self._preprocess(text, linebreak, whitespaces) 111 | if not add_dummy_prefix: 112 | text = "" + text 113 | tmp = self._get_text_tokenizer().encode(text) 114 | tokens = [x + self.num_image_tokens for x in tmp] 115 | return tokens if add_dummy_prefix else tokens[2:] 116 | 117 | def postprocess(self, text): 118 | text = text.replace("", "\n") 119 | text = text.replace(SPTokenizer.get_tab_token(), "\t") 120 | for i in range(2, self.max_blank_length + 1): 121 | text = text.replace(self.get_blank_token(i), " " * i) 122 | return text 123 | 124 | def decode(self, text_ids: List[int]) -> str: 125 | ids = [int(_id) - self.num_image_tokens for _id in text_ids] 126 | ids = [_id for _id in ids if _id >= 0] 127 | text = self._get_text_tokenizer().decode(ids) 128 | text = self.postprocess(text) 129 | return text 130 | 131 | def decode_tokens(self, tokens: List[str]) -> str: 132 | text = self._get_text_tokenizer().convert_tokens_to_string(tokens) 133 | text = self.postprocess(text) 134 | return text 135 | 136 | def tokenize( 137 | self, text: str, linebreak=True, whitespaces=True, add_dummy_prefix=True 138 | ) -> List[str]: 139 | """ 140 | @param text: Text to encode. 141 | @param linebreak: Whether to encode newline (\n) in text. 142 | @param whitespaces: Whether to encode multiple whitespaces or tab in text, useful for source code encoding. 143 | @param special_tokens: Whether to encode special token ([MASK], [gMASK], etc.) in text. 144 | @param add_dummy_prefix: Whether to add dummy blank space in the beginning. 
145 | """ 146 | text = self._preprocess(text, linebreak, whitespaces) 147 | if not add_dummy_prefix: 148 | text = "" + text 149 | tokens = self._get_text_tokenizer().tokenize(text) 150 | return tokens if add_dummy_prefix else tokens[2:] 151 | 152 | def __getitem__(self, x: Union[int, str]): 153 | if isinstance(x, int): 154 | if x < self.num_image_tokens: 155 | return "".format(x) 156 | else: 157 | return self.text_tokenizer.convert_id_to_token(x - self.num_image_tokens) 158 | elif isinstance(x, str): 159 | if x.startswith("") and x[7:-1].isdigit(): 160 | return int(x[7:-1]) 161 | else: 162 | return self.text_tokenizer.convert_token_to_id(x) + self.num_image_tokens 163 | else: 164 | raise ValueError("The key should be str or int.") 165 | 166 | 167 | class ChatGLMTokenizer(PreTrainedTokenizer): 168 | """ 169 | Construct a ChatGLM tokenizer. Based on byte-level Byte-Pair-Encoding. 170 | 171 | Args: 172 | vocab_file (`str`): 173 | Path to the vocabulary file. 174 | """ 175 | 176 | vocab_files_names = {"vocab_file": "ice_text.model"} 177 | max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES 178 | model_input_names = ["input_ids", "attention_mask", "position_ids"] 179 | 180 | def __init__( 181 | self, 182 | vocab_file, 183 | do_lower_case=False, 184 | remove_space=False, 185 | bos_token='', 186 | eos_token='', 187 | end_token='', 188 | mask_token='[MASK]', 189 | gmask_token='[gMASK]', 190 | padding_side="left", 191 | pad_token="", 192 | unk_token="", 193 | num_image_tokens=20000, 194 | **kwargs 195 | ) -> None: 196 | super().__init__( 197 | do_lower_case=do_lower_case, 198 | remove_space=remove_space, 199 | padding_side=padding_side, 200 | bos_token=bos_token, 201 | eos_token=eos_token, 202 | end_token=end_token, 203 | mask_token=mask_token, 204 | gmask_token=gmask_token, 205 | pad_token=pad_token, 206 | unk_token=unk_token, 207 | num_image_tokens=num_image_tokens, 208 | **kwargs 209 | ) 210 | 211 | self.do_lower_case = do_lower_case 212 | self.remove_space = remove_space 213 | self.vocab_file = vocab_file 214 | 215 | self.bos_token = bos_token 216 | self.eos_token = eos_token 217 | self.end_token = end_token 218 | self.mask_token = mask_token 219 | self.gmask_token = gmask_token 220 | 221 | self.sp_tokenizer = SPTokenizer(vocab_file, num_image_tokens=num_image_tokens) 222 | 223 | """ Initialisation """ 224 | 225 | @property 226 | def gmask_token_id(self) -> Optional[int]: 227 | if self.gmask_token is None: 228 | return None 229 | return self.convert_tokens_to_ids(self.gmask_token) 230 | 231 | @property 232 | def end_token_id(self) -> Optional[int]: 233 | """ 234 | `Optional[int]`: Id of the end of context token in the vocabulary. Returns `None` if the token has not been 235 | set. 236 | """ 237 | if self.end_token is None: 238 | return None 239 | return self.convert_tokens_to_ids(self.end_token) 240 | 241 | @property 242 | def vocab_size(self): 243 | """ Returns vocab size """ 244 | return self.sp_tokenizer.num_tokens 245 | 246 | def get_vocab(self): 247 | """ Returns vocab as a dict """ 248 | vocab = {self._convert_id_to_token(i): i for i in range(self.vocab_size)} 249 | vocab.update(self.added_tokens_encoder) 250 | return vocab 251 | 252 | def preprocess_text(self, inputs): 253 | if self.remove_space: 254 | outputs = " ".join(inputs.strip().split()) 255 | else: 256 | outputs = inputs 257 | 258 | if self.do_lower_case: 259 | outputs = outputs.lower() 260 | 261 | return outputs 262 | 263 | def _tokenize(self, text, **kwargs): 264 | """ Returns a tokenized string. 
""" 265 | text = self.preprocess_text(text) 266 | 267 | seq = self.sp_tokenizer.tokenize(text) 268 | 269 | return seq 270 | 271 | def convert_tokens_to_string(self, tokens: List[str]) -> str: 272 | return self.sp_tokenizer.decode_tokens(tokens) 273 | 274 | def _decode( 275 | self, 276 | token_ids: Union[int, List[int]], 277 | **kwargs 278 | ) -> str: 279 | if isinstance(token_ids, int): 280 | token_ids = [token_ids] 281 | if len(token_ids) == 0: 282 | return "" 283 | if self.pad_token_id in token_ids: # remove pad 284 | token_ids = list(filter((self.pad_token_id).__ne__, token_ids)) 285 | return super()._decode(token_ids, **kwargs) 286 | 287 | def _convert_token_to_id(self, token): 288 | """ Converts a token (str) in an id using the vocab. """ 289 | return self.sp_tokenizer[token] 290 | 291 | def _convert_id_to_token(self, index): 292 | """Converts an index (integer) in a token (str) using the vocab.""" 293 | return self.sp_tokenizer[index] 294 | 295 | def save_vocabulary(self, save_directory, filename_prefix=None): 296 | """ 297 | Save the vocabulary and special tokens file to a directory. 298 | 299 | Args: 300 | save_directory (`str`): 301 | The directory in which to save the vocabulary. 302 | filename_prefix (`str`, *optional*): 303 | An optional prefix to add to the named of the saved files. 304 | 305 | Returns: 306 | `Tuple(str)`: Paths to the files saved. 307 | """ 308 | if os.path.isdir(save_directory): 309 | vocab_file = os.path.join( 310 | save_directory, self.vocab_files_names["vocab_file"] 311 | ) 312 | else: 313 | vocab_file = save_directory 314 | 315 | with open(self.vocab_file, 'rb') as fin: 316 | proto_str = fin.read() 317 | 318 | with open(vocab_file, "wb") as writer: 319 | writer.write(proto_str) 320 | 321 | return (vocab_file,) 322 | 323 | def build_inputs_with_special_tokens( 324 | self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None 325 | ) -> List[int]: 326 | """ 327 | Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and 328 | adding special tokens. A BERT sequence has the following format: 329 | 330 | - single sequence: `[CLS] X [SEP]` 331 | - pair of sequences: `[CLS] A [SEP] B [SEP]` 332 | 333 | Args: 334 | token_ids_0 (`List[int]`): 335 | List of IDs to which the special tokens will be added. 336 | token_ids_1 (`List[int]`, *optional*): 337 | Optional second list of IDs for sequence pairs. 338 | 339 | Returns: 340 | `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. 341 | """ 342 | gmask_id = self.sp_tokenizer[self.gmask_token] 343 | eos_id = self.sp_tokenizer[self.eos_token] 344 | token_ids_0 = token_ids_0 + [gmask_id, self.sp_tokenizer[self.bos_token]] 345 | if token_ids_1 is not None: 346 | token_ids_0 = token_ids_0 + token_ids_1 + [eos_id] 347 | return token_ids_0 348 | 349 | def _pad( 350 | self, 351 | encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding], 352 | max_length: Optional[int] = None, 353 | padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, 354 | pad_to_multiple_of: Optional[int] = None, 355 | return_attention_mask: Optional[bool] = None, 356 | ) -> dict: 357 | """ 358 | Pad encoded inputs (on left/right and up to predefined length or max length in the batch) 359 | 360 | Args: 361 | encoded_inputs: 362 | Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`). 363 | max_length: maximum length of the returned list and optionally padding length (see below). 
364 | Will truncate by taking into account the special tokens. 365 | padding_strategy: PaddingStrategy to use for padding. 366 | 367 | - PaddingStrategy.LONGEST Pad to the longest sequence in the batch 368 | - PaddingStrategy.MAX_LENGTH: Pad to the max length (default) 369 | - PaddingStrategy.DO_NOT_PAD: Do not pad 370 | The tokenizer padding sides are defined in self.padding_side: 371 | 372 | - 'left': pads on the left of the sequences 373 | - 'right': pads on the right of the sequences 374 | pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value. 375 | This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability 376 | `>= 7.5` (Volta). 377 | return_attention_mask: 378 | (optional) Set to False to avoid returning attention mask (default: set to model specifics) 379 | """ 380 | # Load from model defaults 381 | bos_token_id = self.sp_tokenizer[self.bos_token] 382 | mask_token_id = self.sp_tokenizer[self.mask_token] 383 | gmask_token_id = self.sp_tokenizer[self.gmask_token] 384 | assert self.padding_side == "left" 385 | 386 | required_input = encoded_inputs[self.model_input_names[0]] 387 | seq_length = len(required_input) 388 | 389 | if padding_strategy == PaddingStrategy.LONGEST: 390 | max_length = len(required_input) 391 | 392 | if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0): 393 | max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of 394 | 395 | needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length 396 | 397 | # Initialize attention mask if not present. 398 | if max_length is not None: 399 | if "attention_mask" not in encoded_inputs: 400 | if bos_token_id in required_input: 401 | context_length = required_input.index(bos_token_id) 402 | else: 403 | context_length = seq_length 404 | attention_mask = np.ones((1, seq_length, seq_length)) 405 | attention_mask = np.tril(attention_mask) 406 | attention_mask[:, :, :context_length] = 1 407 | attention_mask = np.bool_(attention_mask < 0.5) 408 | encoded_inputs["attention_mask"] = attention_mask 409 | 410 | if "position_ids" not in encoded_inputs: 411 | if bos_token_id in required_input: 412 | context_length = required_input.index(bos_token_id) 413 | else: 414 | context_length = seq_length 415 | position_ids = np.arange(seq_length, dtype=np.int64) 416 | mask_token = mask_token_id if mask_token_id in required_input else gmask_token_id 417 | if mask_token in required_input: 418 | mask_position = required_input.index(mask_token) 419 | position_ids[context_length:] = mask_position 420 | block_position_ids = np.concatenate( 421 | [np.zeros(context_length, dtype=np.int64), 422 | np.arange(1, seq_length - context_length + 1, dtype=np.int64)]) 423 | encoded_inputs["position_ids"] = np.stack([position_ids, block_position_ids], axis=0) 424 | 425 | if needs_to_be_padded: 426 | difference = max_length - len(required_input) 427 | 428 | if "attention_mask" in encoded_inputs: 429 | encoded_inputs["attention_mask"] = np.pad(encoded_inputs["attention_mask"], 430 | pad_width=[(0, 0), (difference, 0), (difference, 0)], 431 | mode='constant', constant_values=True) 432 | if "token_type_ids" in encoded_inputs: 433 | encoded_inputs["token_type_ids"] = [self.pad_token_type_id] * difference + encoded_inputs[ 434 | "token_type_ids" 435 | ] 436 | if "special_tokens_mask" in encoded_inputs: 437 | encoded_inputs["special_tokens_mask"] = [1] * difference + 
encoded_inputs["special_tokens_mask"] 438 | if "position_ids" in encoded_inputs: 439 | encoded_inputs["position_ids"] = np.pad(encoded_inputs["position_ids"], 440 | pad_width=[(0, 0), (difference, 0)]) 441 | encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input 442 | 443 | return encoded_inputs 444 | -------------------------------------------------------------------------------- /tokenizer.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/woai3c/lowcode-llm-demo/5a1ffa6f9ed9b2d2eef78650c021ffa0b74dcd62/tokenizer.model -------------------------------------------------------------------------------- /tokenizer_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "name_or_path": "D:\\res\\lowcode-llm-demo", 3 | "bos_token": "", 4 | "eos_token": "", 5 | "end_token": "", 6 | "gmask_token": "[gMASK]", 7 | "mask_token": "[MASK]", 8 | "pad_token": "", 9 | "unk_token": "", 10 | "remove_space": false, 11 | "do_lower_case": false, 12 | "tokenizer_class": "ChatGLMTokenizer", 13 | "num_image_tokens": 0, 14 | "auto_map": { 15 | "AutoTokenizer": [ 16 | "tokenization_chatglm.ChatGLMTokenizer", 17 | null 18 | ] 19 | } 20 | } 21 | -------------------------------------------------------------------------------- /web-demo.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoModel, AutoTokenizer 2 | import gradio as gr 3 | import mdtex2html 4 | 5 | # 这里要替换为你的 ChatGLM-6B 模型路径 6 | tokenizer = AutoTokenizer.from_pretrained("D:\\res\\lowcode-llm-demo", trust_remote_code=True, revision="") 7 | model = AutoModel.from_pretrained("D:\\res\\lowcode-llm-demo",trust_remote_code=True, revision="").float() 8 | model = model.eval() 9 | 10 | """Override Chatbot.postprocess""" 11 | 12 | 13 | def postprocess(self, y): 14 | if y is None: 15 | return [] 16 | for i, (message, response) in enumerate(y): 17 | y[i] = ( 18 | None if message is None else mdtex2html.convert((message)), 19 | None if response is None else mdtex2html.convert(response), 20 | ) 21 | return y 22 | 23 | 24 | gr.Chatbot.postprocess = postprocess 25 | 26 | 27 | def parse_text(text): 28 | print(text) 29 | lines = text.split("\n") 30 | lines = [line for line in lines if line != ""] 31 | count = 0 32 | for i, line in enumerate(lines): 33 | if "```" in line: 34 | count += 1 35 | items = line.split('`') 36 | if count % 2 == 1: 37 | lines[i] = f'
<pre><code class="language-{items[-1]}">'
38 |             else:
39 |                 lines[i] = f'<br></code></pre>'
40 |         else:
41 |             if i > 0:
42 |                 if count % 2 == 1:
43 |                     line = line.replace("`", "\`")
44 |                     line = line.replace("<", "&lt;")
45 |                     line = line.replace(">", "&gt;")
46 |                     line = line.replace(" ", "&nbsp;")
47 |                     line = line.replace("*", "&ast;")
48 |                     line = line.replace("_", "&lowbar;")
49 |                     line = line.replace("-", "&#45;")
50 |                     line = line.replace(".", "&#46;")
51 |                     line = line.replace("!", "&#33;")
52 |                     line = line.replace("(", "&#40;")
53 |                     line = line.replace(")", "&#41;")
54 |                     line = line.replace("$", "&#36;")
55 |                 lines[i] = "<br>
"+line 56 | text = "".join(lines) 57 | return text 58 | 59 | 60 | def predict(input, chatbot, max_length, top_p, temperature, history): 61 | chatbot.append((parse_text(input), "")) 62 | for response, history in model.stream_chat(tokenizer, input, history, max_length=max_length, top_p=top_p, 63 | temperature=temperature): 64 | chatbot[-1] = (parse_text(input), parse_text(response)) 65 | 66 | yield chatbot, history 67 | 68 | 69 | def reset_user_input(): 70 | return gr.update(value='') 71 | 72 | 73 | def reset_state(): 74 | return [], [] 75 | 76 | 77 | with gr.Blocks() as demo: 78 | gr.HTML("""

<h1 align="center">ChatGLM</h1>

""") 79 | 80 | chatbot = gr.Chatbot() 81 | with gr.Row(): 82 | with gr.Column(scale=4): 83 | with gr.Column(scale=12): 84 | user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10) 85 | with gr.Column(min_width=32, scale=1): 86 | submitBtn = gr.Button("Submit", variant="primary") 87 | with gr.Column(scale=1): 88 | emptyBtn = gr.Button("Clear History") 89 | max_length = gr.Slider(0, 4096, value=2048, step=1.0, label="Maximum length", interactive=True) 90 | top_p = gr.Slider(0, 1, value=0.7, step=0.01, label="Top P", interactive=True) 91 | temperature = gr.Slider(0, 1, value=0.95, step=0.01, label="Temperature", interactive=True) 92 | 93 | history = gr.State([]) 94 | 95 | submitBtn.click(predict, [user_input, chatbot, max_length, top_p, temperature, history], [chatbot, history], 96 | show_progress=True) 97 | submitBtn.click(reset_user_input, [], [user_input]) 98 | 99 | emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress=True) 100 | 101 | demo.queue().launch(share=False, inbrowser=True) 102 | --------------------------------------------------------------------------------