├── .gitignore ├── LICENSE ├── README.md ├── code ├── benchmark │ └── load_test.py ├── client │ ├── __init__.py │ ├── client.py │ └── openai_jumper.py ├── client_app.py ├── client_requirements.txt ├── continuous_batching_server_app.py ├── gunicorn_config.py ├── protocol │ ├── __init__.py │ ├── completion_task.py │ ├── error.py │ └── routes.py ├── server │ ├── __init__.py │ ├── continuous_batching_server │ │ ├── __init__.py │ │ ├── batcher.py │ │ ├── beam.py │ │ ├── cache │ │ │ ├── README.md │ │ │ ├── __init__.py │ │ │ ├── cache.py │ │ │ └── cache_manager.py │ │ ├── config.py │ │ ├── generation_utils │ │ │ ├── README.md │ │ │ ├── __init__.py │ │ │ ├── logits_process.py │ │ │ └── tokens.py │ │ ├── modeling │ │ │ ├── __init__.py │ │ │ ├── llama.py │ │ │ └── utils │ │ │ │ ├── __init__.py │ │ │ │ ├── attention.py │ │ │ │ ├── linear.py │ │ │ │ └── weights.py │ │ ├── server.py │ │ └── worker.py │ └── static_batching_server │ │ ├── __init__.py │ │ ├── batcher.py │ │ ├── config.py │ │ ├── server.py │ │ └── worker.py ├── server_requirements.txt ├── start_app.py ├── static_batching_server_app.py └── utils │ ├── __init__.py │ └── log_util.py └── docs └── tutorial ├── 0_概述.md └── 1_大语言模型推理引擎设计与代码讲解.md /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

# LLM Inference Deployment Tutorial
2 | An advanced tutorial for large language model engineers, covering inference engine design, service deployment, and performance benchmarking, together with an optimized open-source single-GPU inference engine for research and study.
3 |
4 | The tutorial documents are being updated continuously; watch this project to get the latest content as soon as it is published!
5 |
6 | ## Table of Contents
7 |
8 | - [Project Layout](#project-layout)
9 | - [Inference Engine Code](#inference-engine-code)
10 |   * [Introduction](#introduction)
11 |   * [Environment Setup](#environment-setup)
12 |   * [Service Deployment](#service-deployment)
13 |
14 |
15 | ## Project Layout
16 |
17 | This project consists of two main directories: `code` and `docs`.
18 | - `code`: holds the inference engine code, written primarily for research and study. With a few simple steps you can set up the environment and deploy the services under this directory, then compare how different strategies behave under different traffic patterns.
19 | - `docs`: holds the tutorial documents for every chapter. Reading them will give you a broad picture of what it takes to deploy large language models for inference in a commercial production environment.
20 |
21 | ## Inference Engine Code
22 |
23 | ### Introduction
24 |
25 | This project open-sources an optimized single-GPU inference engine for large language models. It follows a client/server design and is meant to show developers how an LLM inference engine works internally, so that you can deploy it yourself and observe the performance differences between strategies under different traffic patterns.
26 |
27 | #### Strengths
28 |
29 | - Loads `int4` GPTQ models
30 | - (Client) load balancing
31 | - Multi-strategy serving: a **Static Batching** (SB) server and a **Continuous Batching** (CB) server
32 | - (CB server) inference accelerated with xformers and PagedAttention
33 | - (CB server) heterogeneous decoding strategies
34 | - GPU-CPU memory swapping
35 |
36 | #### Limitations
37 |
38 | - No multi-GPU inference
39 | - No streaming output
40 | - (CB server) only supports LLaMA v1 and LLaMA v2 models (except 70B)
41 | - (CB server) no inference optimizations for modules and operations other than the attention layers
42 | - (CB server) only supports weight files in the `safetensors` format
43 | - (CB server) cannot extend the maximum context length the model can handle
44 | - No adapter loading
45 | - No speculative sampling (also known as assisted generation)
46 |
47 | #### Other open-source LLM inference frameworks
48 |
49 | The inference engine code in this project is meant to show developers how an LLM inference engine works internally, so aspects only loosely related (or unrelated) to that goal have not been optimized further.
50 |
51 | With this engine you can easily deploy models of 20B parameters or fewer on a single GPU, but it falls short for larger models. For that reason, here is a short list of more fully featured open-source LLM inference frameworks:
52 |
53 | - [TGI](https://github.com/huggingface/text-generation-inference): the LLM inference framework Hugging Face uses in its internal production environment; note that for versions after 0.9.4, commercial use requires an official license;
54 | - [vLLM](https://github.com/vllm-project/vllm): a high-performance, easy-to-use LLM inference framework that proposed and implements PagedAttention;
55 | - [lmdeploy](https://github.com/InternLM/lmdeploy): an LLM inference and deployment framework developed by the InternLM team;
56 | - [TransformerEngine](https://github.com/NVIDIA/TransformerEngine): NVIDIA's next-generation inference framework for Transformer-architecture models, with `fp8` support;
57 | - [DeepSpeed-MII](https://github.com/microsoft/DeepSpeed-MII): a low-latency, high-throughput inference engine built on DeepSpeed, not limited to large language models.
58 |
59 | There are also frameworks such as [FasterTransformer](https://github.com/NVIDIA/FasterTransformer), [FlexGen](https://github.com/FMInference/FlexGen), and [EnergonAI](https://github.com/hpcaitech/EnergonAI) that target LLM inference performance, but they are updated less frequently, so we do not cover them further here.
60 |
61 | We strongly encourage you to read the code of the projects above after finishing this tutorial, to deepen your understanding of inference engine design and optimization techniques. For a more detailed comparison of LLM inference engines and deployment frameworks, see this [article](https://mp.weixin.qq.com/s/xIbNSAI9cKGIA19yZhIEgg) (in Chinese).
62 |
63 | > The inference engine code in this project borrows from TGI for weight loading, model network code design, and text generation strategies, and from vLLM for memory management.
64 |
65 | ### Environment Setup
66 |
67 | #### Client
68 |
69 | In the `code` directory, run `pip install -r client_requirements.txt` to install the third-party libraries the client needs.
70 |
71 | #### Server
72 |
73 | First, make sure a CUDA-enabled PyTorch 2.0.0 or later is installed in your virtual environment. If not, pick and install a prebuilt package matching your hardware and software [here](https://pytorch.org/get-started/locally/), or build from source following the instructions [here](https://github.com/pytorch/pytorch#from-source).
74 |
75 | Next, install vLLM, so that our code can use its paged-attention and memory-management kernels.
76 | - Quick install: `pip install vllm`;
77 | - Build from source: follow the instructions [here](https://vllm.readthedocs.io/en/latest/getting_started/installation.html#build-from-source).
78 |
79 | Then install the `auto-gptq` library following the instructions [here](https://github.com/PanQiWei/AutoGPTQ#installation); the inference engine supports loading and serving GPTQ models.
80 |
81 | Finally, in the `code` directory, run `pip install -r server_requirements.txt` to install the remaining third-party libraries the server needs.
82 |
83 | ### Service Deployment
84 | As you can see, the `code` directory contains four modules:
85 | - `protocol`: defines the communication protocol between client and server;
86 | - `client`: client-side code;
87 | - `server`: server-side code;
88 | - `utils`: other utilities used when building the HTTP services.
89 |
90 | You can combine these modules and wrap them into standalone HTTP services, for example:
91 | - `client_app.py`: a client HTTP service built from the `protocol`, `client`, and `utils` modules;
92 | - `continuous_batching_server_app.py` and `static_batching_server_app.py`: server HTTP services built from the `protocol`, `server`, and `utils` modules.
93 |
94 | You can also copy these modules straight into the project you are currently working on and integrate them there, as sketched below.
95 |
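For example, the sketch below (not part of the repository) drives the `client` and `protocol` modules from your own Python code instead of going through `client_app.py`. It assumes it is run from inside the `code` directory, that a continuous batching server is already listening at `http://127.0.0.1:8001`, and that `llama-7b` stands in for whatever `--model_id` that server was started with; the generation-config fields mirror the ones used in `code/benchmark/load_test.py`.

```python
import asyncio

from client import Client, ClientConfig, ServerType, get_client
from protocol.completion_task import (
    HuggingFaceCompletionInputs,
    HuggingFaceGenerationConfig,
)

# Initialize the process-wide client singleton; during construction it probes the
# configured server URLs to learn which model_id each one serves.
Client(
    config=ClientConfig(
        continuous_batching_server_urls=["http://127.0.0.1:8001"],  # assumed CB server address
        heart_beat_interval_seconds=600,
    )
)


async def demo() -> None:
    client = get_client()
    inputs = HuggingFaceCompletionInputs(
        model="llama-7b",  # hypothetical model_id; must match the server's --model_id
        prompt="Briefly explain what continuous batching is.",
        generation_config=HuggingFaceGenerationConfig(
            max_new_tokens=64,
            min_new_tokens=1,
            num_beams=1,
            num_return_sequences=1,
        ),
    )
    # The client routes the request to the least-loaded server serving this model_id.
    outputs = await client.huggingface_completion(inputs, server_type=ServerType.CB)
    print(outputs.dict(by_alias=True))


asyncio.run(demo())
```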
96 | The `code` directory also ships three ready-made HTTP service scripts whose file names end in "_app.py"; you can run them directly to start a service, for example:
97 | ```shell
98 | CUDA_VISIBLE_DEVICES=0 python client_app.py --port 8000  # run `python client_app.py --help` to see all command-line options
99 | ```
100 |
101 | The service process then stays in the foreground; press `ctrl`+`c` to stop it.
102 |
103 | We also provide a unified deployment script, `start_app.py`. Running it lets you use `gunicorn` to manage the service processes and keep them running in the background.
104 |
105 | > Note: the inference engine currently supports single-GPU inference only, so when starting a server script we strongly recommend also setting the CUDA_VISIBLE_DEVICES environment variable to pin it to one specific GPU and avoid unexpected behavior.
106 |
107 | Deploying each HTTP service requires a corresponding configuration file; templates for each service follow (remember to strip the comments before actually using them):
108 |
109 |
110 | **client_config.json**
111 |
112 | ```json
113 | {
114 |     "continuous_batching_server_urls": ["http://127.0.0.1:8001"],  # URLs of the continuous batching server HTTP services; requests are load-balanced across them
115 |     "static_batching_server_urls": ["http://127.0.0.1:8002"],  # URLs of the static batching server HTTP services; requests are load-balanced across them
116 |     "openai_jumper_configs": [
117 |         {
118 |             "api_key": "YOUR_OPENAI_KEY",
119 |             "org_id": null
120 |         }
121 |     ],  # list of OpenAI accounts; requests are load-balanced across them
122 |     "heart_beat_interval_seconds": 600  # interval (in seconds) of the heartbeat checks against the server HTTP services
123 | }
124 | ```
125 |
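Because the comments above must be stripped before the file is used, it can help to sanity-check the cleaned-up file against the `ClientConfig` model before starting the app; a minimal sketch (run from the `code` directory, with the file assumed to be saved as `client_config.json`):

```python
import json

from client import ClientConfig

# Parse the config through the same pydantic model that client_app.py builds at startup;
# malformed JSON or a value of the wrong type fails loudly here rather than at launch.
with open("client_config.json", "r", encoding="utf-8") as f:
    config = ClientConfig(**json.load(f))

print(config.continuous_batching_server_urls, config.heart_beat_interval_seconds)
```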
126 | 127 |
128 | **cb_server_config.json**
129 |
130 | ```json
131 | {
132 |     "model_loading_config": {
133 |         "model_type": "llama",  # model architecture; currently only llama is supported
134 |         "model_name_or_path": "PATH_TO_MODEL_DIR",  # directory holding the model weight files; only the safetensors format is supported
135 |         "torch_dtype": "float16",  # numeric type for weights and computation (non-GPTQ models); either float16 or bfloat16
136 |         "tokenizer_name_or_path": null,  # directory holding the tokenizer files; if null, the model weight directory is used
137 |         "use_fast_tokenizer": false,  # if true, the tokenizer is loaded with use_fast=True
138 |         "trust_remote_code": false,  # whether to run model or tokenizer code not officially provided by Hugging Face
139 |         "quantize_method": null,  # quantization method; the only supported value is gptq
140 |         "model_max_length": 2048,  # maximum context length the model can handle
141 |         "gptq_model_base_name": null,  # base name (without extension) of the GPTQ weight file; if null, the default naming scheme is used to locate it
142 |         "gptq_config_base_name": null  # base name (without extension) of the GPTQ config file; if null, the default naming scheme is used to locate it
143 |     },
144 |     "batcher_config": {
145 |         "batch_max_tokens": 56000,  # maximum number of tokens processed in one batch; this value is reasonable for a llama-7b fp16 model on an A100-40G
146 |         "batch_max_beams": 32  # maximum number of beams (prediction branches during generation) in one batch; this value is reasonable for a llama-7b fp16 model on an A100-40G
147 |     },
148 |     "cache_config": {
149 |         "num_blocks": 2500,  # number of GPU cache blocks; this value is reasonable for a llama-7b fp16 model on an A100-40G
150 |         "num_blocks_cpu": 1024,  # number of CPU cache blocks
151 |         "block_size": 16,  # size of one cache block, in tokens
152 |         "watermark": 0.01  # fraction of GPU blocks kept in reserve, so that prompt prefills and KV cache swapped in from CPU memory cannot grab so many GPU blocks that generation runs short of memory
153 |     }
154 | }
155 | ```
156 |
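When adapting `cache_config` to a different GPU or model, note that `num_blocks * block_size` bounds how many KV-cache tokens the server can hold, and that for a standard multi-head-attention model each cached token costs roughly `2 * num_layers * hidden_size * bytes_per_value` bytes. The back-of-the-envelope sketch below is our own illustration (not code from the repository), plugging in the usual llama-7b fp16 numbers for the values above:

```python
# Rough KV-cache budget check for the cache_config above.
# Assumptions: llama-7b with 32 layers, hidden size 4096, fp16 (2 bytes per value),
# and full multi-head attention (key/value width equals the hidden size).
num_blocks = 2500        # GPU blocks from cache_config
block_size = 16          # tokens per block
num_layers = 32
hidden_size = 4096
bytes_per_value = 2      # fp16

max_cached_tokens = num_blocks * block_size                        # 40,000 tokens
bytes_per_token = 2 * num_layers * hidden_size * bytes_per_value   # K and V -> ~0.5 MiB per token
kv_cache_gib = max_cached_tokens * bytes_per_token / 1024**3

print(f"KV cache capacity: {max_cached_tokens} tokens ~= {kv_cache_gib:.1f} GiB")
# ~19.5 GiB, which together with roughly 13 GiB of fp16 weights fits on an A100-40G.
```

If weights plus KV cache (plus some headroom for activations) do not fit on your GPU, lower `num_blocks` accordingly.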
157 | 158 |
159 | **sb_server_config.json**
160 |
161 | ```json
162 | {
163 |     "batcher_config": {
164 |         "package_max_workload": 16,  # maximum workload of one task package, measured in beams
165 |         "packaging_interval_seconds": 2  # packaging interval; the lower your TPS/QPS, the longer this interval can be
166 |     },
167 |     "worker_config": {
168 |         "model_name_or_path": "PATH_TO_MODEL_DIR",  # directory holding the model weight files
169 |         "tokenizer_name_or_path": null,  # directory holding the tokenizer files
170 |         "revision": "main",  # model repository branch to use; only takes effect when the path is a Hugging Face Hub model name or a GitHub repository
171 |         "low_cpu_mem_usage": true,  # whether to reduce peak CPU memory usage while loading the model weights
172 |         "torch_dtype": "float16",  # numeric type for weights and computation (non-GPTQ models); either float16 or bfloat16
173 |         "use_fast_tokenizer": false,  # if true, the tokenizer is loaded with use_fast=True
174 |         "trust_remote_code": false,  # whether to run model or tokenizer code not officially provided by Hugging Face
175 |         "use_safetensors": false,  # whether to load weight files in the safetensors format
176 |         "batch_size": -1,  # batch_size used when running TextGenerationPipeline; -1 means all inputs are processed at once
177 |         "is_gptq_quantized": false  # whether the model is a GPTQ model
178 |     }
179 | }
180 | ```
181 |
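Once one of the server configs above is in place and the corresponding `*_app.py` is running, you can also call a batching server directly, bypassing the client, the same way `code/benchmark/load_test.py` does. A minimal sketch (not part of the repository), assuming a continuous batching server on `http://127.0.0.1:8001` and execution from inside the `code` directory:

```python
import requests

from protocol.completion_task import HuggingFaceCompletionInputs, HuggingFaceGenerationConfig
from protocol.routes import ROUTE_POST_CONTINUOUS_BATCHING_COMPLETION

SERVER = "http://127.0.0.1:8001"  # assumed address of continuous_batching_server_app.py

payload = HuggingFaceCompletionInputs(
    model="",  # may stay empty when talking to a server directly (see load_test.py)
    prompt="Say hello in five words.",
    generation_config=HuggingFaceGenerationConfig(
        max_new_tokens=32,
        min_new_tokens=1,
        num_beams=1,
        num_return_sequences=1,
    ),
).dict(by_alias=True)

resp = requests.post(f"{SERVER}{ROUTE_POST_CONTINUOUS_BATCHING_COMPLETION}", json=payload, timeout=600)
resp.raise_for_status()
print(resp.json())
```

For heavier load testing of the same route, `code/benchmark/load_test.py` wraps this pattern with configurable request rates and response-length distributions; it expects a `load_test_data.json` file containing a JSON list of prompt strings (see its `--help` for the full set of options).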
182 | -------------------------------------------------------------------------------- /code/benchmark/load_test.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import json 3 | import random 4 | import sys 5 | import time 6 | from argparse import ArgumentParser 7 | from os.path import abspath, dirname 8 | from typing import List 9 | 10 | import aiohttp 11 | import numpy as np 12 | import pandas as pd 13 | from transformers import AutoTokenizer, PreTrainedTokenizerBase 14 | 15 | sys.path.insert(0, abspath(dirname(dirname(__file__)))) 16 | 17 | from protocol.completion_task import ( 18 | HuggingFaceGenerationConfig, 19 | HuggingFaceCompletionInputs, 20 | HuggingFaceCompletionOutputs, 21 | ) 22 | from protocol.routes import ROUTE_POST_CONTINUOUS_BATCHING_COMPLETION, ROUTE_POST_STATIC_BATCHING_COMPLETION 23 | 24 | 25 | def load_tokenizer(tokenizer_name_or_path: str, use_fast: bool, max_length: int) -> PreTrainedTokenizerBase: 26 | tokenizer = AutoTokenizer.from_pretrained( 27 | tokenizer_name_or_path, 28 | use_fast=use_fast, 29 | model_max_length=max_length, 30 | padding_side="left", 31 | truncation_side="right", 32 | ) 33 | if not tokenizer.pad_token: 34 | tokenizer.pad_token = tokenizer.eos_token 35 | tokenizer.pad_token_id = tokenizer.eos_token_id 36 | 37 | return tokenizer 38 | 39 | 40 | def gen_random_lens(distribution: str, len_mean: int, len_range: int, num_requests: int) -> List[int]: 41 | if distribution == 'uniform': 42 | if len_range == 0: 43 | return [len_mean for _ in range(num_requests)] 44 | 45 | low = len_mean - (len_range // 2) 46 | high = len_mean + (len_range // 2) 47 | num_to_generate = list( 48 | map(lambda _: random.randint(low, high), range(num_requests))) 49 | return num_to_generate 50 | elif distribution == 'exponential': 51 | np.random.seed(random.randint(0, 1e6)) 52 | return [min(round(s), len_range) for s in np.random.exponential(scale=len_mean, size=num_requests)] 53 | elif distribution == 'capped_exponential': 54 | np.random.seed(random.randint(0, 1e6)) 55 | response_lens = [] 56 | while len(response_lens) < num_requests: 57 | sample = round(np.random.exponential(scale=len_mean)) 58 | if sample <= len_range: 59 | response_lens.append(sample) 60 | return response_lens 61 | else: 62 | raise ValueError(f'unknown distribution {distribution=}') 63 | 64 | 65 | def prepare_payloads( 66 | model_id: str, 67 | data_path: str, 68 | tokenizer: PreTrainedTokenizerBase, 69 | num_beams: int, 70 | prompt_len: int, 71 | response_lens: List[int] 72 | ) -> List[dict]: 73 | with open(data_path, "r", encoding="utf-8") as f: 74 | prompts = json.load(f) 75 | 76 | assert len(prompts) >= len(response_lens) 77 | 78 | prompts = tokenizer.batch_decode( 79 | tokenizer( 80 | prompts, 81 | max_length=prompt_len, 82 | truncation=True, 83 | padding="max_length" 84 | )["input_ids"] 85 | ) 86 | prompts = random.sample(prompts, len(response_lens)) 87 | random.shuffle(prompts) 88 | 89 | payloads = [] 90 | for idx, (prompt, response_len) in enumerate(zip(prompts, response_lens)): 91 | payload = HuggingFaceCompletionInputs( 92 | model=model_id, 93 | prompt=prompt, 94 | generation_config=HuggingFaceGenerationConfig( 95 | max_new_tokens=response_len, 96 | min_new_tokens=response_len, 97 | num_beams=num_beams, 98 | num_return_sequences=1 99 | ) 100 | ).dict(by_alias=True) 101 | payloads.append(payload) 102 | 103 | return payloads 104 | 105 | 106 | async def load_test( 107 | payloads: List[dict], 108 | throughput_oriented: bool, 109 | batch_size: 
int, 110 | url: str 111 | ) -> dict: 112 | async def request_one(request_id, payload, sleep_seconds): 113 | await asyncio.sleep(sleep_seconds) 114 | request_start = time.time() 115 | async with aiohttp.request( 116 | method="post", 117 | url=url, 118 | json=payload, 119 | timeout=aiohttp.ClientTimeout(total=60*60) 120 | ) as resp: 121 | if resp.status: 122 | outputs = HuggingFaceCompletionOutputs(**(await resp.json())) 123 | else: 124 | outputs = HuggingFaceCompletionOutputs() 125 | num_gen_tokens = outputs.usage.completion_tokens 126 | status = resp.status 127 | request_end = time.time() 128 | wall_time = request_end - request_start 129 | print(f" - request_id={request_id} :: {status=}, {wall_time=: .4f}s, {num_gen_tokens=}") 130 | 131 | return num_gen_tokens, wall_time, status 132 | 133 | if throughput_oriented: 134 | total_tokens = 0 135 | duration = 0 136 | num_success = 0 137 | num_failed = 0 138 | for start_idx in range(0, len(payloads), batch_size): 139 | end_idx = start_idx + batch_size 140 | start = time.time() 141 | batch_results = await asyncio.gather( 142 | *[ 143 | request_one(start_idx + idx, payload, 0) 144 | for idx, payload in enumerate(payloads[start_idx: end_idx]) 145 | ] 146 | ) 147 | end = time.time() 148 | duration += (end - start) 149 | total_tokens += sum([each[0] for each in batch_results if each[2] == 200]) 150 | if all(each[2] == 200 for each in batch_results): 151 | num_success += 1 152 | else: 153 | num_failed += 1 154 | latency = -1 155 | throughput = total_tokens / duration 156 | fail_rate = f"{num_failed / (num_failed + num_success) * 100: .4f}%" 157 | else: 158 | sleep_sec = 0 159 | tasks = [] 160 | for start_idx in range(0, len(payloads), batch_size): 161 | end_idx = start_idx + batch_size 162 | sleep_sec += (0 if throughput_oriented else 1) 163 | tasks += [ 164 | asyncio.create_task(request_one(start_idx + idx, payload, sleep_sec)) 165 | for idx, payload in enumerate(payloads[start_idx: end_idx]) 166 | ] 167 | results = await asyncio.gather(*tasks) 168 | results = [each for each in results if each[2] == 200] 169 | latency = pd.Series([each[1] for each in results]).describe().to_dict() 170 | throughput = -1 171 | fail_rate = f"{(len(payloads) - len(results)) / len(payloads) * 100: .4f}%" 172 | 173 | return { 174 | "throughput_oriented": throughput_oriented, 175 | "batch_size": batch_size, 176 | "latency(s)": latency, 177 | "throughput(generated_tokens/s)": throughput, 178 | "fail_rate": fail_rate 179 | } 180 | 181 | 182 | def main(): 183 | parser = ArgumentParser() 184 | parser.add_argument("--host", type=str, default="http://127.0.0.1") 185 | parser.add_argument("--port", type=int, default=8000) 186 | parser.add_argument("--model_id", type=str, default="") # default to not specify if request to server directly 187 | parser.add_argument("--data_path", type=str, default="load_test_data.json") 188 | parser.add_argument("--save_path", type=str, default="load_test_report.json") 189 | parser.add_argument("--tokenizer_name_or_path", type=str, default="gpt2") 190 | parser.add_argument("--use_fast_tokenizer", action="store_true") 191 | parser.add_argument("--throughput_oriented", action="store_true") 192 | # when throughput_oriented is True, num_beams will be treated as batch size 193 | # which means how many requests will be sent to inference server at the same time, 194 | # where each request's num_beams is 1 195 | parser.add_argument("--num_beams", type=int, default=1) 196 | parser.add_argument("--prompt_len", type=int, default=512) 197 | 
parser.add_argument("--response_len_mean", type=int, default=512) 198 | parser.add_argument("--response_len_range", type=int, default=0) 199 | parser.add_argument("--distribution", type=str, choices=["uniform", "exponential", "capped_exponential"], default="uniform") 200 | parser.add_argument("--num_emit", type=int, default=60) # num batches or num seconds continues 201 | parser.add_argument("--qps", type=int, default=1) # used only when throughput_oriented is not True 202 | parser.add_argument("--server_type", type=str, default="cb", choices=["cb", "sb"]) 203 | args = parser.parse_args() 204 | 205 | if args.distribution == "uniform": 206 | assert args.response_len_mean > args.response_len_range 207 | else: 208 | assert args.response_len_mean <= args.response_len_range 209 | 210 | tokenizer = load_tokenizer(args.tokenizer_name_or_path, args.use_fast_tokenizer, args.prompt_len) 211 | 212 | if args.throughput_oriented: 213 | num_requests = args.num_beams * args.num_emit 214 | response_lens = [args.response_len_mean for _ in range(num_requests)] 215 | args.num_beams = 1 216 | else: 217 | num_requests = args.qps * args.num_emit 218 | response_lens = gen_random_lens(args.distribution, args.response_len_mean, args.response_len_range, num_requests) 219 | batch_size = num_requests // args.num_emit 220 | 221 | payloads = prepare_payloads(args.model_id, args.data_path, tokenizer, args.num_beams, args.prompt_len, response_lens) 222 | 223 | print( 224 | f"Load Test :: {args.throughput_oriented=}, {num_requests=}, {batch_size=}, " 225 | f"{args.response_len_mean=}, {args.response_len_range=}, {args.distribution=}" 226 | ) 227 | 228 | route = ROUTE_POST_CONTINUOUS_BATCHING_COMPLETION if args.server_type == "cb" else ROUTE_POST_STATIC_BATCHING_COMPLETION 229 | report = asyncio.run( 230 | load_test( 231 | payloads, args.throughput_oriented, batch_size, f"{args.host}:{args.port}{route}" 232 | ) 233 | ) 234 | report["response_len_mean"] = args.response_len_mean 235 | report["response_len_range"] = args.response_len_range 236 | report["distribution"] = args.distribution 237 | 238 | print("REPORT ::") 239 | for k, v in report.items(): 240 | print(f" - {k}: {v}") 241 | 242 | with open(args.save_path, "w", encoding="utf-8") as f: 243 | json.dump(report, f) 244 | 245 | 246 | if __name__ == "__main__": 247 | main() 248 | -------------------------------------------------------------------------------- /code/client/__init__.py: -------------------------------------------------------------------------------- 1 | from .client import get_client, Client, ClientConfig, ServerType 2 | 3 | __all__ = ["Client", "ClientConfig", "ServerType", "get_client"] 4 | -------------------------------------------------------------------------------- /code/client/client.py: -------------------------------------------------------------------------------- 1 | import json 2 | import time 3 | from collections import defaultdict 4 | from enum import Enum 5 | from itertools import chain 6 | from logging import getLogger, Logger 7 | from typing import Dict, List, Optional, Tuple 8 | from threading import Thread 9 | 10 | import aiohttp 11 | import requests 12 | from fastapi import HTTPException 13 | from pydantic import BaseModel, Field 14 | 15 | from .openai_jumper import * 16 | from protocol.completion_task import * 17 | from protocol.routes import * 18 | 19 | 20 | class ServerType(Enum): 21 | SB = "static_batching_server" 22 | CB = "continuous_batching_server" 23 | 24 | 25 | class ServerURL: 26 | def __init__(self, url: str): 27 | self.url 
= url 28 | self.model_id = None 29 | self.workload = 0 30 | self.available = True 31 | 32 | @staticmethod 33 | def _calculate_workload(request_inputs: HuggingFaceCompletionInputs): 34 | # A simple heuristic method to calculate workload: 35 | # - 1.35 here means we assume 1 word ≈ 1.35 tokens 36 | # - 20 here means we assume the time consumed to decode 1 token ≈ prefill 20 tokens 37 | # You can also change the calculation logic here by yourself 38 | 39 | num_prompt_tokens = int(len(request_inputs.prompt.split()) * 1.35) 40 | num_beams = request_inputs.generation_config.num_beams 41 | max_new_tokens = request_inputs.generation_config.max_new_tokens 42 | 43 | return num_beams * max_new_tokens + num_prompt_tokens // 20 44 | 45 | def increase_workload(self, request_inputs: HuggingFaceCompletionInputs): 46 | self.workload += self._calculate_workload(request_inputs) 47 | 48 | def decrease_workload(self, request_inputs: HuggingFaceCompletionInputs): 49 | self.workload -= self._calculate_workload(request_inputs) 50 | 51 | def __eq__(self, other: "ServerURL"): 52 | return self.workload == other.workload 53 | 54 | def __gt__(self, other: "ServerURL"): 55 | return self.workload > other.workload 56 | 57 | def __ge__(self, other: "ServerURL"): 58 | return self.workload >= other.workload 59 | 60 | def __lt__(self, other: "ServerURL"): 61 | return self.workload < other.workload 62 | 63 | def __le__(self, other: "ServerURL"): 64 | return self.workload <= other.workload 65 | 66 | def __hash__(self): 67 | return hash(self.url) 68 | 69 | 70 | CLIENT_SINGLETON = None 71 | 72 | 73 | class ClientNotInitializedError(Exception): 74 | def __repr__(self): 75 | return "client is not initialized, please initialize a client object first." 76 | 77 | def __str__(self): 78 | return self.__repr__() 79 | 80 | 81 | class ClientDoubleInitializeError(Exception): 82 | def __repr__(self): 83 | return "client is initialized, do not initialize again, please use get_client() instead." 
84 | 85 | def __str__(self): 86 | return self.__repr__() 87 | 88 | 89 | class ClientConfig(BaseModel): 90 | static_batching_server_urls: Optional[List[str]] = Field(default=None) 91 | continuous_batching_server_urls: Optional[List[str]] = Field(default=None) 92 | openai_jumper_configs: Optional[List[OpenAIJumperConfig]] = Field(default=None) 93 | heart_beat_interval_seconds: int = Field(default=600) 94 | 95 | 96 | class Client: 97 | def __init__(self, config: ClientConfig, logger: Optional[Logger] = None): 98 | global CLIENT_SINGLETON 99 | 100 | if CLIENT_SINGLETON is not None: 101 | raise ClientDoubleInitializeError() 102 | 103 | self.config = config 104 | self.logger = logger if logger else getLogger(__name__) 105 | 106 | # containers 107 | self.model_id2static_batching_server_urls: Dict[str, List[ServerURL]] = defaultdict(list) 108 | self.model_id2continuous_batching_server_urls: Dict[str, List[ServerURL]] = defaultdict(list) 109 | self.openai_jumpers: List[OpenAIJumper] = [] 110 | 111 | self._update_containers_once() 112 | 113 | Thread(target=self._heart_beat_loop, daemon=True).start() 114 | 115 | # set singleton to self 116 | CLIENT_SINGLETON = self 117 | 118 | def _update_containers_once(self): 119 | self._update_server_urls_map( 120 | static_batching_server_urls=self.config.static_batching_server_urls, 121 | continuous_batching_server_urls=self.config.continuous_batching_server_urls 122 | ) 123 | self._update_openai_jumpers(openai_jumper_configs=self.config.openai_jumper_configs) 124 | 125 | def _heart_beat_loop(self): 126 | while True: 127 | time.sleep(self.config.heart_beat_interval_seconds) 128 | self._update_containers_once() 129 | 130 | def _update_server_urls_map( 131 | self, 132 | static_batching_server_urls: Optional[List[str]] = None, 133 | continuous_batching_server_urls: Optional[List[str]] = None 134 | ): 135 | def build_server_urls_map(old_server_url_objs: List[ServerURL], new_server_urls: List[str]): 136 | # TODO: parallelize the execution, the logic here for now may very slow 137 | 138 | server_urls_map = defaultdict(list) 139 | old_server_url_hash_values = [hash(url_obj) for url_obj in old_server_url_objs] 140 | for url in new_server_urls: 141 | hash_value = hash(url) 142 | if hash_value in old_server_url_hash_values: 143 | url_obj = old_server_url_objs[old_server_url_hash_values.index(hash_value)] 144 | else: 145 | url_obj = ServerURL(url=url) 146 | self.get_model_id(url_obj) 147 | if url_obj.model_id: 148 | server_urls_map[url_obj.model_id].append(url_obj) 149 | 150 | return server_urls_map 151 | 152 | if static_batching_server_urls is not None: 153 | self.model_id2static_batching_server_urls = build_server_urls_map( 154 | old_server_url_objs=list(chain(*self.model_id2static_batching_server_urls.values())), 155 | new_server_urls=static_batching_server_urls 156 | ) 157 | 158 | if continuous_batching_server_urls is not None: 159 | self.model_id2continuous_batching_server_urls = build_server_urls_map( 160 | old_server_url_objs=list(chain(*self.model_id2continuous_batching_server_urls)), 161 | new_server_urls=continuous_batching_server_urls 162 | ) 163 | 164 | def _update_openai_jumpers(self, openai_jumper_configs: Optional[List[OpenAIJumperConfig]] = None): 165 | if openai_jumper_configs is None: 166 | return 167 | 168 | old_jumpers = self.openai_jumpers 169 | old_jumper_hash_values = [hash(jumper) for jumper in old_jumpers] 170 | new_jumpers = [] 171 | for config in openai_jumper_configs: 172 | hash_value = hash(config.api_key) 173 | if hash_value in 
old_jumper_hash_values: 174 | jumper = old_jumpers[old_jumper_hash_values.index(hash_value)] 175 | else: 176 | jumper = OpenAIJumper(config=config, logger=self.logger) 177 | new_jumpers.append(jumper) 178 | 179 | self.openai_jumpers = new_jumpers 180 | 181 | for jumper in [jumper for jumper in old_jumpers if jumper not in new_jumpers]: 182 | destroy_openai_jumper(jumper) 183 | 184 | def update_config(self, config: ClientConfig): 185 | # One can use this method to implement hot reload logic to update client's behavior on the fly. 186 | # For example: 187 | # 1. manually update locally saved config file to update config's parameters; 188 | # 2. receive an event that a server is startup or shutdown, and automatically update server_urls. 189 | 190 | if set([c.api_key for c in config.openai_jumper_configs]) != \ 191 | set([c.api_key for c in self.config.openai_jumper_configs]): 192 | self._update_openai_jumpers(openai_jumper_configs=config.openai_jumper_configs) 193 | new_static_batching_server_urls = None 194 | new_continuous_batching_server_urls = None 195 | if set(config.static_batching_server_urls) != set(self.config.static_batching_server_urls): 196 | new_static_batching_server_urls = config.static_batching_server_urls 197 | if set(config.continuous_batching_server_urls) != set(self.config.continuous_batching_server_urls): 198 | new_continuous_batching_server_urls = config.continuous_batching_server_urls 199 | if new_static_batching_server_urls is not None or new_continuous_batching_server_urls is not None: 200 | self._update_server_urls_map( 201 | static_batching_server_urls=new_static_batching_server_urls, 202 | continuous_batching_server_urls=new_continuous_batching_server_urls 203 | ) 204 | self.config = config 205 | 206 | def save_config(self, save_path: str): 207 | with open(save_path, "w", encoding="utf-8") as f: 208 | json.dump(self.config.dict(by_alias=True), f) 209 | 210 | async def openai_chat_completion( 211 | self, 212 | request_inputs: OpenAIChatCompletionInputs, 213 | max_retries: int = 3, 214 | raise_on_error: bool = True 215 | ) -> OpenAIChatCompletionOutputs: 216 | request_inputs.verify_and_preprocess() 217 | 218 | available_jumpers = [jumper for jumper in self.openai_jumpers if jumper.available] 219 | if not available_jumpers: 220 | if raise_on_error: 221 | raise HTTPException( 222 | status_code=404, 223 | detail="LookupError: none of openai jumper is available for now." 
224 | ) 225 | return OpenAIChatCompletionOutputs() 226 | 227 | jumper = min(available_jumpers) 228 | request_outputs, error, status_code = await jumper.chat_completion( 229 | inputs=request_inputs, 230 | max_retries=max_retries 231 | ) 232 | if status_code != 200 and raise_on_error: 233 | raise HTTPException(status_code=status_code, detail=str(error)) 234 | return request_outputs 235 | 236 | async def huggingface_completion( 237 | self, 238 | request_inputs: HuggingFaceCompletionInputs, 239 | max_retries: int = 3, 240 | raise_on_error: bool = True, 241 | server_type: ServerType = ServerType.CB, 242 | timeout: int = 100 243 | ) -> HuggingFaceCompletionOutputs: 244 | request_inputs.verify_and_preprocess() 245 | 246 | async def request( 247 | payload: dict, 248 | url_obj: ServerURL, 249 | route: str, 250 | ) -> Tuple[HuggingFaceCompletionOutputs, Optional[str], int]: 251 | async with aiohttp.request( 252 | method="post", 253 | url=f"{url_obj.url}{route}", 254 | json=payload, 255 | headers={}, 256 | timeout=aiohttp.ClientTimeout(timeout) 257 | ) as resp: 258 | if resp.status == 200: 259 | return HuggingFaceCompletionOutputs(**(await resp.json())), None, resp.status 260 | else: 261 | return HuggingFaceCompletionOutputs(), resp.reason, resp.status 262 | 263 | # check if the requested model_id is available 264 | model_id = request_inputs.model 265 | server_urls_map = ( 266 | self.model_id2static_batching_server_urls if server_type == ServerType.SB 267 | else self.model_id2continuous_batching_server_urls 268 | ) 269 | url_objs = [url_obj for url_obj in server_urls_map.get(model_id, []) if url_obj.available] 270 | if not url_objs: 271 | if raise_on_error: 272 | raise HTTPException( 273 | status_code=404, 274 | detail=f"LookupError: requested model [{model_id}] is not available for now." 275 | ) 276 | return HuggingFaceCompletionOutputs() 277 | 278 | # get the url_obj whose workload is smallest 279 | url_obj = min(url_objs) 280 | 281 | # request 282 | url_obj.increase_workload(request_inputs) 283 | try: 284 | request_outputs, error, status_code = await request( 285 | payload=request_inputs.dict(by_alias=True), 286 | url_obj=url_obj, 287 | route=ROUTE_POST_STATIC_BATCHING_COMPLETION if server_type == ServerType.SB 288 | else ROUTE_POST_CONTINUOUS_BATCHING_COMPLETION 289 | ) 290 | except Exception as e: 291 | request_outputs = HuggingFaceCompletionOutputs() 292 | self.logger.error(msg=f"running request method failed with error type [{e.__class__.__name__}]", exc_info=e) 293 | error = "ServerRequestError: Unknown error occurred when request to server." 
294 | status_code = 500 295 | url_obj.decrease_workload(request_inputs) 296 | 297 | if status_code == 200: 298 | return request_outputs 299 | else: 300 | if max_retries > 0: 301 | return await self.huggingface_completion( 302 | request_inputs=request_inputs, 303 | max_retries=max_retries-1, 304 | raise_on_error=raise_on_error, 305 | server_type=server_type, 306 | timeout=timeout 307 | ) 308 | elif raise_on_error: 309 | raise HTTPException(status_code=status_code, detail=error) 310 | else: 311 | return request_outputs 312 | 313 | def get_model_id(self, url: ServerURL, max_retries: int = 3) -> Optional[str]: 314 | res = requests.get(f"{url.url}{ROUTE_GET_MODEL_ID}", timeout=(1, 1), verify=False) 315 | if res.status_code == 200: 316 | url.available = True 317 | url.model_id = res.json() 318 | return url.model_id 319 | else: 320 | if max_retries > 0: 321 | time.sleep(1) 322 | return self.get_model_id(url, max_retries - 1) 323 | self.logger.error( 324 | msg=f"request to {url.url} to get model_id failed with status [{res.status_code}]" 325 | ) 326 | url.available = False 327 | return None 328 | 329 | 330 | def get_client(): 331 | if CLIENT_SINGLETON is None: 332 | raise ClientNotInitializedError() 333 | return CLIENT_SINGLETON 334 | 335 | 336 | __all__ = [ 337 | "ServerType", 338 | "Client", 339 | "ClientConfig", 340 | "get_client" 341 | ] 342 | -------------------------------------------------------------------------------- /code/client/openai_jumper.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | from logging import getLogger, Logger 4 | from typing import Optional, Tuple 5 | from threading import Thread 6 | 7 | import openai 8 | from openai.util import convert_to_dict 9 | from pydantic import BaseModel, Field, Required 10 | 11 | from protocol.completion_task import ( 12 | TokenUsage, 13 | OpenAIChatCompletionMessage, 14 | OpenAIChatCompletionChoice, 15 | OpenAIChatCompletionInputs, 16 | OpenAIChatCompletionOutputs 17 | ) 18 | from protocol.error import Error 19 | 20 | 21 | AVAILABLE_OPENAI_CHAT_COMPLETION_MODELS = [ 22 | "gpt-3.5-turbo", 23 | # You can extend available chat completion models here by yourself 24 | ] 25 | 26 | 27 | class OpenAIJumperConfig(BaseModel): 28 | api_key: str = Field(default=Required) 29 | org_id: Optional[str] = Field(default=None) 30 | 31 | 32 | class OpenAIJumper: 33 | def __init__(self, config: OpenAIJumperConfig, logger: Optional[Logger] = None): 34 | self.config = config 35 | self.logger = logger if logger else getLogger(__name__) 36 | 37 | self.workload = 0 38 | self.referenced = 0 39 | self._available = True 40 | 41 | async def chat_completion( 42 | self, 43 | inputs: OpenAIChatCompletionInputs, 44 | max_retries: int = 3 45 | ) -> Tuple[ 46 | OpenAIChatCompletionOutputs, 47 | Optional[Error], 48 | int 49 | ]: 50 | if inputs.model not in AVAILABLE_OPENAI_CHAT_COMPLETION_MODELS: 51 | error_body = Error( 52 | type="ValueError", 53 | detail=f"LookupError: Required model [{inputs.model}] is not available, " 54 | f"available models are {AVAILABLE_OPENAI_CHAT_COMPLETION_MODELS}" 55 | ) 56 | return OpenAIChatCompletionOutputs(), error_body, 404 57 | 58 | request_dict = inputs.dict(exclude_none=True) 59 | request_dict.update({"api_key": self.config.api_key}) 60 | if self.config.org_id: 61 | request_dict.update({"organization": self.config.org_id}) 62 | 63 | try: 64 | self.referenced += 1 65 | resp = await openai.ChatCompletion.acreate(**request_dict) 66 | except openai.error.Timeout as e: 67 | 
self.referenced -= 1 68 | if max_retries > 0: 69 | max_retries -= 1 70 | self.logger.warning( 71 | msg=f"Request to openai chat completion api timeout, will retry again (chance_left={max_retries})" 72 | ) 73 | return await self.chat_completion(inputs, max_retries=max_retries) 74 | error_body = Error( 75 | type=e.__class__.__name__, 76 | detail=e.user_message 77 | ) 78 | self.logger.error(msg=str(error_body)) 79 | return OpenAIChatCompletionOutputs(), error_body, e.http_status 80 | except openai.error.OpenAIError as e: 81 | self.referenced -= 1 82 | error_body = Error( 83 | type=e.__class__.__name__, 84 | detail=e.user_message 85 | ) 86 | self.logger.error(msg=str(error_body)) 87 | return OpenAIChatCompletionOutputs(), error_body, e.http_status 88 | except Exception as e: 89 | self.referenced -= 1 90 | error_body = Error( 91 | type=e.__class__.__name__, 92 | detail=str(e) 93 | ) 94 | self.logger.error(msg=str(error_body)) 95 | return OpenAIChatCompletionOutputs(), error_body, 500 96 | else: 97 | self.referenced -= 1 98 | resp = convert_to_dict(resp) 99 | outputs = OpenAIChatCompletionOutputs( 100 | choices=[ 101 | OpenAIChatCompletionChoice( 102 | message=OpenAIChatCompletionMessage( 103 | role=choice["message"]["role"], 104 | content=choice["message"]["content"] 105 | ), 106 | index=choice["index"], 107 | finish_reason=choice["finish_reason"] 108 | ) for choice in resp["choices"] 109 | ], 110 | usage=TokenUsage(**resp["usage"]) 111 | ) 112 | 113 | self.workload += outputs.usage.total_tokens 114 | 115 | return outputs, None, 200 116 | 117 | def reset_workload(self): 118 | self.workload = 0 119 | 120 | def freeze(self): 121 | self._available = False 122 | 123 | @property 124 | def available(self): 125 | return self._available 126 | 127 | def __eq__(self, other: "OpenAIJumper"): 128 | return self.workload == other.workload 129 | 130 | def __gt__(self, other: "OpenAIJumper"): 131 | return self.workload > other.workload 132 | 133 | def __ge__(self, other: "OpenAIJumper"): 134 | return self.workload >= other.workload 135 | 136 | def __lt__(self, other: "OpenAIJumper"): 137 | return self.workload < other.workload 138 | 139 | def __le__(self, other: "OpenAIJumper"): 140 | return self.workload <= other.workload 141 | 142 | def __hash__(self): 143 | return hash(self.config.api_key) 144 | 145 | 146 | def destroy_openai_jumper(openai_jumper: OpenAIJumper): 147 | def destroy(): 148 | openai_jumper.freeze() 149 | while True: 150 | if openai_jumper.referenced == 0: 151 | del openai_jumper 152 | return 153 | time.sleep(1) 154 | 155 | Thread(target=destroy, daemon=True).start() 156 | 157 | 158 | __all__ = [ 159 | "AVAILABLE_OPENAI_CHAT_COMPLETION_MODELS", 160 | "OpenAIJumper", 161 | "OpenAIJumperConfig", 162 | "destroy_openai_jumper", 163 | ] 164 | -------------------------------------------------------------------------------- /code/client_app.py: -------------------------------------------------------------------------------- 1 | import json 2 | import time 3 | import os 4 | from logging import getLogger, DEBUG, INFO, Logger 5 | from typing import Optional 6 | from threading import Thread 7 | 8 | 9 | from fastapi import FastAPI 10 | from fastapi.responses import JSONResponse 11 | from pydantic import BaseModel, Field 12 | 13 | from client import get_client, Client, ClientConfig, ServerType 14 | from protocol.completion_task import * 15 | from protocol.error import Error 16 | from protocol.routes import ( 17 | ROUTE_POST_OPENAI_CHAT_COMPLETION, 18 | ROUTE_POST_CONTINUOUS_BATCHING_COMPLETION, 19 | 
ROUTE_POST_STATIC_BATCHING_COMPLETION, 20 | ROUTE_CLIENT_ONLY_POST_SERVER_STARTUP_EVENT, 21 | ROUTE_CLIENT_ONLY_POST_SERVER_SHUTDOWN_EVENT 22 | ) 23 | from utils.log_util import RequestLoggingMiddleware 24 | 25 | 26 | logger = getLogger("gunicorn.logger") # by default, we use gunicorn to wrap the app 27 | 28 | 29 | class AppConfig(BaseModel): 30 | client_config_file_path: str = Field(default=...) 31 | client_config_hot_update_interval_minutes: int = Field(default=1) 32 | debug: bool = Field(default=False) 33 | 34 | 35 | APP_NAME = "LLM-Inference-Client" 36 | 37 | app = FastAPI(title=APP_NAME, version="0.1.0") 38 | app_config: Optional[AppConfig] = None 39 | 40 | 41 | def build_app( 42 | client_config_file_path: str = "client_config.json", 43 | client_config_hot_update_interval_minutes: int = 1, 44 | debug: bool = False 45 | ): 46 | global app, app_config 47 | 48 | logger.setLevel(DEBUG if debug else INFO) 49 | 50 | app_config = AppConfig( 51 | client_config_file_path=client_config_file_path, 52 | client_config_hot_update_interval_minutes=client_config_hot_update_interval_minutes, 53 | debug=debug 54 | ) 55 | 56 | app.add_middleware(RequestLoggingMiddleware, logger=logger) 57 | 58 | return app 59 | 60 | 61 | def hot_update_client_config_loop(): 62 | while True: 63 | time.sleep(app_config.client_config_hot_update_interval_minutes * 60) 64 | client = get_client() 65 | 66 | fp = app_config.client_config_file_path 67 | if not os.path.exists(fp): 68 | logger.warning( 69 | msg=f"Client config file path [{fp}] not exists, skip hot update this time," 70 | f"and will try to save a client config snapshot." 71 | ) 72 | try: 73 | client.save_config(fp) 74 | except: 75 | pass 76 | continue 77 | new_client_config = Client(**json.load(open(fp, "r", encoding="utf-8"))) 78 | client.update_config(new_client_config) 79 | 80 | 81 | @app.on_event("startup") 82 | def startup(): 83 | # init client 84 | Client( 85 | config=ClientConfig( 86 | **json.load(open(app_config.client_config_file_path, "r", encoding="utf-8")) 87 | ), 88 | logger=logger 89 | ) 90 | 91 | # start client config hot update loop 92 | Thread(target=hot_update_client_config_loop, daemon=True).start() 93 | 94 | 95 | # === Routes that request to LLM servers === # 96 | 97 | @app.post(ROUTE_POST_OPENAI_CHAT_COMPLETION, response_model=OpenAIChatCompletionOutputs) 98 | async def request_openai_chat_completion(request_inputs: OpenAIChatCompletionInputs): 99 | client = get_client() 100 | return await client.openai_chat_completion(request_inputs) 101 | 102 | 103 | @app.post(ROUTE_POST_CONTINUOUS_BATCHING_COMPLETION, response_model=HuggingFaceCompletionOutputs) 104 | async def request_continuous_batching_server(request_inputs: HuggingFaceCompletionInputs): 105 | client = get_client() 106 | return await client.huggingface_completion(request_inputs, server_type=ServerType.CB) 107 | 108 | 109 | @app.post(ROUTE_POST_STATIC_BATCHING_COMPLETION, response_model=HuggingFaceCompletionOutputs) 110 | async def request_static_batching_server(request_inputs: HuggingFaceCompletionInputs): 111 | client = get_client() 112 | return await client.huggingface_completion(request_inputs, server_type=ServerType.SB) 113 | 114 | 115 | # === Routes that provide some meta information === # 116 | 117 | @app.get("/cb_server/available_models") 118 | async def get_cb_server_available_models(): 119 | client = get_client() 120 | available_models = [] 121 | for model_id, server_urls in client.model_id2continuous_batching_server_urls: 122 | if any(url_obj.available for url_obj in 
server_urls): 123 | available_models.append(model_id) 124 | return JSONResponse(content=available_models) 125 | 126 | 127 | @app.get("/sb_server/available_models") 128 | async def get_sb_server_available_models(): 129 | client = get_client() 130 | available_models = [] 131 | for model_id, server_urls in client.model_id2static_batching_server_urls: 132 | if any(url_obj.available for url_obj in server_urls): 133 | available_models.append(model_id) 134 | return JSONResponse(content=available_models) 135 | 136 | 137 | @app.get("/openai_jumper/is_available") 138 | async def get_openai_jumper_is_available(): 139 | client = get_client() 140 | if any(jumper.available for jumper in client.openai_jumpers): 141 | return JSONResponse(content="1") 142 | return JSONResponse("0") 143 | 144 | 145 | # === Routes that receive server event and update client config === # 146 | # TODO: Implement routes 147 | 148 | 149 | if __name__ == "__main__": 150 | import uvicorn 151 | from argparse import ArgumentParser 152 | from logging import basicConfig 153 | 154 | parser = ArgumentParser() 155 | parser.add_argument("--client_config_file_path", type=str, default="client_config.json") 156 | parser.add_argument("--client_config_hot_update_interval_minutes", type=int, default=1) 157 | parser.add_argument("--debug", action="store_true") 158 | parser.add_argument("--port", type=int, default=8000) 159 | args = parser.parse_args() 160 | 161 | logger = getLogger(__name__) # override gunicorn logger if we use uvicorn directly 162 | basicConfig( 163 | format="%(asctime)s %(levelname)s [%(name)s] %(message)s", 164 | datefmt="%Y-%m-%d %H:%M:%S" 165 | ) 166 | 167 | uvicorn.run( 168 | build_app( 169 | client_config_file_path=args.client_config_file_path, 170 | client_config_hot_update_interval_minutes=args.client_config_hot_update_interval_minutes, 171 | debug=args.debug, 172 | ), 173 | host="0.0.0.0", 174 | port=args.port 175 | ) 176 | -------------------------------------------------------------------------------- /code/client_requirements.txt: -------------------------------------------------------------------------------- 1 | aiohttp 2 | requests 3 | openai 4 | pydantic<2.0.0 5 | fastapi[all]==0.96.0 6 | setproctitle 7 | uvicorn[standard] 8 | gunicorn -------------------------------------------------------------------------------- /code/continuous_batching_server_app.py: -------------------------------------------------------------------------------- 1 | import json 2 | from logging import getLogger, DEBUG, INFO 3 | from typing import Optional 4 | 5 | from fastapi import HTTPException, FastAPI 6 | from fastapi.responses import JSONResponse 7 | from pydantic import BaseModel, Field 8 | 9 | from server.continuous_batching_server import get_server, Server, ServerConfig 10 | from protocol.completion_task import ( 11 | HuggingFaceCompletionInputs, 12 | HuggingFaceCompletionOutputs 13 | ) 14 | from protocol.error import Error 15 | from protocol.routes import ( 16 | ROUTE_GET_MODEL_ID, 17 | ROUTE_POST_CONTINUOUS_BATCHING_COMPLETION, 18 | ROUTE_CLIENT_ONLY_POST_SERVER_STARTUP_EVENT, 19 | ROUTE_CLIENT_ONLY_POST_SERVER_SHUTDOWN_EVENT 20 | ) 21 | from utils.log_util import RequestLoggingMiddleware 22 | 23 | 24 | logger = getLogger("gunicorn.logger") # by default, we use gunicorn to wrap the app 25 | 26 | 27 | class AppConfig(BaseModel): 28 | model_id: str = Field(default=...) 29 | server_config_file_path: str = Field(default=...) 
30 | client_url: Optional[str] = Field(default=None) 31 | debug: bool = Field(default=False) 32 | 33 | 34 | APP_NAME = "LLM-Inference-CB-Server" 35 | 36 | app = FastAPI(title=APP_NAME, version="0.1.0") 37 | app_config: Optional[AppConfig] = None 38 | 39 | 40 | def build_app( 41 | model_id: str = None, 42 | server_config_file_path: str = "cb_server_config.json", 43 | client_url: Optional[str] = None, 44 | debug: bool = False 45 | ): 46 | global app, app_config 47 | 48 | if model_id is None: 49 | raise ValueError("You must specify a real value to model_id.") 50 | 51 | logger.setLevel(DEBUG if debug else INFO) 52 | 53 | app_config = AppConfig( 54 | model_id=model_id, 55 | server_config_file_path=server_config_file_path, 56 | client_url=client_url, 57 | debug=debug 58 | ) 59 | 60 | app.add_middleware(RequestLoggingMiddleware, logger=logger) 61 | 62 | return app 63 | 64 | 65 | @app.on_event("startup") 66 | def startup(): 67 | # initialize server 68 | Server( 69 | config=ServerConfig(**json.load(open(app_config.server_config_file_path, "r", encoding="utf-8"))), 70 | logger=logger 71 | ) 72 | # TODO: implement logic to inform client that server is startup 73 | 74 | 75 | @app.on_event("shutdown") 76 | def shutdown(): 77 | pass # TODO: implement logic to inform client that server is shutdown 78 | 79 | 80 | @app.get(ROUTE_GET_MODEL_ID) 81 | async def get_model_id(): 82 | return JSONResponse(content=app_config.model_id) 83 | 84 | 85 | @app.post(ROUTE_POST_CONTINUOUS_BATCHING_COMPLETION, response_model=HuggingFaceCompletionOutputs) 86 | async def execute_completion(request_inputs: HuggingFaceCompletionInputs): 87 | server = get_server() 88 | outputs, error, status_code, wall_time = await server.wait_task_done(request_inputs) 89 | if status_code != 200: 90 | logger.error(msg=str(error)) 91 | raise HTTPException( 92 | status_code=status_code, 93 | detail=str(error) 94 | ) 95 | return outputs 96 | 97 | 98 | if __name__ == "__main__": 99 | import uvicorn 100 | from argparse import ArgumentParser 101 | from logging import basicConfig 102 | 103 | parser = ArgumentParser() 104 | parser.add_argument("--model_id", type=str) 105 | parser.add_argument("--server_config_file_path", type=str, default="cb_server_config.json") 106 | parser.add_argument("--client_url", type=str, default=None) 107 | parser.add_argument("--debug", action="store_true") 108 | parser.add_argument("--port", type=int, default=8001) 109 | args = parser.parse_args() 110 | 111 | logger = getLogger(__name__) # override gunicorn logger if we use uvicorn directly 112 | basicConfig( 113 | format="%(asctime)s %(levelname)s [%(name)s] %(message)s", 114 | datefmt="%Y-%m-%d %H:%M:%S" 115 | ) 116 | 117 | uvicorn.run( 118 | build_app( 119 | model_id=args.model_id, 120 | server_config_file_path=args.server_config_file_path, 121 | client_url=args.client_url, 122 | debug=args.debug 123 | ), 124 | host="0.0.0.0", 125 | port=args.port 126 | ) 127 | -------------------------------------------------------------------------------- /code/gunicorn_config.py: -------------------------------------------------------------------------------- 1 | worker_class = "uvicorn.workers.UvicornWorker" 2 | 3 | # Number of worker process 4 | workers = 1 5 | # Number of threads each worker process, 6 | # (workers * threads) means how many requests the app can process simultaneously 7 | threads = 32 8 | # Number of requests that can be waiting to be served, 9 | # requests exceed this number will be rejected and receive an error 10 | backlog = 64 11 | # Workers silent for more than 
this many seconds are killed and restarted 12 | timeout = 300 13 | 14 | # log format for access log, error log can't set format 15 | access_log_format = '%(h)s %(l)s %(t)s "%(r)s" %(m)s %(s)s %(b)s "%(f)s" "%(a)s"' 16 | 17 | """ 18 | Below are description for each format option: 19 | h remote address 20 | l '-' 21 | u currently '-', may be user name in future releases 22 | t date of the request 23 | r status line (e.g. ``GET / HTTP/1.1``) 24 | s status 25 | b response length or '-' 26 | f referer 27 | a user agent 28 | T request time in seconds 29 | D request time in microseconds 30 | L request time in decimal seconds 31 | p process ID 32 | """ 33 | -------------------------------------------------------------------------------- /code/protocol/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/modelize-ai/LLM-Inference-Deployment-Tutorial/e3264eacc4752a7d829241fb614a0b81892cdc8f/code/protocol/__init__.py -------------------------------------------------------------------------------- /code/protocol/completion_task.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Union 2 | 3 | from fastapi import status, HTTPException 4 | from pydantic import BaseModel, Field 5 | 6 | 7 | class TokenUsage(BaseModel): 8 | prompt_tokens: int = Field(default=0) 9 | completion_tokens: int = Field(default=0) 10 | total_tokens: int = Field(default=0) 11 | 12 | 13 | class HuggingFaceGenerationConfig(BaseModel): 14 | do_sample: bool = Field(default=False) 15 | early_stopping: bool = Field(default=True) 16 | num_beams: int = Field(default=1) 17 | num_return_sequences: int = Field(default=1) 18 | max_new_tokens: int = Field(default=32) 19 | min_new_tokens: int = Field(default=1) 20 | temperature: float = Field(default=1) 21 | top_p: float = Field(default=1, ge=0, le=1) 22 | top_k: int = Field(default=0) 23 | typical_p: float = Field(default=1, ge=0, le=1) 24 | repetition_penalty: float = Field(default=1) 25 | eos_token_id: Optional[int] = Field(default=None) 26 | pad_token_id: Optional[int] = Field(default=None) 27 | seed: int = Field(default=1024) 28 | 29 | def __hash__(self): 30 | return hash( 31 | str(self.do_sample) + 32 | str(self.early_stopping) + 33 | str(self.num_beams) + 34 | str(self.num_return_sequences) + 35 | str(self.max_new_tokens) + 36 | str(self.min_new_tokens) + 37 | str(self.temperature) + 38 | str(self.top_p) + 39 | str(self.top_k) + 40 | str(self.typical_p) + 41 | str(self.repetition_penalty) + 42 | str(self.eos_token_id) + 43 | str(self.pad_token_id) + 44 | str(self.seed) 45 | ) 46 | 47 | 48 | class HuggingFaceCompletionInputs(BaseModel): 49 | model: str = Field(default=...) 50 | prompt: str = Field(default=...) 51 | generation_config: HuggingFaceGenerationConfig = Field(default=HuggingFaceGenerationConfig()) 52 | 53 | def verify_and_preprocess(self): 54 | # verify 55 | # here we only do the simplest verification for some parameters 56 | # we should also check with other information again in the server for: 57 | # - whether num_prompt_tokens exceeds model's max_seq_len, if yse we should abort request directly. 
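        # ------------------------------------------------------------------
        # Illustrative sketch only (not part of this file): the server-side
        # length check mentioned in the comment above could look roughly like
        # the snippet below, assuming the server has a `tokenizer` and a
        # `model_max_length` value in scope (placeholder names, not APIs
        # defined in this repository):
        #
        #     num_prompt_tokens = len(tokenizer(self.prompt)["input_ids"])
        #     if num_prompt_tokens + self.generation_config.max_new_tokens > model_max_length:
        #         raise HTTPException(
        #             status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
        #             detail=f"prompt of {num_prompt_tokens} tokens exceeds the model's max_seq_len"
        #         )
        # ------------------------------------------------------------------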
58 | if self.generation_config.num_return_sequences > self.generation_config.num_beams: 59 | raise HTTPException( 60 | status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, 61 | detail=( 62 | f"num_return_sequences(get value of {self.generation_config.num_return_sequences}) " 63 | f"has to less than or equal to num_beams(get value of {self.generation_config.num_beams})" 64 | ) 65 | ) 66 | if self.generation_config.min_new_tokens > self.generation_config.max_new_tokens: 67 | raise HTTPException( 68 | status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, 69 | detail=( 70 | f"min_new_tokens(get value of {self.generation_config.min_new_tokens}) " 71 | f"has to less than or equal to max_new_tokens(get value of {self.generation_config.max_new_tokens})" 72 | ) 73 | ) 74 | 75 | # make sure some parameters in generation_config is specified 76 | # to default values so that do_sample will not be triggered 77 | if not self.generation_config.do_sample: 78 | self.generation_config.temperature = 1.0 79 | self.generation_config.top_p = 1.0 80 | self.generation_config.top_k = 0 81 | self.generation_config.typical_p = 1.0 82 | 83 | 84 | class HuggingFaceCompletionChoice(BaseModel): 85 | text: str = Field(default=...) 86 | index: int = Field(default=...) 87 | finish_reason: str = Field(default=...) 88 | 89 | 90 | class HuggingFaceCompletionOutputs(BaseModel): 91 | choices: Optional[List[HuggingFaceCompletionChoice]] = Field(default=None) 92 | usage: TokenUsage = Field(default=TokenUsage()) 93 | 94 | 95 | class OpenAIChatCompletionMessage(BaseModel): 96 | # Note: we removed parameters relevant to function call here for the simplest use case 97 | role: str = Field(default=..., regex=r"(system|user|assistant)") 98 | content: str = Field(default=...) 99 | 100 | 101 | class OpenAIChatCompletionInputs(BaseModel): 102 | model: str = Field(default=...) 103 | messages: List[OpenAIChatCompletionMessage] = Field(default=...) 104 | n: int = Field(default=1) 105 | max_tokens: int = Field(default=32) 106 | temperature: float = Field(default=1) 107 | top_p: float = Field(default=1, ge=0, le=1) 108 | stream: bool = Field(default=False) 109 | stop: Optional[Union[str, List[str]]] = Field(default=None) 110 | presence_penalty: float = Field(default=0) 111 | frequency_penalty: float = Field(default=0) 112 | 113 | def verify_and_preprocess(self): 114 | # verify 115 | # Not do anything, you can add logit here 116 | pass 117 | 118 | 119 | class OpenAIChatCompletionChoice(BaseModel): 120 | message: OpenAIChatCompletionMessage = Field(default=...) 121 | index: int = Field(default=...) 122 | finish_reason: str = Field(default=...) 123 | 124 | 125 | class OpenAIChatCompletionOutputs(BaseModel): 126 | choices: Optional[List[OpenAIChatCompletionChoice]] = Field(default=None) 127 | usage: TokenUsage = Field(default=TokenUsage()) 128 | 129 | 130 | __all__ = [ 131 | "TokenUsage", 132 | "HuggingFaceGenerationConfig", 133 | "HuggingFaceCompletionChoice", 134 | "HuggingFaceCompletionInputs", 135 | "HuggingFaceCompletionOutputs", 136 | "OpenAIChatCompletionMessage", 137 | "OpenAIChatCompletionChoice", 138 | "OpenAIChatCompletionInputs", 139 | "OpenAIChatCompletionOutputs" 140 | ] 141 | -------------------------------------------------------------------------------- /code/protocol/error.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from pydantic import BaseModel, Field 4 | 5 | 6 | class Error(BaseModel): 7 | type: str = Field(default=...) 
8 | detail: Optional[str] = Field(default=None) 9 | 10 | def __repr__(self): 11 | return f"{self.type}: {self.detail}" 12 | 13 | def __str__(self): 14 | return self.__repr__() 15 | 16 | 17 | __all__ = [ 18 | "Error" 19 | ] 20 | -------------------------------------------------------------------------------- /code/protocol/routes.py: -------------------------------------------------------------------------------- 1 | ROUTE_POST_STATIC_BATCHING_COMPLETION = "/completion/v1" 2 | ROUTE_POST_CONTINUOUS_BATCHING_COMPLETION = "/completion/v2" 3 | ROUTE_POST_OPENAI_CHAT_COMPLETION = "/chat_completion/v1" 4 | ROUTE_GET_MODEL_ID = "/model_id" 5 | 6 | ROUTE_CLIENT_ONLY_POST_SERVER_STARTUP_EVENT = "/event/server_startup" 7 | ROUTE_CLIENT_ONLY_POST_SERVER_SHUTDOWN_EVENT = "/event/server_shutdown" 8 | 9 | 10 | __all__ = [ 11 | "ROUTE_POST_STATIC_BATCHING_COMPLETION", 12 | "ROUTE_POST_CONTINUOUS_BATCHING_COMPLETION", 13 | "ROUTE_POST_OPENAI_CHAT_COMPLETION", 14 | "ROUTE_GET_MODEL_ID", 15 | "ROUTE_CLIENT_ONLY_POST_SERVER_STARTUP_EVENT", 16 | "ROUTE_CLIENT_ONLY_POST_SERVER_SHUTDOWN_EVENT", 17 | ] 18 | -------------------------------------------------------------------------------- /code/server/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/modelize-ai/LLM-Inference-Deployment-Tutorial/e3264eacc4752a7d829241fb614a0b81892cdc8f/code/server/__init__.py -------------------------------------------------------------------------------- /code/server/continuous_batching_server/__init__.py: -------------------------------------------------------------------------------- 1 | from .config import * 2 | from .server import * 3 | 4 | DEFAULT_BLOCK_SIZE_A100 = 16 5 | DEFAULT_NUM_BLOCKS_A100 = 2500 6 | DEFAULT_BATCH_MAX_TOKENS_A100 = 56000 7 | DEFAULT_BATCH_MAX_BEAMS_A100 = 32 8 | 9 | 10 | __all__ = [ 11 | "get_server", 12 | "Server", 13 | "BatcherConfig", 14 | "CacheConfig", 15 | "ModelLoadingConfig", 16 | "ParallelConfig", 17 | "ServerConfig", 18 | "DEFAULT_BLOCK_SIZE_A100", 19 | "DEFAULT_NUM_BLOCKS_A100", 20 | "DEFAULT_BATCH_MAX_TOKENS_A100", 21 | "DEFAULT_BATCH_MAX_BEAMS_A100" 22 | ] 23 | -------------------------------------------------------------------------------- /code/server/continuous_batching_server/batcher.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from collections import defaultdict 3 | from logging import getLogger, Logger 4 | from typing import * 5 | from uuid import UUID 6 | 7 | import torch 8 | 9 | from .beam import Beam, BeamGroup, BeamStatus 10 | from .config import BatcherConfig, CacheConfig 11 | from .cache.cache_manager import CacheBlockManager 12 | from .generation_utils import HeterogeneousNextTokensChooser 13 | 14 | 15 | class Batch: 16 | def __init__(self): 17 | self.prefill_beams: List[Beam] = [] 18 | self.generation_beams: List[Beam] = [] 19 | self.block_tables: Dict[UUID, List[int]] = {} 20 | self.prefill_logits: Optional[torch.Tensor] = None 21 | self.generation_logits: Optional[torch.Tensor] = None 22 | 23 | @property 24 | def num_beams(self) -> int: 25 | return len(self.prefill_beams) + len(self.generation_beams) 26 | 27 | @property 28 | def request_ids(self): 29 | return list(set([beam.request_id for beam in self.prefill_beams + self.generation_beams])) 30 | 31 | def add(self, beam: Beam, block_ids: List[int]) -> None: 32 | if beam.num_generated_tokens == 0: 33 | self.prefill_beams.append(beam) 34 | else: 35 | self.generation_beams.append(beam) 36 
| self.block_tables[beam.beam_id] = block_ids 37 | 38 | 39 | class Batcher: 40 | def __init__( 41 | self, 42 | batcher_config: BatcherConfig, 43 | cache_config: CacheConfig, 44 | logger: Optional[Logger] = None 45 | ): 46 | self.batcher_config = batcher_config 47 | self.cache_config = cache_config 48 | 49 | self.logger = logger if logger else getLogger(__name__) 50 | 51 | self.waiting: List[BeamGroup] = [] 52 | self.running: List[BeamGroup] = [] 53 | self.preempting: List[BeamGroup] = [] 54 | 55 | self.req_id2beam_group: Dict[UUID, BeamGroup] = {} 56 | 57 | self.cache_manager = CacheBlockManager( 58 | cache_config.block_size, 59 | cache_config.num_blocks, 60 | cache_config.num_blocks_cpu, 61 | cache_config.watermark 62 | ) 63 | 64 | def add_request(self, request: BeamGroup): 65 | self.waiting.append(request) 66 | self.req_id2beam_group[request.request_id] = request 67 | 68 | def _allocate(self, beam: Beam) -> bool: 69 | if not self.cache_manager.can_allocate(beam): 70 | return False 71 | self.cache_manager.allocate(beam) 72 | return True 73 | 74 | def _append_slot(self, beam: Beam, blocks_to_copy: Dict[int, List[int]]): 75 | ret = self.cache_manager.append_slot(beam) 76 | if ret is not None: 77 | src_block, dst_block = ret 78 | blocks_to_copy[src_block].append(dst_block) 79 | 80 | def _free_finished_beams(self): 81 | for beam in [beam for beam_group in self.running for beam in beam_group.get_beams(BeamStatus.FINISHED)]: 82 | self.logger.debug(f"free blocks of beam-{beam.beam_id} for it is finished.") 83 | self.cache_manager.free(beam.beam_id) 84 | 85 | def schedule(self) -> Tuple[ 86 | Batch, 87 | Dict[int, List[int]], 88 | Dict[int, int], 89 | Dict[int, int], 90 | List[BeamGroup] 91 | ]: 92 | """ 93 | 大致的执行流程如下: 94 | 1. 从运行队列中移除已经完成生成的请求(beam_group) 95 | 2. 对运行队列中仍需继续生成的请求进行缓存空间的分配 96 | a. 统计运行队列中所有请求所需增量分配的 GPU 缓存空间的总块数 97 | b. 如果空闲的 GPU 缓存空间总块数少于增量分配所需的总块数,则从运行队列尾端起依次将每个请求所占用的 GPU 缓存空间交互至 CPU,直至 98 | 剩余空间足够用于分配给运行队列前端的其他请求 99 | c. 如果不发生 GPU->CPU 交互且存在被交换至 CPU 的等待继续被处理的请求时,将这些请求按优先级依次交换回 GPU 直至无法被换回 100 | d. 为运行队列中剩余的请求进行缓存空间的分配 101 | 3. 
在不发生 GPU->CPU 交互的情况下,或运行队列无请求时,尝试将等待队列中的请求移至运行队列 102 | 103 | :return: ( 104 | Batch, 105 | blocks_to_copy: Dict[int, List[int]], 106 | blocks_to_swap_in: Dict[int, int], 107 | blocks_to_swap_out: Dict[int, int], 108 | finished_requests: List[BeamGroup] 109 | ) 110 | """ 111 | batch = Batch() 112 | 113 | # step1: 从运行队列中移除已经完成生成的请求(beam_group) 114 | self._free_finished_beams() 115 | running = [] 116 | finishing = [] 117 | while self.running: 118 | beam_group = self.running.pop(0) 119 | if beam_group.is_finished: 120 | finishing.append(beam_group) 121 | self.logger.debug(f"BeamGroup-{beam_group.request_id} finished, put into finishing queue.") 122 | else: 123 | running.append(beam_group) 124 | self.logger.debug(f"BeamGroup-{beam_group.request_id} not finish, put back to running queue.") 125 | self.running = running 126 | 127 | # step2: 对运行队列中仍需继续生成的请求进行缓存空间的分配 128 | running = [] 129 | swapping_out = [] 130 | swapping_in = [] 131 | blocks_to_copy: Dict[int, List[int]] = defaultdict(list) 132 | blocks_to_swap_in: Dict[int, int] = {} 133 | blocks_to_swap_out: Dict[int, int] = {} 134 | # step2.a: 统计运行队列中所有请求所需增量分配的 GPU 缓存空间的总块数 135 | run_request2num_append_blocks = defaultdict(int) 136 | for beam_group in self.running: 137 | run_request2num_append_blocks[beam_group.request_id] = 0 138 | for beam in beam_group.get_beams(BeamStatus.RUNNING): 139 | if self.cache_manager.is_need_to_append_slot(beam): 140 | run_request2num_append_blocks[beam_group.request_id] += 1 141 | # step2.b: 如果空闲的 GPU 缓存空间总块数少于增量分配所需的总块数,则从运行队列尾端起 142 | # 依次将每个请求所占用的 GPU 缓存空间交互至 CPU, 143 | # 直至剩余空间足够用于分配给运行队列前端的其他请求 144 | while self.cache_manager.allocator.num_free_blocks < sum(run_request2num_append_blocks.values()): 145 | beam_group = self.running.pop(-1) 146 | num_append_blocks = run_request2num_append_blocks.pop(beam_group.request_id) 147 | if num_append_blocks == 0: 148 | running.insert(0, beam_group) 149 | continue 150 | if not self.cache_manager.can_swap_out(beam_group): 151 | # FIXME: do not raise error, abort this beam group, mark as finished with an abortion reason, free cache space 152 | raise RuntimeError("No enough CPU RAM to swap out") 153 | else: 154 | blocks_to_swap_out.update(self.cache_manager.swap_out(beam_group)) 155 | for beam in beam_group.get_beams(BeamStatus.RUNNING): 156 | beam_group.update_beam_status(beam, BeamStatus.SWAPPED) 157 | swapping_out.insert(0, beam_group) 158 | self.logger.debug( 159 | f"one request swapped out, " 160 | f"free_gpu_blocks={self.cache_manager.allocator.num_free_blocks}, " 161 | f"free_cpu_blocks={self.cache_manager.cpu_allocator.num_free_blocks}" 162 | ) 163 | self.running += running 164 | self.preempting += swapping_out 165 | # step2.c: 如果不发生 GPU->CPU 交互且存在被交换至 CPU 的等待继续被处理的请求时, 166 | # 将这些请求按优先级依次交换回 GPU 直至无法被换回 167 | if not swapping_out: 168 | preserved_num_blocks = sum(run_request2num_append_blocks.values()) 169 | while self.preempting: 170 | beam_group = self.preempting[0] 171 | if not self.cache_manager.can_swap_in(beam_group, preserved_num_blocks): 172 | self.logger.debug( 173 | f"attempt to swap in beam_group-{beam_group.request_id} but not have enough free gpu blocks." 
174 | ) 175 | if not self.running: 176 | raise RuntimeError( 177 | "running queue is empty but still can't swap in request, " 178 | "please consider increase num_blocks or decrease max tokens number" 179 | ) 180 | else: 181 | break # exceed num available free gpu blocks if swap in this beam_group, break 182 | beam_group = self.preempting.pop(0) 183 | blocks_to_swap_in.update(self.cache_manager.swap_in(beam_group)) 184 | for beam in beam_group.get_beams(BeamStatus.SWAPPED): 185 | beam_group.update_beam_status(beam, BeamStatus.RUNNING) 186 | swapping_in.append(beam_group) 187 | preserved_num_blocks += sum( 188 | [ 189 | self.cache_manager.is_need_to_append_slot(beam) 190 | for beam in beam_group.get_beams(BeamStatus.RUNNING) 191 | ] 192 | ) 193 | self.logger.debug( 194 | f"one request swapped in, " 195 | f"free_gpu_blocks={self.cache_manager.allocator.num_free_blocks}, " 196 | f"free_cpu_blocks={self.cache_manager.cpu_allocator.num_free_blocks}" 197 | ) 198 | self.running += swapping_in 199 | # step2.d: 为运行队列中剩余的请求进行缓存空间的分配 200 | for beam_group in self.running: 201 | self.logger.debug( 202 | f"beam_group-{beam_group.request_id}'s beams' status: " 203 | f"{[beam.status.name for beam in beam_group.get_beams()]}" 204 | ) 205 | beams = beam_group.get_beams(BeamStatus.RUNNING) 206 | for beam in beams: 207 | self._append_slot(beam, blocks_to_copy) 208 | block_ids = self.cache_manager.get_block_table(beam.beam_id) 209 | batch.add(beam, block_ids) 210 | beam_group.beams.pop(beam.beam_id) 211 | 212 | # step3. 在不发生 GPU->CPU 交互的情况下,尝试将等待队列中的请求移至运行队列 213 | batch_tokens = batch.num_beams 214 | if (not swapping_out or not self.running) and not self.preempting: 215 | while self.waiting: 216 | beam_group = self.waiting[0] 217 | beam = beam_group.get_beams()[0] 218 | if batch_tokens + beam.num_tokens > self.batcher_config.batch_max_tokens: 219 | self.logger.debug( 220 | f"reach batch_max_tokens {self.batcher_config.batch_max_tokens}, " 221 | f"current batch_tokens {batch_tokens}" 222 | ) 223 | break 224 | if batch.num_beams + 1 > self.batcher_config.batch_max_beams: 225 | self.logger.debug( 226 | f"reach batch_max_beams {self.batcher_config.batch_max_beams}, " 227 | f"current batch_beams {batch.num_beams}" 228 | ) 229 | break 230 | has_cache_space = self._allocate(beam) 231 | if not has_cache_space: 232 | self.logger.debug("hasn't cache space to allocate") 233 | break 234 | beam_group = self.waiting.pop(0) 235 | beam = beam_group.get_beams()[0] 236 | batch_tokens += beam.num_tokens 237 | beam_group.update_beam_status(beam, BeamStatus.RUNNING) 238 | self.running.append(beam_group) 239 | block_ids = self.cache_manager.get_block_table(beam.beam_id) 240 | batch.add(beam, block_ids) 241 | beam_group.beams.pop(beam.beam_id) 242 | 243 | return batch, blocks_to_copy, blocks_to_swap_in, blocks_to_swap_out, finishing 244 | 245 | def _generate(self, old_beams: List[Beam], logits: torch.Tensor): 246 | next_tokens_chooser = HeterogeneousNextTokensChooser( 247 | dtype=logits.dtype, 248 | device=logits.device, 249 | temperature=[self.req_id2beam_group[beam.request_id].generation_config.temperature for beam in old_beams], 250 | repetition_penalty=[self.req_id2beam_group[beam.request_id].generation_config.repetition_penalty for beam in old_beams], 251 | top_k=[self.req_id2beam_group[beam.request_id].generation_config.top_k for beam in old_beams], 252 | top_p=[self.req_id2beam_group[beam.request_id].generation_config.top_p for beam in old_beams], 253 | 
typical_p=[self.req_id2beam_group[beam.request_id].generation_config.typical_p for beam in old_beams], 254 | do_sample=[self.req_id2beam_group[beam.request_id].generation_config.do_sample for beam in old_beams], 255 | num_beams=[self.req_id2beam_group[beam.request_id].generation_config.num_beams for beam in old_beams], 256 | seeds=[self.req_id2beam_group[beam.request_id].generation_config.seed for beam in old_beams] 257 | ) 258 | all_input_ids = [beam.token_ids for beam in old_beams] 259 | max_input_len = max([len(input_ids) for input_ids in all_input_ids]) 260 | all_input_ids = [input_ids + [0] * (max_input_len - len(input_ids)) for input_ids in all_input_ids] 261 | all_input_ids_tensor = torch.tensor(all_input_ids, dtype=torch.long, device=logits.device) 262 | outputs = next_tokens_chooser( 263 | input_ids=all_input_ids_tensor, 264 | scores=logits 265 | ) 266 | for beam, output in zip(old_beams, outputs): 267 | new_beams = [] 268 | generation_config = self.req_id2beam_group[beam.request_id].generation_config 269 | for nxt_token_id, nxt_prob in zip(output.next_token_ids, output.next_probs): 270 | new_beam = beam.copy() 271 | new_beam.append_token_id(nxt_token_id, nxt_prob) 272 | new_beam.check_finished( 273 | eos_token_id=generation_config.eos_token_id, 274 | max_new_tokens=generation_config.max_new_tokens 275 | ) 276 | new_beams.append(new_beam) 277 | self.req_id2beam_group[beam.request_id].cache_new_beams(new_beams) 278 | 279 | req_ids = list(set([beam.request_id for beam in old_beams])) 280 | for req_id in req_ids: 281 | beam_group = self.req_id2beam_group[req_id] 282 | if not beam_group.get_beams(BeamStatus.RUNNING): 283 | new_beams = beam_group.new_beams 284 | beam_group.clear_new_beams() 285 | new_beams = sorted(new_beams, reverse=True)[:beam_group.generation_config.num_beams] 286 | beam_group.add_beams(new_beams) 287 | for new_beam in new_beams: 288 | if not new_beam.is_finished: 289 | self.cache_manager.copy(new_beam.parent_beam_id, new_beam.beam_id) 290 | for old_beam in old_beams: 291 | if old_beam.request_id == req_id: 292 | self.cache_manager.free(old_beam.beam_id) 293 | 294 | def batch_generate(self, batch: Batch): 295 | if batch.prefill_beams: 296 | self._generate(batch.prefill_beams, batch.prefill_logits) 297 | if batch.generation_beams: 298 | self._generate(batch.generation_beams, batch.generation_logits) 299 | 300 | 301 | __all__ = ["Batch", "Batcher"] 302 | -------------------------------------------------------------------------------- /code/server/continuous_batching_server/beam.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import enum 3 | from typing import Dict, List, Optional 4 | from uuid import uuid4, UUID 5 | 6 | from .cache.cache import LogicalCacheBlock 7 | from protocol.completion_task import HuggingFaceGenerationConfig 8 | 9 | 10 | class BeamStatus(enum.Enum): 11 | WAITING = 0 12 | RUNNING = 1 13 | FINISHED = 2 14 | SWAPPED = 3 15 | 16 | 17 | class BeamFinishReason(enum.Enum): 18 | LENGTH = "length" 19 | STOP = "stop" 20 | ABORT = "abort" 21 | NOT_FINISHED = "not_finished" 22 | 23 | 24 | class Beam: 25 | def __init__( 26 | self, 27 | request_id: UUID, 28 | prompt: str, 29 | prompt_token_ids: List[int], 30 | block_size: int, 31 | ): 32 | self.request_id = request_id 33 | self.parent_beam_id = None 34 | self.beam_id = uuid4() 35 | self.prompt = prompt 36 | self.prompt_token_ids = prompt_token_ids 37 | self.block_size = block_size 38 | self.generated_token_ids = [] 39 | self.cumulative_logprob = 0.0 40 | 
41 | self.cache_blocks: List[LogicalCacheBlock] = [] 42 | self._append_tokens_to_blocks(prompt_token_ids) 43 | 44 | self.status = BeamStatus.WAITING 45 | self._finish_reason = BeamFinishReason.NOT_FINISHED 46 | 47 | def _append_cache_block(self): 48 | block = LogicalCacheBlock(block_id=len(self.cache_blocks), block_size=self.block_size) 49 | self.cache_blocks.append(block) 50 | 51 | def _append_tokens_to_blocks(self, token_ids: List[int]): 52 | offset = 0 53 | while offset < len(token_ids): 54 | if not self.cache_blocks: 55 | self._append_cache_block() 56 | 57 | last_block = self.cache_blocks[-1] 58 | if last_block.is_full: 59 | self._append_cache_block() 60 | last_block = self.cache_blocks[-1] 61 | 62 | num_empty_slots = last_block.num_empty_slots 63 | last_block.append_tokens(token_ids[offset: offset + num_empty_slots]) 64 | 65 | offset += num_empty_slots 66 | 67 | def append_token_id(self, token_id: int, prob: float): 68 | self._append_tokens_to_blocks([token_id]) 69 | self.generated_token_ids.append(token_id) 70 | self.cumulative_logprob += prob 71 | 72 | @property 73 | def last_token_id(self): 74 | if not self.generated_token_ids: 75 | return self.prompt_token_ids[-1] 76 | return self.generated_token_ids[-1] 77 | 78 | @property 79 | def token_ids(self): 80 | return self.prompt_token_ids + self.generated_token_ids 81 | 82 | @property 83 | def num_tokens(self): 84 | return len(self.prompt_token_ids) + len(self.generated_token_ids) 85 | 86 | @property 87 | def num_generated_tokens(self): 88 | return len(self.generated_token_ids) 89 | 90 | def update_status(self, status: BeamStatus): 91 | self.status = status 92 | 93 | def copy(self) -> "Beam": 94 | beam = Beam( 95 | request_id=self.request_id, 96 | prompt=self.prompt, 97 | prompt_token_ids=copy.deepcopy(self.prompt_token_ids), 98 | block_size=self.block_size 99 | ) 100 | beam.parent_beam_id = self.beam_id 101 | beam.generated_token_ids = copy.deepcopy(self.generated_token_ids) 102 | beam.cumulative_logprob = self.cumulative_logprob 103 | 104 | beam.cache_blocks = copy.deepcopy(self.cache_blocks) 105 | beam.status = copy.deepcopy(self.status) 106 | beam._finish_reason = copy.deepcopy(self._finish_reason) 107 | 108 | return beam 109 | 110 | def check_finished( 111 | self, 112 | eos_token_id: int, 113 | max_new_tokens: int 114 | ) -> None: 115 | if not self.generated_token_ids: 116 | return 117 | if eos_token_id == int(self.generated_token_ids[-1]): 118 | self.status = BeamStatus.FINISHED 119 | self.finish_reason = BeamFinishReason.STOP 120 | if len(self.generated_token_ids) == max_new_tokens: 121 | self.status = BeamStatus.FINISHED 122 | self.finish_reason = BeamFinishReason.LENGTH 123 | 124 | def __eq__(self, other: "Beam"): 125 | return self.cumulative_logprob == other.cumulative_logprob 126 | 127 | def __gt__(self, other: "Beam"): 128 | return self.cumulative_logprob > other.cumulative_logprob 129 | 130 | def __ge__(self, other: "Beam"): 131 | return self.cumulative_logprob >= other.cumulative_logprob 132 | 133 | def __lt__(self, other: "Beam"): 134 | return self.cumulative_logprob < other.cumulative_logprob 135 | 136 | def __le__(self, other: "Beam"): 137 | return self.cumulative_logprob <= other.cumulative_logprob 138 | 139 | @property 140 | def finish_reason(self) -> str: 141 | return self._finish_reason.value 142 | 143 | @finish_reason.setter 144 | def finish_reason(self, finish_reason: BeamFinishReason): 145 | self._finish_reason = finish_reason 146 | 147 | @property 148 | def is_finished(self) -> bool: 149 | return self.status == 
BeamStatus.FINISHED 150 | 151 | 152 | class BeamGroup: 153 | """ A group of beams that are generated from the same prompt""" 154 | def __init__( 155 | self, 156 | request_id: UUID, 157 | arrival_time: float, 158 | beams: List[Beam], 159 | generation_config: HuggingFaceGenerationConfig 160 | ): 161 | self.request_id = request_id 162 | self.arrival_time = arrival_time 163 | self.beams: Dict[UUID, Beam] = {beam.beam_id: beam for beam in beams} 164 | self.generation_config = generation_config 165 | 166 | self._new_beams: List[Beam] = [] 167 | 168 | def add_beams(self, beams: List[Beam]): 169 | for beam in beams: 170 | self.beams[beam.beam_id] = beam 171 | 172 | def get_beams(self, status: Optional[BeamStatus] = None) -> List[Beam]: 173 | if status is None: 174 | return list(self.beams.values()) 175 | else: 176 | return [beam for beam in self.beams.values() if beam.status == status] 177 | 178 | def num_beams(self, status: Optional[BeamStatus] = None) -> int: 179 | return len(self.get_beams(status)) 180 | 181 | def cache_new_beams(self, new_beams: List[Beam]): 182 | self._new_beams += new_beams 183 | 184 | def clear_new_beams(self): 185 | self._new_beams = [] 186 | 187 | @property 188 | def new_beams(self): 189 | return self._new_beams 190 | 191 | @staticmethod 192 | def update_beam_status(beam: Beam, status: BeamStatus): 193 | beam.update_status(status) 194 | 195 | def find(self, beam_id: UUID) -> Beam: 196 | if beam_id not in self.beams: 197 | raise LookupError(f"Beam {beam_id} not found.") 198 | return self.beams[beam_id] 199 | 200 | @property 201 | def is_finished(self) -> bool: 202 | if not self.generation_config.early_stopping: 203 | return all(beam.is_finished for beam in self.beams.values()) 204 | else: 205 | num_finished_beams = len(self.get_beams(BeamStatus.FINISHED)) 206 | return num_finished_beams >= self.generation_config.num_return_sequences 207 | 208 | def get_final_beams(self) -> List[Beam]: 209 | if not self.is_finished: 210 | raise AttributeError("Can't get final beams for they are not finished.") 211 | return sorted(self.get_beams(BeamStatus.FINISHED), reverse=True)[:self.generation_config.num_return_sequences] 212 | 213 | 214 | __all__ = [ 215 | "Beam", 216 | "BeamGroup", 217 | "BeamStatus", 218 | "BeamFinishReason", 219 | ] 220 | -------------------------------------------------------------------------------- /code/server/continuous_batching_server/cache/README.md: -------------------------------------------------------------------------------- 1 | > Codes in this package are copied directly from vLLM with some modifications. 
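To make the relationship between `Beam` and `BeamGroup` above concrete, here is a small usage sketch. It is illustrative only: the token ids, logprob, block size and generation settings are arbitrary example values, and it assumes the `code/` directory is on `PYTHONPATH` so the package imports resolve.

```python
# Illustrative sketch: one request becomes a BeamGroup holding its candidate Beams.
import time
from uuid import uuid4

from protocol.completion_task import HuggingFaceGenerationConfig
from server.continuous_batching_server.beam import Beam, BeamGroup

request_id = uuid4()
generation_config = HuggingFaceGenerationConfig(
    num_beams=2, num_return_sequences=1, max_new_tokens=8, eos_token_id=2
)

# The "init" beam holds the prompt tokens; they are laid out into fixed-size
# logical cache blocks (block_size tokens per block).
init_beam = Beam(request_id=request_id, prompt="Hello", prompt_token_ids=[1, 15043], block_size=16)

group = BeamGroup(
    request_id=request_id,
    arrival_time=time.time(),
    beams=[init_beam],
    generation_config=generation_config,
)

# During decoding, the batcher forks a child beam per candidate next token:
child = init_beam.copy()           # child.parent_beam_id == init_beam.beam_id
child.append_token_id(450, -0.1)   # append a sampled token id and its logprob
child.check_finished(eos_token_id=2, max_new_tokens=8)

# Candidate children are buffered on the group; the batcher later keeps the
# best `num_beams` of them (sorted by cumulative logprob).
group.cache_new_beams([child])
print(child.num_generated_tokens, child.is_finished, group.is_finished)
```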
2 | -------------------------------------------------------------------------------- /code/server/continuous_batching_server/cache/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/modelize-ai/LLM-Inference-Deployment-Tutorial/e3264eacc4752a7d829241fb614a0b81892cdc8f/code/server/continuous_batching_server/cache/__init__.py -------------------------------------------------------------------------------- /code/server/continuous_batching_server/cache/cache.py: -------------------------------------------------------------------------------- 1 | # adopt from vllm 2 | 3 | from typing import * 4 | 5 | import torch 6 | import vllm.cache_ops as vllm_cache_ops 7 | 8 | from ..config import CacheConfig, ModelConfig, ParallelConfig 9 | 10 | 11 | KVCache = Tuple[torch.Tensor, torch.Tensor] 12 | 13 | 14 | class LogicalCacheBlock: 15 | def __init__(self, block_id: int, block_size: int): 16 | self.block_id = block_id 17 | self.block_size = block_size 18 | 19 | self._token_ids = [-1] * block_size 20 | self.num_tokens = 0 21 | 22 | @property 23 | def is_empty(self): 24 | return self.num_tokens == 0 25 | 26 | @property 27 | def num_empty_slots(self): 28 | return self.block_size - self.num_tokens 29 | 30 | @property 31 | def is_full(self): 32 | return self.num_tokens == self.block_size 33 | 34 | def append_tokens(self, token_ids: List[int]): 35 | assert len(token_ids) <= self.num_empty_slots 36 | offset = self.num_tokens 37 | self._token_ids[offset: offset + len(token_ids)] = token_ids 38 | self.num_tokens += len(token_ids) 39 | 40 | @property 41 | def token_ids(self): 42 | return self._token_ids[:self.num_tokens] 43 | 44 | @property 45 | def last_token_id(self): 46 | assert self.num_tokens > 0 47 | return self.token_ids[self.num_tokens - 1] 48 | 49 | 50 | class PhysicalCacheBlock: 51 | def __init__(self, block_id: int, block_size: int): 52 | self.block_id = block_id 53 | self.block_size = block_size 54 | 55 | self.ref_count = 0 56 | 57 | def __repr__(self): 58 | return f"[block_id={self.block_id} block_size={self.block_size} ref_count={self.ref_count}]" 59 | 60 | def __str__(self): 61 | return self.__repr__() 62 | 63 | 64 | class Cache: 65 | def __init__( 66 | self, 67 | cache_config: CacheConfig, 68 | model_config: ModelConfig, 69 | parallel_config: ParallelConfig, 70 | dtype: torch.dtype, 71 | device: torch.device 72 | ): 73 | self.cache_config = cache_config 74 | self.model_config = model_config 75 | self.parallel_config = parallel_config 76 | 77 | self.head_size = model_config.get_head_size() 78 | self.num_layers = model_config.get_num_layers() 79 | self.num_heads = model_config.get_num_heads() 80 | 81 | self.dtype = dtype 82 | self.device = device 83 | 84 | self.block_size = cache_config.block_size 85 | self.num_blocks = cache_config.num_blocks 86 | self.num_blocks_cpu = cache_config.num_blocks_cpu 87 | 88 | self.cache = [ 89 | ( 90 | torch.empty( 91 | size=(self.num_blocks, *self.key_block_shape), 92 | dtype=self.dtype, 93 | device=self.device 94 | ), 95 | torch.empty( 96 | size=(self.num_blocks, *self.value_block_shape), 97 | dtype=self.dtype, 98 | device=self.device 99 | ) 100 | ) for _ in range(self.num_layers) 101 | ] 102 | 103 | self.cpu_cache = [ 104 | ( 105 | torch.empty( 106 | size=(self.num_blocks_cpu, *self.key_block_shape), 107 | dtype=self.dtype, 108 | pin_memory=cache_config.pin_memory_on_cpu, 109 | ), 110 | torch.empty( 111 | size=(self.num_blocks_cpu, *self.value_block_shape), 112 | dtype=self.dtype, 113 | 
pin_memory=cache_config.pin_memory_on_cpu, 114 | ) 115 | ) for _ in range(self.num_layers) 116 | ] 117 | 118 | self.cache_stream = torch.cuda.Stream() 119 | assert self.cache_stream != torch.cuda.current_stream() 120 | 121 | self.events = [torch.cuda.Event() for _ in range(self.num_layers)] 122 | 123 | @property 124 | def key_block_shape(self) -> Tuple[int, int, int, int]: 125 | element_size = torch.tensor([], dtype=self.dtype).element_size() 126 | x = 16 // element_size 127 | return ( 128 | self.num_heads, 129 | self.head_size // x, 130 | self.block_size, 131 | x, 132 | ) 133 | 134 | @property 135 | def value_block_shape(self) -> Tuple[int, int, int]: 136 | return ( 137 | self.num_heads, 138 | self.head_size, 139 | self.block_size, 140 | ) 141 | 142 | def _swap( 143 | self, 144 | src: List[KVCache], 145 | dst: List[KVCache], 146 | src_to_dst: Dict[int, int], 147 | ) -> None: 148 | with torch.cuda.stream(self.cache_stream): 149 | for i in range(self.num_layers): 150 | src_key_cache, src_value_cache = src[i] 151 | dst_key_cache, dst_value_cache = dst[i] 152 | # Copy the key blocks. 153 | vllm_cache_ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst) 154 | # Copy the value blocks. 155 | vllm_cache_ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst) 156 | event = self.events[i] 157 | event.record(stream=self.cache_stream) 158 | 159 | def swap_in(self, src_to_dst: Dict[int, int]) -> None: 160 | self._swap(self.cpu_cache, self.cache, src_to_dst) 161 | 162 | def swap_out(self, src_to_dst: Dict[int, int]) -> None: 163 | self._swap(self.cache, self.cpu_cache, src_to_dst) 164 | 165 | def copy(self, src_to_dsts: Dict[int, List[int]]) -> None: 166 | key_caches = [key_cache for key_cache, _ in self.cache] 167 | value_caches = [value_cache for _, value_cache in self.cache] 168 | # This operation implicitly synchronizes the CPU and GPU. 
169 | vllm_cache_ops.copy_blocks(key_caches, value_caches, src_to_dsts) 170 | 171 | @staticmethod 172 | def get_cache_block_size( 173 | block_size: int, 174 | model_config: ModelConfig, 175 | dtype: torch.dtype 176 | ) -> int: 177 | head_size = model_config.get_head_size() 178 | num_heads = model_config.get_num_heads() 179 | num_layers = model_config.get_num_layers() 180 | 181 | key_cache_block = block_size * num_heads * head_size 182 | value_cache_block = key_cache_block 183 | total = num_layers * (key_cache_block + value_cache_block) 184 | dtype_size = torch.tensor([], dtype=dtype).element_size() 185 | return dtype_size * total 186 | 187 | 188 | __all__ = [ 189 | "LogicalCacheBlock", 190 | "PhysicalCacheBlock", 191 | "Cache" 192 | ] 193 | -------------------------------------------------------------------------------- /code/server/continuous_batching_server/cache/cache_manager.py: -------------------------------------------------------------------------------- 1 | # adopt from vllm 2 | 3 | from logging import getLogger 4 | from typing import * 5 | from uuid import UUID 6 | 7 | from .cache import LogicalCacheBlock, PhysicalCacheBlock 8 | from ..beam import Beam, BeamGroup, BeamStatus 9 | 10 | 11 | logger = getLogger(__name__) 12 | 13 | 14 | class OOMError(Exception): 15 | pass 16 | 17 | 18 | class DoubleFreeBlockError(Exception): 19 | pass 20 | 21 | 22 | class CacheBlockAllocator: 23 | def __init__(self, block_size: int, num_blocks: int): 24 | self.block_size = block_size 25 | self.num_blocks = num_blocks 26 | 27 | self.free_blocks: List[PhysicalCacheBlock] = [ 28 | PhysicalCacheBlock(block_id=i, block_size=block_size) for i in range(num_blocks) 29 | ] 30 | 31 | def allocate(self) -> PhysicalCacheBlock: 32 | if not self.free_blocks: 33 | raise OOMError("No free blocks are available.") 34 | block = self.free_blocks.pop(0) 35 | block.ref_count = 1 36 | return block 37 | 38 | def free(self, block: PhysicalCacheBlock) -> None: 39 | if block.ref_count == 0: 40 | raise DoubleFreeBlockError(f"Double free! block-{block.block_id} is already freed.") 41 | block.ref_count -= 1 42 | if block.ref_count == 0: 43 | self.free_blocks.append(block) 44 | # self.free_blocks = sorted(self.free_blocks, key=lambda block: block.block_id) 45 | 46 | @property 47 | def num_free_blocks(self) -> int: 48 | return len(self.free_blocks) 49 | 50 | 51 | BlockTable = List[PhysicalCacheBlock] 52 | 53 | 54 | class CacheBlockManager: 55 | def __init__(self, block_size: int, num_blocks: int, num_blocks_cpu: int, watermark: float = 0.01): 56 | self.block_size = block_size 57 | self.num_blocks = num_blocks 58 | self.watermark = watermark 59 | assert self.watermark >= 0.0 60 | 61 | self.watermark_blocks = int(watermark * num_blocks) 62 | self.allocator = CacheBlockAllocator(block_size, num_blocks) 63 | self.cpu_allocator = CacheBlockAllocator(block_size, num_blocks_cpu) 64 | 65 | self.block_tables: Dict[UUID, BlockTable] = {} 66 | 67 | def can_allocate(self, beam: Beam): 68 | num_required_blocks = len(beam.cache_blocks) 69 | num_free_blocks = self.allocator.num_free_blocks 70 | # Use watermark to avoid frequent cache eviction. 71 | return num_free_blocks - num_required_blocks >= self.watermark_blocks 72 | 73 | def allocate(self, beam: Beam): 74 | # NOTE: only do to beam that is 'init' beam. 75 | block_table: BlockTable = [] 76 | 77 | # Allocate new physical cache blocks that will store the prompt tokens. 
78 | for _ in range(len(beam.cache_blocks)): 79 | block = self.allocator.allocate() 80 | block_table.append(block) 81 | 82 | self.block_tables[beam.beam_id] = block_table.copy() 83 | logger.debug(f"beam-{beam.beam_id} allocate block_table: {[block.block_id for block in block_table]}") 84 | 85 | def is_need_to_append_slot(self, beam: Beam): 86 | logical_blocks = beam.cache_blocks 87 | block_table: BlockTable = self.block_tables[beam.beam_id] 88 | 89 | if len(block_table) < len(logical_blocks): 90 | return True 91 | if block_table[-1].ref_count > 1: 92 | # The last block is shared with other beams, which means should copy on write 93 | return True 94 | return False 95 | 96 | def can_append_slot(self): 97 | return self.allocator.num_free_blocks >= 1 98 | 99 | def append_slot(self, beam: Beam) -> Optional[Tuple[int, int]]: 100 | logical_blocks = beam.cache_blocks 101 | block_table: BlockTable = self.block_tables[beam.beam_id] 102 | 103 | if len(block_table) < len(logical_blocks): 104 | block = self.allocator.allocate() 105 | block_table.append(block) 106 | logger.debug(f"beam-{beam.beam_id} add one block-{block.block_id}") 107 | return 108 | 109 | last_block = block_table[-1] 110 | if last_block.ref_count == 1: 111 | # Not shared with other sequences. Appendable. 112 | return 113 | else: 114 | # The last block is shared with other sequences. 115 | # Copy on Write: Allocate a new block and copy the tokens. 116 | new_block = self.allocator.allocate() 117 | block_table[-1] = new_block 118 | self.allocator.free(last_block) 119 | logger.debug(f"beam-{beam.beam_id} replace block-{last_block.block_id} with block-{new_block.block_id}") 120 | return last_block.block_id, new_block.block_id 121 | 122 | def copy(self, parent: UUID, child: UUID): 123 | src_block_table = self.block_tables[parent] 124 | self.block_tables[child] = src_block_table.copy() 125 | for block in src_block_table: 126 | block.ref_count += 1 127 | logger.debug(f"beam-{parent} copy block_table: {[block.block_id for block in src_block_table]} to beam-{child}") 128 | 129 | def _get_physical_blocks(self, beam_group: BeamGroup) -> List[PhysicalCacheBlock]: 130 | blocks: Set[PhysicalCacheBlock] = set() 131 | for beam in beam_group.get_beams(): 132 | if beam.is_finished: 133 | continue 134 | block_table = self.block_tables[beam.beam_id] 135 | for block in block_table: 136 | blocks.add(block) 137 | return list(blocks) 138 | 139 | def can_swap_in(self, beam_group: BeamGroup, preserved_num_blocks: int = 0) -> bool: 140 | blocks = self._get_physical_blocks(beam_group) 141 | num_swapped_seqs = beam_group.num_beams(status=BeamStatus.SWAPPED) 142 | num_free_blocks = self.allocator.num_free_blocks 143 | # NOTE: Conservatively, we assume that every sequence will allocate 144 | # at least one free block right after the swap-in. 145 | # NOTE: This should match the logic in can_append_slot(). 146 | num_required_blocks = len(blocks) + num_swapped_seqs 147 | return num_free_blocks - num_required_blocks - preserved_num_blocks >= self.watermark_blocks 148 | 149 | def swap_in(self, beam_group: BeamGroup) -> Dict[int, int]: 150 | # CPU block -> GPU block. 
151 | mapping: Dict[PhysicalCacheBlock, PhysicalCacheBlock] = {} 152 | for beam in beam_group.get_beams(): 153 | if beam.is_finished: 154 | continue 155 | new_block_table: BlockTable = [] 156 | block_table = self.block_tables[beam.beam_id] 157 | 158 | for cpu_block in block_table: 159 | if cpu_block in mapping: 160 | gpu_block = mapping[cpu_block] 161 | gpu_block.ref_count += 1 162 | else: 163 | gpu_block = self.allocator.allocate() 164 | mapping[cpu_block] = gpu_block 165 | new_block_table.append(gpu_block) 166 | # Free the CPU block swapped in to GPU. 167 | self.cpu_allocator.free(cpu_block) 168 | self.block_tables[beam.beam_id] = new_block_table 169 | 170 | block_ids_map = { 171 | cpu_block.block_id: gpu_block.block_id 172 | for cpu_block, gpu_block in mapping.items() 173 | } 174 | logger.debug( 175 | f"swap in beam_group-{beam_group.request_id} from CPU to GPU." 176 | ) 177 | return block_ids_map 178 | 179 | def can_swap_out(self, beam_group: BeamGroup) -> bool: 180 | blocks = self._get_physical_blocks(beam_group) 181 | return len(blocks) <= self.cpu_allocator.num_free_blocks 182 | 183 | def swap_out(self, beam_group: BeamGroup) -> Dict[int, int]: 184 | # GPU block -> CPU block. 185 | mapping: Dict[PhysicalCacheBlock, PhysicalCacheBlock] = {} 186 | for beam in beam_group.get_beams(): 187 | if beam.is_finished: 188 | continue 189 | new_block_table: BlockTable = [] 190 | block_table = self.block_tables[beam.beam_id] 191 | 192 | for gpu_block in block_table: 193 | if gpu_block in mapping: 194 | cpu_block = mapping[gpu_block] 195 | cpu_block.ref_count += 1 196 | else: 197 | cpu_block = self.cpu_allocator.allocate() 198 | mapping[gpu_block] = cpu_block 199 | new_block_table.append(cpu_block) 200 | # Free the GPU block swapped out to CPU. 201 | self.allocator.free(gpu_block) 202 | self.block_tables[beam.beam_id] = new_block_table 203 | 204 | block_ids_map = { 205 | gpu_block.block_id: cpu_block.block_id 206 | for gpu_block, cpu_block in mapping.items() 207 | } 208 | logger.debug( 209 | f"swap out beam_group-{beam_group.request_id} from GPU to CPU." 
210 | ) 211 | return block_ids_map 212 | 213 | def _free_block_table(self, block_table: BlockTable): 214 | for block in block_table: 215 | self.allocator.free(block) 216 | 217 | def free(self, beam_id: UUID): 218 | if beam_id not in self.block_tables: 219 | return 220 | block_table = self.block_tables.pop(beam_id) 221 | self._free_block_table(block_table) 222 | logger.debug(f"beam-{beam_id} free blocks: {[block.block_id for block in block_table]}") 223 | 224 | def reset(self): 225 | for block_table in self.block_tables.values(): 226 | self._free_block_table(block_table) 227 | self.block_tables.clear() 228 | 229 | def get_block_table(self, beam_id: UUID) -> List[int]: 230 | block_table = self.block_tables[beam_id] 231 | return [block.block_id for block in block_table] 232 | 233 | @property 234 | def num_free_blocks(self): 235 | return self.allocator.num_free_blocks 236 | 237 | @property 238 | def num_free_cpu_blocks(self): 239 | return self.cpu_allocator.num_free_blocks 240 | -------------------------------------------------------------------------------- /code/server/continuous_batching_server/config.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | from typing import Optional 3 | 4 | import torch 5 | from pydantic import BaseModel, Field 6 | from transformers import PretrainedConfig 7 | 8 | from .modeling.llama import LlamaForCausalLM, LlamaConfig 9 | 10 | 11 | ModelFactory = namedtuple( 12 | "ModelFactory", 13 | [ 14 | "model_cls", 15 | "model_config_cls" 16 | ] 17 | ) 18 | 19 | 20 | TORCH_FLOAT_DTYPE_MAP = { 21 | "float": torch.float32, 22 | "float32": torch.float32, 23 | "float16": torch.float16, 24 | "bfloat16": torch.bfloat16, 25 | "int32": torch.int32, 26 | "int64": torch.int64 27 | } 28 | 29 | MODEL_AUTO_TABLE = { 30 | "llama": ModelFactory(model_cls=LlamaForCausalLM, model_config_cls=LlamaConfig) 31 | } 32 | 33 | 34 | class BatcherConfig(BaseModel): 35 | # default value is suitable for 7B model in A100 GPU, this controls prefill rate of each step 36 | batch_max_tokens: int = Field(default=56000) 37 | # default value is suitable for 7B model in A100 GPU, this controls batch size of each step 38 | batch_max_beams: int = Field(default=32) 39 | 40 | 41 | class CacheConfig(BaseModel): 42 | num_blocks: Optional[int] = Field(default=2500) # default value is suitable for 7B model in A100 GPU 43 | num_blocks_cpu: Optional[int] = Field(default=1024) 44 | block_size: int = Field(default=16) 45 | gpu_memory_utilization: float = Field(default=0.98) 46 | watermark: float = Field(default=0.01) 47 | pin_memory_on_cpu: bool = Field(default=True) 48 | 49 | 50 | class ModelLoadingConfig(BaseModel): 51 | model_type: str = Field(default="llama", regex="(" + "|".join(list(MODEL_AUTO_TABLE.keys())) + ")") 52 | model_name_or_path: str = Field(default="dummy_path_to_llama_model") 53 | torch_dtype: str = Field(default="float16", regex="(" + "|".join(list(TORCH_FLOAT_DTYPE_MAP.keys())) + ")") 54 | tokenizer_name_or_path: Optional[str] = Field(default=None) 55 | use_fast_tokenizer: bool = Field(default=False) 56 | trust_remote_code: bool = Field(default=False) 57 | quantize_method: Optional[str] = Field(default=None, regex="(gptq|)") 58 | model_max_length: int = Field(default=2048) 59 | device: int = Field(default=0) 60 | gptq_model_base_name: Optional[str] = Field(default=None) 61 | gptq_config_base_name: Optional[str] = Field(default=None) 62 | 63 | 64 | class ParallelConfig(BaseModel): 65 | tp_size: int = Field(default=1) 66 | 
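Taken together, the config models above describe the JSON file that continuous_batching_server_app.py loads at startup (the `ServerConfig` model defined further down in this module simply nests them). Below is an illustrative sketch of such a `cb_server_config.json`; the model path is a placeholder and the numeric values just mirror the defaults above (tuned for a 7B model on an A100), so adjust them for your own model and GPU.

```python
# Illustrative sketch: write a cb_server_config.json matching the config models above.
# "/path/to/llama-7b" is a placeholder; fields left out fall back to their defaults.
import json

cb_server_config = {
    "model_loading_config": {
        "model_type": "llama",
        "model_name_or_path": "/path/to/llama-7b",
        "torch_dtype": "float16",
        "model_max_length": 2048,
        "device": 0,
    },
    "batcher_config": {"batch_max_tokens": 56000, "batch_max_beams": 32},
    "cache_config": {"num_blocks": 2500, "num_blocks_cpu": 1024, "block_size": 16},
    "parallel_config": {"tp_size": 1},
}

with open("cb_server_config.json", "w", encoding="utf-8") as f:
    json.dump(cb_server_config, f, indent=2)

# The server app then builds its config with:
#   ServerConfig(**json.load(open("cb_server_config.json", "r", encoding="utf-8")))
```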
67 | 68 | class ModelConfig: 69 | def __init__(self, model_config: PretrainedConfig, parallel_config: ParallelConfig): 70 | self.model_config = model_config 71 | self.parallel_config = parallel_config 72 | 73 | def get_hidden_size(self): 74 | return self.model_config.hidden_size 75 | 76 | def get_head_size(self): 77 | return self.model_config.hidden_size // self.model_config.num_attention_heads 78 | 79 | def get_num_heads(self): 80 | # For GPTBigCode: 81 | if getattr(self.model_config, "multi_query", False): 82 | # Multi-query attention, only one KV head. 83 | return 1 84 | # For Falcon: 85 | if getattr(self.model_config, "n_head_kv", None) is not None: 86 | return self.model_config.n_head_kv 87 | 88 | return self.model_config.num_attention_heads // self.parallel_config.tp_size 89 | 90 | def get_num_layers(self): 91 | return self.model_config.num_hidden_layers 92 | 93 | 94 | class ServerConfig(BaseModel): 95 | model_loading_config: ModelLoadingConfig = Field(default=ModelLoadingConfig()) 96 | batcher_config: BatcherConfig = Field(default=BatcherConfig()) 97 | cache_config: CacheConfig = Field(default=CacheConfig()) 98 | parallel_config: ParallelConfig = Field(default=ParallelConfig()) 99 | 100 | 101 | __all__ = [ 102 | "ModelFactory", 103 | "TORCH_FLOAT_DTYPE_MAP", 104 | "MODEL_AUTO_TABLE", 105 | "BatcherConfig", 106 | "CacheConfig", 107 | "ModelLoadingConfig", 108 | "ParallelConfig", 109 | "ModelConfig", 110 | "ServerConfig" 111 | ] 112 | -------------------------------------------------------------------------------- /code/server/continuous_batching_server/generation_utils/README.md: -------------------------------------------------------------------------------- 1 | > Codes in this package are copied directly from TGI with some modifications. 2 | -------------------------------------------------------------------------------- /code/server/continuous_batching_server/generation_utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .tokens import HeterogeneousNextTokensChooser 2 | -------------------------------------------------------------------------------- /code/server/continuous_batching_server/generation_utils/logits_process.py: -------------------------------------------------------------------------------- 1 | # Adopt from Hugging Face text-generation-inference 2 | 3 | import math 4 | import torch 5 | 6 | from typing import List, Dict, Union 7 | 8 | from transformers import LogitsWarper, LogitsProcessor 9 | 10 | mempool = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None 11 | 12 | 13 | class HeterogeneousRepetitionPenaltyLogitsProcessor(LogitsProcessor): 14 | r""" 15 | [`LogitsProcessor`] enforcing an exponential penalty on repeated sequences. 16 | This version allows for a separate value for each sample and runs inplace when possible. 17 | It doesn't validate inputs. 18 | 19 | Args: 20 | repetition_penalty (`List[float]`): 21 | The parameter for repetition penalty. 1.0 means no penalty. See [this 22 | paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. 
23 | """ 24 | 25 | def __init__(self, penalty: List[float], dtype: torch.dtype, device: torch.device): 26 | self.penalty = penalty 27 | self.penalty_tensor = torch.tensor( 28 | penalty, dtype=dtype, device=device 29 | ).unsqueeze(1) 30 | 31 | def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor: 32 | score = torch.gather(scores, 1, input_ids) 33 | 34 | # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability 35 | score = torch.where( 36 | score < 0, score * self.penalty_tensor, score / self.penalty_tensor 37 | ) 38 | 39 | scores.scatter_(1, input_ids, score) 40 | return scores 41 | 42 | 43 | class HeterogeneousTemperatureLogitsWarper: 44 | r""" 45 | [`LogitsWarper`] for temperature (exponential scaling output probability distribution). 46 | This version allows for a separate value for each sample and runs inplace when possible. 47 | It doesn't validate inputs. 48 | 49 | Args: 50 | temperature (`float`): 51 | The value used to module the logits distribution. 52 | """ 53 | 54 | def __init__( 55 | self, temperature: List[float], dtype: torch.dtype, device: torch.device 56 | ): 57 | self.temperature = temperature 58 | self.temperature_tensor = torch.tensor( 59 | temperature, dtype=dtype, device=device 60 | ).unsqueeze(1) 61 | 62 | def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor: 63 | scores.div_(self.temperature_tensor) 64 | return scores 65 | 66 | 67 | class HeterogeneousTopPLogitsWarper(LogitsWarper): 68 | """ 69 | [`LogitsWarper`] that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= prob_cut_off. 70 | This version allows for a separate value for each sample and runs inplace when possible. 71 | It doesn't validate inputs. 72 | 73 | Args: 74 | top_p (`float`): 75 | If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or 76 | higher are kept for generation. 77 | filter_value (`float`, *optional*, defaults to `-float("Inf")`): 78 | All filtered values will be set to this float value. 79 | min_tokens_to_keep (`int`, *optional*, defaults to 1): 80 | Minimum number of tokens that cannot be filtered. 
81 | """ 82 | 83 | def __init__( 84 | self, 85 | top_p: List[float], 86 | dtype: torch.dtype, 87 | device: torch.device, 88 | filter_value: float = -math.inf, 89 | min_tokens_to_keep: int = 1, 90 | ): 91 | self.top_p = top_p 92 | self.top_p_opposite = 1 - torch.tensor( 93 | top_p, dtype=dtype, device=device 94 | ).unsqueeze(1) 95 | self.filter_value = filter_value 96 | self.min_tokens_to_keep = min_tokens_to_keep 97 | 98 | def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor: 99 | sorted_logits, sorted_indices = torch.sort(scores, descending=False) 100 | probs = sorted_logits.softmax(dim=-1) 101 | # This is way faster for some reason 102 | for i in range(probs.shape[0]): 103 | probs[i] = probs[i].cumsum(dim=-1) 104 | 105 | # Remove tokens with cumulative top_p above the threshold (token with 0 are kept) 106 | sorted_indices_to_remove = probs <= self.top_p_opposite 107 | # Keep at least min_tokens_to_keep 108 | sorted_indices_to_remove[..., -self.min_tokens_to_keep:] = 0 109 | 110 | # scatter sorted tensors to original indexing 111 | indices_to_remove = sorted_indices_to_remove.scatter( 112 | 1, sorted_indices, sorted_indices_to_remove 113 | ) 114 | warped_scores = scores.masked_fill_(indices_to_remove, self.filter_value) 115 | 116 | return warped_scores 117 | 118 | 119 | class HeterogeneousTopKLogitsWarper(LogitsWarper): 120 | r""" 121 | [`LogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements. 122 | This version allows for a separate value for each sample and runs inplace when possible. 123 | It doesn't validate inputs. 124 | 125 | Args: 126 | top_k (`int`): 127 | The number of highest probability vocabulary tokens to keep for top-k-filtering. 128 | filter_value (`float`, *optional*, defaults to `-float("Inf")`): 129 | All filtered values will be set to this float value. 130 | min_tokens_to_keep (`int`, *optional*, defaults to 1): 131 | Minimum number of tokens that cannot be filtered. 
132 | """ 133 | 134 | def __init__( 135 | self, 136 | top_k: List[int], 137 | device: torch.device, 138 | filter_value: float = -math.inf, 139 | min_tokens_to_keep: int = 1, 140 | ): 141 | self.top_k = top_k 142 | self.max_top_k = max(top_k) 143 | # value - 1 as we will use top_k to index and python uses 0 based numbering 144 | self.top_k_tensor = torch.tensor( 145 | [max(x - 1, min_tokens_to_keep - 1) for x in top_k], 146 | dtype=torch.int64, 147 | device=device, 148 | ).unsqueeze(1) 149 | 150 | # 0 is a special value that disables top_k warping for this member of the batch 151 | disabled = [x == 0 for x in top_k] 152 | 153 | if any(disabled): 154 | self.top_k_disabled_mask = torch.tensor( 155 | disabled, dtype=torch.bool, device=device 156 | ).view(-1, 1) 157 | else: 158 | self.top_k_disabled_mask = None 159 | 160 | self.filter_value = filter_value 161 | 162 | def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor: 163 | # If max_top_k is superior to the vocab, we need to clamp or the warper will fail 164 | if scores.size(-1) < self.max_top_k: 165 | max_top_k = scores.size(-1) 166 | top_k = torch.clamp_max(self.top_k_tensor, max_top_k) 167 | else: 168 | max_top_k = self.max_top_k 169 | top_k = self.top_k_tensor 170 | 171 | # Get the kth score for each member of the batch 172 | kth_scores = torch.gather(torch.topk(scores, max_top_k)[0], 1, top_k) 173 | 174 | # Mask member of kth_scores that do not want to use top_k warping 175 | if self.top_k_disabled_mask is not None: 176 | kth_scores.masked_fill_(self.top_k_disabled_mask, self.filter_value) 177 | 178 | # Remove all tokens with a probability less than the last token of the top-k 179 | indices_to_remove = scores < kth_scores 180 | scores.masked_fill_(indices_to_remove, self.filter_value) 181 | return scores 182 | 183 | 184 | class HeterogeneousTypicalLogitsWarper(LogitsWarper): 185 | r""" 186 | [`LogitsWarper`] that performs typical decoding. See [Typical Decoding for Natural Language 187 | Generation](https://arxiv.org/abs/2202.00666) for more information. 188 | This version allows for a separate value for each sample and runs inplace when possible. 189 | It doesn't validate inputs. 190 | 191 | Args: 192 | mass (`float`): 193 | Value of typical_p between 0 and 1 inclusive, defaults to 0.9. 194 | filter_value (`float`, *optional*, defaults to `-float("Inf")`): 195 | All filtered values will be set to this float value. 196 | min_tokens_to_keep (`int`, *optional*, defaults to 1): 197 | Minimum number of tokens that cannot be filtered. 
198 | """ 199 | 200 | def __init__( 201 | self, 202 | mass: List[float], 203 | dtype: torch.dtype, 204 | device: torch.device, 205 | filter_value: float = -math.inf, 206 | min_tokens_to_keep: int = 1, 207 | ): 208 | self.mass = mass 209 | self.mass_tensor = torch.tensor(mass, dtype=dtype, device=device).unsqueeze(1) 210 | 211 | # 1 is a special value that disables typical_p warping for this member of the batch 212 | disabled = [x == 1.0 for x in mass] 213 | 214 | if any(disabled): 215 | self.disabled_mask = torch.tensor(disabled, dtype=torch.bool, device=device) 216 | else: 217 | self.disabled_mask = None 218 | 219 | self.filter_value = filter_value 220 | self.min_tokens_to_keep = min_tokens_to_keep 221 | 222 | def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor: 223 | # calculate entropy 224 | normalized = torch.nn.functional.log_softmax(scores, dim=-1) 225 | p = torch.exp(normalized) 226 | ent = -(normalized * p).nansum(-1, keepdim=True) 227 | 228 | # shift and sort 229 | shifted_scores = torch.abs((-normalized) - ent) 230 | sorted_scores, sorted_indices = torch.sort(shifted_scores, descending=False) 231 | sorted_logits = scores.gather(-1, sorted_indices) 232 | probs = sorted_logits.softmax(dim=-1) 233 | # This is way faster for some reason 234 | for i in range(probs.shape[0]): 235 | probs[i] = probs[i].cumsum(dim=-1) 236 | 237 | # Remove tokens with cumulative mass above the threshold 238 | last_ind = (probs < self.mass_tensor).sum(dim=1) 239 | last_ind[last_ind < 0] = 0 240 | 241 | if self.disabled_mask is not None: 242 | last_ind.masked_fill_(self.disabled_mask, scores.shape[-1] - 1) 243 | 244 | sorted_indices_to_remove = sorted_scores > sorted_scores.gather( 245 | 1, last_ind.view(-1, 1) 246 | ) 247 | if self.min_tokens_to_keep > 1: 248 | # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below) 249 | sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0 250 | indices_to_remove = sorted_indices_to_remove.scatter( 251 | 1, sorted_indices, sorted_indices_to_remove 252 | ) 253 | 254 | warped_scores = scores.masked_fill_(indices_to_remove, self.filter_value) 255 | 256 | return warped_scores 257 | 258 | 259 | class HeterogeneousProcessorWrapper(LogitsProcessor): 260 | r""" 261 | A wrapper for logit warpers or processors without heterogeneous parameter support. 262 | Args: 263 | processors (`Dict[int, Union[LogitsProcessor, LogitsWarper]]`): 264 | A mapping of sample indices to logit warpers or processors, to be run sequentially. 
265 | """ 266 | 267 | def __init__( 268 | self, 269 | processors: Dict[int, Union[LogitsProcessor, LogitsWarper]], 270 | ): 271 | self.processors = processors 272 | 273 | def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor: 274 | for i, processor in self.processors.items(): 275 | scores[i: i + 1] = processor(input_ids[i: i + 1], scores[i: i + 1]) 276 | return scores 277 | -------------------------------------------------------------------------------- /code/server/continuous_batching_server/generation_utils/tokens.py: -------------------------------------------------------------------------------- 1 | # Adopt from Hugging Face text-generation-inference 2 | 3 | from typing import * 4 | 5 | import torch 6 | from pydantic import BaseModel, Field, Required 7 | 8 | from .logits_process import ( 9 | HeterogeneousRepetitionPenaltyLogitsProcessor, 10 | HeterogeneousTemperatureLogitsWarper, 11 | HeterogeneousTopKLogitsWarper, 12 | HeterogeneousTopPLogitsWarper, 13 | HeterogeneousTypicalLogitsWarper, 14 | HeterogeneousProcessorWrapper, 15 | ) 16 | 17 | 18 | class NextTokensChooserOutput(BaseModel): 19 | next_probs: List[float] = Field(default=Required) 20 | next_token_ids: List[int] = Field(default=Required) 21 | 22 | 23 | class HeterogeneousNextTokensChooser: 24 | def __init__( 25 | self, 26 | dtype: torch.dtype, 27 | device: torch.device, 28 | temperature: List[float], 29 | repetition_penalty: List[float], 30 | top_k: List[int], 31 | top_p: List[float], 32 | typical_p: List[float], 33 | do_sample: List[bool], 34 | num_beams: List[int], 35 | seeds: List[int], 36 | ): 37 | warpers = [] 38 | 39 | self.repetition_processor = ( 40 | HeterogeneousRepetitionPenaltyLogitsProcessor( 41 | repetition_penalty, dtype, device 42 | ) 43 | if any([x != 1.0 for x in repetition_penalty]) 44 | else None 45 | ) 46 | 47 | if any([x != 1.0 for x in temperature]): 48 | do_sample = [ 49 | sample or x != 1.0 for x, sample in zip(temperature, do_sample) 50 | ] 51 | warpers.append( 52 | HeterogeneousTemperatureLogitsWarper(temperature, dtype, device) 53 | ) 54 | 55 | if any([x != 0 for x in top_k]): 56 | do_sample = [sample or x != 0 for x, sample in zip(top_k, do_sample)] 57 | warpers.append(HeterogeneousTopKLogitsWarper(top_k, device)) 58 | 59 | if any([x < 1.0 for x in top_p]): 60 | do_sample = [sample or x < 1.0 for x, sample in zip(top_p, do_sample)] 61 | warpers.append(HeterogeneousTopPLogitsWarper(top_p, dtype, device)) 62 | 63 | if any([x < 1.0 for x in typical_p]): 64 | do_sample = [sample or x < 1.0 for x, sample in zip(typical_p, do_sample)] 65 | warpers.append(HeterogeneousTypicalLogitsWarper(typical_p, dtype, device)) 66 | 67 | self.warpers = warpers 68 | 69 | if any(do_sample): 70 | self.choice = HeterogeneousSampling(do_sample, num_beams, seeds, device) 71 | else: 72 | self.choice = Greedy(num_beams) 73 | 74 | self.seeds = seeds 75 | self.do_sample = do_sample 76 | self.dtype = dtype 77 | self.device = device 78 | 79 | def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> List[NextTokensChooserOutput]: 80 | if self.repetition_processor is not None: 81 | scores = self.repetition_processor(input_ids, scores) 82 | 83 | for warper in self.warpers: 84 | scores = warper(input_ids, scores) 85 | 86 | log_scores = torch.log_softmax(scores, dim=-1) 87 | 88 | token_ids = self.choice(scores) 89 | log_probs = [ 90 | [log_scores[i, token_id].item() for token_id in beam_nxt_token_ids] 91 | for i, beam_nxt_token_ids in enumerate(token_ids) 92 | ] 93 | 94 | return [ 95 | 
NextTokensChooserOutput(next_token_ids=nxt_token_ids, next_probs=nxt_probs) 96 | for nxt_token_ids, nxt_probs in zip(token_ids, log_probs) 97 | ] 98 | 99 | 100 | class Sampling: 101 | def __init__(self, num_beams: int, seed: int, device: str = "cpu"): 102 | self.num_beams = num_beams 103 | self.generator = torch.Generator(device) 104 | self.generator.manual_seed(seed) 105 | self.seed = seed 106 | 107 | def __call__(self, logits) -> List[int]: 108 | probs = torch.nn.functional.softmax(logits, -1) 109 | # Avoid GPU<->CPU sync done by torch multinomial 110 | # See: https://github.com/pytorch/pytorch/blob/925a3788ec5c06db62ca732a0e9425a26a00916f/aten/src/ATen/native/Distributions.cpp#L631-L637 111 | q = torch.empty_like(probs).exponential_(1, generator=self.generator) 112 | return torch.topk(probs.div_(q), k=self.num_beams + 1).indices.tolist() 113 | 114 | 115 | class Greedy: 116 | def __init__(self, num_beams: List[int]): 117 | self.num_beams = num_beams 118 | 119 | def __call__(self, logits) -> List[List[int]]: 120 | return torch.topk(logits, k=max(self.num_beams) + 1, dim=-1).indices.tolist() 121 | 122 | 123 | class HeterogeneousSampling: 124 | r""" 125 | Mixed greedy and probabilistic sampling. Compute both and pick the right one for each sample. 126 | """ 127 | 128 | def __init__(self, do_sample: List[bool], num_beams: List[int], seeds: List[int], device: torch.device): 129 | self.num_beams = num_beams 130 | self.seeds = seeds 131 | 132 | greedy_indices = [] 133 | self.sampling_mapping = {} 134 | for i, (sample, seed) in enumerate(zip(do_sample, seeds)): 135 | if sample: 136 | self.sampling_mapping[i] = Sampling(num_beams[i], seed, device) 137 | else: 138 | greedy_indices.append(i) 139 | 140 | self.greedy_indices = greedy_indices 141 | 142 | def __call__(self, logits) -> List[List[int]]: 143 | out = [None for _ in range(logits.shape[0])] 144 | if self.greedy_indices: 145 | # Computing for all indices is faster than slicing 146 | greedy = Greedy(self.num_beams) 147 | out = greedy(logits) 148 | 149 | for i, sampling in self.sampling_mapping.items(): 150 | out[i] = sampling(logits[i]) 151 | return out 152 | -------------------------------------------------------------------------------- /code/server/continuous_batching_server/modeling/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/modelize-ai/LLM-Inference-Deployment-Tutorial/e3264eacc4752a7d829241fb614a0b81892cdc8f/code/server/continuous_batching_server/modeling/__init__.py -------------------------------------------------------------------------------- /code/server/continuous_batching_server/modeling/llama.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. 3 | # 4 | # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX 5 | # and OPT implementations in this library. It has been modified from its 6 | # original forms to accommodate minor architectural differences compared 7 | # to GPT-NeoX and OPT used by the Meta AI team that trained the model. 8 | # 9 | # Licensed under the Apache License, Version 2.0 (the "License"); 10 | # you may not use this file except in compliance with the License. 
11 | # You may obtain a copy of the License at 12 | # 13 | # http://www.apache.org/licenses/LICENSE-2.0 14 | # 15 | # Unless required by applicable law or agreed to in writing, software 16 | # distributed under the License is distributed on an "AS IS" BASIS, 17 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 18 | # See the License for the specific language governing permissions and 19 | # limitations under the License. 20 | 21 | 22 | from typing import * 23 | 24 | import torch 25 | import torch.nn as nn 26 | import xformers.ops as xops 27 | from accelerate import init_empty_weights 28 | from auto_gptq import BaseQuantizeConfig as GPTQConfig 29 | from transformers.activations import ACT2FN 30 | from transformers.configuration_utils import PretrainedConfig 31 | 32 | 33 | from .utils.attention import VarLenAttentionWithRoPE 34 | from .utils.linear import DynamicLinear 35 | from .utils.weights import Weights 36 | 37 | 38 | class LlamaConfig(PretrainedConfig): 39 | def __init__( 40 | self, 41 | vocab_size=32000, 42 | hidden_size=4096, 43 | intermediate_size=11008, 44 | num_hidden_layers=32, 45 | num_attention_heads=32, 46 | num_key_value_heads=None, 47 | hidden_act="silu", 48 | max_position_embeddings=2048, 49 | initializer_range=0.02, 50 | rms_norm_eps=1e-6, 51 | use_cache=True, 52 | pad_token_id=0, 53 | bos_token_id=1, 54 | eos_token_id=2, 55 | pretraining_tp=1, 56 | tie_word_embeddings=False, 57 | rope_scaling=None, 58 | **kwargs, 59 | ): 60 | self.vocab_size = vocab_size 61 | self.max_position_embeddings = max_position_embeddings 62 | self.hidden_size = hidden_size 63 | self.intermediate_size = intermediate_size 64 | self.num_hidden_layers = num_hidden_layers 65 | self.num_attention_heads = num_attention_heads 66 | 67 | # for backward compatibility 68 | if num_key_value_heads is None: 69 | num_key_value_heads = num_attention_heads 70 | 71 | self.num_key_value_heads = num_key_value_heads 72 | self.hidden_act = hidden_act 73 | self.initializer_range = initializer_range 74 | self.rms_norm_eps = rms_norm_eps 75 | self.pretraining_tp = pretraining_tp 76 | self.use_cache = use_cache 77 | self.rope_scaling = rope_scaling 78 | 79 | super().__init__( 80 | pad_token_id=pad_token_id, 81 | bos_token_id=bos_token_id, 82 | eos_token_id=eos_token_id, 83 | tie_word_embeddings=tie_word_embeddings, 84 | **kwargs, 85 | ) 86 | 87 | 88 | class LlamaRMSNorm(nn.Module): 89 | def __init__(self, hidden_size, eps=1e-6): 90 | """ 91 | LlamaRMSNorm is equivalent to T5LayerNorm 92 | """ 93 | super().__init__() 94 | self.weight = nn.Parameter(torch.ones(hidden_size)) 95 | self.variance_epsilon = eps 96 | 97 | def forward(self, hidden_states): 98 | input_dtype = hidden_states.dtype 99 | hidden_states = hidden_states.to(torch.float32) 100 | variance = hidden_states.pow(2).mean(-1, keepdim=True) 101 | hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) 102 | return self.weight * hidden_states.to(input_dtype) 103 | 104 | @classmethod 105 | def load(cls, prefix: str, weights: Weights, eps: float = 1e-6): 106 | weight = weights.get_tensor(f"{prefix}.weight") 107 | with init_empty_weights(): 108 | ln = cls(weight.shape[0], eps) 109 | ln.weight = nn.Parameter(weight) 110 | return ln 111 | 112 | 113 | def _load_gqa(config, prefix: str, weights: Weights, gptq_config: Optional[GPTQConfig] = None): 114 | w = [ 115 | weights.get_tensor(f"{prefix}.q_proj.weight"), 116 | weights.get_tensor(f"{prefix}.k_proj.weight"), 117 | weights.get_tensor(f"{prefix}.v_proj.weight") 118 | ] 119 
| weight = torch.cat(w, dim=0) 120 | weight = weight.to(dtype=weights.dtype).to(device=weights.device) 121 | 122 | assert config.hidden_size % config.num_attention_heads == 0 123 | head_size = config.hidden_size // config.num_attention_heads 124 | num_heads = config.num_attention_heads 125 | num_key_value_heads = config.num_key_value_heads 126 | assert list(weight.shape) == [ 127 | (num_heads + 2 * num_key_value_heads) * head_size, 128 | config.hidden_size, 129 | ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}" 130 | 131 | return DynamicLinear.load(config, prefix, weights, False, gptq_config) 132 | 133 | 134 | class LlamaMLP(nn.Module): 135 | def __init__(self, prefix: str, config: LlamaConfig, weights: Weights, gptq_config: Optional[GPTQConfig] = None): 136 | super().__init__() 137 | act = config.hidden_act 138 | self.act = ( 139 | ACT2FN[act] 140 | if "gelu" not in act 141 | else lambda x: torch.nn.functional.gelu( 142 | x, 143 | approximate="tanh" 144 | if act in ["gelu_fast", "gelu_pytorch_tanh"] 145 | else "none", 146 | ) 147 | ) 148 | # Fuse gate and up proj 149 | self.gate_up_proj = DynamicLinear.load_multi( 150 | config, 151 | prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"], 152 | weights=weights, 153 | dim=0, 154 | bias=False, 155 | gptq_config=gptq_config 156 | ) 157 | self.down_proj = DynamicLinear.load( 158 | config, 159 | prefix=f"{prefix}.down_proj", 160 | weights=weights, 161 | bias=False, 162 | gptq_config=gptq_config 163 | ) 164 | self.intermediate_size = config.intermediate_size 165 | 166 | def forward(self, hidden_states): 167 | gate_up_states = self.gate_up_proj(hidden_states) 168 | gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size) 169 | return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1]) 170 | 171 | 172 | class LlamaLayer(nn.Module): 173 | def __init__(self, layer_id: int, config: LlamaConfig, weights: Weights, gptq_config: Optional[GPTQConfig] = None): 174 | super().__init__() 175 | 176 | prefix = f"model.layers.{layer_id}" 177 | 178 | self.config = config 179 | self.gptq_config = gptq_config 180 | 181 | self.self_attn = self._init_attention_module(config, f"{prefix}.self_attn", weights) 182 | self.mlp = LlamaMLP(prefix=f"{prefix}.mlp", config=config, weights=weights, gptq_config=gptq_config) 183 | 184 | self.input_layernorm = LlamaRMSNorm.load( 185 | prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps 186 | ) 187 | self.post_attention_layernorm = LlamaRMSNorm.load( 188 | prefix=f"{prefix}.post_attention_layernorm", weights=weights, eps=config.rms_norm_eps 189 | ) 190 | 191 | def _init_attention_module(self, config: LlamaConfig, prefix: str, weights: Weights) -> nn.Module: 192 | qkv_proj = DynamicLinear.load_multi( 193 | config=config, 194 | prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], 195 | weights=weights, 196 | bias=False, 197 | dim=0, 198 | gptq_config=self.gptq_config 199 | ) 200 | o_proj = DynamicLinear.load( 201 | config=config, 202 | prefix=f"{prefix}.o_proj", 203 | weights=weights, 204 | bias=False, 205 | gptq_config=self.gptq_config 206 | ) 207 | 208 | cos_sin_cache = VarLenAttentionWithRoPE.build_rope_cache( 209 | rotary_dim=config.hidden_size // config.num_attention_heads, 210 | max_position=config.max_position_embeddings, 211 | base=10000, 212 | device=o_proj.weight.device if not self.gptq_config else o_proj.scales.device, 213 | dtype=o_proj.weight.dtype if not self.gptq_config else 
o_proj.scales.dtype 214 | ) 215 | 216 | head_dim = config.hidden_size // config.num_attention_heads 217 | attn_fw_op = xops.fmha.flash.FwOp if head_dim <= 128 else xops.fmha.cutlass.FwOp 218 | return VarLenAttentionWithRoPE( 219 | qkv_proj=qkv_proj, 220 | out_proj=o_proj, 221 | cos_sin_cache=cos_sin_cache, 222 | num_query_heads=config.num_attention_heads, 223 | num_key_heads=config.num_key_value_heads, 224 | num_value_heads=config.num_key_value_heads, 225 | dropout=0.0, 226 | scale=head_dim ** -0.5, 227 | attention_ops=(attn_fw_op, None) 228 | ) 229 | 230 | def forward( 231 | self, 232 | hidden_states: torch.Tensor, 233 | position_ids: torch.Tensor, 234 | kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], 235 | prefill: bool, 236 | block_tables: torch.Tensor, 237 | slots: torch.Tensor, 238 | context_lengths: torch.Tensor, 239 | cache_event: Optional[torch.cuda.Event] = None, 240 | ): 241 | residual = hidden_states 242 | 243 | hidden_states = self.input_layernorm(hidden_states) 244 | 245 | # Self Attention 246 | hidden_states = self.self_attn( 247 | hidden_states, 248 | position_ids, 249 | kv_cache, 250 | prefill, 251 | block_tables, 252 | slots, 253 | context_lengths, 254 | cache_event 255 | ) 256 | hidden_states = residual + hidden_states 257 | 258 | # Fully Connected 259 | residual = hidden_states 260 | hidden_states = self.post_attention_layernorm(hidden_states) 261 | hidden_states = self.mlp(hidden_states) 262 | hidden_states = residual + hidden_states 263 | 264 | return hidden_states 265 | 266 | 267 | class LlamaModel(torch.nn.Module): 268 | def __init__(self, config: LlamaConfig, weights: Weights, gptq_config: Optional[GPTQConfig] = None): 269 | super().__init__() 270 | self.config = config 271 | self.gptq_config = gptq_config 272 | 273 | self.embed_tokens = nn.Embedding.from_pretrained( 274 | embeddings=weights.get_tensor("model.embed_tokens.weight"), 275 | freeze=True, 276 | padding_idx=config.pad_token_id 277 | ) 278 | self.layers = nn.ModuleList( 279 | [ 280 | LlamaLayer( 281 | layer_id, 282 | config, 283 | weights, 284 | gptq_config 285 | ) 286 | for layer_id in range(config.num_hidden_layers) 287 | ] 288 | ) 289 | self.norm = LlamaRMSNorm.load( 290 | prefix="model.norm", weights=weights, eps=config.rms_norm_eps 291 | ) 292 | 293 | self.gradient_checkpointing = False 294 | 295 | def forward( 296 | self, 297 | input_ids: torch.Tensor, 298 | position_ids: torch.Tensor, 299 | kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], 300 | prefill: bool, 301 | block_tables: torch.Tensor, 302 | slots: torch.Tensor, 303 | context_lengths: torch.Tensor, 304 | cache_events: Optional[List[torch.cuda.Event]] = None, 305 | ) -> torch.Tensor: 306 | hidden_states = self.embed_tokens(input_ids) 307 | 308 | for i, layer in enumerate(self.layers): 309 | hidden_states = layer( 310 | hidden_states, 311 | position_ids, 312 | kv_cache[i], 313 | prefill, 314 | block_tables, 315 | slots, 316 | context_lengths, 317 | cache_events[i] if cache_events is not None else None 318 | ) 319 | 320 | hidden_states = self.norm(hidden_states) 321 | 322 | return hidden_states 323 | 324 | 325 | class LlamaForCausalLM(torch.nn.Module): 326 | def __init__(self, config: LlamaConfig, weights: Weights, gptq_config: Optional[GPTQConfig] = None): 327 | super().__init__() 328 | 329 | self.config = config 330 | self.gptq_config = gptq_config 331 | self.model = LlamaModel(config, weights, gptq_config) 332 | self.lm_head = DynamicLinear.load( 333 | config, 334 | prefix="lm_head", 335 | weights=weights, 336 | bias=False, 337 | 
gptq_config=None 338 | ) 339 | 340 | def forward( 341 | self, 342 | input_ids: torch.Tensor, 343 | position_ids: torch.Tensor, 344 | kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], 345 | prefill: bool, 346 | block_tables: torch.Tensor, 347 | slots: torch.Tensor, 348 | context_lengths: torch.Tensor, 349 | cache_events: Optional[List[torch.cuda.Event]] = None, 350 | lm_head_indices: Optional[torch.Tensor] = None, 351 | ) -> torch.Tensor: 352 | hidden_states = self.model( 353 | input_ids, 354 | position_ids, 355 | kv_cache, 356 | prefill, 357 | block_tables, 358 | slots, 359 | context_lengths, 360 | cache_events, 361 | ) 362 | if lm_head_indices is not None: 363 | hidden_states = hidden_states[lm_head_indices] 364 | logits = self.lm_head(hidden_states) 365 | return logits 366 | -------------------------------------------------------------------------------- /code/server/continuous_batching_server/modeling/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/modelize-ai/LLM-Inference-Deployment-Tutorial/e3264eacc4752a7d829241fb614a0b81892cdc8f/code/server/continuous_batching_server/modeling/utils/__init__.py -------------------------------------------------------------------------------- /code/server/continuous_batching_server/modeling/utils/attention.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | 3 | import torch 4 | import torch.nn as nn 5 | import xformers.ops as xops 6 | from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask, LowerTriangularMask 7 | 8 | import vllm.attention_ops as vllm_attention_ops 9 | import vllm.cache_ops as vllm_cache_ops 10 | import vllm.pos_encoding_ops as vllm_pos_encoding_ops 11 | 12 | 13 | def prefill_attention( 14 | query: torch.Tensor, 15 | key: torch.Tensor, 16 | value: torch.Tensor, 17 | attention_ops: Optional[xops.AttentionOp] = None, 18 | attention_bias: Optional[xops.AttentionBias] = None, 19 | dropout: float = 0.0, 20 | scale: Optional[float] = None 21 | ): 22 | for each in [query, key, value]: 23 | if len(each.shape) != 4: 24 | raise ValueError( 25 | "input tensor must have 4-dim shape which are [bsz, seq_len, num_heads, head_size] respectively," 26 | f"but get {each.shape}" 27 | ) 28 | 29 | bsz, seq_len = query.shape[:2] 30 | 31 | if value.shape[2] != query.shape[2]: 32 | # MQA expand 33 | if value.shape[2] == 1: 34 | pass # TODO 35 | # GQA reshape 36 | else: 37 | original_shape = value.shape 38 | pass # TODO 39 | 40 | return xops.memory_efficient_attention( 41 | query=query, 42 | key=key, 43 | value=value, 44 | attn_bias=attention_bias, 45 | p=dropout, 46 | scale=scale, 47 | op=attention_ops 48 | ).reshape(bsz, seq_len, -1) 49 | 50 | 51 | def decode_attention( 52 | query: torch.Tensor, 53 | key_cache: torch.Tensor, 54 | value_cache: torch.Tensor, 55 | kv_head_mapping: torch.Tensor, 56 | scale: float, 57 | block_tables: torch.Tensor, 58 | context_lengths: torch.Tensor, 59 | alibi_slopes: Optional[torch.Tensor] = None 60 | ): 61 | if len(query.shape) != 3: 62 | raise ValueError( 63 | "query must have 3-dim shape which are [seq_len, num_heads, head_size] respectively, " 64 | f"but get shape {query.shape}" 65 | ) 66 | 67 | attention_output = torch.empty_like(query) 68 | block_size = value_cache.shape[-1] 69 | vllm_attention_ops.single_query_cached_kv_attention( 70 | attention_output, 71 | query, 72 | key_cache, 73 | value_cache, 74 | kv_head_mapping, 75 | scale, 76 | block_tables, 77 
| context_lengths, 78 | block_size, 79 | context_lengths.max().item(), 80 | alibi_slopes 81 | ) 82 | return attention_output 83 | 84 | 85 | class AttentionWithRoPE(nn.Module): 86 | def __init__( 87 | self, 88 | qkv_proj: nn.Module, 89 | out_proj: nn.Module, 90 | cos_sin_cache: torch.Tensor, 91 | num_query_heads: int, 92 | num_key_heads: int, 93 | num_value_heads: int, 94 | dropout: float = 0.0, 95 | scale: Optional[float] = None, 96 | attention_ops: Optional[xops.AttentionOp] = None 97 | ): 98 | super(AttentionWithRoPE, self).__init__() 99 | 100 | self.qkv_proj = qkv_proj 101 | self.out_proj = out_proj 102 | 103 | self.register_buffer("cos_sin_cache", cos_sin_cache, persistent=False) 104 | 105 | self.num_query_heads = num_query_heads 106 | self.num_key_heads = num_key_heads 107 | self.num_value_heads = num_value_heads 108 | 109 | self.dropout = dropout 110 | self.scale = scale 111 | 112 | self.attention_ops = attention_ops 113 | 114 | # TODO: for now only compatible with GQA, make it also compatible with MQA 115 | self.num_groups = self.num_query_heads // self.num_value_heads 116 | self.kv_head_mapping = torch.arange( 117 | 0, self.num_value_heads, dtype=torch.int32, device=cos_sin_cache.device 118 | ).repeat_interleave(self.num_groups) 119 | 120 | def _qkv_forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 121 | hidden_size = hidden_states.shape[-1] 122 | # for each out tensor, shape ==> [total_tokens, hidden_size] 123 | return self.qkv_proj(hidden_states.view(-1, hidden_size)).chunk(chunks=3, dim=-1) 124 | 125 | def _rope_forward(self, query: torch.Tensor, key: torch.Tensor, position_ids: Optional[torch.Tensor]) -> None: 126 | if position_ids is None: 127 | return 128 | 129 | hidden_size = query.shape[-1] 130 | position_ids = position_ids.view(-1) 131 | 132 | vllm_pos_encoding_ops.rotary_embedding_neox( 133 | position_ids, 134 | query, 135 | key, 136 | hidden_size // self.num_query_heads, 137 | self.cos_sin_cache 138 | ) 139 | 140 | def _out_forward(self, hidden_states: torch.Tensor, shape: tuple) -> torch.Tensor: 141 | return self.out_proj(hidden_states).view(shape) 142 | 143 | def forward( 144 | self, 145 | hidden_states: torch.Tensor, 146 | position_ids: Optional[torch.Tensor], 147 | kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]], 148 | prefill: bool, 149 | block_tables: Optional[torch.Tensor], 150 | slots: Optional[torch.Tensor], 151 | context_lengths: torch.Tensor, 152 | cache_event: Optional[torch.cuda.Event] = None 153 | ) -> torch.Tensor: 154 | # The shape of hidden_states and position_ids: 155 | # if is prefill ==> [bsz, max(context_lengths), hidden_size] 156 | # otherwise ==> [bsz, 1, hidden_size] 157 | if len(hidden_states.shape) != 3: 158 | raise ValueError("hidden_states must have 3-dim shape.") 159 | bsz, max_len, hidden_size = hidden_states.shape 160 | 161 | # QKV projection 162 | query, key, value = self._qkv_forward(hidden_states) # for each: shape ==> [total_tokens, hidden_size] 163 | 164 | # Add RoPE info 165 | self._rope_forward(query, key, position_ids) 166 | 167 | # Prefill Attention 168 | if prefill: 169 | attn_out = prefill_attention( 170 | query.view(bsz, max_len, self.num_query_heads, -1), 171 | key.view(bsz, max_len, self.num_key_heads, -1), 172 | value.view(bsz, max_len, self.num_value_heads, -1), 173 | self.attention_ops, 174 | LowerTriangularMask(), 175 | self.dropout, 176 | self.scale 177 | ) 178 | 179 | # Wait until the cache op is done 180 | if cache_event is not None: 181 | cache_event.wait() 182 | 183 | 
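        # The projections above were computed on padded [bsz, max_len] inputs, so during
        # prefill only the last `context_lengths[i]` positions of each row correspond to
        # real tokens; those are the only positions written into the paged KV cache below.
        # During decode a single new token per sequence is cached. `slots` holds the flat
        # cache position (block_id * block_size + in-block offset) for each cached token.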
# Cache key and value 184 | if kv_cache is not None: 185 | if prefill: 186 | valid_token_indices = [] 187 | for i, start_idx in enumerate(range(0, bsz * max_len, max_len)): 188 | end_idx = start_idx + max_len 189 | indices = list(range(start_idx, end_idx))[-context_lengths[i]:] 190 | valid_token_indices += indices 191 | key_to_cache = key[valid_token_indices] 192 | value_to_cache = value[valid_token_indices] 193 | else: 194 | key_to_cache = key[:len(slots)] 195 | value_to_cache = value[:len(slots)] 196 | num_valid_tokens = key_to_cache.shape[0] 197 | key_to_cache = key_to_cache.reshape(num_valid_tokens, self.num_key_heads, -1) 198 | value_to_cache = value_to_cache.reshape(num_valid_tokens, self.num_value_heads, -1) 199 | vllm_cache_ops.reshape_and_cache( 200 | key_to_cache, value_to_cache, kv_cache[0], kv_cache[1], slots 201 | ) 202 | elif not prefill: 203 | raise ValueError("kv_cache can't be None when in decode stage.") 204 | 205 | # Decode Attention 206 | if not prefill: 207 | attn_out = decode_attention( 208 | query.view(bsz * max_len, self.num_query_heads, -1), 209 | kv_cache[0], 210 | kv_cache[1], 211 | self.kv_head_mapping, 212 | self.scale, 213 | block_tables, 214 | context_lengths, 215 | None 216 | ).view(bsz, max_len, -1) 217 | return self._out_forward(attn_out, (bsz, max_len, hidden_size)) 218 | 219 | @staticmethod 220 | def build_rope_cache( 221 | rotary_dim: int, 222 | max_position: int = 2048, 223 | base: int = 10000, 224 | device: torch.device = torch.device("cuda:0"), 225 | dtype: torch.dtype = torch.float16 226 | ): 227 | inv_freq = (1.0 / (base ** (torch.arange(0, rotary_dim, 2, device=device, dtype=dtype) / rotary_dim))) 228 | t = torch.arange(max_position, device=device, dtype=dtype) 229 | freqs = torch.einsum("i,j -> ij", t, inv_freq) 230 | cos = freqs.cos() 231 | sin = freqs.sin() 232 | cache = torch.cat((cos, sin), dim=-1) 233 | 234 | return cache 235 | 236 | 237 | class VarLenAttentionWithRoPE(AttentionWithRoPE): 238 | def forward( 239 | self, 240 | hidden_states: torch.Tensor, 241 | position_ids: Optional[torch.Tensor], 242 | kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]], 243 | prefill: bool, 244 | block_tables: Optional[torch.Tensor], 245 | slots: Optional[torch.Tensor], 246 | context_lengths: torch.Tensor, 247 | cache_event: Optional[torch.cuda.Event] = None 248 | ) -> torch.Tensor: 249 | # The shape of hidden_states and position_ids for both prefill and decode: 250 | # [total_tokens, hidden_size] 251 | 252 | # QKV projection 253 | query, key, value = self._qkv_forward(hidden_states) # for each: shape ==> [total_tokens, hidden_size] 254 | 255 | # Add RoPE info 256 | self._rope_forward(query, key, position_ids) 257 | 258 | total_tokens = query.shape[0] 259 | query = query.view(total_tokens, self.num_query_heads, -1) 260 | key = key.view(total_tokens, self.num_key_heads, -1) 261 | value = value.view(total_tokens, self.num_value_heads, -1) 262 | 263 | # Prefill Attention 264 | if prefill: 265 | attn_out = prefill_attention( 266 | query.unsqueeze(0), 267 | key.unsqueeze(0), 268 | value.unsqueeze(0), 269 | self.attention_ops, 270 | BlockDiagonalCausalMask.from_seqlens(context_lengths.tolist()), 271 | self.dropout, 272 | self.scale 273 | ).squeeze(0) 274 | 275 | # Wait until the cache op is done 276 | if cache_event is not None: 277 | cache_event.wait() 278 | 279 | # Cache key and value 280 | if kv_cache is not None: 281 | vllm_cache_ops.reshape_and_cache( 282 | key[:len(slots)], value[:len(slots)], kv_cache[0], kv_cache[1], slots 283 | ) 284 | elif not 
prefill: 285 | raise ValueError("kv_cache can't be None when in decode stage.") 286 | 287 | # Decode Attention 288 | if not prefill: 289 | attn_out = decode_attention( 290 | query, 291 | kv_cache[0], 292 | kv_cache[1], 293 | self.kv_head_mapping, 294 | self.scale, 295 | block_tables, 296 | context_lengths, 297 | None 298 | ).view(total_tokens, -1) 299 | 300 | return self._out_forward(attn_out, hidden_states.shape) 301 | 302 | 303 | __all__ = ["AttentionWithRoPE", "VarLenAttentionWithRoPE"] 304 | -------------------------------------------------------------------------------- /code/server/continuous_batching_server/modeling/utils/linear.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | 3 | import torch 4 | import torch.nn as nn 5 | from accelerate import init_empty_weights 6 | from auto_gptq import BaseQuantizeConfig as GPTQConfig 7 | from auto_gptq.utils.import_utils import dynamically_import_QuantLinear 8 | from torch.nn import functional as F 9 | 10 | from .weights import Weights 11 | 12 | 13 | class FastLinear(nn.Module): 14 | def __init__( 15 | self, 16 | weight: torch.Tensor, 17 | bias: torch.Tensor, 18 | ) -> None: 19 | super().__init__() 20 | self.weight = nn.Parameter(weight, requires_grad=False) 21 | if bias is not None: 22 | self.bias = nn.Parameter(bias, requires_grad=False) 23 | else: 24 | self.bias = None 25 | 26 | @classmethod 27 | def load( 28 | cls, 29 | config, 30 | prefix: str, 31 | weights: Weights, 32 | bias: bool 33 | ): 34 | weight = weights.get_tensor(f"{prefix}.weight") 35 | if bias: 36 | bias = weights.get_tensor(f"{prefix}.bias") 37 | else: 38 | bias = None 39 | return cls(weight, bias) 40 | 41 | @classmethod 42 | def load_multi( 43 | cls, 44 | config, 45 | prefixes: List[str], 46 | weights: Weights, 47 | bias: bool, 48 | dim: int 49 | ): 50 | w = [ 51 | weights.get_tensor(f"{prefix}.weight") for prefix in prefixes 52 | ] 53 | weight = torch.cat(w, dim=dim) 54 | 55 | if bias: 56 | b = [weights.get_tensor(f"{p}.bias") for p in prefixes] 57 | bias = torch.cat(b, dim=dim) 58 | else: 59 | bias = None 60 | return cls(weight, bias) 61 | 62 | def forward(self, input: torch.Tensor) -> torch.Tensor: 63 | return F.linear(input, self.weight, self.bias) 64 | 65 | 66 | class DynamicLinear: 67 | @classmethod 68 | def load( 69 | cls, 70 | config, 71 | prefix: str, 72 | weights: Weights, 73 | bias: bool, 74 | gptq_config: Optional[GPTQConfig] = None 75 | ): 76 | if not gptq_config: 77 | return FastLinear.load(config, prefix, weights, bias) 78 | 79 | disable_exllama = False 80 | if gptq_config.bits != 4 or gptq_config.desc_act: # for the later condition, can be removed once auto-gptq fixed it 81 | disable_exllama = True 82 | QuantLinear = dynamically_import_QuantLinear( 83 | use_triton=False, 84 | desc_act=gptq_config.desc_act, 85 | group_size=gptq_config.group_size, 86 | bits=gptq_config.bits, 87 | disable_exllama=disable_exllama 88 | ) 89 | 90 | qweight, qzeros, scales, g_idx, bias = weights.get_gptq_weight(prefix) 91 | 92 | init_args = ( 93 | gptq_config.bits, 94 | gptq_config.group_size, 95 | qweight.shape[0] * 32 // gptq_config.bits, 96 | qweight.shape[1], 97 | bias is not None 98 | ) 99 | with init_empty_weights(include_buffers=True): 100 | quant_linear = QuantLinear(*init_args, trainable=False) 101 | quant_linear.qweight = qweight 102 | quant_linear.qzeros = qzeros 103 | quant_linear.scales = scales 104 | quant_linear.g_idx = g_idx 105 | quant_linear.bias = bias 106 | 107 | return quant_linear 108 | 109 | 
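    # load_multi below mirrors load(), but fuses several projections (e.g. q/k/v or
    # gate/up) into one linear layer by concatenating their weights along the output
    # dimension; for GPTQ checkpoints the packed qweight/qzeros/scales tensors are
    # concatenated in the same way, so the fused layer is equivalent to running the
    # individual quantized projections side by side.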
@classmethod 110 | def load_multi( 111 | cls, 112 | config, 113 | prefixes: List[str], 114 | weights: Weights, 115 | bias: bool, 116 | dim: int, 117 | gptq_config: Optional[GPTQConfig] = None 118 | ): 119 | if not gptq_config: 120 | return FastLinear.load_multi(config, prefixes, weights, bias, dim) 121 | 122 | disable_exllama = False 123 | if gptq_config.bits != 4 or gptq_config.desc_act: # for the later condition, can be removed once auto-gptq fixed it 124 | disable_exllama = True 125 | QuantLinear = dynamically_import_QuantLinear( 126 | use_triton=False, 127 | desc_act=gptq_config.desc_act, 128 | group_size=gptq_config.group_size, 129 | bits=gptq_config.bits, 130 | disable_exllama=disable_exllama 131 | ) 132 | 133 | qweight_li, qzeros_li, scales_li, g_idx_li, bias_li = [], [], [], [], [] 134 | outfeatures = 0 135 | for prefix in prefixes: 136 | qweight, qzeros, scales, g_idx, bias = weights.get_gptq_weight(prefix) 137 | qweight_li.append(qweight) 138 | qzeros_li.append(qzeros) 139 | scales_li.append(scales) 140 | g_idx_li.append(g_idx) 141 | bias_li.append(bias) 142 | outfeatures += qweight.shape[1] 143 | 144 | qweight = torch.cat(qweight_li, dim=1) 145 | qzeros = torch.cat(qzeros_li, dim=1) 146 | scales = torch.cat(scales_li, dim=1) 147 | g_idx = torch.cat(g_idx_li, dim=0) 148 | if bias_li[0] is not None: 149 | bias = torch.cat(bias_li, dim=0) 150 | else: 151 | bias = None 152 | 153 | init_args = ( 154 | gptq_config.bits, 155 | gptq_config.group_size, 156 | qweight.shape[0] * 32 // gptq_config.bits, 157 | qweight.shape[1], 158 | bias is not None 159 | ) 160 | 161 | with init_empty_weights(include_buffers=True): 162 | quant_linear = QuantLinear(*init_args, trainable=False) 163 | quant_linear.qweight = qweight 164 | quant_linear.qzeros = qzeros 165 | quant_linear.scales = scales 166 | quant_linear.g_idx = g_idx 167 | quant_linear.bias = bias 168 | 169 | return quant_linear 170 | -------------------------------------------------------------------------------- /code/server/continuous_batching_server/modeling/utils/weights.py: -------------------------------------------------------------------------------- 1 | from glob import glob 2 | from os import path 3 | from typing import * 4 | 5 | import torch 6 | from safetensors import safe_open 7 | 8 | 9 | class Weights: 10 | def __init__( 11 | self, 12 | model_name_or_path: str, 13 | device: torch.device, 14 | dtype: torch.dtype, 15 | quantize_method: Optional[str] = None, 16 | gptq_model_base_name: Optional[str] = None 17 | ): 18 | if not path.isdir(model_name_or_path): 19 | raise NotADirectoryError(f"{model_name_or_path} not exists.") 20 | routing = {} 21 | file_pattern = "*.safetensors" 22 | if quantize_method == "gptq": 23 | file_pattern = "gptq_model*.safetensors" 24 | if gptq_model_base_name: 25 | file_pattern = f"{gptq_model_base_name}*.safetensors" 26 | for model_file in glob(path.join(model_name_or_path, file_pattern)): 27 | with safe_open(model_file, framework="pt") as f: 28 | for k in f.keys(): 29 | if k in routing: 30 | raise RuntimeError( 31 | f"Key {k} was found in multiple files: {model_file} and {routing[k]}" 32 | ) 33 | routing[k] = model_file 34 | self.routing = routing 35 | self.device = device 36 | self.dtype = dtype 37 | self._handles = {} 38 | 39 | def _get_handle(self, filename: str): 40 | if filename not in self._handles: 41 | f = safe_open(filename, framework="pt") 42 | self._handles[filename] = f 43 | 44 | return self._handles[filename] 45 | 46 | def get_filename(self, tensor_name: str) -> (str, str): 47 | filename = 
self.routing.get(tensor_name, None) 48 | if filename is None: 49 | raise RuntimeError(f"weight {tensor_name} does not exist") 50 | return str(filename), tensor_name 51 | 52 | def _get_slice(self, tensor_name: str): 53 | filename, tensor_name = self.get_filename(tensor_name) 54 | f = self._get_handle(filename) 55 | slice_ = f.get_slice(tensor_name) 56 | return slice_ 57 | 58 | def get_shape(self, tensor_name: str): 59 | return self._get_slice(tensor_name).get_shape() 60 | 61 | def get_tensor(self, tensor_name: str): 62 | filename, tensor_name = self.get_filename(tensor_name) 63 | f = self._get_handle(filename) 64 | tensor = f.get_tensor(tensor_name) 65 | # Special case for gptq which shouldn't convert 66 | # u4 which are disguised as int32 67 | if tensor.dtype not in [torch.int32, torch.int64]: 68 | tensor = tensor.to(dtype=self.dtype) 69 | tensor = tensor.to(device=self.device) 70 | return tensor 71 | 72 | def get_gptq_weight(self, prefix: str): 73 | try: 74 | qweight = self.get_tensor(f"{prefix}.qweight") 75 | except RuntimeError: 76 | raise RuntimeError( 77 | "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`" 78 | ) 79 | qzeros = self.get_tensor(f"{prefix}.qzeros") 80 | scales = self.get_tensor(f"{prefix}.scales") 81 | g_idx = self.get_tensor(f"{prefix}.g_idx") 82 | try: 83 | bias = self.get_tensor(f"{prefix}.bias") 84 | except: 85 | bias = None 86 | 87 | return qweight, qzeros, scales, g_idx, bias 88 | -------------------------------------------------------------------------------- /code/server/continuous_batching_server/server.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import time 3 | import threading 4 | from logging import getLogger, Logger 5 | from typing import * 6 | from uuid import uuid4, UUID 7 | 8 | from transformers import AutoTokenizer 9 | 10 | from .batcher import Batcher 11 | from .beam import Beam, BeamGroup, BeamStatus 12 | from .worker import Worker 13 | from .config import ServerConfig 14 | from protocol.completion_task import ( 15 | TokenUsage, 16 | HuggingFaceCompletionChoice, 17 | HuggingFaceGenerationConfig, 18 | HuggingFaceCompletionInputs, 19 | HuggingFaceCompletionOutputs 20 | ) 21 | from protocol.error import Error 22 | 23 | 24 | SERVER_SINGLETON = None 25 | 26 | 27 | class ServerNotInitializedError(Exception): 28 | def __repr__(self): 29 | return "server is not initialized, please initialize a server object first." 30 | 31 | def __str__(self): 32 | return self.__repr__() 33 | 34 | 35 | class ServerDoubleInitializeError(Exception): 36 | def __repr__(self): 37 | return "server is initialized, do not initialize again, please use get_server() instead." 38 | 39 | def __str__(self): 40 | return self.__repr__() 41 | 42 | 43 | class Server: 44 | def __init__( 45 | self, 46 | config: ServerConfig, 47 | logger: Optional[Logger] = None 48 | ): 49 | global SERVER_SINGLETON 50 | if SERVER_SINGLETON is not None: 51 | raise ServerDoubleInitializeError() 52 | 53 | self.config = config 54 | 55 | batcher_config = config.batcher_config 56 | cache_config = config.cache_config 57 | model_loading_config = config.model_loading_config 58 | parallel_config = config.parallel_config 59 | 60 | assert parallel_config.tp_size == 1, "we don't provide model parallelism support for now." 
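        # The two components built below split the work: the Worker owns the model
        # weights and the paged KV cache and executes forward passes, while the Batcher
        # schedules beams into per-step batches. Generation runs on the daemon thread
        # started at the end of __init__; wait_task_done() only enqueues a request and
        # polls finished_table for its result.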
61 | 62 | self.logger = logger if logger else getLogger(__name__) 63 | 64 | self.worker = Worker(cache_config, model_loading_config, parallel_config, logger) 65 | self.batcher = Batcher(batcher_config, cache_config, logger) 66 | 67 | self.tokenizer_max_length = model_loading_config.model_max_length 68 | self.tokenizer = AutoTokenizer.from_pretrained( 69 | model_loading_config.tokenizer_name_or_path or model_loading_config.model_name_or_path, 70 | use_fast=model_loading_config.use_fast_tokenizer, 71 | trust_remote_code=model_loading_config.trust_remote_code, 72 | truncation_side="left", 73 | ) 74 | if not self.tokenizer.pad_token_id: 75 | self.tokenizer.pad_token_id = self.tokenizer.eos_token_id 76 | self.tokenizer.pad_token = self.tokenizer.eos_token 77 | 78 | self.finished_table: Dict[UUID, Tuple[HuggingFaceCompletionOutputs, Optional[Error], int]] = dict() 79 | 80 | threading.Thread(target=self._run, daemon=True).start() 81 | 82 | SERVER_SINGLETON = self 83 | 84 | def _encode(self, text: str) -> List[int]: 85 | return self.tokenizer(text, truncation=True, max_length=self.tokenizer_max_length)["input_ids"] 86 | 87 | def _decode(self, token_ids: List[int]) -> str: 88 | return self.tokenizer.decode(token_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False) 89 | 90 | def _construct_generation_result( 91 | self, 92 | request: BeamGroup 93 | ) -> Tuple[HuggingFaceCompletionOutputs, Optional[Error], int]: 94 | final_beams = request.get_final_beams() 95 | choices = [ 96 | HuggingFaceCompletionChoice( 97 | text=self._decode(beam.generated_token_ids), 98 | index=idx, 99 | finish_reason=beam.finish_reason 100 | ) 101 | for idx, beam in enumerate(final_beams) 102 | ] 103 | usage = TokenUsage( 104 | prompt_tokens=len(request.get_final_beams()[0].prompt_token_ids), 105 | completion_tokens=sum([beam.num_generated_tokens for beam in request.get_beams(BeamStatus.FINISHED)]) 106 | ) 107 | usage.total_tokens = usage.prompt_tokens + usage.completion_tokens 108 | return HuggingFaceCompletionOutputs(choices=choices, usage=usage), None, 200 109 | 110 | async def wait_task_done( 111 | self, 112 | inp: HuggingFaceCompletionInputs 113 | ) -> Tuple[HuggingFaceCompletionOutputs, Optional[Error], int, float]: 114 | request_id = uuid4() 115 | start = time.time() 116 | 117 | inp.generation_config.eos_token_id = self.tokenizer.eos_token_id 118 | inp.generation_config.pad_token_id = self.tokenizer.pad_token_id 119 | 120 | request = BeamGroup( 121 | request_id=request_id, 122 | arrival_time=time.time(), 123 | beams=[ 124 | Beam( 125 | request_id, 126 | prompt=inp.prompt, 127 | prompt_token_ids=self._encode(inp.prompt), 128 | block_size=self.batcher.cache_config.block_size 129 | ) 130 | ], 131 | generation_config=inp.generation_config 132 | ) 133 | self.logger.info(msg=f"Task-{request_id} is added.") 134 | self.batcher.add_request(request) 135 | 136 | while True: 137 | await asyncio.sleep(0.1) 138 | if request_id in self.finished_table: 139 | end = time.time() 140 | outputs, error, status_code = self.finished_table.pop(request_id) 141 | wall_time = end - start 142 | self.logger.info(msg=f"Task-{request_id} is finished, {wall_time=: .4f}s") 143 | return outputs, error, status_code, wall_time 144 | 145 | def _run(self) -> None: 146 | steps = 0 147 | while True: 148 | steps += 1 149 | batch, blocks_to_copy, blocks_to_swap_in, blocks_to_swap_out, finished_requests = self.batcher.schedule() 150 | for req in finished_requests: 151 | self.finished_table[req.request_id] = self._construct_generation_result(req) 152 | 
self.worker.forward(batch, blocks_to_copy, blocks_to_swap_in, blocks_to_swap_out) 153 | self.batcher.batch_generate(batch) 154 | time.sleep(0.001) 155 | 156 | 157 | def get_server(): 158 | if SERVER_SINGLETON is None: 159 | raise ServerNotInitializedError() 160 | return SERVER_SINGLETON 161 | 162 | 163 | __all__ = ["Server", "get_server"] 164 | -------------------------------------------------------------------------------- /code/server/continuous_batching_server/worker.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from logging import getLogger, Logger 4 | from typing import * 5 | 6 | import torch 7 | from auto_gptq import BaseQuantizeConfig as GPTQConfig 8 | from auto_gptq.modeling._utils import autogptq_post_init 9 | 10 | from .batcher import Batch 11 | from .cache.cache import Cache 12 | from .modeling.utils.weights import Weights 13 | from .config import ( 14 | CacheConfig, 15 | ModelConfig, 16 | ModelLoadingConfig, 17 | MODEL_AUTO_TABLE, 18 | ParallelConfig, 19 | TORCH_FLOAT_DTYPE_MAP 20 | ) 21 | 22 | 23 | def _get_gptq_config(config_dir: str, config_file_base_name: Optional[str] = None) -> GPTQConfig: 24 | config_path = os.path.join(config_dir, "quantize_config.json") 25 | if config_file_base_name: 26 | config_path = os.path.join(config_dir, f"{config_file_base_name}.json") 27 | return GPTQConfig(**json.load(open(config_path, "r", encoding="utf-8"))) 28 | 29 | 30 | def _pad_to_alignment(x: List[int], multiple_of: int) -> List[int]: 31 | return x + [0] * ((-len(x)) % multiple_of) 32 | 33 | 34 | def _pad_to_max(x: List[int], max_len: int) -> List[int]: 35 | return x + [0] * (max_len - len(x)) 36 | 37 | 38 | class Worker: 39 | def __init__( 40 | self, 41 | cache_config: CacheConfig, 42 | model_loading_config: ModelLoadingConfig, 43 | parallel_config: ParallelConfig, 44 | logger: Optional[Logger] = None 45 | ): 46 | self.cache_config = cache_config 47 | self.model_loading_config = model_loading_config 48 | self.parallel_config = parallel_config 49 | 50 | self.logger = logger if logger else getLogger(__name__) 51 | 52 | # load model 53 | self.device = torch.device(self.model_loading_config.device) 54 | self.dtype = TORCH_FLOAT_DTYPE_MAP[self.model_loading_config.torch_dtype] 55 | 56 | torch.cuda.set_device(self.device) 57 | 58 | factory = MODEL_AUTO_TABLE[self.model_loading_config.model_type] 59 | model_config_cls = factory.model_config_cls 60 | model_cls = factory.model_cls 61 | model_config = model_config_cls.from_pretrained( 62 | self.model_loading_config.model_name_or_path, 63 | trust_remote_code=self.model_loading_config.trust_remote_code 64 | ) 65 | model_weights = Weights( 66 | self.model_loading_config.model_name_or_path, 67 | device=self.device, 68 | dtype=self.dtype, 69 | quantize_method=self.model_loading_config.quantize_method, 70 | gptq_model_base_name=self.model_loading_config.gptq_model_base_name 71 | ) 72 | if self.model_loading_config.quantize_method == "gptq": 73 | gptq_config = _get_gptq_config( 74 | self.model_loading_config.model_name_or_path, 75 | self.model_loading_config.gptq_config_base_name 76 | ) 77 | else: 78 | gptq_config = None 79 | self.model = model_cls(config=model_config, weights=model_weights, gptq_config=gptq_config) 80 | if self.model_loading_config.quantize_method == "gptq": 81 | self.model = autogptq_post_init(self.model, gptq_config.desc_act) 82 | self.model.eval() 83 | self.model_config = ModelConfig(self.model.config, self.parallel_config) 84 | 85 | self.cache = Cache( 86 | 
cache_config=self.cache_config, 87 | model_config=self.model_config, 88 | parallel_config=self.parallel_config, 89 | dtype=self.dtype, 90 | device=self.device 91 | ) 92 | 93 | def _prepare_inputs(self, batch: Batch, prefill: bool) -> Optional[Dict[str, Optional[torch.Tensor]]]: 94 | input_ids: List[int] = [] 95 | position_ids: List[int] = [] 96 | block_tables: Optional[List[List[int]]] = [] 97 | slots: List[int] = [] 98 | context_lengths: List[int] = [] 99 | lm_head_indices: Optional[List[int]] = [] 100 | 101 | beams = batch.prefill_beams if prefill else batch.generation_beams 102 | if not beams: 103 | return 104 | if prefill: 105 | block_tables = None 106 | for beam in beams: 107 | input_ids += beam.prompt_token_ids 108 | position_ids += list(range(beam.num_tokens)) 109 | context_lengths.append(beam.num_tokens) 110 | lm_head_indices.append(sum(context_lengths) - 1) 111 | 112 | block_ids = batch.block_tables[beam.beam_id] 113 | if block_ids is None: 114 | slots += [0] * beam.num_tokens 115 | else: 116 | for i in range(beam.num_tokens): 117 | block_id = block_ids[i // self.cache.block_size] 118 | block_offset = i % self.cache.block_size 119 | slot = block_id * self.cache.block_size + block_offset 120 | slots.append(slot) 121 | else: 122 | lm_head_indices = None 123 | for beam in beams: 124 | input_ids.append(beam.last_token_id) 125 | position_ids.append(beam.num_tokens - 1) 126 | context_lengths.append(beam.num_tokens) 127 | 128 | block_ids = batch.block_tables[beam.beam_id] 129 | block_tables.append(block_ids) 130 | 131 | block_id = block_ids[position_ids[-1] // self.cache.block_size] 132 | block_offset = position_ids[-1] % self.cache_config.block_size 133 | slot = block_id * self.cache_config.block_size + block_offset 134 | slots.append(slot) 135 | 136 | # Optimization: Pad the input length to be a multiple of 8. 137 | # This is required for utilizing the Tensor Cores in NVIDIA GPUs. 138 | # FIXME: after padding, model execution will fail if bsz is not a multiple of 8 139 | # input_ids = _pad_to_alignment(input_ids, multiple_of=8) 140 | # position_ids = _pad_to_alignment(position_ids, multiple_of=8) 141 | 142 | # Convert to tensors. 
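        # block_tables is a ragged list of per-beam block ids, so it is padded to the
        # longest row before becoming an IntTensor; context_lengths, slots and the padded
        # block_tables stay 32-bit to match what the vllm attention/cache ops take, while
        # input_ids and position_ids use int64 for the embedding lookup.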
143 | input_ids = torch.tensor(input_ids, dtype=torch.long, device=self.device) 144 | position_ids = torch.tensor(position_ids, dtype=torch.long, device=self.device) 145 | context_lengths = torch.tensor(context_lengths, dtype=torch.int32, device=self.device) 146 | if block_tables is not None: 147 | max_num_blocks = max([len(block_table) for block_table in block_tables]) 148 | block_tables = torch.IntTensor( 149 | [_pad_to_max(block_table, max_num_blocks) for block_table in block_tables] 150 | ).to(self.device) 151 | slots = torch.IntTensor(slots).to(self.device) 152 | 153 | return { 154 | "input_ids": input_ids, 155 | "position_ids": position_ids, 156 | "kv_cache": self.cache.cache, 157 | "prefill": prefill, 158 | "block_tables": block_tables, 159 | "slots": slots, 160 | "context_lengths": context_lengths, 161 | "lm_head_indices": lm_head_indices 162 | } 163 | 164 | @torch.no_grad() 165 | @torch.cuda.amp.autocast() 166 | def _forward(self, batch: Batch, cache_events: Optional[List[torch.cuda.Event]] = None): 167 | prefill_inputs = self._prepare_inputs(batch, prefill=True) 168 | generation_inputs = self._prepare_inputs(batch, prefill=False) 169 | 170 | if prefill_inputs: 171 | self.logger.debug("executing model for prefilling.") 172 | batch.prefill_logits = self.model(cache_events=cache_events, **prefill_inputs) 173 | self.logger.debug("executed model for prefilling.") 174 | if generation_inputs: 175 | self.logger.debug("executing model for decoding.") 176 | batch.generation_logits = self.model(cache_events=cache_events, **generation_inputs) 177 | self.logger.debug("executed model for decoding.") 178 | 179 | def forward( 180 | self, 181 | batch: Batch, 182 | blocks_to_copy: Dict[int, List[int]], 183 | blocks_to_swap_in: Dict[int, int], 184 | blocks_to_swap_out: Dict[int, int] 185 | ): 186 | # Issue cache operations. 
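        # Swap-in/swap-out move whole cache blocks between CPU and GPU memory when the
        # scheduler preempts or resumes sequences, and copy duplicates blocks when beams
        # sharing a prefix diverge and need their own writable copies. Each cache op
        # records CUDA events so the model forward below only waits when cache work was
        # actually issued.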
187 | issued_cache_op = False 188 | if blocks_to_swap_in: 189 | self.logger.debug("executing cache swap in operation.") 190 | self.cache.swap_in(blocks_to_swap_in) 191 | issued_cache_op = True 192 | self.logger.debug("executed cache swap in operation.") 193 | if blocks_to_swap_out: 194 | self.logger.debug("executing cache swap out operation.") 195 | self.cache.swap_out(blocks_to_swap_out) 196 | issued_cache_op = True 197 | self.logger.debug("executed cache swap out operation.") 198 | if blocks_to_copy: 199 | self.logger.debug("execution cache copy operation.") 200 | self.cache.copy(blocks_to_copy) 201 | issued_cache_op = True 202 | self.logger.debug("executed cache copy operation") 203 | 204 | if issued_cache_op: 205 | cache_events = self.cache.events 206 | else: 207 | cache_events = None 208 | 209 | if batch.num_beams == 0: 210 | if cache_events is not None: 211 | for event in cache_events: 212 | event.wait() 213 | self.logger.debug("no beams need to be processed, return directly.") 214 | return 215 | 216 | self._forward(batch, cache_events=cache_events) 217 | 218 | 219 | __all__ = ["Worker"] 220 | -------------------------------------------------------------------------------- /code/server/static_batching_server/__init__.py: -------------------------------------------------------------------------------- 1 | from .server import get_server, Server 2 | from .config import ServerConfig 3 | 4 | 5 | __all__ = ["get_server", "Server", "ServerConfig"] 6 | -------------------------------------------------------------------------------- /code/server/static_batching_server/batcher.py: -------------------------------------------------------------------------------- 1 | from logging import getLogger, Logger 2 | from typing import List, Optional, Tuple 3 | from uuid import UUID 4 | 5 | from pydantic import BaseModel, Field 6 | 7 | from .config import BatcherConfig 8 | from protocol.completion_task import HuggingFaceCompletionInputs, HuggingFaceGenerationConfig 9 | 10 | 11 | class Package(BaseModel): 12 | ids: List[UUID] = Field(default=...) 13 | prompts: List[str] = Field(default=...) 14 | generation_config: HuggingFaceGenerationConfig = Field(default=...) 
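    # A Package groups requests that share an identical generation_config so they can run
    # as one static batch: __hash__ below delegates to the config, and workload counts
    # prompts * num_beams, i.e. the number of sequences the worker will actually decode.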
15 | 16 | def __hash__(self): 17 | return hash(self.generation_config) 18 | 19 | @property 20 | def workload(self): 21 | return len(self.prompts) * self.generation_config.num_beams 22 | 23 | def add(self, prompt: str, uid: UUID): 24 | self.prompts.append(prompt) 25 | self.ids.append(uid) 26 | 27 | def __repr__(self): 28 | return f"Package(workload={self.workload})" 29 | 30 | def __str__(self): 31 | return self.__repr__() 32 | 33 | 34 | class Batcher: 35 | def __init__(self, config: BatcherConfig, logger: Optional[Logger] = None): 36 | self.config = config 37 | self.logger = logger if logger else getLogger(__name__) 38 | 39 | self.inputs: List[Tuple[HuggingFaceCompletionInputs, UUID]] = [] 40 | 41 | def pack(self) -> Optional[Package]: 42 | if not self.inputs: 43 | return None 44 | 45 | # ============================= 46 | # Strategy 1: 47 | # pack first input, then select input who has the same generation_config util package is full or 48 | # there is no other input left 49 | # ============================= 50 | 51 | inp, inp_id = self.inputs.pop(0) 52 | package = Package(ids=[inp_id], prompts=[inp.prompt], generation_config=inp.generation_config) 53 | inputs = [] 54 | while self.inputs: 55 | if package.workload > self.config.package_max_workload: # package is full, put back and return 56 | self.inputs = inputs + self.inputs 57 | self.logger.debug(msg=str(package)) 58 | return package 59 | inp, inp_id = self.inputs.pop(0) 60 | if hash(inp.generation_config) != hash(package): # gen_config is different, put back 61 | inputs.append((inp, inp_id)) 62 | else: 63 | package.add(inp.prompt, inp_id) 64 | self.logger.debug(msg=str(package)) 65 | return package 66 | 67 | # ============================= 68 | # Strategy 2: 69 | # pack input one by one, return immediately when package is full or generation_config is different 70 | # ============================= 71 | 72 | # package = None 73 | # while self.inputs: 74 | # inp, inp_id = self.inputs.pop(0) 75 | # if package is None: # the first input, initialize package 76 | # package = Package(ids=[inp_id], prompts=[inp.prompt], generation_config=inp.generation_config) 77 | # else: # if gen_config not the same or package is full, return package, otherwise add prompt into package 78 | # hash_value = hash(inp.generation_config) 79 | # if hash_value != hash(package) or package.workload > self.config.package_max_workload: 80 | # self.inputs.insert(0, (inp, inp_id)) 81 | # self.logger.debug(msg=str(package)) 82 | # return package 83 | # package.add(inp.prompt, inp_id) 84 | # self.logger.debug(msg=str(package)) 85 | # return package 86 | 87 | def add(self, inp: HuggingFaceCompletionInputs, inp_id: UUID): 88 | self.inputs.append((inp, inp_id)) 89 | 90 | 91 | __all__ = ["Batcher"] 92 | -------------------------------------------------------------------------------- /code/server/static_batching_server/config.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Optional, Union 2 | 3 | from pydantic import BaseModel, Field 4 | 5 | 6 | class BatcherConfig(BaseModel): 7 | package_max_workload: int = Field(default=1) 8 | packaging_interval_seconds: int = Field(default=2) 9 | 10 | 11 | class WorkerConfig(BaseModel): 12 | model_name_or_path: str = Field(default="dummy_model_name_or_path") 13 | tokenizer_name_or_path: Optional[str] = Field(default=None) 14 | revision: str = Field(default="main") 15 | low_cpu_mem_usage: bool = Field(default=True) 16 | torch_dtype: Union[str] = Field(default="float16", 
regex="(float16|bfloat16)") 17 | device: Optional[Union[int, str]] = Field(default=None) 18 | max_memory: Optional[Dict[Union[str, int], str]] = Field(default=None) 19 | device_map: Union[str, Dict[str, Union[int, str]]] = Field(default="auto") 20 | use_fast_tokenizer: bool = Field(default=False) 21 | trust_remote_code: bool = Field(default=False) 22 | use_safetensors: bool = Field(default=False) 23 | batch_size: int = Field(default=-1) # -1 means execute all inputs together no matter how many they are 24 | is_gptq_quantized: bool = Field(default=False) 25 | 26 | 27 | class ServerConfig(BatcherConfig): 28 | batcher_config: BatcherConfig = Field(default=BatcherConfig()) 29 | worker_config: WorkerConfig = Field(default=WorkerConfig()) 30 | 31 | 32 | __all__ = [ 33 | "BatcherConfig", 34 | "WorkerConfig", 35 | "ServerConfig" 36 | ] 37 | -------------------------------------------------------------------------------- /code/server/static_batching_server/server.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import time 3 | from logging import getLogger, Logger 4 | from typing import Dict, Optional, Tuple 5 | from threading import Thread 6 | from uuid import uuid4, UUID 7 | 8 | from .batcher import Batcher 9 | from .config import ServerConfig 10 | from .worker import Worker 11 | from protocol.completion_task import HuggingFaceCompletionInputs, HuggingFaceCompletionOutputs 12 | from protocol.error import Error 13 | 14 | 15 | SERVER_SINGLETON = None 16 | 17 | 18 | class ServerNotInitializedError(Exception): 19 | def __repr__(self): 20 | return "server is not initialized, please initialize a server object first." 21 | 22 | def __str__(self): 23 | return self.__repr__() 24 | 25 | 26 | class ServerDoubleInitializeError(Exception): 27 | def __repr__(self): 28 | return "server is initialized, do not initialize again, please use get_server() instead." 
29 | 30 | def __str__(self): 31 | return self.__repr__() 32 | 33 | 34 | class Server: 35 | def __init__(self, config: ServerConfig, logger: Optional[Logger] = None): 36 | global SERVER_SINGLETON 37 | if SERVER_SINGLETON is not None: 38 | raise ServerDoubleInitializeError() 39 | 40 | self.config = config 41 | self.batcher_config = config.batcher_config 42 | self.worker_config = config.worker_config 43 | self.logger = logger if logger else getLogger(__name__) 44 | 45 | self.batcher = Batcher(config=self.batcher_config, logger=logger) 46 | self.worker = Worker(config=self.worker_config, logger=logger) 47 | 48 | self.outputs: Dict[UUID, Tuple[HuggingFaceCompletionOutputs, Optional[Error], int, float]] = dict() 49 | 50 | Thread(target=self._run, daemon=True).start() 51 | 52 | SERVER_SINGLETON = self 53 | 54 | def _run(self): 55 | while True: 56 | package = self.batcher.pack() 57 | if package is not None: 58 | self.outputs.update(self.worker.execute(package.prompts, package.ids, package.generation_config)) 59 | else: 60 | time.sleep(self.batcher_config.packaging_interval_seconds) 61 | 62 | async def wait_task_done( 63 | self, 64 | inp: HuggingFaceCompletionInputs 65 | ) -> Tuple[HuggingFaceCompletionOutputs, Optional[Error], int, float, float]: 66 | uid = uuid4() 67 | start = time.time() 68 | 69 | self.logger.info(msg=f"Task-{uid} is added.") 70 | self.batcher.add(inp, uid) 71 | 72 | while True: 73 | await asyncio.sleep(0.1) 74 | if uid in self.outputs: 75 | end = time.time() 76 | outputs, error, status_code, cpu_time = self.outputs.pop(uid) 77 | wall_time = end - start 78 | self.logger.info(msg=f"Task-{uid} is finished, {cpu_time=: .4f}s, {wall_time=: .4f}s") 79 | return outputs, error, status_code, cpu_time, wall_time 80 | 81 | 82 | def get_server(): 83 | if SERVER_SINGLETON is None: 84 | raise ServerNotInitializedError() 85 | return SERVER_SINGLETON 86 | 87 | 88 | __all__ = ["Server", "get_server"] 89 | -------------------------------------------------------------------------------- /code/server/static_batching_server/worker.py: -------------------------------------------------------------------------------- 1 | import time 2 | from logging import getLogger, Logger 3 | from typing import Dict, List, Optional, Tuple 4 | from uuid import UUID 5 | 6 | import torch 7 | from auto_gptq import AutoGPTQForCausalLM 8 | from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase 9 | 10 | from .config import WorkerConfig 11 | from protocol.completion_task import ( 12 | TokenUsage, 13 | HuggingFaceGenerationConfig, 14 | HuggingFaceCompletionChoice, 15 | HuggingFaceCompletionOutputs 16 | ) 17 | from protocol.error import Error 18 | 19 | 20 | class TextGenerationPipeline: 21 | """A simplified pipeline to show what HF's TextGenerationPipeline mainly do under the hood""" 22 | def __init__(self, model: torch.nn.Module, tokenizer: PreTrainedTokenizerBase, batch_size: int = -1): 23 | self.model = model 24 | self.tokenizer = tokenizer 25 | device = model.device 26 | if model.hf_device_map: 27 | device = model.hf_device_map[next(iter(model.hf_device_map))] 28 | self.device = torch.device(device) if not isinstance(device, torch.device) else device 29 | self.batch_size = batch_size 30 | 31 | def _preprocess( 32 | self, 33 | prompt_texts: List[str], 34 | generation_config: HuggingFaceGenerationConfig, 35 | handle_long_generation=None, 36 | ) -> "BatchEncoding": 37 | inputs = self.tokenizer( 38 | prompt_texts, 39 | padding=True, 40 | truncation=True, 41 | return_tensors="pt" 42 | 
).to(self.device) 43 | 44 | if handle_long_generation == "hole": 45 | cur_len = inputs["input_ids"].shape[-1] 46 | new_tokens = generation_config.max_new_tokens 47 | if cur_len + new_tokens > self.tokenizer.model_max_length: 48 | keep_length = self.tokenizer.model_max_length - new_tokens 49 | if keep_length <= 0: 50 | raise ValueError( 51 | "We cannot use `hole` to handle this generation the number of desired tokens exceeds the" 52 | " models max length" 53 | ) 54 | 55 | inputs["input_ids"] = inputs["input_ids"][:, -keep_length:] 56 | if "attention_mask" in inputs: 57 | inputs["attention_mask"] = inputs["attention_mask"][:, -keep_length:] 58 | 59 | return inputs 60 | 61 | def _model_generate( 62 | self, 63 | input_ids: torch.Tensor, 64 | attention_mask: torch.Tensor, 65 | generation_config: HuggingFaceGenerationConfig 66 | ) -> torch.Tensor: 67 | generation_config.eos_token_id = self.tokenizer.eos_token_id 68 | generation_config.pad_token_id = self.tokenizer.pad_token_id 69 | decode_dict = generation_config.dict(by_alias=True) 70 | decode_dict.pop("seed") 71 | batch_gen_sequences = self.model.generate( 72 | input_ids=input_ids, 73 | attention_mask=attention_mask, 74 | **decode_dict 75 | ) 76 | return batch_gen_sequences 77 | 78 | def _postprocess( 79 | self, 80 | input_ids: torch.Tensor, 81 | generated_sequences: torch.Tensor, 82 | generation_config: HuggingFaceGenerationConfig, 83 | clean_up_tokenization_spaces=True 84 | ) -> List[HuggingFaceCompletionOutputs]: 85 | input_ids = input_ids.cpu() 86 | generated_sequences = generated_sequences.cpu() 87 | 88 | num_return_sequences = len(generated_sequences) // len(input_ids) 89 | batch_outputs = [] 90 | for idx, start in enumerate(range(0, len(generated_sequences), num_return_sequences)): 91 | inp = input_ids[idx].tolist() 92 | if self.tokenizer.pad_token_id in inp: 93 | inp = inp[:inp.index(self.tokenizer.pad_token_id)] 94 | 95 | sequences = generated_sequences[start: start + num_return_sequences] 96 | sequences = sequences[..., input_ids[idx].size(0):].tolist() 97 | for i, seq in enumerate(sequences): 98 | if self.tokenizer.pad_token_id in seq: 99 | sequences[i] = seq[: seq.index(self.tokenizer.pad_token_id)] 100 | sequences_num_tokens = [len(seq) for seq in sequences] 101 | 102 | usage = TokenUsage(prompt_tokens=len(inp), completion_tokens=sum(sequences_num_tokens)) 103 | usage.total_tokens = usage.prompt_tokens + usage.completion_tokens 104 | 105 | generated_texts = self.tokenizer.batch_decode( 106 | sequences, 107 | clean_up_tokenization_spaces=clean_up_tokenization_spaces, 108 | skip_special_tokens=True 109 | ) 110 | choices = [ 111 | HuggingFaceCompletionChoice( 112 | text=text, 113 | index=index, 114 | finish_reason="stop" if num_new_tokens < generation_config.max_new_tokens else "length" 115 | ) 116 | for index, (text, num_new_tokens) in enumerate(zip(generated_texts, sequences_num_tokens)) 117 | ] 118 | 119 | batch_outputs.append(HuggingFaceCompletionOutputs(choices=choices, usage=usage)) 120 | 121 | return batch_outputs 122 | 123 | def __call__( 124 | self, 125 | text_inputs, 126 | generation_config: HuggingFaceGenerationConfig, 127 | clean_up_tokenization_spaces=False, 128 | handle_long_generation="hole", 129 | ) -> List[HuggingFaceCompletionOutputs]: 130 | if isinstance(text_inputs, str): 131 | text_inputs = [text_inputs] 132 | 133 | outputs = [] 134 | batch_size = self.batch_size 135 | if batch_size == -1: 136 | batch_size = len(text_inputs) 137 | for start in range(0, len(text_inputs), batch_size): 138 | batch_input_texts = 
text_inputs[start: start + batch_size] 139 | 140 | batch_inputs = self._preprocess(batch_input_texts, generation_config, handle_long_generation) 141 | batch_input_ids = batch_inputs.input_ids 142 | batch_attention_mask = batch_inputs.attention_mask 143 | 144 | batch_gen_sequences = self._model_generate(batch_input_ids, batch_attention_mask, generation_config) 145 | 146 | outputs += self._postprocess( 147 | batch_input_ids, 148 | batch_gen_sequences, 149 | generation_config, 150 | clean_up_tokenization_spaces 151 | ) 152 | return outputs 153 | 154 | 155 | class Worker: 156 | def __init__(self, config: WorkerConfig, logger: Optional[Logger] = None): 157 | self.config = config 158 | self.logger = logger if logger else getLogger(__name__) 159 | 160 | self.model, self.tokenizer = self._load_model_tokenizer() 161 | self.pipeline = TextGenerationPipeline(self.model, self.tokenizer, batch_size=self.config.batch_size) 162 | 163 | def _load_model_tokenizer(self): 164 | tokenizer = AutoTokenizer.from_pretrained( 165 | self.config.tokenizer_name_or_path or self.config.model_name_or_path, 166 | use_fast=self.config.use_fast_tokenizer, 167 | padding_side="left", 168 | truncation_side="left", 169 | trust_remote_code=self.config.trust_remote_code 170 | ) 171 | if not tokenizer.pad_token_id: 172 | tokenizer.pad_token_id = tokenizer.eos_token_id 173 | 174 | max_memory = self.config.max_memory 175 | if max_memory: 176 | max_memory = {(eval(k) if isinstance(k, str) else k): v for k, v in max_memory.items()} 177 | 178 | if self.config.is_gptq_quantized: 179 | model = AutoGPTQForCausalLM.from_quantized( 180 | model_name_or_path=self.config.model_name_or_path, 181 | device_map=self.config.device_map, 182 | max_memory=max_memory, 183 | low_cpu_mem_usage=self.config.low_cpu_mem_usage, 184 | trust_remote_code=self.config.trust_remote_code, 185 | use_safetensors=self.config.use_safetensors 186 | ) 187 | else: 188 | model = AutoModelForCausalLM.from_pretrained( 189 | pretrained_model_name_or_path=self.config.model_name_or_path, 190 | torch_dtype=getattr(torch, self.config.torch_dtype), 191 | device_map=self.config.device_map, 192 | max_memory=max_memory, 193 | low_cpu_mem_usage=self.config.low_cpu_mem_usage, 194 | revision=self.config.revision, 195 | trust_remote_code=self.config.trust_remote_code 196 | ) 197 | 198 | return model, tokenizer 199 | 200 | def execute( 201 | self, 202 | prompts: List[str], 203 | uids: List[UUID], 204 | generation_config: HuggingFaceGenerationConfig 205 | ) -> Dict[UUID, Tuple[HuggingFaceCompletionOutputs, Optional[Error], int, float]]: 206 | start = time.time() 207 | try: 208 | pipeline_results = self.pipeline(prompts, generation_config) 209 | end = time.time() 210 | return {uid: (outputs, None, 200, end - start) for uid, outputs in zip(uids, pipeline_results)} 211 | except Exception as e: 212 | end = time.time() 213 | error = Error(type=e.__class__.__name__, detail=str(e)) 214 | self.logger.error(msg=str(error), exc_info=e) 215 | return {uid: (HuggingFaceCompletionOutputs(), error, 500, end - start) for uid in uids} 216 | 217 | 218 | __all__ = ["Worker"] 219 | -------------------------------------------------------------------------------- /code/server_requirements.txt: -------------------------------------------------------------------------------- 1 | aiohttp 2 | requests 3 | openai 4 | pydantic<2.0.0 5 | fastapi[all]==0.96.0 6 | setproctitle 7 | uvicorn[standard] 8 | gunicorn 9 | accelerate 10 | sentencepiece 11 | transformers 12 | xformers 
-------------------------------------------------------------------------------- /code/start_app.py: -------------------------------------------------------------------------------- 1 | import os 2 | from argparse import ArgumentParser 3 | from os.path import abspath, dirname, join 4 | 5 | 6 | CURRENT_DIR = dirname(abspath(__file__)) 7 | 8 | 9 | def build_gunicorn_cmd_str( 10 | gunicorn_config_file_path: str, 11 | proc_name: str, 12 | log_dir: str, 13 | port: int 14 | ): 15 | access_log_file = join(log_dir, f"{proc_name}_access.log") 16 | error_log_file = join(log_dir, f"{proc_name}_error.log") 17 | cmd_str = ( 18 | f'-c "{gunicorn_config_file_path}" ' 19 | f'--bind "0.0.0.0:{port}" ' 20 | f'--name "{proc_name}" ' 21 | f'--access-logfile "{access_log_file}" ' 22 | f'--error-logfile "{error_log_file}" ' 23 | f'-D' # running in background 24 | ) 25 | return cmd_str 26 | 27 | 28 | def main(): 29 | parser = ArgumentParser() 30 | 31 | # args to start client app 32 | parser.add_argument( 33 | "--start_client_app", action="store_true", 34 | help="whether to start client app" 35 | ) 36 | parser.add_argument( 37 | "--client_config_file_path", type=str, default="client_config.json", 38 | help="local path to read client config file" 39 | ) 40 | parser.add_argument( 41 | "--client_config_hot_update_interval_minutes", type=int, default=1, 42 | help="specify the interval minutes to hot update client config" 43 | ) 44 | parser.add_argument( 45 | "--client_debug", action="store_true", 46 | help="whether to start client app in debug mode" 47 | ) 48 | parser.add_argument( 49 | "--client_port", type=int, default=8000, 50 | help="the port the client app will use" 51 | ) 52 | 53 | # args to start CB server app 54 | parser.add_argument( 55 | "--start_cb_server_app", action="store_true", 56 | help="whether to start CB server app" 57 | ) 58 | parser.add_argument( 59 | "--cb_server_model_id", type=str, 60 | help="model id for the CB server that will be started" 61 | ) 62 | parser.add_argument( 63 | "--cb_server_config_file_path", type=str, default="cb_server_config.json", 64 | help="local path to read CB server config file" 65 | ) 66 | parser.add_argument( 67 | "--cb_server_debug", action="store_true", 68 | help="whether to start CB server app in debug mode" 69 | ) 70 | parser.add_argument( 71 | "--cb_server_port", type=int, default=8001, 72 | help="the part the CB server app will use" 73 | ) 74 | 75 | # args to start SB server app 76 | parser.add_argument( 77 | "--start_sb_server_app", action="store_true", 78 | help="whether to start SB server app" 79 | ) 80 | parser.add_argument( 81 | "--sb_server_model_id", type=str, 82 | help="model id for the SB server that will be started" 83 | ) 84 | parser.add_argument( 85 | "--sb_server_config_file_path", type=str, default="sb_server_config.json", 86 | help="local path to read SB server config file" 87 | ) 88 | parser.add_argument( 89 | "--sb_server_debug", action="store_true", 90 | help="whether to start SB server app in debug mode" 91 | ) 92 | parser.add_argument( 93 | "--sb_server_port", type=int, default=8002, 94 | help="the part the SB server app will use" 95 | ) 96 | 97 | # args that are shared among apps 98 | parser.add_argument( 99 | "--client_url", type=str, default=None, 100 | help="URL of client, only be used by servers, has no effect for now" 101 | ) 102 | parser.add_argument( 103 | "--gunicorn_config_file_path", type=str, default="gunicorn_config.py", 104 | help="local path to read a python script that stores some common gunicorn settings for APPs" 105 | ) 
106 | 107 | args = parser.parse_args() 108 | 109 | # start client app if triggered 110 | if args.start_client_app: 111 | cmd_str = ( 112 | f"""gunicorn 'client_app:build_app(""" 113 | f"""client_config_file_path="{args.client_config_file_path}",""" 114 | f"""client_config_hot_update_interval_minutes="{args.client_config_hot_update_interval_minutes}",""" 115 | f"""debug={args.client_debug})' """ 116 | ) 117 | cmd_str += build_gunicorn_cmd_str( 118 | gunicorn_config_file_path=args.gunicorn_config_file_path, 119 | proc_name="llm_inference_client", 120 | log_dir=join(CURRENT_DIR, "logs"), 121 | port=args.client_port 122 | ) 123 | print(f"start client app, command being executed is:\n{cmd_str}") 124 | os.system(cmd_str) 125 | 126 | # start CB server app if triggered 127 | if args.start_cb_server_app: 128 | cmd_str = ( 129 | f"""gunicorn 'continuous_batching_server_app:build_app(""" 130 | f"""model_id="{args.cb_server_model_id}",""" 131 | f"""server_config_file_path="{args.cb_server_config_file_path}",""" 132 | # f"""client_url={args.client_url}""" 133 | f"""debug={args.cb_server_debug})' """ 134 | ) 135 | cmd_str += build_gunicorn_cmd_str( 136 | gunicorn_config_file_path=args.gunicorn_config_file_path, 137 | proc_name="llm_inference_cb_server", 138 | log_dir=join(CURRENT_DIR, "logs"), 139 | port=args.cb_server_port 140 | ) 141 | print(f"start CB server app, command being executed is:\n{cmd_str}") 142 | os.system(cmd_str) 143 | 144 | # start SB server app if triggered 145 | if args.start_sb_server_app: 146 | cmd_str = ( 147 | f"""gunicorn 'static_batching_server_app:build_app(""" 148 | f"""model_id="{args.sb_server_model_id}",""" 149 | f"""server_config_file_path="{args.sb_server_config_file_path}",""" 150 | # f"""client_url={args.client_url}""" 151 | f"""debug={args.sb_server_debug})' """ 152 | ) 153 | cmd_str += build_gunicorn_cmd_str( 154 | gunicorn_config_file_path=args.gunicorn_config_file_path, 155 | proc_name="llm_inference_sb_server", 156 | log_dir=join(CURRENT_DIR, "logs"), 157 | port=args.sb_server_port 158 | ) 159 | print(f"start SB server app, command being executed is:\n{cmd_str}") 160 | os.system(cmd_str) 161 | 162 | 163 | if __name__ == "__main__": 164 | main() 165 | -------------------------------------------------------------------------------- /code/static_batching_server_app.py: -------------------------------------------------------------------------------- 1 | import json 2 | from logging import getLogger, DEBUG, INFO 3 | from typing import Optional 4 | 5 | from fastapi import HTTPException, FastAPI 6 | from fastapi.responses import JSONResponse 7 | from pydantic import BaseModel, Field 8 | 9 | from server.static_batching_server import get_server, Server, ServerConfig 10 | from protocol.completion_task import ( 11 | HuggingFaceCompletionInputs, 12 | HuggingFaceCompletionOutputs 13 | ) 14 | from protocol.error import Error 15 | from protocol.routes import ( 16 | ROUTE_GET_MODEL_ID, 17 | ROUTE_POST_CONTINUOUS_BATCHING_COMPLETION, 18 | ROUTE_CLIENT_ONLY_POST_SERVER_STARTUP_EVENT, 19 | ROUTE_CLIENT_ONLY_POST_SERVER_SHUTDOWN_EVENT 20 | ) 21 | from utils.log_util import RequestLoggingMiddleware 22 | 23 | 24 | logger = getLogger("gunicorn.logger") # by default, we use gunicorn to wrap the app 25 | 26 | 27 | class AppConfig(BaseModel): 28 | model_id: str = Field(default=...) 29 | server_config_file_path: str = Field(default=...) 
30 | client_url: Optional[str] = Field(default=None) 31 | debug: bool = Field(default=False) 32 | 33 | 34 | APP_NAME = "LLM-Inference-SB-Server" 35 | 36 | app = FastAPI(title=APP_NAME, version="0.1.0") 37 | app_config: Optional[AppConfig] = None 38 | 39 | 40 | def build_app( 41 | model_id: str = None, 42 | server_config_file_path: str = "sb_server_config.json", 43 | client_url: Optional[str] = None, 44 | debug: bool = False 45 | ): 46 | global app, app_config 47 | 48 | if model_id is None: 49 | raise ValueError("You must specify a real value to model_id.") 50 | 51 | logger.setLevel(DEBUG if debug else INFO) 52 | 53 | app_config = AppConfig( 54 | model_id=model_id, 55 | server_config_file_path=server_config_file_path, 56 | client_url=client_url, 57 | debug=debug 58 | ) 59 | 60 | app.add_middleware(RequestLoggingMiddleware, logger=logger) 61 | 62 | return app 63 | 64 | 65 | @app.on_event("startup") 66 | def startup(): 67 | # initialize server 68 | Server( 69 | config=ServerConfig(**json.load(open(app_config.server_config_file_path, "r", encoding="utf-8"))), 70 | logger=logger 71 | ) 72 | # TODO: implement logic to inform client that server is startup 73 | 74 | 75 | @app.on_event("shutdown") 76 | def shutdown(): 77 | pass # TODO: implement logic to inform client that server is shutdown 78 | 79 | 80 | @app.get(ROUTE_GET_MODEL_ID) 81 | async def get_model_id(): 82 | return JSONResponse(content=app_config.model_id) 83 | 84 | 85 | @app.post(ROUTE_POST_CONTINUOUS_BATCHING_COMPLETION, response_model=HuggingFaceCompletionOutputs) 86 | async def execute_completion(request_inputs: HuggingFaceCompletionInputs): 87 | server = get_server() 88 | outputs, error, status_code, cpu_time, wall_time = await server.wait_task_done(request_inputs) 89 | if status_code != 200: 90 | logger.error(msg=str(error)) 91 | raise HTTPException(status_code=status_code, detail=str(error)) 92 | return outputs 93 | 94 | 95 | if __name__ == "__main__": 96 | import uvicorn 97 | from argparse import ArgumentParser 98 | from logging import basicConfig 99 | 100 | parser = ArgumentParser() 101 | parser.add_argument("--model_id", type=str) 102 | parser.add_argument("--server_config_file_path", type=str, default="sb_server_config.json") 103 | parser.add_argument("--client_url", type=str, default=None) 104 | parser.add_argument("--debug", action="store_true") 105 | parser.add_argument("--port", type=int, default=8002) 106 | args = parser.parse_args() 107 | 108 | logger = getLogger(__name__) # override gunicorn logger if we use uvicorn directly 109 | basicConfig( 110 | format="%(asctime)s %(levelname)s [%(name)s] %(message)s", 111 | datefmt="%Y-%m-%d %H:%M:%S" 112 | ) 113 | 114 | uvicorn.run( 115 | build_app( 116 | model_id=args.model_id, 117 | server_config_file_path=args.server_config_file_path, 118 | client_url=args.client_url, 119 | debug=args.debug 120 | ), 121 | host="0.0.0.0", 122 | port=args.port 123 | ) 124 | -------------------------------------------------------------------------------- /code/utils/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /code/utils/log_util.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import re 3 | from datetime import datetime 4 | from typing import * 5 | 6 | from fastapi import FastAPI, Request, Response 7 | from starlette.middleware.base import BaseHTTPMiddleware 8 | 9 | 10 | # adopt from: 11 | # 
https://github.com/jeffsiver/fastapi-route-logger/blob/main/fastapi_route_logger_middleware/__init__.py 12 | class RequestLoggingMiddleware(BaseHTTPMiddleware): 13 | def __init__( 14 | self, 15 | app: FastAPI, 16 | *, 17 | logger: Optional[logging.Logger] = None, 18 | skip_routes: List[str] = None, 19 | skip_regexes: List[str] = None, 20 | skip_request_methods: List[str] = None, 21 | ): 22 | self._logger = logger if logger else logging.getLogger(__name__) 23 | self._skip_routes = skip_routes if skip_routes else [] 24 | self._skip_regexes = ( 25 | list(map(lambda regex: re.compile(regex), skip_regexes)) 26 | if skip_regexes 27 | else [] 28 | ) 29 | self._skip_request_methods = skip_request_methods if skip_request_methods else [] 30 | 31 | BaseHTTPMiddleware.__init__(self, app) 32 | 33 | async def dispatch(self, request: Request, call_next: Callable) -> Response: 34 | if self._should_route_be_skipped(request): 35 | return await call_next(request) 36 | 37 | return await self._execute_request_with_logging(request, call_next) 38 | 39 | def _should_route_be_skipped(self, request: Request) -> bool: 40 | return any( 41 | [True for path in self._skip_routes if request.url.path.startswith(path)] 42 | + [True for regex in self._skip_regexes if regex.match(request.url.path)] 43 | + [True for method in self._skip_request_methods if request.method.lower() == method.lower()] 44 | ) 45 | 46 | async def _execute_request_with_logging( 47 | self, request: Request, call_next: Callable 48 | ) -> Response: 49 | self._logging_before_execution(request) 50 | start_time = datetime.utcnow() 51 | try: 52 | response = await call_next(request) 53 | except Exception as e: 54 | self._logging_when_error_raised(e) 55 | raise 56 | 57 | finish_time = datetime.utcnow() 58 | self._logging_after_execution(response, (finish_time - start_time).total_seconds()) 59 | 60 | return response 61 | 62 | def _logging_before_execution(self, request: Request): 63 | content = ( 64 | f"receive a request from {request.client.host}:{request.client.port} " 65 | f"to {request.url.path}, method={request.method}" 66 | ) 67 | self._logger.info(content) 68 | 69 | def _logging_after_execution(self, response: Response, execution_time: float): 70 | overall_status = "successfully" if response.status_code < 400 else "failed" 71 | content = ( 72 | f"{overall_status} executed a request, duration={execution_time}s, " 73 | f"status_code={response.status_code}" 74 | ) 75 | self._logger.info(content) 76 | 77 | def _logging_when_error_raised(self, exception: Exception): 78 | content = ( 79 | f"error occurred when execute request, " 80 | f"error_type=[{exception.__class__.__name__}], error_msg=[{str(exception)}]" 81 | ) 82 | self._logger.error(content) 83 | 84 | 85 | __all__ = ["RequestLoggingMiddleware"] 86 | -------------------------------------------------------------------------------- /docs/tutorial/0_概述.md: -------------------------------------------------------------------------------- 1 |

# Overview

2 | 3 | In this chapter we give an overview of the whole tutorial and introduce the core content of each of the following chapters, so that readers get a clear picture of the tutorial's overall structure and can jump straight to the topics they find most interesting or least familiar, saving valuable time. 4 | 5 | ## Scope 6 | 7 | In the current version we do not cover every aspect of putting large language model inference into production. For example, we do not discuss how to design a good logging, monitoring, and alerting system, how to design and operate distributed containerized deployment and management, or how to implement real load balancing. Instead, this tutorial focuses on: the design and code walkthrough of an LLM inference engine, optimization directions for the inference engine and ideas for implementing them, methods for testing the inference engine's performance, and how to choose hardware. 8 | 9 | > If readers have a strong desire or need to learn about the parts this tutorial does not cover, feel free to raise them in an issue; we will consider adding the corresponding chapters based on community feedback. 10 | 11 | ## Chapter Overview 12 | 13 | ### Chapter 1: LLM Inference Engine Design and Code Walkthrough 14 | 15 | In Chapter 1 we first give a high-level introduction to the LLM inference engine; then, in the remaining sections, we explain each module of the inference engine in more detail and walk through the related code, so that readers can ultimately implement a simple high-performance inference engine on their own. 16 | 17 | The sections of Chapter 1 are: 18 | 0. Overview 19 | 1. Request packing module 20 | 2. Cache management module 21 | 3. Model execution module 22 | 4. Model utilities module 23 | 5. Generation utilities module 24 | 6. Server wrapping and HTTP serving 25 | 26 | ### Chapter 2: Optimization Directions and Implementation Ideas for the Inference Engine 27 | 28 | In Chapter 2 we draw on popular existing inference engines and the latest research to analyze the directions in which an inference engine can be further optimized, in terms of inference performance, generation quality, and more, along with possible ways to implement them. 29 | 30 | The sections of Chapter 2 are: 31 | 1. Optimizing model and cache compression 32 | 2. Optimizing generation strategies 33 | 3. Optimizing request packing strategies 34 | 4. Other possible optimizations 35 | 36 | ### Chapter 3: Performance Testing Methods for the Inference Engine 37 | 38 | In Chapter 3 we introduce how to evaluate the performance of an inference engine effectively and how to choose a testing strategy based on the characteristics of real business traffic. 39 | 40 | ### Chapter 4: Choosing the Hardware 41 | 42 | In Chapter 4 we introduce how to choose GPUs in a targeted way for the actual business scenario. 43 | --------------------------------------------------------------------------------
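For readers who want to try the static batching server shown earlier in this dump, the sketch below ties its pieces together: it builds an `sb_server_config.json` from the pydantic models in `code/server/static_batching_server/config.py` and then points at the `__main__` entry of `static_batching_server_app.py`. This is a hedged, illustrative example only, not part of the repository: it assumes pydantic<2 (as pinned in `server_requirements.txt`, hence `.dict()`), it is run from the `code/` directory so the `server.*` imports resolve, and both `my-org/my-llama-checkpoint` and `my-model` are hypothetical placeholders.

```python
# Minimal sketch (assumptions noted above): write a config file that matches
# ServerConfig / BatcherConfig / WorkerConfig, then launch the SB server app.
import json

from server.static_batching_server.config import BatcherConfig, ServerConfig, WorkerConfig

config = ServerConfig(
    batcher_config=BatcherConfig(package_max_workload=8, packaging_interval_seconds=2),
    worker_config=WorkerConfig(
        model_name_or_path="my-org/my-llama-checkpoint",  # hypothetical checkpoint
        torch_dtype="float16",
        device_map="auto",
        batch_size=-1,  # -1: run every prompt in a package through generate() at once
    ),
)

# ServerConfig is a pydantic<2 model, so .dict() serializes it (and its nested configs).
with open("sb_server_config.json", "w", encoding="utf-8") as f:
    json.dump(config.dict(), f, indent=2)

# Then start the app directly with uvicorn via its __main__ block, for example:
#   python static_batching_server_app.py --model_id my-model \
#       --server_config_file_path sb_server_config.json --port 8002
```

The same config file can also be consumed through `start_app.py --start_sb_server_app`, which wraps the app with gunicorn instead of calling uvicorn directly.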