├── .gitignore
├── LICENSE
├── README.md
├── code
│   ├── benchmark
│   │   └── load_test.py
│   ├── client
│   │   ├── __init__.py
│   │   ├── client.py
│   │   └── openai_jumper.py
│   ├── client_app.py
│   ├── client_requirements.txt
│   ├── continuous_batching_server_app.py
│   ├── gunicorn_config.py
│   ├── protocol
│   │   ├── __init__.py
│   │   ├── completion_task.py
│   │   ├── error.py
│   │   └── routes.py
│   ├── server
│   │   ├── __init__.py
│   │   ├── continuous_batching_server
│   │   │   ├── __init__.py
│   │   │   ├── batcher.py
│   │   │   ├── beam.py
│   │   │   ├── cache
│   │   │   │   ├── README.md
│   │   │   │   ├── __init__.py
│   │   │   │   ├── cache.py
│   │   │   │   └── cache_manager.py
│   │   │   ├── config.py
│   │   │   ├── generation_utils
│   │   │   │   ├── README.md
│   │   │   │   ├── __init__.py
│   │   │   │   ├── logits_process.py
│   │   │   │   └── tokens.py
│   │   │   ├── modeling
│   │   │   │   ├── __init__.py
│   │   │   │   ├── llama.py
│   │   │   │   └── utils
│   │   │   │       ├── __init__.py
│   │   │   │       ├── attention.py
│   │   │   │       ├── linear.py
│   │   │   │       └── weights.py
│   │   │   ├── server.py
│   │   │   └── worker.py
│   │   └── static_batching_server
│   │       ├── __init__.py
│   │       ├── batcher.py
│   │       ├── config.py
│   │       ├── server.py
│   │       └── worker.py
│   ├── server_requirements.txt
│   ├── start_app.py
│   ├── static_batching_server_app.py
│   └── utils
│       ├── __init__.py
│       └── log_util.py
└── docs
    └── tutorial
        ├── 0_概述.md
        └── 1_大语言模型推理引擎设计与代码讲解.md
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # LLM Inference Deployment Tutorial
2 | An advanced tutorial for LLM development engineers that covers inference engine design, service deployment, performance evaluation, and more, together with an open-sourced, optimized single-GPU inference engine for research and learning.
3 |
4 | The tutorial documents are being updated continuously and quickly; watch this project to get the latest content as soon as it is released!
5 |
6 | ## Table of Contents
7 |
8 | - [Project Structure](#project-structure)
9 | - [Inference Engine Code](#inference-engine-code)
10 |   * [Introduction](#introduction)
11 |   * [Environment Setup](#environment-setup)
12 |   * [Service Deployment](#service-deployment)
13 |
14 |
15 | ## Project Structure
16 |
17 | This project mainly contains two directories: `code` and `docs`.
18 | - `code`: holds the inference engine code, whose main purpose is research and learning. In this directory, developers can set up the environment and deploy the services in a few simple steps, to compare and experience how different strategies behave under different traffic patterns.
19 | - `docs`: holds the tutorial documents of all chapters. By reading them, developers can learn about the many aspects of deploying large language models for inference in commercial production environments.
20 |
21 | ## Inference Engine Code
22 |
23 | ### Introduction
24 |
25 | In this project we open-source an optimized single-GPU large language model inference engine that follows a client/server (C/S) design. It is meant to show developers how an LLM inference engine works internally, and to make it easy to deploy it yourself and experience the performance differences between strategies under different traffic patterns.
26 |
27 | #### Features
28 |
29 | - Loading `int4` GPTQ models
30 | - (Client) load balancing
31 | - Multiple server strategies: **Static Batching** (SB) and **Continuous Batching** (CB)
32 | - (Continuous batching server) accelerated inference based on xformers and PagedAttention
33 | - (Continuous batching server) heterogeneous decoding strategies
34 | - GPU-CPU memory swapping
35 |
36 | #### Limitations
37 |
38 | - No multi-GPU inference
39 | - No streaming
40 | - (Continuous batching server) only supports llama v1 and llama v2 models (except 70B)
41 | - (Continuous batching server) no inference optimizations for network modules and operations other than the attention layers
42 | - (Continuous batching server) only supports weight files in `safetensors` format
43 | - (Continuous batching server) no support for extending the maximum context length that can be processed
44 | - No support for loading adapter(s)
45 | - No support for speculative sampling (a.k.a. assisted generation)
46 |
47 | #### Other Open-Source LLM Inference Frameworks
48 |
49 | The inference engine code in this project is meant to show developers how an LLM inference engine works internally, so aspects that are weakly related or unrelated to this goal have not been optimized further.
50 |
51 | With the inference engine code we provide, developers can easily deploy models of 20B parameters or less on a single GPU, but it falls short for models with more parameters. For that case, we list the more feature-complete open-source LLM inference frameworks that already exist in the community, with a brief introduction for each:
52 |
53 | - [TGI](https://github.com/huggingface/text-generation-inference): the LLM inference framework Hugging Face uses in its internal production environment; note that commercial use of versions after 0.9.4 requires an official license;
54 | - [vLLM](https://github.com/vllm-project/vllm): a high-performance, easy-to-use LLM inference framework that proposed and implemented PagedAttention;
55 | - [lmdeploy](https://github.com/InternLM/lmdeploy): an LLM inference and deployment framework developed by the InternLM team;
56 | - [TransformerEngine](https://github.com/NVIDIA/TransformerEngine): NVIDIA's next-generation inference framework for Transformer-architecture models, with `fp8` support;
57 | - [DeepSpeed-MII](https://github.com/microsoft/DeepSpeed-MII): a low-latency, high-throughput inference engine built on DeepSpeed, and not limited to LLMs.
58 |
59 | There are also frameworks designed for LLM inference performance optimization such as [FasterTransformer](https://github.com/NVIDIA/FasterTransformer), [FlexGen](https://github.com/FMInference/FlexGen), and [EnergonAI](https://github.com/hpcaitech/EnergonAI), but they are updated less frequently, so we do not introduce them further here.
60 |
61 | We strongly encourage developers to read the code of the open-source projects above after finishing this project, to further understand the design ideas and optimization techniques of LLM inference engines. For a more detailed comparison of LLM inference engines and deployment frameworks, see this [article](https://mp.weixin.qq.com/s/xIbNSAI9cKGIA19yZhIEgg).
62 |
63 | > The inference engine code in this project borrows from the TGI project for weight loading, model network code design, and text generation strategy implementation, and from the vLLM project for memory management.
64 |
65 | ### Environment Setup
66 |
67 | #### Client
68 |
69 | In the `code` directory, run `pip install -r client_requirements.txt` to install the third-party libraries required by the client.
70 |
71 | #### Server
72 |
73 | First, make sure your virtual environment has a CUDA-enabled PyTorch of version 2.0.0 or later installed. If not, you can pick and install a pre-built package matching your hardware and software [here](https://pytorch.org/get-started/locally/), or build and install from source following the instructions [here](https://github.com/pytorch/pytorch#from-source).
74 |
75 | Next, install vLLM, so that we can conveniently use its paged-attention operator and memory-management-related operators in our code.
76 | - Quick install: `pip install vllm`;
77 | - Build from source: follow the instructions [here](https://vllm.readthedocs.io/en/latest/getting_started/installation.html#build-from-source).
78 |
79 | Then, install the `auto-gptq` library following the instructions [here](https://github.com/PanQiWei/AutoGPTQ#installation); the inference engine we provide supports loading and using GPTQ models.
80 |
81 | Finally, in the `code` directory, run `pip install -r server_requirements.txt` to install the other third-party libraries required by the server.
82 |
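In short, the server-side setup boils down to the commands below. This is a sketch only; replace the PyTorch line with the exact command for your CUDA version from the link above, and see the AutoGPTQ instructions if you prefer building it from source:

```shell
# run inside the code directory with your virtual environment activated
pip install "torch>=2.0.0"              # pick the CUDA build matching your environment from pytorch.org
pip install vllm                        # provides the paged-attention and memory management operators we reuse
pip install auto-gptq                   # needed for loading GPTQ models; see the AutoGPTQ guide for other install options
pip install -r server_requirements.txt  # remaining server dependencies
```
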
83 | ### Service Deployment
84 | As you can see, the `code` directory contains four modules:
85 | - `protocol`: defines the communication protocol between the client and the servers;
86 | - `client`: client-side code;
87 | - `server`: server-side code;
88 | - `utils`: other utility code used when building the HTTP services.
89 |
90 | You can combine these modules and wrap them into standalone HTTP services, for example:
91 | - `client_app.py`: a client HTTP service built from the `protocol`, `client`, and `utils` modules;
92 | - `continuous_batching_server_app.py` and `static_batching_server_app.py`: server HTTP services built from the `protocol`, `server`, and `utils` modules.
93 |
94 | You can also copy these modules directly into the project you are working on and integrate them there.
95 |
96 | We provide three HTTP service scripts whose names end with "_app.py" in the `code` directory; you can run them directly to start a service, for example:
97 | ```shell
98 | CUDA_VISIBLE_DEVICES=0 python client_app.py --port 8000  # run `python client_app.py --help` to see all command-line arguments
99 | ```
100 |
101 | The service process then keeps running in the foreground; press `ctrl`+`c` to stop it.
102 |
103 | In addition, we provide a unified deployment script, `start_app.py`, which lets you use `gunicorn` to manage the service processes and run them in the background.
104 |
105 | > Note: since the inference engine currently only supports single-GPU inference, we strongly recommend setting the CUDA_VISIBLE_DEVICES environment variable when running a service script, so that one specific GPU is used and unexpected behavior is avoided.
106 |
107 | Each HTTP service needs a corresponding configuration file. The templates for each service are given below (remember to remove the comment text before actually using them):
108 |
109 |
110 | **client_config.json**
111 |
112 | ```json
113 | {
114 | "continuous_batching_server_urls": ["http://127.0.0.1:8001"], # URLs of continuous batching server HTTP services; requests are load-balanced across them
115 | "static_batching_server_urls": ["http://127.0.0.1:8002"], # URLs of static batching server HTTP services; requests are load-balanced across them
116 | "openai_jumper_configs": [
117 | {
118 | "api_key": "YOUR_OPENAI_KEY",
119 | "org_id": null
120 | }
121 | ], # list of OpenAI accounts; requests are load-balanced across them
122 | "heart_beat_interval_seconds": 600 # interval in seconds between heart-beat checks on the server HTTP services
123 | }
124 | ```
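
For reference, this is the same client configuration with the comments removed, i.e. what you would actually save as `client_config.json` (the URLs and the API key are placeholders):

```json
{
    "continuous_batching_server_urls": ["http://127.0.0.1:8001"],
    "static_batching_server_urls": ["http://127.0.0.1:8002"],
    "openai_jumper_configs": [
        {
            "api_key": "YOUR_OPENAI_KEY",
            "org_id": null
        }
    ],
    "heart_beat_interval_seconds": 600
}
```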
125 |
126 |
127 |
128 | **cb_server_config.json**
129 |
130 | ```json
131 | {
132 | "model_loading_config": {
133 | "model_type": "llama", # model architecture type; currently only llama is supported
134 | "model_name_or_path": "PATH_TO_MODEL_DIR", # path to the directory holding the model weight files; only safetensors format is supported
135 | "torch_dtype": "float16", # dtype used for the weights and computation (for non-GPTQ models); options are float16 and bfloat16
136 | "tokenizer_name_or_path": null, # path to the directory holding the tokenizer files; if null, the model weight directory is used
137 | "use_fast_tokenizer": false, # if true, the tokenizer is loaded with use_fast=True
138 | "trust_remote_code": false, # whether to use model or tokenizer code that is not officially provided by Hugging Face
139 | "quantize_method": null, # quantization method; the only supported value is gptq
140 | "model_max_length": 2048, # maximum context length the model can handle
141 | "gptq_model_base_name": null, # GPTQ model weight file name (without extension); if null, the default naming scheme is used to locate the file
142 | "gptq_config_base_name": null # GPTQ config file name (without extension); if null, the default naming scheme is used to locate the file
143 | },
144 | "batcher_config": {
145 | "batch_max_tokens": 56000, # maximum number of tokens processed concurrently in one batch; this value is reasonable for a llama-7b fp16 model on an A100-40G
146 | "batch_max_beams": 32 # maximum number of beams (prediction branches during text generation) processed concurrently in one batch; this value is reasonable for a llama-7b fp16 model on an A100-40G
147 | },
148 | "cache_config": {
149 | "num_blocks": 2500, # number of GPU memory blocks; this value is reasonable for a llama-7b fp16 model on an A100-40G
150 | "num_blocks_cpu": 1024, # number of CPU memory blocks
151 | "block_size": 16, # size of one memory block
152 | "watermark": 0.01 # fraction of GPU memory blocks kept in reserve, to prevent allocating too many blocks to prompt KV cache and KV cache swapped in from CPU memory, which would leave GPU memory tight during text generation
153 | }
154 | }
155 | ```
156 |
157 |
158 |
159 | **sb_server_config.json**
160 |
161 | ```json
162 | {
163 | "batcher_config": {
164 | "package_max_workload": "16", # maximum workload of one task package, measured in beams
165 | "packaging_interval_seconds": 2 # packaging interval in seconds; the lower the TPS/QPS, the longer this can be
166 | },
167 | "worker_config": {
168 | "model_name_or_path": "PATH_TO_MODEL_DIR", # path to the directory holding the model weight files
169 | "tokenizer_name_or_path": null, # path to the directory holding the tokenizer files; if null, the model weight directory is used
170 | "revision": "main", # model repository branch to use; only takes effect when the path is a Hugging Face Hub model name or a git repository
171 | "low_cpu_mem_usage": true, # whether to minimize CPU memory usage when loading the model weights
172 | "torch_dtype": "float16", # dtype used for the weights and computation (for non-GPTQ models); options are float16 and bfloat16
173 | "use_fast_tokenizer": false, # if true, the tokenizer is loaded with use_fast=True
174 | "trust_remote_code": false, # whether to use model or tokenizer code that is not officially provided by Hugging Face
175 | "use_safetensors": false, # whether to load weight files in `safetensors` format
176 | "batch_size": -1, # batch_size used when running TextGenerationPipeline; -1 means process all inputs at once
177 | "is_gptq_quantized": false # whether the model is a GPTQ model
178 | }
179 | }
180 | ```
181 |
182 |
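Once a server app and the client app are both running, you can do a quick end-to-end check by posting a completion request through the client service. The snippet below is a minimal sketch: it assumes it is run from the `code` directory, that the client listens on `http://127.0.0.1:8000` as in the example above, and that `"my-model"` is a placeholder for the `model_id` you passed to the server app; generation fields not shown rely on their defaults.

```python
import requests

from protocol.completion_task import HuggingFaceCompletionInputs, HuggingFaceGenerationConfig
from protocol.routes import ROUTE_POST_CONTINUOUS_BATCHING_COMPLETION

# build the request payload the same way benchmark/load_test.py does
payload = HuggingFaceCompletionInputs(
    model="my-model",  # placeholder: must match the model_id of a registered continuous batching server
    prompt="Hello, my name is",
    generation_config=HuggingFaceGenerationConfig(max_new_tokens=32, num_beams=1, num_return_sequences=1),
).dict(by_alias=True)

resp = requests.post(f"http://127.0.0.1:8000{ROUTE_POST_CONTINUOUS_BATCHING_COMPLETION}", json=payload)
print(resp.status_code)
print(resp.json())
```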
--------------------------------------------------------------------------------
/code/benchmark/load_test.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import json
3 | import random
4 | import sys
5 | import time
6 | from argparse import ArgumentParser
7 | from os.path import abspath, dirname
8 | from typing import List
9 |
10 | import aiohttp
11 | import numpy as np
12 | import pandas as pd
13 | from transformers import AutoTokenizer, PreTrainedTokenizerBase
14 |
15 | sys.path.insert(0, abspath(dirname(dirname(__file__))))
16 |
17 | from protocol.completion_task import (
18 | HuggingFaceGenerationConfig,
19 | HuggingFaceCompletionInputs,
20 | HuggingFaceCompletionOutputs,
21 | )
22 | from protocol.routes import ROUTE_POST_CONTINUOUS_BATCHING_COMPLETION, ROUTE_POST_STATIC_BATCHING_COMPLETION
23 |
24 |
25 | def load_tokenizer(tokenizer_name_or_path: str, use_fast: bool, max_length: int) -> PreTrainedTokenizerBase:
26 | tokenizer = AutoTokenizer.from_pretrained(
27 | tokenizer_name_or_path,
28 | use_fast=use_fast,
29 | model_max_length=max_length,
30 | padding_side="left",
31 | truncation_side="right",
32 | )
33 | if not tokenizer.pad_token:
34 | tokenizer.pad_token = tokenizer.eos_token
35 | tokenizer.pad_token_id = tokenizer.eos_token_id
36 |
37 | return tokenizer
38 |
39 |
40 | def gen_random_lens(distribution: str, len_mean: int, len_range: int, num_requests: int) -> List[int]:
41 | if distribution == 'uniform':
42 | if len_range == 0:
43 | return [len_mean for _ in range(num_requests)]
44 |
45 | low = len_mean - (len_range // 2)
46 | high = len_mean + (len_range // 2)
47 | num_to_generate = list(
48 | map(lambda _: random.randint(low, high), range(num_requests)))
49 | return num_to_generate
50 | elif distribution == 'exponential':
51 |         np.random.seed(random.randint(0, 10**6))
52 | return [min(round(s), len_range) for s in np.random.exponential(scale=len_mean, size=num_requests)]
53 | elif distribution == 'capped_exponential':
54 |         np.random.seed(random.randint(0, 10**6))
55 | response_lens = []
56 | while len(response_lens) < num_requests:
57 | sample = round(np.random.exponential(scale=len_mean))
58 | if sample <= len_range:
59 | response_lens.append(sample)
60 | return response_lens
61 | else:
62 | raise ValueError(f'unknown distribution {distribution=}')
63 |
64 |
65 | def prepare_payloads(
66 | model_id: str,
67 | data_path: str,
68 | tokenizer: PreTrainedTokenizerBase,
69 | num_beams: int,
70 | prompt_len: int,
71 | response_lens: List[int]
72 | ) -> List[dict]:
73 | with open(data_path, "r", encoding="utf-8") as f:
74 | prompts = json.load(f)
75 |
76 | assert len(prompts) >= len(response_lens)
77 |
78 | prompts = tokenizer.batch_decode(
79 | tokenizer(
80 | prompts,
81 | max_length=prompt_len,
82 | truncation=True,
83 | padding="max_length"
84 | )["input_ids"]
85 | )
86 | prompts = random.sample(prompts, len(response_lens))
87 | random.shuffle(prompts)
88 |
89 | payloads = []
90 | for idx, (prompt, response_len) in enumerate(zip(prompts, response_lens)):
91 | payload = HuggingFaceCompletionInputs(
92 | model=model_id,
93 | prompt=prompt,
94 | generation_config=HuggingFaceGenerationConfig(
95 | max_new_tokens=response_len,
96 | min_new_tokens=response_len,
97 | num_beams=num_beams,
98 | num_return_sequences=1
99 | )
100 | ).dict(by_alias=True)
101 | payloads.append(payload)
102 |
103 | return payloads
104 |
105 |
106 | async def load_test(
107 | payloads: List[dict],
108 | throughput_oriented: bool,
109 | batch_size: int,
110 | url: str
111 | ) -> dict:
112 | async def request_one(request_id, payload, sleep_seconds):
113 | await asyncio.sleep(sleep_seconds)
114 | request_start = time.time()
115 | async with aiohttp.request(
116 | method="post",
117 | url=url,
118 | json=payload,
119 | timeout=aiohttp.ClientTimeout(total=60*60)
120 | ) as resp:
121 |             if resp.status == 200:
122 | outputs = HuggingFaceCompletionOutputs(**(await resp.json()))
123 | else:
124 | outputs = HuggingFaceCompletionOutputs()
125 | num_gen_tokens = outputs.usage.completion_tokens
126 | status = resp.status
127 | request_end = time.time()
128 | wall_time = request_end - request_start
129 | print(f" - request_id={request_id} :: {status=}, {wall_time=: .4f}s, {num_gen_tokens=}")
130 |
131 | return num_gen_tokens, wall_time, status
132 |
133 | if throughput_oriented:
134 | total_tokens = 0
135 | duration = 0
136 | num_success = 0
137 | num_failed = 0
138 | for start_idx in range(0, len(payloads), batch_size):
139 | end_idx = start_idx + batch_size
140 | start = time.time()
141 | batch_results = await asyncio.gather(
142 | *[
143 | request_one(start_idx + idx, payload, 0)
144 | for idx, payload in enumerate(payloads[start_idx: end_idx])
145 | ]
146 | )
147 | end = time.time()
148 | duration += (end - start)
149 | total_tokens += sum([each[0] for each in batch_results if each[2] == 200])
150 | if all(each[2] == 200 for each in batch_results):
151 | num_success += 1
152 | else:
153 | num_failed += 1
154 | latency = -1
155 | throughput = total_tokens / duration
156 | fail_rate = f"{num_failed / (num_failed + num_success) * 100: .4f}%"
157 | else:
158 | sleep_sec = 0
159 | tasks = []
160 | for start_idx in range(0, len(payloads), batch_size):
161 | end_idx = start_idx + batch_size
162 | sleep_sec += (0 if throughput_oriented else 1)
163 | tasks += [
164 | asyncio.create_task(request_one(start_idx + idx, payload, sleep_sec))
165 | for idx, payload in enumerate(payloads[start_idx: end_idx])
166 | ]
167 | results = await asyncio.gather(*tasks)
168 | results = [each for each in results if each[2] == 200]
169 | latency = pd.Series([each[1] for each in results]).describe().to_dict()
170 | throughput = -1
171 | fail_rate = f"{(len(payloads) - len(results)) / len(payloads) * 100: .4f}%"
172 |
173 | return {
174 | "throughput_oriented": throughput_oriented,
175 | "batch_size": batch_size,
176 | "latency(s)": latency,
177 | "throughput(generated_tokens/s)": throughput,
178 | "fail_rate": fail_rate
179 | }
180 |
181 |
182 | def main():
183 | parser = ArgumentParser()
184 | parser.add_argument("--host", type=str, default="http://127.0.0.1")
185 | parser.add_argument("--port", type=int, default=8000)
186 |     parser.add_argument("--model_id", type=str, default="")  # leave empty when requesting a server app directly (no client in between)
187 | parser.add_argument("--data_path", type=str, default="load_test_data.json")
188 | parser.add_argument("--save_path", type=str, default="load_test_report.json")
189 | parser.add_argument("--tokenizer_name_or_path", type=str, default="gpt2")
190 | parser.add_argument("--use_fast_tokenizer", action="store_true")
191 | parser.add_argument("--throughput_oriented", action="store_true")
192 | # when throughput_oriented is True, num_beams will be treated as batch size
193 | # which means how many requests will be sent to inference server at the same time,
194 | # where each request's num_beams is 1
195 | parser.add_argument("--num_beams", type=int, default=1)
196 | parser.add_argument("--prompt_len", type=int, default=512)
197 | parser.add_argument("--response_len_mean", type=int, default=512)
198 | parser.add_argument("--response_len_range", type=int, default=0)
199 | parser.add_argument("--distribution", type=str, choices=["uniform", "exponential", "capped_exponential"], default="uniform")
200 |     parser.add_argument("--num_emit", type=int, default=60)  # number of batches (throughput mode) or number of seconds to keep emitting requests (latency mode)
201 | parser.add_argument("--qps", type=int, default=1) # used only when throughput_oriented is not True
202 | parser.add_argument("--server_type", type=str, default="cb", choices=["cb", "sb"])
203 | args = parser.parse_args()
204 |
205 | if args.distribution == "uniform":
206 | assert args.response_len_mean > args.response_len_range
207 | else:
208 | assert args.response_len_mean <= args.response_len_range
209 |
210 | tokenizer = load_tokenizer(args.tokenizer_name_or_path, args.use_fast_tokenizer, args.prompt_len)
211 |
212 | if args.throughput_oriented:
213 | num_requests = args.num_beams * args.num_emit
214 | response_lens = [args.response_len_mean for _ in range(num_requests)]
215 | args.num_beams = 1
216 | else:
217 | num_requests = args.qps * args.num_emit
218 | response_lens = gen_random_lens(args.distribution, args.response_len_mean, args.response_len_range, num_requests)
219 | batch_size = num_requests // args.num_emit
220 |
221 | payloads = prepare_payloads(args.model_id, args.data_path, tokenizer, args.num_beams, args.prompt_len, response_lens)
222 |
223 | print(
224 | f"Load Test :: {args.throughput_oriented=}, {num_requests=}, {batch_size=}, "
225 | f"{args.response_len_mean=}, {args.response_len_range=}, {args.distribution=}"
226 | )
227 |
228 | route = ROUTE_POST_CONTINUOUS_BATCHING_COMPLETION if args.server_type == "cb" else ROUTE_POST_STATIC_BATCHING_COMPLETION
229 | report = asyncio.run(
230 | load_test(
231 | payloads, args.throughput_oriented, batch_size, f"{args.host}:{args.port}{route}"
232 | )
233 | )
234 | report["response_len_mean"] = args.response_len_mean
235 | report["response_len_range"] = args.response_len_range
236 | report["distribution"] = args.distribution
237 |
238 | print("REPORT ::")
239 | for k, v in report.items():
240 | print(f" - {k}: {v}")
241 |
242 | with open(args.save_path, "w", encoding="utf-8") as f:
243 | json.dump(report, f)
244 |
245 |
246 | if __name__ == "__main__":
247 | main()
248 |
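249 | # Example invocation (illustrative values; run `python load_test.py --help` for the full list of flags):
250 | #   python load_test.py --host http://127.0.0.1 --port 8000 --server_type cb \
251 | #       --tokenizer_name_or_path gpt2 --data_path load_test_data.json \
252 | #       --prompt_len 512 --response_len_mean 512 --qps 1 --num_emit 60
253 | # where load_test_data.json is a JSON file containing a list of prompt strings.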
--------------------------------------------------------------------------------
/code/client/__init__.py:
--------------------------------------------------------------------------------
1 | from .client import get_client, Client, ClientConfig, ServerType
2 |
3 | __all__ = ["Client", "ClientConfig", "ServerType", "get_client"]
4 |
--------------------------------------------------------------------------------
/code/client/client.py:
--------------------------------------------------------------------------------
1 | import json
2 | import time
3 | from collections import defaultdict
4 | from enum import Enum
5 | from itertools import chain
6 | from logging import getLogger, Logger
7 | from typing import Dict, List, Optional, Tuple
8 | from threading import Thread
9 |
10 | import aiohttp
11 | import requests
12 | from fastapi import HTTPException
13 | from pydantic import BaseModel, Field
14 |
15 | from .openai_jumper import *
16 | from protocol.completion_task import *
17 | from protocol.routes import *
18 |
19 |
20 | class ServerType(Enum):
21 | SB = "static_batching_server"
22 | CB = "continuous_batching_server"
23 |
24 |
25 | class ServerURL:
26 | def __init__(self, url: str):
27 | self.url = url
28 | self.model_id = None
29 | self.workload = 0
30 | self.available = True
31 |
32 | @staticmethod
33 | def _calculate_workload(request_inputs: HuggingFaceCompletionInputs):
34 | # A simple heuristic method to calculate workload:
35 | # - 1.35 here means we assume 1 word ≈ 1.35 tokens
36 | # - 20 here means we assume the time consumed to decode 1 token ≈ prefill 20 tokens
37 | # You can also change the calculation logic here by yourself
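        # For example (illustrative): a 100-word prompt with num_beams=1 and max_new_tokens=128
        # gives num_prompt_tokens = int(100 * 1.35) = 135, so workload = 1 * 128 + 135 // 20 = 134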
38 |
39 | num_prompt_tokens = int(len(request_inputs.prompt.split()) * 1.35)
40 | num_beams = request_inputs.generation_config.num_beams
41 | max_new_tokens = request_inputs.generation_config.max_new_tokens
42 |
43 | return num_beams * max_new_tokens + num_prompt_tokens // 20
44 |
45 | def increase_workload(self, request_inputs: HuggingFaceCompletionInputs):
46 | self.workload += self._calculate_workload(request_inputs)
47 |
48 | def decrease_workload(self, request_inputs: HuggingFaceCompletionInputs):
49 | self.workload -= self._calculate_workload(request_inputs)
50 |
51 | def __eq__(self, other: "ServerURL"):
52 | return self.workload == other.workload
53 |
54 | def __gt__(self, other: "ServerURL"):
55 | return self.workload > other.workload
56 |
57 | def __ge__(self, other: "ServerURL"):
58 | return self.workload >= other.workload
59 |
60 | def __lt__(self, other: "ServerURL"):
61 | return self.workload < other.workload
62 |
63 | def __le__(self, other: "ServerURL"):
64 | return self.workload <= other.workload
65 |
66 | def __hash__(self):
67 | return hash(self.url)
68 |
69 |
70 | CLIENT_SINGLETON = None
71 |
72 |
73 | class ClientNotInitializedError(Exception):
74 | def __repr__(self):
75 | return "client is not initialized, please initialize a client object first."
76 |
77 | def __str__(self):
78 | return self.__repr__()
79 |
80 |
81 | class ClientDoubleInitializeError(Exception):
82 | def __repr__(self):
83 | return "client is initialized, do not initialize again, please use get_client() instead."
84 |
85 | def __str__(self):
86 | return self.__repr__()
87 |
88 |
89 | class ClientConfig(BaseModel):
90 | static_batching_server_urls: Optional[List[str]] = Field(default=None)
91 | continuous_batching_server_urls: Optional[List[str]] = Field(default=None)
92 | openai_jumper_configs: Optional[List[OpenAIJumperConfig]] = Field(default=None)
93 | heart_beat_interval_seconds: int = Field(default=600)
94 |
95 |
96 | class Client:
97 | def __init__(self, config: ClientConfig, logger: Optional[Logger] = None):
98 | global CLIENT_SINGLETON
99 |
100 | if CLIENT_SINGLETON is not None:
101 | raise ClientDoubleInitializeError()
102 |
103 | self.config = config
104 | self.logger = logger if logger else getLogger(__name__)
105 |
106 | # containers
107 | self.model_id2static_batching_server_urls: Dict[str, List[ServerURL]] = defaultdict(list)
108 | self.model_id2continuous_batching_server_urls: Dict[str, List[ServerURL]] = defaultdict(list)
109 | self.openai_jumpers: List[OpenAIJumper] = []
110 |
111 | self._update_containers_once()
112 |
113 | Thread(target=self._heart_beat_loop, daemon=True).start()
114 |
115 | # set singleton to self
116 | CLIENT_SINGLETON = self
117 |
118 | def _update_containers_once(self):
119 | self._update_server_urls_map(
120 | static_batching_server_urls=self.config.static_batching_server_urls,
121 | continuous_batching_server_urls=self.config.continuous_batching_server_urls
122 | )
123 | self._update_openai_jumpers(openai_jumper_configs=self.config.openai_jumper_configs)
124 |
125 | def _heart_beat_loop(self):
126 | while True:
127 | time.sleep(self.config.heart_beat_interval_seconds)
128 | self._update_containers_once()
129 |
130 | def _update_server_urls_map(
131 | self,
132 | static_batching_server_urls: Optional[List[str]] = None,
133 | continuous_batching_server_urls: Optional[List[str]] = None
134 | ):
135 | def build_server_urls_map(old_server_url_objs: List[ServerURL], new_server_urls: List[str]):
136 |             # TODO: parallelize the execution; the logic here may be very slow for now
137 |
138 | server_urls_map = defaultdict(list)
139 | old_server_url_hash_values = [hash(url_obj) for url_obj in old_server_url_objs]
140 | for url in new_server_urls:
141 | hash_value = hash(url)
142 | if hash_value in old_server_url_hash_values:
143 | url_obj = old_server_url_objs[old_server_url_hash_values.index(hash_value)]
144 | else:
145 | url_obj = ServerURL(url=url)
146 | self.get_model_id(url_obj)
147 | if url_obj.model_id:
148 | server_urls_map[url_obj.model_id].append(url_obj)
149 |
150 | return server_urls_map
151 |
152 | if static_batching_server_urls is not None:
153 | self.model_id2static_batching_server_urls = build_server_urls_map(
154 | old_server_url_objs=list(chain(*self.model_id2static_batching_server_urls.values())),
155 | new_server_urls=static_batching_server_urls
156 | )
157 |
158 | if continuous_batching_server_urls is not None:
159 | self.model_id2continuous_batching_server_urls = build_server_urls_map(
160 |                 old_server_url_objs=list(chain(*self.model_id2continuous_batching_server_urls.values())),
161 | new_server_urls=continuous_batching_server_urls
162 | )
163 |
164 | def _update_openai_jumpers(self, openai_jumper_configs: Optional[List[OpenAIJumperConfig]] = None):
165 | if openai_jumper_configs is None:
166 | return
167 |
168 | old_jumpers = self.openai_jumpers
169 | old_jumper_hash_values = [hash(jumper) for jumper in old_jumpers]
170 | new_jumpers = []
171 | for config in openai_jumper_configs:
172 | hash_value = hash(config.api_key)
173 | if hash_value in old_jumper_hash_values:
174 | jumper = old_jumpers[old_jumper_hash_values.index(hash_value)]
175 | else:
176 | jumper = OpenAIJumper(config=config, logger=self.logger)
177 | new_jumpers.append(jumper)
178 |
179 | self.openai_jumpers = new_jumpers
180 |
181 | for jumper in [jumper for jumper in old_jumpers if jumper not in new_jumpers]:
182 | destroy_openai_jumper(jumper)
183 |
184 | def update_config(self, config: ClientConfig):
185 | # One can use this method to implement hot reload logic to update client's behavior on the fly.
186 | # For example:
187 | # 1. manually update locally saved config file to update config's parameters;
188 | # 2. receive an event that a server is startup or shutdown, and automatically update server_urls.
189 |
190 | if set([c.api_key for c in config.openai_jumper_configs]) != \
191 | set([c.api_key for c in self.config.openai_jumper_configs]):
192 | self._update_openai_jumpers(openai_jumper_configs=config.openai_jumper_configs)
193 | new_static_batching_server_urls = None
194 | new_continuous_batching_server_urls = None
195 | if set(config.static_batching_server_urls) != set(self.config.static_batching_server_urls):
196 | new_static_batching_server_urls = config.static_batching_server_urls
197 | if set(config.continuous_batching_server_urls) != set(self.config.continuous_batching_server_urls):
198 | new_continuous_batching_server_urls = config.continuous_batching_server_urls
199 | if new_static_batching_server_urls is not None or new_continuous_batching_server_urls is not None:
200 | self._update_server_urls_map(
201 | static_batching_server_urls=new_static_batching_server_urls,
202 | continuous_batching_server_urls=new_continuous_batching_server_urls
203 | )
204 | self.config = config
205 |
206 | def save_config(self, save_path: str):
207 | with open(save_path, "w", encoding="utf-8") as f:
208 | json.dump(self.config.dict(by_alias=True), f)
209 |
210 | async def openai_chat_completion(
211 | self,
212 | request_inputs: OpenAIChatCompletionInputs,
213 | max_retries: int = 3,
214 | raise_on_error: bool = True
215 | ) -> OpenAIChatCompletionOutputs:
216 | request_inputs.verify_and_preprocess()
217 |
218 | available_jumpers = [jumper for jumper in self.openai_jumpers if jumper.available]
219 | if not available_jumpers:
220 | if raise_on_error:
221 | raise HTTPException(
222 | status_code=404,
223 |                     detail="LookupError: no openai jumper is available right now."
224 | )
225 | return OpenAIChatCompletionOutputs()
226 |
227 | jumper = min(available_jumpers)
228 | request_outputs, error, status_code = await jumper.chat_completion(
229 | inputs=request_inputs,
230 | max_retries=max_retries
231 | )
232 | if status_code != 200 and raise_on_error:
233 | raise HTTPException(status_code=status_code, detail=str(error))
234 | return request_outputs
235 |
236 | async def huggingface_completion(
237 | self,
238 | request_inputs: HuggingFaceCompletionInputs,
239 | max_retries: int = 3,
240 | raise_on_error: bool = True,
241 | server_type: ServerType = ServerType.CB,
242 | timeout: int = 100
243 | ) -> HuggingFaceCompletionOutputs:
244 | request_inputs.verify_and_preprocess()
245 |
246 | async def request(
247 | payload: dict,
248 | url_obj: ServerURL,
249 | route: str,
250 | ) -> Tuple[HuggingFaceCompletionOutputs, Optional[str], int]:
251 | async with aiohttp.request(
252 | method="post",
253 | url=f"{url_obj.url}{route}",
254 | json=payload,
255 | headers={},
256 | timeout=aiohttp.ClientTimeout(timeout)
257 | ) as resp:
258 | if resp.status == 200:
259 | return HuggingFaceCompletionOutputs(**(await resp.json())), None, resp.status
260 | else:
261 | return HuggingFaceCompletionOutputs(), resp.reason, resp.status
262 |
263 | # check if the requested model_id is available
264 | model_id = request_inputs.model
265 | server_urls_map = (
266 | self.model_id2static_batching_server_urls if server_type == ServerType.SB
267 | else self.model_id2continuous_batching_server_urls
268 | )
269 | url_objs = [url_obj for url_obj in server_urls_map.get(model_id, []) if url_obj.available]
270 | if not url_objs:
271 | if raise_on_error:
272 | raise HTTPException(
273 | status_code=404,
274 | detail=f"LookupError: requested model [{model_id}] is not available for now."
275 | )
276 | return HuggingFaceCompletionOutputs()
277 |
278 | # get the url_obj whose workload is smallest
279 | url_obj = min(url_objs)
280 |
281 | # request
282 | url_obj.increase_workload(request_inputs)
283 | try:
284 | request_outputs, error, status_code = await request(
285 | payload=request_inputs.dict(by_alias=True),
286 | url_obj=url_obj,
287 | route=ROUTE_POST_STATIC_BATCHING_COMPLETION if server_type == ServerType.SB
288 | else ROUTE_POST_CONTINUOUS_BATCHING_COMPLETION
289 | )
290 | except Exception as e:
291 | request_outputs = HuggingFaceCompletionOutputs()
292 | self.logger.error(msg=f"running request method failed with error type [{e.__class__.__name__}]", exc_info=e)
293 | error = "ServerRequestError: Unknown error occurred when request to server."
294 | status_code = 500
295 | url_obj.decrease_workload(request_inputs)
296 |
297 | if status_code == 200:
298 | return request_outputs
299 | else:
300 | if max_retries > 0:
301 | return await self.huggingface_completion(
302 | request_inputs=request_inputs,
303 | max_retries=max_retries-1,
304 | raise_on_error=raise_on_error,
305 | server_type=server_type,
306 | timeout=timeout
307 | )
308 | elif raise_on_error:
309 | raise HTTPException(status_code=status_code, detail=error)
310 | else:
311 | return request_outputs
312 |
313 |     def get_model_id(self, url: ServerURL, max_retries: int = 3) -> Optional[str]:
314 |         try:  # an unreachable server should just be marked unavailable instead of crashing the heart-beat loop
315 |             res = requests.get(f"{url.url}{ROUTE_GET_MODEL_ID}", timeout=(1, 1), verify=False)
316 |         except requests.RequestException:
317 |             res = None
318 |         if res is not None and res.status_code == 200:
319 |             url.available = True
320 |             url.model_id = res.json()
321 |             return url.model_id
322 |         if max_retries > 0:
323 |             time.sleep(1)
324 |             return self.get_model_id(url, max_retries - 1)
325 |         self.logger.error(msg=f"request to {url.url} to get model_id failed with status [{res.status_code if res else 'no response'}]")
326 |         url.available = False
327 |         return None
328 |
329 |
330 | def get_client():
331 | if CLIENT_SINGLETON is None:
332 | raise ClientNotInitializedError()
333 | return CLIENT_SINGLETON
334 |
335 |
336 | __all__ = [
337 | "ServerType",
338 | "Client",
339 | "ClientConfig",
340 | "get_client"
341 | ]
342 |
--------------------------------------------------------------------------------
/code/client/openai_jumper.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | from logging import getLogger, Logger
4 | from typing import Optional, Tuple
5 | from threading import Thread
6 |
7 | import openai
8 | from openai.util import convert_to_dict
9 | from pydantic import BaseModel, Field, Required
10 |
11 | from protocol.completion_task import (
12 | TokenUsage,
13 | OpenAIChatCompletionMessage,
14 | OpenAIChatCompletionChoice,
15 | OpenAIChatCompletionInputs,
16 | OpenAIChatCompletionOutputs
17 | )
18 | from protocol.error import Error
19 |
20 |
21 | AVAILABLE_OPENAI_CHAT_COMPLETION_MODELS = [
22 | "gpt-3.5-turbo",
23 | # You can extend available chat completion models here by yourself
24 | ]
25 |
26 |
27 | class OpenAIJumperConfig(BaseModel):
28 | api_key: str = Field(default=Required)
29 | org_id: Optional[str] = Field(default=None)
30 |
31 |
32 | class OpenAIJumper:
33 | def __init__(self, config: OpenAIJumperConfig, logger: Optional[Logger] = None):
34 | self.config = config
35 | self.logger = logger if logger else getLogger(__name__)
36 |
37 | self.workload = 0
38 | self.referenced = 0
39 | self._available = True
40 |
41 | async def chat_completion(
42 | self,
43 | inputs: OpenAIChatCompletionInputs,
44 | max_retries: int = 3
45 | ) -> Tuple[
46 | OpenAIChatCompletionOutputs,
47 | Optional[Error],
48 | int
49 | ]:
50 | if inputs.model not in AVAILABLE_OPENAI_CHAT_COMPLETION_MODELS:
51 | error_body = Error(
52 | type="ValueError",
53 | detail=f"LookupError: Required model [{inputs.model}] is not available, "
54 | f"available models are {AVAILABLE_OPENAI_CHAT_COMPLETION_MODELS}"
55 | )
56 | return OpenAIChatCompletionOutputs(), error_body, 404
57 |
58 | request_dict = inputs.dict(exclude_none=True)
59 | request_dict.update({"api_key": self.config.api_key})
60 | if self.config.org_id:
61 | request_dict.update({"organization": self.config.org_id})
62 |
63 | try:
64 | self.referenced += 1
65 | resp = await openai.ChatCompletion.acreate(**request_dict)
66 | except openai.error.Timeout as e:
67 | self.referenced -= 1
68 | if max_retries > 0:
69 | max_retries -= 1
70 | self.logger.warning(
71 | msg=f"Request to openai chat completion api timeout, will retry again (chance_left={max_retries})"
72 | )
73 | return await self.chat_completion(inputs, max_retries=max_retries)
74 | error_body = Error(
75 | type=e.__class__.__name__,
76 | detail=e.user_message
77 | )
78 | self.logger.error(msg=str(error_body))
79 | return OpenAIChatCompletionOutputs(), error_body, e.http_status
80 | except openai.error.OpenAIError as e:
81 | self.referenced -= 1
82 | error_body = Error(
83 | type=e.__class__.__name__,
84 | detail=e.user_message
85 | )
86 | self.logger.error(msg=str(error_body))
87 | return OpenAIChatCompletionOutputs(), error_body, e.http_status
88 | except Exception as e:
89 | self.referenced -= 1
90 | error_body = Error(
91 | type=e.__class__.__name__,
92 | detail=str(e)
93 | )
94 | self.logger.error(msg=str(error_body))
95 | return OpenAIChatCompletionOutputs(), error_body, 500
96 | else:
97 | self.referenced -= 1
98 | resp = convert_to_dict(resp)
99 | outputs = OpenAIChatCompletionOutputs(
100 | choices=[
101 | OpenAIChatCompletionChoice(
102 | message=OpenAIChatCompletionMessage(
103 | role=choice["message"]["role"],
104 | content=choice["message"]["content"]
105 | ),
106 | index=choice["index"],
107 | finish_reason=choice["finish_reason"]
108 | ) for choice in resp["choices"]
109 | ],
110 | usage=TokenUsage(**resp["usage"])
111 | )
112 |
113 | self.workload += outputs.usage.total_tokens
114 |
115 | return outputs, None, 200
116 |
117 | def reset_workload(self):
118 | self.workload = 0
119 |
120 | def freeze(self):
121 | self._available = False
122 |
123 | @property
124 | def available(self):
125 | return self._available
126 |
127 | def __eq__(self, other: "OpenAIJumper"):
128 | return self.workload == other.workload
129 |
130 | def __gt__(self, other: "OpenAIJumper"):
131 | return self.workload > other.workload
132 |
133 | def __ge__(self, other: "OpenAIJumper"):
134 | return self.workload >= other.workload
135 |
136 | def __lt__(self, other: "OpenAIJumper"):
137 | return self.workload < other.workload
138 |
139 | def __le__(self, other: "OpenAIJumper"):
140 | return self.workload <= other.workload
141 |
142 | def __hash__(self):
143 | return hash(self.config.api_key)
144 |
145 |
146 | def destroy_openai_jumper(openai_jumper: OpenAIJumper):
147 | def destroy():
148 | openai_jumper.freeze()
149 | while True:
150 | if openai_jumper.referenced == 0:
151 | del openai_jumper
152 | return
153 | time.sleep(1)
154 |
155 | Thread(target=destroy, daemon=True).start()
156 |
157 |
158 | __all__ = [
159 | "AVAILABLE_OPENAI_CHAT_COMPLETION_MODELS",
160 | "OpenAIJumper",
161 | "OpenAIJumperConfig",
162 | "destroy_openai_jumper",
163 | ]
164 |
--------------------------------------------------------------------------------
/code/client_app.py:
--------------------------------------------------------------------------------
1 | import json
2 | import time
3 | import os
4 | from logging import getLogger, DEBUG, INFO, Logger
5 | from typing import Optional
6 | from threading import Thread
7 |
8 |
9 | from fastapi import FastAPI
10 | from fastapi.responses import JSONResponse
11 | from pydantic import BaseModel, Field
12 |
13 | from client import get_client, Client, ClientConfig, ServerType
14 | from protocol.completion_task import *
15 | from protocol.error import Error
16 | from protocol.routes import (
17 | ROUTE_POST_OPENAI_CHAT_COMPLETION,
18 | ROUTE_POST_CONTINUOUS_BATCHING_COMPLETION,
19 | ROUTE_POST_STATIC_BATCHING_COMPLETION,
20 | ROUTE_CLIENT_ONLY_POST_SERVER_STARTUP_EVENT,
21 | ROUTE_CLIENT_ONLY_POST_SERVER_SHUTDOWN_EVENT
22 | )
23 | from utils.log_util import RequestLoggingMiddleware
24 |
25 |
26 | logger = getLogger("gunicorn.logger") # by default, we use gunicorn to wrap the app
27 |
28 |
29 | class AppConfig(BaseModel):
30 | client_config_file_path: str = Field(default=...)
31 | client_config_hot_update_interval_minutes: int = Field(default=1)
32 | debug: bool = Field(default=False)
33 |
34 |
35 | APP_NAME = "LLM-Inference-Client"
36 |
37 | app = FastAPI(title=APP_NAME, version="0.1.0")
38 | app_config: Optional[AppConfig] = None
39 |
40 |
41 | def build_app(
42 | client_config_file_path: str = "client_config.json",
43 | client_config_hot_update_interval_minutes: int = 1,
44 | debug: bool = False
45 | ):
46 | global app, app_config
47 |
48 | logger.setLevel(DEBUG if debug else INFO)
49 |
50 | app_config = AppConfig(
51 | client_config_file_path=client_config_file_path,
52 | client_config_hot_update_interval_minutes=client_config_hot_update_interval_minutes,
53 | debug=debug
54 | )
55 |
56 | app.add_middleware(RequestLoggingMiddleware, logger=logger)
57 |
58 | return app
59 |
60 |
61 | def hot_update_client_config_loop():
62 | while True:
63 | time.sleep(app_config.client_config_hot_update_interval_minutes * 60)
64 | client = get_client()
65 |
66 | fp = app_config.client_config_file_path
67 | if not os.path.exists(fp):
68 | logger.warning(
69 |                 msg=f"Client config file path [{fp}] does not exist, skipping hot update this time, "
70 |                     f"and will try to save a client config snapshot instead."
71 | )
72 | try:
73 | client.save_config(fp)
74 | except:
75 | pass
76 | continue
77 |         new_client_config = ClientConfig(**json.load(open(fp, "r", encoding="utf-8")))
78 | client.update_config(new_client_config)
79 |
80 |
81 | @app.on_event("startup")
82 | def startup():
83 | # init client
84 | Client(
85 | config=ClientConfig(
86 | **json.load(open(app_config.client_config_file_path, "r", encoding="utf-8"))
87 | ),
88 | logger=logger
89 | )
90 |
91 | # start client config hot update loop
92 | Thread(target=hot_update_client_config_loop, daemon=True).start()
93 |
94 |
95 | # === Routes that request to LLM servers === #
96 |
97 | @app.post(ROUTE_POST_OPENAI_CHAT_COMPLETION, response_model=OpenAIChatCompletionOutputs)
98 | async def request_openai_chat_completion(request_inputs: OpenAIChatCompletionInputs):
99 | client = get_client()
100 | return await client.openai_chat_completion(request_inputs)
101 |
102 |
103 | @app.post(ROUTE_POST_CONTINUOUS_BATCHING_COMPLETION, response_model=HuggingFaceCompletionOutputs)
104 | async def request_continuous_batching_server(request_inputs: HuggingFaceCompletionInputs):
105 | client = get_client()
106 | return await client.huggingface_completion(request_inputs, server_type=ServerType.CB)
107 |
108 |
109 | @app.post(ROUTE_POST_STATIC_BATCHING_COMPLETION, response_model=HuggingFaceCompletionOutputs)
110 | async def request_static_batching_server(request_inputs: HuggingFaceCompletionInputs):
111 | client = get_client()
112 | return await client.huggingface_completion(request_inputs, server_type=ServerType.SB)
113 |
114 |
115 | # === Routes that provide some meta information === #
116 |
117 | @app.get("/cb_server/available_models")
118 | async def get_cb_server_available_models():
119 | client = get_client()
120 | available_models = []
121 |     for model_id, server_urls in client.model_id2continuous_batching_server_urls.items():
122 | if any(url_obj.available for url_obj in server_urls):
123 | available_models.append(model_id)
124 | return JSONResponse(content=available_models)
125 |
126 |
127 | @app.get("/sb_server/available_models")
128 | async def get_sb_server_available_models():
129 | client = get_client()
130 | available_models = []
131 |     for model_id, server_urls in client.model_id2static_batching_server_urls.items():
132 | if any(url_obj.available for url_obj in server_urls):
133 | available_models.append(model_id)
134 | return JSONResponse(content=available_models)
135 |
136 |
137 | @app.get("/openai_jumper/is_available")
138 | async def get_openai_jumper_is_available():
139 | client = get_client()
140 | if any(jumper.available for jumper in client.openai_jumpers):
141 | return JSONResponse(content="1")
142 | return JSONResponse("0")
143 |
144 |
145 | # === Routes that receive server event and update client config === #
146 | # TODO: Implement routes
147 |
148 |
149 | if __name__ == "__main__":
150 | import uvicorn
151 | from argparse import ArgumentParser
152 | from logging import basicConfig
153 |
154 | parser = ArgumentParser()
155 | parser.add_argument("--client_config_file_path", type=str, default="client_config.json")
156 | parser.add_argument("--client_config_hot_update_interval_minutes", type=int, default=1)
157 | parser.add_argument("--debug", action="store_true")
158 | parser.add_argument("--port", type=int, default=8000)
159 | args = parser.parse_args()
160 |
161 | logger = getLogger(__name__) # override gunicorn logger if we use uvicorn directly
162 | basicConfig(
163 | format="%(asctime)s %(levelname)s [%(name)s] %(message)s",
164 | datefmt="%Y-%m-%d %H:%M:%S"
165 | )
166 |
167 | uvicorn.run(
168 | build_app(
169 | client_config_file_path=args.client_config_file_path,
170 | client_config_hot_update_interval_minutes=args.client_config_hot_update_interval_minutes,
171 | debug=args.debug,
172 | ),
173 | host="0.0.0.0",
174 | port=args.port
175 | )
176 |
--------------------------------------------------------------------------------
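For reference, a minimal sketch of calling the client app's OpenAI-style chat completion route defined above. The host/port and model name are placeholder assumptions; the payload mirrors OpenAIChatCompletionInputs from protocol/completion_task.py:

    import requests

    resp = requests.post(
        "http://localhost:8000/chat_completion/v1",  # ROUTE_POST_OPENAI_CHAT_COMPLETION
        json={
            "model": "gpt-3.5-turbo",  # placeholder: must be a model the configured OpenAI jumpers serve
            "messages": [{"role": "user", "content": "Hello!"}],
            "n": 1,
            "max_tokens": 32,
        },
        timeout=60,
    )
    print(resp.json())  # OpenAIChatCompletionOutputs: {"choices": [...], "usage": {...}}
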
/code/client_requirements.txt:
--------------------------------------------------------------------------------
1 | aiohttp
2 | requests
3 | openai
4 | pydantic<2.0.0
5 | fastapi[all]==0.96.0
6 | setproctitle
7 | uvicorn[standard]
8 | gunicorn
--------------------------------------------------------------------------------
/code/continuous_batching_server_app.py:
--------------------------------------------------------------------------------
1 | import json
2 | from logging import getLogger, DEBUG, INFO
3 | from typing import Optional
4 |
5 | from fastapi import HTTPException, FastAPI
6 | from fastapi.responses import JSONResponse
7 | from pydantic import BaseModel, Field
8 |
9 | from server.continuous_batching_server import get_server, Server, ServerConfig
10 | from protocol.completion_task import (
11 | HuggingFaceCompletionInputs,
12 | HuggingFaceCompletionOutputs
13 | )
14 | from protocol.error import Error
15 | from protocol.routes import (
16 | ROUTE_GET_MODEL_ID,
17 | ROUTE_POST_CONTINUOUS_BATCHING_COMPLETION,
18 | ROUTE_CLIENT_ONLY_POST_SERVER_STARTUP_EVENT,
19 | ROUTE_CLIENT_ONLY_POST_SERVER_SHUTDOWN_EVENT
20 | )
21 | from utils.log_util import RequestLoggingMiddleware
22 |
23 |
24 | logger = getLogger("gunicorn.logger") # by default, we use gunicorn to wrap the app
25 |
26 |
27 | class AppConfig(BaseModel):
28 | model_id: str = Field(default=...)
29 | server_config_file_path: str = Field(default=...)
30 | client_url: Optional[str] = Field(default=None)
31 | debug: bool = Field(default=False)
32 |
33 |
34 | APP_NAME = "LLM-Inference-CB-Server"
35 |
36 | app = FastAPI(title=APP_NAME, version="0.1.0")
37 | app_config: Optional[AppConfig] = None
38 |
39 |
40 | def build_app(
41 | model_id: str = None,
42 | server_config_file_path: str = "cb_server_config.json",
43 | client_url: Optional[str] = None,
44 | debug: bool = False
45 | ):
46 | global app, app_config
47 |
48 | if model_id is None:
49 |         raise ValueError("You must specify a real value for model_id.")
50 |
51 | logger.setLevel(DEBUG if debug else INFO)
52 |
53 | app_config = AppConfig(
54 | model_id=model_id,
55 | server_config_file_path=server_config_file_path,
56 | client_url=client_url,
57 | debug=debug
58 | )
59 |
60 | app.add_middleware(RequestLoggingMiddleware, logger=logger)
61 |
62 | return app
63 |
64 |
65 | @app.on_event("startup")
66 | def startup():
67 | # initialize server
68 | Server(
69 | config=ServerConfig(**json.load(open(app_config.server_config_file_path, "r", encoding="utf-8"))),
70 | logger=logger
71 | )
72 |     # TODO: implement logic to inform the client that the server has started up
73 |
74 |
75 | @app.on_event("shutdown")
76 | def shutdown():
77 |     pass  # TODO: implement logic to inform the client that the server has shut down
78 |
79 |
80 | @app.get(ROUTE_GET_MODEL_ID)
81 | async def get_model_id():
82 | return JSONResponse(content=app_config.model_id)
83 |
84 |
85 | @app.post(ROUTE_POST_CONTINUOUS_BATCHING_COMPLETION, response_model=HuggingFaceCompletionOutputs)
86 | async def execute_completion(request_inputs: HuggingFaceCompletionInputs):
87 | server = get_server()
88 | outputs, error, status_code, wall_time = await server.wait_task_done(request_inputs)
89 | if status_code != 200:
90 | logger.error(msg=str(error))
91 | raise HTTPException(
92 | status_code=status_code,
93 | detail=str(error)
94 | )
95 | return outputs
96 |
97 |
98 | if __name__ == "__main__":
99 | import uvicorn
100 | from argparse import ArgumentParser
101 | from logging import basicConfig
102 |
103 | parser = ArgumentParser()
104 | parser.add_argument("--model_id", type=str)
105 | parser.add_argument("--server_config_file_path", type=str, default="cb_server_config.json")
106 | parser.add_argument("--client_url", type=str, default=None)
107 | parser.add_argument("--debug", action="store_true")
108 | parser.add_argument("--port", type=int, default=8001)
109 | args = parser.parse_args()
110 |
111 | logger = getLogger(__name__) # override gunicorn logger if we use uvicorn directly
112 | basicConfig(
113 | format="%(asctime)s %(levelname)s [%(name)s] %(message)s",
114 | datefmt="%Y-%m-%d %H:%M:%S"
115 | )
116 |
117 | uvicorn.run(
118 | build_app(
119 | model_id=args.model_id,
120 | server_config_file_path=args.server_config_file_path,
121 | client_url=args.client_url,
122 | debug=args.debug
123 | ),
124 | host="0.0.0.0",
125 | port=args.port
126 | )
127 |
--------------------------------------------------------------------------------
/code/gunicorn_config.py:
--------------------------------------------------------------------------------
1 | worker_class = "uvicorn.workers.UvicornWorker"
2 |
3 | # Number of worker processes
4 | workers = 1
5 | # Number of threads per worker process;
6 | # (workers * threads) is how many requests the app can process simultaneously
7 | threads = 32
8 | # Number of requests that can be waiting to be served;
9 | # requests exceeding this number will be rejected and receive an error
10 | backlog = 64
11 | # Workers silent for more than this many seconds are killed and restarted
12 | timeout = 300
13 |
14 | # Log format for the access log; the error log's format cannot be customized
15 | access_log_format = '%(h)s %(l)s %(t)s "%(r)s" %(m)s %(s)s %(b)s "%(f)s" "%(a)s"'
16 |
17 | """
18 | Below are descriptions of each format option:
19 | h remote address
20 | l '-'
21 | u currently '-', may be user name in future releases
22 | t date of the request
23 | r status line (e.g. ``GET / HTTP/1.1``)
24 | s status
25 | b response length or '-'
26 | f referer
27 | a user agent
28 | T request time in seconds
29 | D request time in microseconds
30 | L request time in decimal seconds
31 | p process ID
32 | """
33 |
--------------------------------------------------------------------------------
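The settings above are intended for wrapping the FastAPI apps in this repo with gunicorn (each app's logger defaults to "gunicorn.logger"). A plausible manual launch, assuming a gunicorn version that accepts a factory call in the app specifier and using placeholder arguments, might look like:

    gunicorn -c gunicorn_config.py \
        'continuous_batching_server_app:build_app(model_id="llama-7b", server_config_file_path="cb_server_config.json")'
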
/code/protocol/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/modelize-ai/LLM-Inference-Deployment-Tutorial/e3264eacc4752a7d829241fb614a0b81892cdc8f/code/protocol/__init__.py
--------------------------------------------------------------------------------
/code/protocol/completion_task.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional, Union
2 |
3 | from fastapi import status, HTTPException
4 | from pydantic import BaseModel, Field
5 |
6 |
7 | class TokenUsage(BaseModel):
8 | prompt_tokens: int = Field(default=0)
9 | completion_tokens: int = Field(default=0)
10 | total_tokens: int = Field(default=0)
11 |
12 |
13 | class HuggingFaceGenerationConfig(BaseModel):
14 | do_sample: bool = Field(default=False)
15 | early_stopping: bool = Field(default=True)
16 | num_beams: int = Field(default=1)
17 | num_return_sequences: int = Field(default=1)
18 | max_new_tokens: int = Field(default=32)
19 | min_new_tokens: int = Field(default=1)
20 | temperature: float = Field(default=1)
21 | top_p: float = Field(default=1, ge=0, le=1)
22 | top_k: int = Field(default=0)
23 | typical_p: float = Field(default=1, ge=0, le=1)
24 | repetition_penalty: float = Field(default=1)
25 | eos_token_id: Optional[int] = Field(default=None)
26 | pad_token_id: Optional[int] = Field(default=None)
27 | seed: int = Field(default=1024)
28 |
29 | def __hash__(self):
30 | return hash(
31 | str(self.do_sample) +
32 | str(self.early_stopping) +
33 | str(self.num_beams) +
34 | str(self.num_return_sequences) +
35 | str(self.max_new_tokens) +
36 | str(self.min_new_tokens) +
37 | str(self.temperature) +
38 | str(self.top_p) +
39 | str(self.top_k) +
40 | str(self.typical_p) +
41 | str(self.repetition_penalty) +
42 | str(self.eos_token_id) +
43 | str(self.pad_token_id) +
44 | str(self.seed)
45 | )
46 |
47 |
48 | class HuggingFaceCompletionInputs(BaseModel):
49 | model: str = Field(default=...)
50 | prompt: str = Field(default=...)
51 | generation_config: HuggingFaceGenerationConfig = Field(default=HuggingFaceGenerationConfig())
52 |
53 | def verify_and_preprocess(self):
54 | # verify
55 | # here we only do the simplest verification for some parameters
56 |         # the server should also check again with additional information, e.g.:
57 |         # - whether num_prompt_tokens exceeds the model's max_seq_len; if yes, we should abort the request directly.
58 | if self.generation_config.num_return_sequences > self.generation_config.num_beams:
59 | raise HTTPException(
60 | status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
61 | detail=(
62 |                     f"num_return_sequences (got value {self.generation_config.num_return_sequences}) "
63 |                     f"has to be less than or equal to num_beams (got value {self.generation_config.num_beams})"
64 | )
65 | )
66 | if self.generation_config.min_new_tokens > self.generation_config.max_new_tokens:
67 | raise HTTPException(
68 | status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
69 | detail=(
70 |                     f"min_new_tokens (got value {self.generation_config.min_new_tokens}) "
71 |                     f"has to be less than or equal to max_new_tokens (got value {self.generation_config.max_new_tokens})"
72 | )
73 | )
74 |
75 |         # make sure some parameters in generation_config are reset to
76 |         # their default values so that sampling will not be triggered
77 | if not self.generation_config.do_sample:
78 | self.generation_config.temperature = 1.0
79 | self.generation_config.top_p = 1.0
80 | self.generation_config.top_k = 0
81 | self.generation_config.typical_p = 1.0
82 |
83 |
84 | class HuggingFaceCompletionChoice(BaseModel):
85 | text: str = Field(default=...)
86 | index: int = Field(default=...)
87 | finish_reason: str = Field(default=...)
88 |
89 |
90 | class HuggingFaceCompletionOutputs(BaseModel):
91 | choices: Optional[List[HuggingFaceCompletionChoice]] = Field(default=None)
92 | usage: TokenUsage = Field(default=TokenUsage())
93 |
94 |
95 | class OpenAIChatCompletionMessage(BaseModel):
96 | # Note: we removed parameters relevant to function call here for the simplest use case
97 | role: str = Field(default=..., regex=r"(system|user|assistant)")
98 | content: str = Field(default=...)
99 |
100 |
101 | class OpenAIChatCompletionInputs(BaseModel):
102 | model: str = Field(default=...)
103 | messages: List[OpenAIChatCompletionMessage] = Field(default=...)
104 | n: int = Field(default=1)
105 | max_tokens: int = Field(default=32)
106 | temperature: float = Field(default=1)
107 | top_p: float = Field(default=1, ge=0, le=1)
108 | stream: bool = Field(default=False)
109 | stop: Optional[Union[str, List[str]]] = Field(default=None)
110 | presence_penalty: float = Field(default=0)
111 | frequency_penalty: float = Field(default=0)
112 |
113 | def verify_and_preprocess(self):
114 | # verify
115 |         # does nothing for now; you can add verification logic here
116 | pass
117 |
118 |
119 | class OpenAIChatCompletionChoice(BaseModel):
120 | message: OpenAIChatCompletionMessage = Field(default=...)
121 | index: int = Field(default=...)
122 | finish_reason: str = Field(default=...)
123 |
124 |
125 | class OpenAIChatCompletionOutputs(BaseModel):
126 | choices: Optional[List[OpenAIChatCompletionChoice]] = Field(default=None)
127 | usage: TokenUsage = Field(default=TokenUsage())
128 |
129 |
130 | __all__ = [
131 | "TokenUsage",
132 | "HuggingFaceGenerationConfig",
133 | "HuggingFaceCompletionChoice",
134 | "HuggingFaceCompletionInputs",
135 | "HuggingFaceCompletionOutputs",
136 | "OpenAIChatCompletionMessage",
137 | "OpenAIChatCompletionChoice",
138 | "OpenAIChatCompletionInputs",
139 | "OpenAIChatCompletionOutputs"
140 | ]
141 |
--------------------------------------------------------------------------------
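For reference, a minimal sketch of a request that satisfies HuggingFaceCompletionInputs above, posted to the client app's continuous-batching route. The host/port and model id are placeholders:

    import requests

    payload = {
        "model": "llama-7b",                 # placeholder model id
        "prompt": "Once upon a time",
        "generation_config": {               # fields of HuggingFaceGenerationConfig
            "do_sample": True,
            "max_new_tokens": 64,
            "temperature": 0.8,
            "top_p": 0.95,
        },
    }
    resp = requests.post("http://localhost:8000/completion/v2", json=payload, timeout=300)
    print(resp.json())  # HuggingFaceCompletionOutputs: {"choices": [...], "usage": {...}}
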
/code/protocol/error.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | from pydantic import BaseModel, Field
4 |
5 |
6 | class Error(BaseModel):
7 | type: str = Field(default=...)
8 | detail: Optional[str] = Field(default=None)
9 |
10 | def __repr__(self):
11 | return f"{self.type}: {self.detail}"
12 |
13 | def __str__(self):
14 | return self.__repr__()
15 |
16 |
17 | __all__ = [
18 | "Error"
19 | ]
20 |
--------------------------------------------------------------------------------
/code/protocol/routes.py:
--------------------------------------------------------------------------------
1 | ROUTE_POST_STATIC_BATCHING_COMPLETION = "/completion/v1"
2 | ROUTE_POST_CONTINUOUS_BATCHING_COMPLETION = "/completion/v2"
3 | ROUTE_POST_OPENAI_CHAT_COMPLETION = "/chat_completion/v1"
4 | ROUTE_GET_MODEL_ID = "/model_id"
5 |
6 | ROUTE_CLIENT_ONLY_POST_SERVER_STARTUP_EVENT = "/event/server_startup"
7 | ROUTE_CLIENT_ONLY_POST_SERVER_SHUTDOWN_EVENT = "/event/server_shutdown"
8 |
9 |
10 | __all__ = [
11 | "ROUTE_POST_STATIC_BATCHING_COMPLETION",
12 | "ROUTE_POST_CONTINUOUS_BATCHING_COMPLETION",
13 | "ROUTE_POST_OPENAI_CHAT_COMPLETION",
14 | "ROUTE_GET_MODEL_ID",
15 | "ROUTE_CLIENT_ONLY_POST_SERVER_STARTUP_EVENT",
16 | "ROUTE_CLIENT_ONLY_POST_SERVER_SHUTDOWN_EVENT",
17 | ]
18 |
--------------------------------------------------------------------------------
/code/server/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/modelize-ai/LLM-Inference-Deployment-Tutorial/e3264eacc4752a7d829241fb614a0b81892cdc8f/code/server/__init__.py
--------------------------------------------------------------------------------
/code/server/continuous_batching_server/__init__.py:
--------------------------------------------------------------------------------
1 | from .config import *
2 | from .server import *
3 |
4 | DEFAULT_BLOCK_SIZE_A100 = 16
5 | DEFAULT_NUM_BLOCKS_A100 = 2500
6 | DEFAULT_BATCH_MAX_TOKENS_A100 = 56000
7 | DEFAULT_BATCH_MAX_BEAMS_A100 = 32
8 |
9 |
10 | __all__ = [
11 | "get_server",
12 | "Server",
13 | "BatcherConfig",
14 | "CacheConfig",
15 | "ModelLoadingConfig",
16 | "ParallelConfig",
17 | "ServerConfig",
18 | "DEFAULT_BLOCK_SIZE_A100",
19 | "DEFAULT_NUM_BLOCKS_A100",
20 | "DEFAULT_BATCH_MAX_TOKENS_A100",
21 | "DEFAULT_BATCH_MAX_BEAMS_A100"
22 | ]
23 |
--------------------------------------------------------------------------------
/code/server/continuous_batching_server/batcher.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from collections import defaultdict
3 | from logging import getLogger, Logger
4 | from typing import *
5 | from uuid import UUID
6 |
7 | import torch
8 |
9 | from .beam import Beam, BeamGroup, BeamStatus
10 | from .config import BatcherConfig, CacheConfig
11 | from .cache.cache_manager import CacheBlockManager
12 | from .generation_utils import HeterogeneousNextTokensChooser
13 |
14 |
15 | class Batch:
16 | def __init__(self):
17 | self.prefill_beams: List[Beam] = []
18 | self.generation_beams: List[Beam] = []
19 | self.block_tables: Dict[UUID, List[int]] = {}
20 | self.prefill_logits: Optional[torch.Tensor] = None
21 | self.generation_logits: Optional[torch.Tensor] = None
22 |
23 | @property
24 | def num_beams(self) -> int:
25 | return len(self.prefill_beams) + len(self.generation_beams)
26 |
27 | @property
28 | def request_ids(self):
29 | return list(set([beam.request_id for beam in self.prefill_beams + self.generation_beams]))
30 |
31 | def add(self, beam: Beam, block_ids: List[int]) -> None:
32 | if beam.num_generated_tokens == 0:
33 | self.prefill_beams.append(beam)
34 | else:
35 | self.generation_beams.append(beam)
36 | self.block_tables[beam.beam_id] = block_ids
37 |
38 |
39 | class Batcher:
40 | def __init__(
41 | self,
42 | batcher_config: BatcherConfig,
43 | cache_config: CacheConfig,
44 | logger: Optional[Logger] = None
45 | ):
46 | self.batcher_config = batcher_config
47 | self.cache_config = cache_config
48 |
49 | self.logger = logger if logger else getLogger(__name__)
50 |
51 | self.waiting: List[BeamGroup] = []
52 | self.running: List[BeamGroup] = []
53 | self.preempting: List[BeamGroup] = []
54 |
55 | self.req_id2beam_group: Dict[UUID, BeamGroup] = {}
56 |
57 | self.cache_manager = CacheBlockManager(
58 | cache_config.block_size,
59 | cache_config.num_blocks,
60 | cache_config.num_blocks_cpu,
61 | cache_config.watermark
62 | )
63 |
64 | def add_request(self, request: BeamGroup):
65 | self.waiting.append(request)
66 | self.req_id2beam_group[request.request_id] = request
67 |
68 | def _allocate(self, beam: Beam) -> bool:
69 | if not self.cache_manager.can_allocate(beam):
70 | return False
71 | self.cache_manager.allocate(beam)
72 | return True
73 |
74 | def _append_slot(self, beam: Beam, blocks_to_copy: Dict[int, List[int]]):
75 | ret = self.cache_manager.append_slot(beam)
76 | if ret is not None:
77 | src_block, dst_block = ret
78 | blocks_to_copy[src_block].append(dst_block)
79 |
80 | def _free_finished_beams(self):
81 | for beam in [beam for beam_group in self.running for beam in beam_group.get_beams(BeamStatus.FINISHED)]:
82 | self.logger.debug(f"free blocks of beam-{beam.beam_id} for it is finished.")
83 | self.cache_manager.free(beam.beam_id)
84 |
85 | def schedule(self) -> Tuple[
86 | Batch,
87 | Dict[int, List[int]],
88 | Dict[int, int],
89 | Dict[int, int],
90 | List[BeamGroup]
91 | ]:
92 |         """
93 |         The overall execution flow is roughly as follows:
94 |         1. Remove requests (beam_groups) that have finished generating from the running queue
95 |         2. Allocate cache space for the requests in the running queue that still need to generate
96 |             a. Count the total number of additional GPU cache blocks required by all requests in the running queue
97 |             b. If the number of free GPU cache blocks is smaller than that total, swap requests out to CPU one by one,
98 |                starting from the tail of the running queue, until the remaining space is enough for the requests at the front
99 |             c. If no GPU->CPU swap happened and there are requests previously swapped out to CPU waiting to resume, swap them back to GPU in priority order until no more can be swapped in
100 |             d. Allocate cache space for the remaining requests in the running queue
101 |         3. If no GPU->CPU swap happened, or the running queue is empty, try to move requests from the waiting queue into the running queue
102 | 
103 |         :return: (
104 |             Batch,
105 |             blocks_to_copy: Dict[int, List[int]],
106 |             blocks_to_swap_in: Dict[int, int],
107 |             blocks_to_swap_out: Dict[int, int],
108 |             finished_requests: List[BeamGroup]
109 |         )
110 |         """
111 | batch = Batch()
112 |
113 |         # step1: remove requests (beam_groups) that have finished generating from the running queue
114 | self._free_finished_beams()
115 | running = []
116 | finishing = []
117 | while self.running:
118 | beam_group = self.running.pop(0)
119 | if beam_group.is_finished:
120 | finishing.append(beam_group)
121 | self.logger.debug(f"BeamGroup-{beam_group.request_id} finished, put into finishing queue.")
122 | else:
123 | running.append(beam_group)
124 | self.logger.debug(f"BeamGroup-{beam_group.request_id} not finish, put back to running queue.")
125 | self.running = running
126 |
127 |         # step2: allocate cache space for the requests in the running queue that still need to generate
128 | running = []
129 | swapping_out = []
130 | swapping_in = []
131 | blocks_to_copy: Dict[int, List[int]] = defaultdict(list)
132 | blocks_to_swap_in: Dict[int, int] = {}
133 | blocks_to_swap_out: Dict[int, int] = {}
134 |         # step2.a: count the total number of additional GPU cache blocks required by all running requests
135 | run_request2num_append_blocks = defaultdict(int)
136 | for beam_group in self.running:
137 | run_request2num_append_blocks[beam_group.request_id] = 0
138 | for beam in beam_group.get_beams(BeamStatus.RUNNING):
139 | if self.cache_manager.is_need_to_append_slot(beam):
140 | run_request2num_append_blocks[beam_group.request_id] += 1
141 |         # step2.b: if the number of free GPU cache blocks is smaller than the number required,
142 |         # swap requests out to CPU one by one starting from the tail of the running queue,
143 |         # until the remaining space is enough for the requests at the front of the running queue
144 | while self.cache_manager.allocator.num_free_blocks < sum(run_request2num_append_blocks.values()):
145 | beam_group = self.running.pop(-1)
146 | num_append_blocks = run_request2num_append_blocks.pop(beam_group.request_id)
147 | if num_append_blocks == 0:
148 | running.insert(0, beam_group)
149 | continue
150 | if not self.cache_manager.can_swap_out(beam_group):
151 | # FIXME: do not raise error, abort this beam group, mark as finished with an abortion reason, free cache space
152 | raise RuntimeError("No enough CPU RAM to swap out")
153 | else:
154 | blocks_to_swap_out.update(self.cache_manager.swap_out(beam_group))
155 | for beam in beam_group.get_beams(BeamStatus.RUNNING):
156 | beam_group.update_beam_status(beam, BeamStatus.SWAPPED)
157 | swapping_out.insert(0, beam_group)
158 | self.logger.debug(
159 | f"one request swapped out, "
160 | f"free_gpu_blocks={self.cache_manager.allocator.num_free_blocks}, "
161 | f"free_cpu_blocks={self.cache_manager.cpu_allocator.num_free_blocks}"
162 | )
163 | self.running += running
164 | self.preempting += swapping_out
165 |         # step2.c: if no GPU->CPU swap happened and there are requests previously swapped out to CPU,
166 |         # swap them back to GPU in priority order until no more can be swapped in
167 | if not swapping_out:
168 | preserved_num_blocks = sum(run_request2num_append_blocks.values())
169 | while self.preempting:
170 | beam_group = self.preempting[0]
171 | if not self.cache_manager.can_swap_in(beam_group, preserved_num_blocks):
172 | self.logger.debug(
173 | f"attempt to swap in beam_group-{beam_group.request_id} but not have enough free gpu blocks."
174 | )
175 | if not self.running:
176 | raise RuntimeError(
177 |                         "running queue is empty but still can't swap in the request, "
178 |                         "please consider increasing num_blocks or decreasing the max number of tokens"
179 | )
180 | else:
181 |                     break  # swapping in this beam_group would exceed the available free GPU blocks, so stop
182 | beam_group = self.preempting.pop(0)
183 | blocks_to_swap_in.update(self.cache_manager.swap_in(beam_group))
184 | for beam in beam_group.get_beams(BeamStatus.SWAPPED):
185 | beam_group.update_beam_status(beam, BeamStatus.RUNNING)
186 | swapping_in.append(beam_group)
187 | preserved_num_blocks += sum(
188 | [
189 | self.cache_manager.is_need_to_append_slot(beam)
190 | for beam in beam_group.get_beams(BeamStatus.RUNNING)
191 | ]
192 | )
193 | self.logger.debug(
194 | f"one request swapped in, "
195 | f"free_gpu_blocks={self.cache_manager.allocator.num_free_blocks}, "
196 | f"free_cpu_blocks={self.cache_manager.cpu_allocator.num_free_blocks}"
197 | )
198 | self.running += swapping_in
199 |         # step2.d: allocate cache space for the remaining requests in the running queue
200 | for beam_group in self.running:
201 | self.logger.debug(
202 | f"beam_group-{beam_group.request_id}'s beams' status: "
203 | f"{[beam.status.name for beam in beam_group.get_beams()]}"
204 | )
205 | beams = beam_group.get_beams(BeamStatus.RUNNING)
206 | for beam in beams:
207 | self._append_slot(beam, blocks_to_copy)
208 | block_ids = self.cache_manager.get_block_table(beam.beam_id)
209 | batch.add(beam, block_ids)
210 | beam_group.beams.pop(beam.beam_id)
211 |
212 |         # step3: if no GPU->CPU swap happened, try to move requests from the waiting queue into the running queue
213 | batch_tokens = batch.num_beams
214 | if (not swapping_out or not self.running) and not self.preempting:
215 | while self.waiting:
216 | beam_group = self.waiting[0]
217 | beam = beam_group.get_beams()[0]
218 | if batch_tokens + beam.num_tokens > self.batcher_config.batch_max_tokens:
219 | self.logger.debug(
220 | f"reach batch_max_tokens {self.batcher_config.batch_max_tokens}, "
221 | f"current batch_tokens {batch_tokens}"
222 | )
223 | break
224 | if batch.num_beams + 1 > self.batcher_config.batch_max_beams:
225 | self.logger.debug(
226 | f"reach batch_max_beams {self.batcher_config.batch_max_beams}, "
227 | f"current batch_beams {batch.num_beams}"
228 | )
229 | break
230 | has_cache_space = self._allocate(beam)
231 | if not has_cache_space:
232 |                     self.logger.debug("no free cache space to allocate")
233 | break
234 | beam_group = self.waiting.pop(0)
235 | beam = beam_group.get_beams()[0]
236 | batch_tokens += beam.num_tokens
237 | beam_group.update_beam_status(beam, BeamStatus.RUNNING)
238 | self.running.append(beam_group)
239 | block_ids = self.cache_manager.get_block_table(beam.beam_id)
240 | batch.add(beam, block_ids)
241 | beam_group.beams.pop(beam.beam_id)
242 |
243 | return batch, blocks_to_copy, blocks_to_swap_in, blocks_to_swap_out, finishing
244 |
245 | def _generate(self, old_beams: List[Beam], logits: torch.Tensor):
246 | next_tokens_chooser = HeterogeneousNextTokensChooser(
247 | dtype=logits.dtype,
248 | device=logits.device,
249 | temperature=[self.req_id2beam_group[beam.request_id].generation_config.temperature for beam in old_beams],
250 | repetition_penalty=[self.req_id2beam_group[beam.request_id].generation_config.repetition_penalty for beam in old_beams],
251 | top_k=[self.req_id2beam_group[beam.request_id].generation_config.top_k for beam in old_beams],
252 | top_p=[self.req_id2beam_group[beam.request_id].generation_config.top_p for beam in old_beams],
253 | typical_p=[self.req_id2beam_group[beam.request_id].generation_config.typical_p for beam in old_beams],
254 | do_sample=[self.req_id2beam_group[beam.request_id].generation_config.do_sample for beam in old_beams],
255 | num_beams=[self.req_id2beam_group[beam.request_id].generation_config.num_beams for beam in old_beams],
256 | seeds=[self.req_id2beam_group[beam.request_id].generation_config.seed for beam in old_beams]
257 | )
258 | all_input_ids = [beam.token_ids for beam in old_beams]
259 | max_input_len = max([len(input_ids) for input_ids in all_input_ids])
260 | all_input_ids = [input_ids + [0] * (max_input_len - len(input_ids)) for input_ids in all_input_ids]
261 | all_input_ids_tensor = torch.tensor(all_input_ids, dtype=torch.long, device=logits.device)
262 | outputs = next_tokens_chooser(
263 | input_ids=all_input_ids_tensor,
264 | scores=logits
265 | )
266 | for beam, output in zip(old_beams, outputs):
267 | new_beams = []
268 | generation_config = self.req_id2beam_group[beam.request_id].generation_config
269 | for nxt_token_id, nxt_prob in zip(output.next_token_ids, output.next_probs):
270 | new_beam = beam.copy()
271 | new_beam.append_token_id(nxt_token_id, nxt_prob)
272 | new_beam.check_finished(
273 | eos_token_id=generation_config.eos_token_id,
274 | max_new_tokens=generation_config.max_new_tokens
275 | )
276 | new_beams.append(new_beam)
277 | self.req_id2beam_group[beam.request_id].cache_new_beams(new_beams)
278 |
279 | req_ids = list(set([beam.request_id for beam in old_beams]))
280 | for req_id in req_ids:
281 | beam_group = self.req_id2beam_group[req_id]
282 | if not beam_group.get_beams(BeamStatus.RUNNING):
283 | new_beams = beam_group.new_beams
284 | beam_group.clear_new_beams()
285 | new_beams = sorted(new_beams, reverse=True)[:beam_group.generation_config.num_beams]
286 | beam_group.add_beams(new_beams)
287 | for new_beam in new_beams:
288 | if not new_beam.is_finished:
289 | self.cache_manager.copy(new_beam.parent_beam_id, new_beam.beam_id)
290 | for old_beam in old_beams:
291 | if old_beam.request_id == req_id:
292 | self.cache_manager.free(old_beam.beam_id)
293 |
294 | def batch_generate(self, batch: Batch):
295 | if batch.prefill_beams:
296 | self._generate(batch.prefill_beams, batch.prefill_logits)
297 | if batch.generation_beams:
298 | self._generate(batch.generation_beams, batch.generation_logits)
299 |
300 |
301 | __all__ = ["Batch", "Batcher"]
302 |
--------------------------------------------------------------------------------
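To make the interplay between schedule(), the KV cache, and batch_generate() concrete, here is a rough sketch of one engine step. This is not the repo's worker.py (not reproduced here); the model_forward interface is an assumption:

    def engine_step(batcher: Batcher, cache, model_forward):
        # 1. decide what runs this step and which cache blocks must move
        batch, blocks_to_copy, blocks_to_swap_in, blocks_to_swap_out, finished = batcher.schedule()
        # 2. carry out the cache movements decided by the scheduler
        if blocks_to_swap_out:
            cache.swap_out(blocks_to_swap_out)   # GPU -> CPU
        if blocks_to_swap_in:
            cache.swap_in(blocks_to_swap_in)     # CPU -> GPU
        if blocks_to_copy:
            cache.copy(blocks_to_copy)           # copy-on-write for forked beams
        # 3. run the model and fill the batch's logits (assumed interface)
        batch.prefill_logits, batch.generation_logits = model_forward(batch)
        # 4. sample / beam-search next tokens and fork, finish, or continue beams
        batcher.batch_generate(batch)
        # 5. finished BeamGroups are ready to be turned into completion outputs
        return finished
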
/code/server/continuous_batching_server/beam.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import enum
3 | from typing import Dict, List, Optional
4 | from uuid import uuid4, UUID
5 |
6 | from .cache.cache import LogicalCacheBlock
7 | from protocol.completion_task import HuggingFaceGenerationConfig
8 |
9 |
10 | class BeamStatus(enum.Enum):
11 | WAITING = 0
12 | RUNNING = 1
13 | FINISHED = 2
14 | SWAPPED = 3
15 |
16 |
17 | class BeamFinishReason(enum.Enum):
18 | LENGTH = "length"
19 | STOP = "stop"
20 | ABORT = "abort"
21 | NOT_FINISHED = "not_finished"
22 |
23 |
24 | class Beam:
25 | def __init__(
26 | self,
27 | request_id: UUID,
28 | prompt: str,
29 | prompt_token_ids: List[int],
30 | block_size: int,
31 | ):
32 | self.request_id = request_id
33 | self.parent_beam_id = None
34 | self.beam_id = uuid4()
35 | self.prompt = prompt
36 | self.prompt_token_ids = prompt_token_ids
37 | self.block_size = block_size
38 | self.generated_token_ids = []
39 | self.cumulative_logprob = 0.0
40 |
41 | self.cache_blocks: List[LogicalCacheBlock] = []
42 | self._append_tokens_to_blocks(prompt_token_ids)
43 |
44 | self.status = BeamStatus.WAITING
45 | self._finish_reason = BeamFinishReason.NOT_FINISHED
46 |
47 | def _append_cache_block(self):
48 | block = LogicalCacheBlock(block_id=len(self.cache_blocks), block_size=self.block_size)
49 | self.cache_blocks.append(block)
50 |
51 | def _append_tokens_to_blocks(self, token_ids: List[int]):
52 | offset = 0
53 | while offset < len(token_ids):
54 | if not self.cache_blocks:
55 | self._append_cache_block()
56 |
57 | last_block = self.cache_blocks[-1]
58 | if last_block.is_full:
59 | self._append_cache_block()
60 | last_block = self.cache_blocks[-1]
61 |
62 | num_empty_slots = last_block.num_empty_slots
63 | last_block.append_tokens(token_ids[offset: offset + num_empty_slots])
64 |
65 | offset += num_empty_slots
66 |
67 | def append_token_id(self, token_id: int, prob: float):
68 | self._append_tokens_to_blocks([token_id])
69 | self.generated_token_ids.append(token_id)
70 | self.cumulative_logprob += prob
71 |
72 | @property
73 | def last_token_id(self):
74 | if not self.generated_token_ids:
75 | return self.prompt_token_ids[-1]
76 | return self.generated_token_ids[-1]
77 |
78 | @property
79 | def token_ids(self):
80 | return self.prompt_token_ids + self.generated_token_ids
81 |
82 | @property
83 | def num_tokens(self):
84 | return len(self.prompt_token_ids) + len(self.generated_token_ids)
85 |
86 | @property
87 | def num_generated_tokens(self):
88 | return len(self.generated_token_ids)
89 |
90 | def update_status(self, status: BeamStatus):
91 | self.status = status
92 |
93 | def copy(self) -> "Beam":
94 | beam = Beam(
95 | request_id=self.request_id,
96 | prompt=self.prompt,
97 | prompt_token_ids=copy.deepcopy(self.prompt_token_ids),
98 | block_size=self.block_size
99 | )
100 | beam.parent_beam_id = self.beam_id
101 | beam.generated_token_ids = copy.deepcopy(self.generated_token_ids)
102 | beam.cumulative_logprob = self.cumulative_logprob
103 |
104 | beam.cache_blocks = copy.deepcopy(self.cache_blocks)
105 | beam.status = copy.deepcopy(self.status)
106 | beam._finish_reason = copy.deepcopy(self._finish_reason)
107 |
108 | return beam
109 |
110 | def check_finished(
111 | self,
112 |         eos_token_id: Optional[int],
113 | max_new_tokens: int
114 | ) -> None:
115 | if not self.generated_token_ids:
116 | return
117 | if eos_token_id == int(self.generated_token_ids[-1]):
118 | self.status = BeamStatus.FINISHED
119 | self.finish_reason = BeamFinishReason.STOP
120 | if len(self.generated_token_ids) == max_new_tokens:
121 | self.status = BeamStatus.FINISHED
122 | self.finish_reason = BeamFinishReason.LENGTH
123 |
124 | def __eq__(self, other: "Beam"):
125 | return self.cumulative_logprob == other.cumulative_logprob
126 |
127 | def __gt__(self, other: "Beam"):
128 | return self.cumulative_logprob > other.cumulative_logprob
129 |
130 | def __ge__(self, other: "Beam"):
131 | return self.cumulative_logprob >= other.cumulative_logprob
132 |
133 | def __lt__(self, other: "Beam"):
134 | return self.cumulative_logprob < other.cumulative_logprob
135 |
136 | def __le__(self, other: "Beam"):
137 | return self.cumulative_logprob <= other.cumulative_logprob
138 |
139 | @property
140 | def finish_reason(self) -> str:
141 | return self._finish_reason.value
142 |
143 | @finish_reason.setter
144 | def finish_reason(self, finish_reason: BeamFinishReason):
145 | self._finish_reason = finish_reason
146 |
147 | @property
148 | def is_finished(self) -> bool:
149 | return self.status == BeamStatus.FINISHED
150 |
151 |
152 | class BeamGroup:
153 |     """A group of beams generated from the same prompt."""
154 | def __init__(
155 | self,
156 | request_id: UUID,
157 | arrival_time: float,
158 | beams: List[Beam],
159 | generation_config: HuggingFaceGenerationConfig
160 | ):
161 | self.request_id = request_id
162 | self.arrival_time = arrival_time
163 | self.beams: Dict[UUID, Beam] = {beam.beam_id: beam for beam in beams}
164 | self.generation_config = generation_config
165 |
166 | self._new_beams: List[Beam] = []
167 |
168 | def add_beams(self, beams: List[Beam]):
169 | for beam in beams:
170 | self.beams[beam.beam_id] = beam
171 |
172 | def get_beams(self, status: Optional[BeamStatus] = None) -> List[Beam]:
173 | if status is None:
174 | return list(self.beams.values())
175 | else:
176 | return [beam for beam in self.beams.values() if beam.status == status]
177 |
178 | def num_beams(self, status: Optional[BeamStatus] = None) -> int:
179 | return len(self.get_beams(status))
180 |
181 | def cache_new_beams(self, new_beams: List[Beam]):
182 | self._new_beams += new_beams
183 |
184 | def clear_new_beams(self):
185 | self._new_beams = []
186 |
187 | @property
188 | def new_beams(self):
189 | return self._new_beams
190 |
191 | @staticmethod
192 | def update_beam_status(beam: Beam, status: BeamStatus):
193 | beam.update_status(status)
194 |
195 | def find(self, beam_id: UUID) -> Beam:
196 | if beam_id not in self.beams:
197 | raise LookupError(f"Beam {beam_id} not found.")
198 | return self.beams[beam_id]
199 |
200 | @property
201 | def is_finished(self) -> bool:
202 | if not self.generation_config.early_stopping:
203 | return all(beam.is_finished for beam in self.beams.values())
204 | else:
205 | num_finished_beams = len(self.get_beams(BeamStatus.FINISHED))
206 | return num_finished_beams >= self.generation_config.num_return_sequences
207 |
208 | def get_final_beams(self) -> List[Beam]:
209 | if not self.is_finished:
210 | raise AttributeError("Can't get final beams for they are not finished.")
211 | return sorted(self.get_beams(BeamStatus.FINISHED), reverse=True)[:self.generation_config.num_return_sequences]
212 |
213 |
214 | __all__ = [
215 | "Beam",
216 | "BeamGroup",
217 | "BeamStatus",
218 | "BeamFinishReason",
219 | ]
220 |
--------------------------------------------------------------------------------
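A minimal sketch of how a single incoming request becomes a BeamGroup with one initial Beam (the prompt token ids here are made up; in the real server the tokenizer produces them):

    import time
    from uuid import uuid4

    request_id = uuid4()
    beam = Beam(
        request_id=request_id,
        prompt="Hello",
        prompt_token_ids=[1, 15043],   # hypothetical token ids
        block_size=16,                 # must match CacheConfig.block_size
    )
    group = BeamGroup(
        request_id=request_id,
        arrival_time=time.time(),
        beams=[beam],
        generation_config=HuggingFaceGenerationConfig(num_beams=2, num_return_sequences=1),
    )
    # the batcher later promotes the beam from WAITING to RUNNING and forks copies of it
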
/code/server/continuous_batching_server/cache/README.md:
--------------------------------------------------------------------------------
1 | > Codes in this package are copied directly from vLLM with some modifications.
2 |
--------------------------------------------------------------------------------
/code/server/continuous_batching_server/cache/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/modelize-ai/LLM-Inference-Deployment-Tutorial/e3264eacc4752a7d829241fb614a0b81892cdc8f/code/server/continuous_batching_server/cache/__init__.py
--------------------------------------------------------------------------------
/code/server/continuous_batching_server/cache/cache.py:
--------------------------------------------------------------------------------
1 | # adapted from vLLM
2 |
3 | from typing import *
4 |
5 | import torch
6 | import vllm.cache_ops as vllm_cache_ops
7 |
8 | from ..config import CacheConfig, ModelConfig, ParallelConfig
9 |
10 |
11 | KVCache = Tuple[torch.Tensor, torch.Tensor]
12 |
13 |
14 | class LogicalCacheBlock:
15 | def __init__(self, block_id: int, block_size: int):
16 | self.block_id = block_id
17 | self.block_size = block_size
18 |
19 | self._token_ids = [-1] * block_size
20 | self.num_tokens = 0
21 |
22 | @property
23 | def is_empty(self):
24 | return self.num_tokens == 0
25 |
26 | @property
27 | def num_empty_slots(self):
28 | return self.block_size - self.num_tokens
29 |
30 | @property
31 | def is_full(self):
32 | return self.num_tokens == self.block_size
33 |
34 | def append_tokens(self, token_ids: List[int]):
35 | assert len(token_ids) <= self.num_empty_slots
36 | offset = self.num_tokens
37 | self._token_ids[offset: offset + len(token_ids)] = token_ids
38 | self.num_tokens += len(token_ids)
39 |
40 | @property
41 | def token_ids(self):
42 | return self._token_ids[:self.num_tokens]
43 |
44 | @property
45 | def last_token_id(self):
46 | assert self.num_tokens > 0
47 | return self.token_ids[self.num_tokens - 1]
48 |
49 |
50 | class PhysicalCacheBlock:
51 | def __init__(self, block_id: int, block_size: int):
52 | self.block_id = block_id
53 | self.block_size = block_size
54 |
55 | self.ref_count = 0
56 |
57 | def __repr__(self):
58 | return f"[block_id={self.block_id} block_size={self.block_size} ref_count={self.ref_count}]"
59 |
60 | def __str__(self):
61 | return self.__repr__()
62 |
63 |
64 | class Cache:
65 | def __init__(
66 | self,
67 | cache_config: CacheConfig,
68 | model_config: ModelConfig,
69 | parallel_config: ParallelConfig,
70 | dtype: torch.dtype,
71 | device: torch.device
72 | ):
73 | self.cache_config = cache_config
74 | self.model_config = model_config
75 | self.parallel_config = parallel_config
76 |
77 | self.head_size = model_config.get_head_size()
78 | self.num_layers = model_config.get_num_layers()
79 | self.num_heads = model_config.get_num_heads()
80 |
81 | self.dtype = dtype
82 | self.device = device
83 |
84 | self.block_size = cache_config.block_size
85 | self.num_blocks = cache_config.num_blocks
86 | self.num_blocks_cpu = cache_config.num_blocks_cpu
87 |
88 | self.cache = [
89 | (
90 | torch.empty(
91 | size=(self.num_blocks, *self.key_block_shape),
92 | dtype=self.dtype,
93 | device=self.device
94 | ),
95 | torch.empty(
96 | size=(self.num_blocks, *self.value_block_shape),
97 | dtype=self.dtype,
98 | device=self.device
99 | )
100 | ) for _ in range(self.num_layers)
101 | ]
102 |
103 | self.cpu_cache = [
104 | (
105 | torch.empty(
106 | size=(self.num_blocks_cpu, *self.key_block_shape),
107 | dtype=self.dtype,
108 | pin_memory=cache_config.pin_memory_on_cpu,
109 | ),
110 | torch.empty(
111 | size=(self.num_blocks_cpu, *self.value_block_shape),
112 | dtype=self.dtype,
113 | pin_memory=cache_config.pin_memory_on_cpu,
114 | )
115 | ) for _ in range(self.num_layers)
116 | ]
117 |
118 | self.cache_stream = torch.cuda.Stream()
119 | assert self.cache_stream != torch.cuda.current_stream()
120 |
121 | self.events = [torch.cuda.Event() for _ in range(self.num_layers)]
122 |
123 | @property
124 | def key_block_shape(self) -> Tuple[int, int, int, int]:
125 | element_size = torch.tensor([], dtype=self.dtype).element_size()
126 | x = 16 // element_size
127 | return (
128 | self.num_heads,
129 | self.head_size // x,
130 | self.block_size,
131 | x,
132 | )
133 |
134 | @property
135 | def value_block_shape(self) -> Tuple[int, int, int]:
136 | return (
137 | self.num_heads,
138 | self.head_size,
139 | self.block_size,
140 | )
141 |
142 | def _swap(
143 | self,
144 | src: List[KVCache],
145 | dst: List[KVCache],
146 | src_to_dst: Dict[int, int],
147 | ) -> None:
148 | with torch.cuda.stream(self.cache_stream):
149 | for i in range(self.num_layers):
150 | src_key_cache, src_value_cache = src[i]
151 | dst_key_cache, dst_value_cache = dst[i]
152 | # Copy the key blocks.
153 | vllm_cache_ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)
154 | # Copy the value blocks.
155 | vllm_cache_ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst)
156 | event = self.events[i]
157 | event.record(stream=self.cache_stream)
158 |
159 | def swap_in(self, src_to_dst: Dict[int, int]) -> None:
160 | self._swap(self.cpu_cache, self.cache, src_to_dst)
161 |
162 | def swap_out(self, src_to_dst: Dict[int, int]) -> None:
163 | self._swap(self.cache, self.cpu_cache, src_to_dst)
164 |
165 | def copy(self, src_to_dsts: Dict[int, List[int]]) -> None:
166 | key_caches = [key_cache for key_cache, _ in self.cache]
167 | value_caches = [value_cache for _, value_cache in self.cache]
168 | # This operation implicitly synchronizes the CPU and GPU.
169 | vllm_cache_ops.copy_blocks(key_caches, value_caches, src_to_dsts)
170 |
171 | @staticmethod
172 | def get_cache_block_size(
173 | block_size: int,
174 | model_config: ModelConfig,
175 | dtype: torch.dtype
176 | ) -> int:
177 | head_size = model_config.get_head_size()
178 | num_heads = model_config.get_num_heads()
179 | num_layers = model_config.get_num_layers()
180 |
181 | key_cache_block = block_size * num_heads * head_size
182 | value_cache_block = key_cache_block
183 | total = num_layers * (key_cache_block + value_cache_block)
184 | dtype_size = torch.tensor([], dtype=dtype).element_size()
185 | return dtype_size * total
186 |
187 |
188 | __all__ = [
189 | "LogicalCacheBlock",
190 | "PhysicalCacheBlock",
191 | "Cache"
192 | ]
193 |
--------------------------------------------------------------------------------
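A quick worked example of the block shapes and sizes above, assuming a LLaMA-7B-like model (32 layers, 32 KV heads, head_size 128) in float16 with the default block_size of 16:

    # key blocks pack x = 16 // element_size values in the last dim: x = 16 // 2 = 8 for fp16,
    # so key_block_shape = (32, 128 // 8, 16, 8) and value_block_shape = (32, 128, 16)
    num_layers, num_heads, head_size, dtype_size, block_size = 32, 32, 128, 2, 16
    block_bytes = num_layers * 2 * (block_size * num_heads * head_size) * dtype_size
    print(block_bytes / 2 ** 20)           # 8.0 MiB of KV cache per block
    print(block_bytes * 2500 / 2 ** 30)    # ~19.5 GiB for the default num_blocks=2500
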
/code/server/continuous_batching_server/cache/cache_manager.py:
--------------------------------------------------------------------------------
1 | # adapted from vLLM
2 |
3 | from logging import getLogger
4 | from typing import *
5 | from uuid import UUID
6 |
7 | from .cache import LogicalCacheBlock, PhysicalCacheBlock
8 | from ..beam import Beam, BeamGroup, BeamStatus
9 |
10 |
11 | logger = getLogger(__name__)
12 |
13 |
14 | class OOMError(Exception):
15 | pass
16 |
17 |
18 | class DoubleFreeBlockError(Exception):
19 | pass
20 |
21 |
22 | class CacheBlockAllocator:
23 | def __init__(self, block_size: int, num_blocks: int):
24 | self.block_size = block_size
25 | self.num_blocks = num_blocks
26 |
27 | self.free_blocks: List[PhysicalCacheBlock] = [
28 | PhysicalCacheBlock(block_id=i, block_size=block_size) for i in range(num_blocks)
29 | ]
30 |
31 | def allocate(self) -> PhysicalCacheBlock:
32 | if not self.free_blocks:
33 | raise OOMError("No free blocks are available.")
34 | block = self.free_blocks.pop(0)
35 | block.ref_count = 1
36 | return block
37 |
38 | def free(self, block: PhysicalCacheBlock) -> None:
39 | if block.ref_count == 0:
40 | raise DoubleFreeBlockError(f"Double free! block-{block.block_id} is already freed.")
41 | block.ref_count -= 1
42 | if block.ref_count == 0:
43 | self.free_blocks.append(block)
44 | # self.free_blocks = sorted(self.free_blocks, key=lambda block: block.block_id)
45 |
46 | @property
47 | def num_free_blocks(self) -> int:
48 | return len(self.free_blocks)
49 |
50 |
51 | BlockTable = List[PhysicalCacheBlock]
52 |
53 |
54 | class CacheBlockManager:
55 | def __init__(self, block_size: int, num_blocks: int, num_blocks_cpu: int, watermark: float = 0.01):
56 | self.block_size = block_size
57 | self.num_blocks = num_blocks
58 | self.watermark = watermark
59 | assert self.watermark >= 0.0
60 |
61 | self.watermark_blocks = int(watermark * num_blocks)
62 | self.allocator = CacheBlockAllocator(block_size, num_blocks)
63 | self.cpu_allocator = CacheBlockAllocator(block_size, num_blocks_cpu)
64 |
65 | self.block_tables: Dict[UUID, BlockTable] = {}
66 |
67 | def can_allocate(self, beam: Beam):
68 | num_required_blocks = len(beam.cache_blocks)
69 | num_free_blocks = self.allocator.num_free_blocks
70 | # Use watermark to avoid frequent cache eviction.
71 | return num_free_blocks - num_required_blocks >= self.watermark_blocks
72 |
73 | def allocate(self, beam: Beam):
74 |         # NOTE: only call this on an 'init' beam, i.e. the first beam of a request.
75 | block_table: BlockTable = []
76 |
77 | # Allocate new physical cache blocks that will store the prompt tokens.
78 | for _ in range(len(beam.cache_blocks)):
79 | block = self.allocator.allocate()
80 | block_table.append(block)
81 |
82 | self.block_tables[beam.beam_id] = block_table.copy()
83 | logger.debug(f"beam-{beam.beam_id} allocate block_table: {[block.block_id for block in block_table]}")
84 |
85 | def is_need_to_append_slot(self, beam: Beam):
86 | logical_blocks = beam.cache_blocks
87 | block_table: BlockTable = self.block_tables[beam.beam_id]
88 |
89 | if len(block_table) < len(logical_blocks):
90 | return True
91 | if block_table[-1].ref_count > 1:
92 |             # The last block is shared with other beams, which means we should copy on write
93 | return True
94 | return False
95 |
96 | def can_append_slot(self):
97 | return self.allocator.num_free_blocks >= 1
98 |
99 | def append_slot(self, beam: Beam) -> Optional[Tuple[int, int]]:
100 | logical_blocks = beam.cache_blocks
101 | block_table: BlockTable = self.block_tables[beam.beam_id]
102 |
103 | if len(block_table) < len(logical_blocks):
104 | block = self.allocator.allocate()
105 | block_table.append(block)
106 | logger.debug(f"beam-{beam.beam_id} add one block-{block.block_id}")
107 | return
108 |
109 | last_block = block_table[-1]
110 | if last_block.ref_count == 1:
111 | # Not shared with other sequences. Appendable.
112 | return
113 | else:
114 | # The last block is shared with other sequences.
115 | # Copy on Write: Allocate a new block and copy the tokens.
116 | new_block = self.allocator.allocate()
117 | block_table[-1] = new_block
118 | self.allocator.free(last_block)
119 | logger.debug(f"beam-{beam.beam_id} replace block-{last_block.block_id} with block-{new_block.block_id}")
120 | return last_block.block_id, new_block.block_id
121 |
122 | def copy(self, parent: UUID, child: UUID):
123 | src_block_table = self.block_tables[parent]
124 | self.block_tables[child] = src_block_table.copy()
125 | for block in src_block_table:
126 | block.ref_count += 1
127 | logger.debug(f"beam-{parent} copy block_table: {[block.block_id for block in src_block_table]} to beam-{child}")
128 |
129 | def _get_physical_blocks(self, beam_group: BeamGroup) -> List[PhysicalCacheBlock]:
130 | blocks: Set[PhysicalCacheBlock] = set()
131 | for beam in beam_group.get_beams():
132 | if beam.is_finished:
133 | continue
134 | block_table = self.block_tables[beam.beam_id]
135 | for block in block_table:
136 | blocks.add(block)
137 | return list(blocks)
138 |
139 | def can_swap_in(self, beam_group: BeamGroup, preserved_num_blocks: int = 0) -> bool:
140 | blocks = self._get_physical_blocks(beam_group)
141 | num_swapped_seqs = beam_group.num_beams(status=BeamStatus.SWAPPED)
142 | num_free_blocks = self.allocator.num_free_blocks
143 | # NOTE: Conservatively, we assume that every sequence will allocate
144 | # at least one free block right after the swap-in.
145 | # NOTE: This should match the logic in can_append_slot().
146 | num_required_blocks = len(blocks) + num_swapped_seqs
147 | return num_free_blocks - num_required_blocks - preserved_num_blocks >= self.watermark_blocks
148 |
149 | def swap_in(self, beam_group: BeamGroup) -> Dict[int, int]:
150 | # CPU block -> GPU block.
151 | mapping: Dict[PhysicalCacheBlock, PhysicalCacheBlock] = {}
152 | for beam in beam_group.get_beams():
153 | if beam.is_finished:
154 | continue
155 | new_block_table: BlockTable = []
156 | block_table = self.block_tables[beam.beam_id]
157 |
158 | for cpu_block in block_table:
159 | if cpu_block in mapping:
160 | gpu_block = mapping[cpu_block]
161 | gpu_block.ref_count += 1
162 | else:
163 | gpu_block = self.allocator.allocate()
164 | mapping[cpu_block] = gpu_block
165 | new_block_table.append(gpu_block)
166 | # Free the CPU block swapped in to GPU.
167 | self.cpu_allocator.free(cpu_block)
168 | self.block_tables[beam.beam_id] = new_block_table
169 |
170 | block_ids_map = {
171 | cpu_block.block_id: gpu_block.block_id
172 | for cpu_block, gpu_block in mapping.items()
173 | }
174 | logger.debug(
175 | f"swap in beam_group-{beam_group.request_id} from CPU to GPU."
176 | )
177 | return block_ids_map
178 |
179 | def can_swap_out(self, beam_group: BeamGroup) -> bool:
180 | blocks = self._get_physical_blocks(beam_group)
181 | return len(blocks) <= self.cpu_allocator.num_free_blocks
182 |
183 | def swap_out(self, beam_group: BeamGroup) -> Dict[int, int]:
184 | # GPU block -> CPU block.
185 | mapping: Dict[PhysicalCacheBlock, PhysicalCacheBlock] = {}
186 | for beam in beam_group.get_beams():
187 | if beam.is_finished:
188 | continue
189 | new_block_table: BlockTable = []
190 | block_table = self.block_tables[beam.beam_id]
191 |
192 | for gpu_block in block_table:
193 | if gpu_block in mapping:
194 | cpu_block = mapping[gpu_block]
195 | cpu_block.ref_count += 1
196 | else:
197 | cpu_block = self.cpu_allocator.allocate()
198 | mapping[gpu_block] = cpu_block
199 | new_block_table.append(cpu_block)
200 | # Free the GPU block swapped out to CPU.
201 | self.allocator.free(gpu_block)
202 | self.block_tables[beam.beam_id] = new_block_table
203 |
204 | block_ids_map = {
205 | gpu_block.block_id: cpu_block.block_id
206 | for gpu_block, cpu_block in mapping.items()
207 | }
208 | logger.debug(
209 | f"swap out beam_group-{beam_group.request_id} from GPU to CPU."
210 | )
211 | return block_ids_map
212 |
213 | def _free_block_table(self, block_table: BlockTable):
214 | for block in block_table:
215 | self.allocator.free(block)
216 |
217 | def free(self, beam_id: UUID):
218 | if beam_id not in self.block_tables:
219 | return
220 | block_table = self.block_tables.pop(beam_id)
221 | self._free_block_table(block_table)
222 | logger.debug(f"beam-{beam_id} free blocks: {[block.block_id for block in block_table]}")
223 |
224 | def reset(self):
225 | for block_table in self.block_tables.values():
226 | self._free_block_table(block_table)
227 | self.block_tables.clear()
228 |
229 | def get_block_table(self, beam_id: UUID) -> List[int]:
230 | block_table = self.block_tables[beam_id]
231 | return [block.block_id for block in block_table]
232 |
233 | @property
234 | def num_free_blocks(self):
235 | return self.allocator.num_free_blocks
236 |
237 | @property
238 | def num_free_cpu_blocks(self):
239 | return self.cpu_allocator.num_free_blocks
240 |
--------------------------------------------------------------------------------
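A tiny self-contained illustration (toy sizes, assuming Beam and CacheBlockManager are importable as in the module above) of the copy-on-write bookkeeping: forking a beam shares its physical blocks by bumping ref_count, and append_slot only returns a (src, dst) block pair when the shared last block must actually be copied:

    from uuid import uuid4

    parent = Beam(request_id=uuid4(), prompt="hi", prompt_token_ids=list(range(20)), block_size=16)
    manager = CacheBlockManager(block_size=16, num_blocks=8, num_blocks_cpu=8, watermark=0.0)
    manager.allocate(parent)                      # 20 prompt tokens -> 2 physical blocks
    child = parent.copy()
    manager.copy(parent.beam_id, child.beam_id)   # share both blocks, ref_count becomes 2
    child.append_token_id(42, -0.1)               # child's last logical block now diverges
    print(manager.append_slot(child))             # e.g. (1, 2): copy-on-write of the shared last block
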
/code/server/continuous_batching_server/config.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 | from typing import Optional
3 |
4 | import torch
5 | from pydantic import BaseModel, Field
6 | from transformers import PretrainedConfig
7 |
8 | from .modeling.llama import LlamaForCausalLM, LlamaConfig
9 |
10 |
11 | ModelFactory = namedtuple(
12 | "ModelFactory",
13 | [
14 | "model_cls",
15 | "model_config_cls"
16 | ]
17 | )
18 |
19 |
20 | TORCH_FLOAT_DTYPE_MAP = {
21 | "float": torch.float32,
22 | "float32": torch.float32,
23 | "float16": torch.float16,
24 | "bfloat16": torch.bfloat16,
25 | "int32": torch.int32,
26 | "int64": torch.int64
27 | }
28 |
29 | MODEL_AUTO_TABLE = {
30 | "llama": ModelFactory(model_cls=LlamaForCausalLM, model_config_cls=LlamaConfig)
31 | }
32 |
33 |
34 | class BatcherConfig(BaseModel):
35 |     # the default value is suitable for a 7B model on an A100 GPU; this controls the prefill rate of each step
36 |     batch_max_tokens: int = Field(default=56000)
37 |     # the default value is suitable for a 7B model on an A100 GPU; this controls the batch size of each step
38 |     batch_max_beams: int = Field(default=32)
39 |
40 |
41 | class CacheConfig(BaseModel):
42 |     num_blocks: Optional[int] = Field(default=2500)  # the default value is suitable for a 7B model on an A100 GPU
43 | num_blocks_cpu: Optional[int] = Field(default=1024)
44 | block_size: int = Field(default=16)
45 | gpu_memory_utilization: float = Field(default=0.98)
46 | watermark: float = Field(default=0.01)
47 | pin_memory_on_cpu: bool = Field(default=True)
48 |
49 |
50 | class ModelLoadingConfig(BaseModel):
51 | model_type: str = Field(default="llama", regex="(" + "|".join(list(MODEL_AUTO_TABLE.keys())) + ")")
52 | model_name_or_path: str = Field(default="dummy_path_to_llama_model")
53 | torch_dtype: str = Field(default="float16", regex="(" + "|".join(list(TORCH_FLOAT_DTYPE_MAP.keys())) + ")")
54 | tokenizer_name_or_path: Optional[str] = Field(default=None)
55 | use_fast_tokenizer: bool = Field(default=False)
56 | trust_remote_code: bool = Field(default=False)
57 | quantize_method: Optional[str] = Field(default=None, regex="(gptq|)")
58 | model_max_length: int = Field(default=2048)
59 | device: int = Field(default=0)
60 | gptq_model_base_name: Optional[str] = Field(default=None)
61 | gptq_config_base_name: Optional[str] = Field(default=None)
62 |
63 |
64 | class ParallelConfig(BaseModel):
65 | tp_size: int = Field(default=1)
66 |
67 |
68 | class ModelConfig:
69 | def __init__(self, model_config: PretrainedConfig, parallel_config: ParallelConfig):
70 | self.model_config = model_config
71 | self.parallel_config = parallel_config
72 |
73 | def get_hidden_size(self):
74 | return self.model_config.hidden_size
75 |
76 | def get_head_size(self):
77 | return self.model_config.hidden_size // self.model_config.num_attention_heads
78 |
79 | def get_num_heads(self):
80 | # For GPTBigCode:
81 | if getattr(self.model_config, "multi_query", False):
82 | # Multi-query attention, only one KV head.
83 | return 1
84 | # For Falcon:
85 | if getattr(self.model_config, "n_head_kv", None) is not None:
86 | return self.model_config.n_head_kv
87 |
88 | return self.model_config.num_attention_heads // self.parallel_config.tp_size
89 |
90 | def get_num_layers(self):
91 | return self.model_config.num_hidden_layers
92 |
93 |
94 | class ServerConfig(BaseModel):
95 | model_loading_config: ModelLoadingConfig = Field(default=ModelLoadingConfig())
96 | batcher_config: BatcherConfig = Field(default=BatcherConfig())
97 | cache_config: CacheConfig = Field(default=CacheConfig())
98 | parallel_config: ParallelConfig = Field(default=ParallelConfig())
99 |
100 |
101 | __all__ = [
102 | "ModelFactory",
103 | "TORCH_FLOAT_DTYPE_MAP",
104 | "MODEL_AUTO_TABLE",
105 | "BatcherConfig",
106 | "CacheConfig",
107 | "ModelLoadingConfig",
108 | "ParallelConfig",
109 | "ModelConfig",
110 | "ServerConfig"
111 | ]
112 |
--------------------------------------------------------------------------------
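A minimal sketch of producing the cb_server_config.json consumed by continuous_batching_server_app.py from the pydantic models above (the model path is a placeholder):

    config = ServerConfig(
        model_loading_config=ModelLoadingConfig(
            model_type="llama",
            model_name_or_path="/path/to/llama-7b",   # placeholder
            torch_dtype="float16",
            model_max_length=2048,
            device=0,
        ),
        batcher_config=BatcherConfig(batch_max_tokens=56000, batch_max_beams=32),
        cache_config=CacheConfig(num_blocks=2500, num_blocks_cpu=1024, block_size=16),
        parallel_config=ParallelConfig(tp_size=1),
    )
    with open("cb_server_config.json", "w", encoding="utf-8") as f:
        f.write(config.json(indent=2))   # pydantic<2 serialization
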
/code/server/continuous_batching_server/generation_utils/README.md:
--------------------------------------------------------------------------------
1 | > Codes in this package are copied directly from TGI with some modifications.
2 |
--------------------------------------------------------------------------------
/code/server/continuous_batching_server/generation_utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .tokens import HeterogeneousNextTokensChooser
2 |
--------------------------------------------------------------------------------
/code/server/continuous_batching_server/generation_utils/logits_process.py:
--------------------------------------------------------------------------------
1 | # Adapted from Hugging Face text-generation-inference
2 |
3 | import math
4 | import torch
5 |
6 | from typing import List, Dict, Union
7 |
8 | from transformers import LogitsWarper, LogitsProcessor
9 |
10 | mempool = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
11 |
12 |
13 | class HeterogeneousRepetitionPenaltyLogitsProcessor(LogitsProcessor):
14 | r"""
15 | [`LogitsProcessor`] enforcing an exponential penalty on repeated sequences.
16 | This version allows for a separate value for each sample and runs inplace when possible.
17 | It doesn't validate inputs.
18 |
19 | Args:
20 | repetition_penalty (`List[float]`):
21 | The parameter for repetition penalty. 1.0 means no penalty. See [this
22 | paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
23 | """
24 |
25 | def __init__(self, penalty: List[float], dtype: torch.dtype, device: torch.device):
26 | self.penalty = penalty
27 | self.penalty_tensor = torch.tensor(
28 | penalty, dtype=dtype, device=device
29 | ).unsqueeze(1)
30 |
31 | def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
32 | score = torch.gather(scores, 1, input_ids)
33 |
34 | # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
35 | score = torch.where(
36 | score < 0, score * self.penalty_tensor, score / self.penalty_tensor
37 | )
38 |
39 | scores.scatter_(1, input_ids, score)
40 | return scores
41 |
42 |
43 | class HeterogeneousTemperatureLogitsWarper:
44 | r"""
45 | [`LogitsWarper`] for temperature (exponential scaling output probability distribution).
46 | This version allows for a separate value for each sample and runs inplace when possible.
47 | It doesn't validate inputs.
48 |
49 | Args:
50 | temperature (`float`):
51 |             The value used to modulate the logits distribution.
52 | """
53 |
54 | def __init__(
55 | self, temperature: List[float], dtype: torch.dtype, device: torch.device
56 | ):
57 | self.temperature = temperature
58 | self.temperature_tensor = torch.tensor(
59 | temperature, dtype=dtype, device=device
60 | ).unsqueeze(1)
61 |
62 | def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
63 | scores.div_(self.temperature_tensor)
64 | return scores
65 |
66 |
67 | class HeterogeneousTopPLogitsWarper(LogitsWarper):
68 | """
69 |     [`LogitsWarper`] that performs top-p, i.e. restricting to the smallest set of top tokens whose cumulative probability reaches top_p.
70 | This version allows for a separate value for each sample and runs inplace when possible.
71 | It doesn't validate inputs.
72 |
73 | Args:
74 |         top_p (`List[float]`):
75 | If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
76 | higher are kept for generation.
77 | filter_value (`float`, *optional*, defaults to `-float("Inf")`):
78 | All filtered values will be set to this float value.
79 | min_tokens_to_keep (`int`, *optional*, defaults to 1):
80 | Minimum number of tokens that cannot be filtered.
81 | """
82 |
83 | def __init__(
84 | self,
85 | top_p: List[float],
86 | dtype: torch.dtype,
87 | device: torch.device,
88 | filter_value: float = -math.inf,
89 | min_tokens_to_keep: int = 1,
90 | ):
91 | self.top_p = top_p
92 | self.top_p_opposite = 1 - torch.tensor(
93 | top_p, dtype=dtype, device=device
94 | ).unsqueeze(1)
95 | self.filter_value = filter_value
96 | self.min_tokens_to_keep = min_tokens_to_keep
97 |
98 | def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
99 | sorted_logits, sorted_indices = torch.sort(scores, descending=False)
100 | probs = sorted_logits.softmax(dim=-1)
101 | # This is way faster for some reason
102 | for i in range(probs.shape[0]):
103 | probs[i] = probs[i].cumsum(dim=-1)
104 |
105 |         # Remove tokens whose cumulative probability (summed from the least likely token up) is still <= 1 - top_p, so the kept set covers at least top_p
106 | sorted_indices_to_remove = probs <= self.top_p_opposite
107 | # Keep at least min_tokens_to_keep
108 | sorted_indices_to_remove[..., -self.min_tokens_to_keep:] = 0
109 |
110 | # scatter sorted tensors to original indexing
111 | indices_to_remove = sorted_indices_to_remove.scatter(
112 | 1, sorted_indices, sorted_indices_to_remove
113 | )
114 | warped_scores = scores.masked_fill_(indices_to_remove, self.filter_value)
115 |
116 | return warped_scores
117 |
118 |
119 | class HeterogeneousTopKLogitsWarper(LogitsWarper):
120 | r"""
121 | [`LogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements.
122 | This version allows for a separate value for each sample and runs inplace when possible.
123 | It doesn't validate inputs.
124 |
125 | Args:
126 |         top_k (`List[int]`):
127 | The number of highest probability vocabulary tokens to keep for top-k-filtering.
128 | filter_value (`float`, *optional*, defaults to `-float("Inf")`):
129 | All filtered values will be set to this float value.
130 | min_tokens_to_keep (`int`, *optional*, defaults to 1):
131 | Minimum number of tokens that cannot be filtered.
132 | """
133 |
134 | def __init__(
135 | self,
136 | top_k: List[int],
137 | device: torch.device,
138 | filter_value: float = -math.inf,
139 | min_tokens_to_keep: int = 1,
140 | ):
141 | self.top_k = top_k
142 | self.max_top_k = max(top_k)
143 | # value - 1 as we will use top_k to index and python uses 0 based numbering
144 | self.top_k_tensor = torch.tensor(
145 | [max(x - 1, min_tokens_to_keep - 1) for x in top_k],
146 | dtype=torch.int64,
147 | device=device,
148 | ).unsqueeze(1)
149 |
150 | # 0 is a special value that disables top_k warping for this member of the batch
151 | disabled = [x == 0 for x in top_k]
152 |
153 | if any(disabled):
154 | self.top_k_disabled_mask = torch.tensor(
155 | disabled, dtype=torch.bool, device=device
156 | ).view(-1, 1)
157 | else:
158 | self.top_k_disabled_mask = None
159 |
160 | self.filter_value = filter_value
161 |
162 | def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
163 |         # If max_top_k is larger than the vocab size, clamp it or the topk call below will fail
164 | if scores.size(-1) < self.max_top_k:
165 | max_top_k = scores.size(-1)
166 | top_k = torch.clamp_max(self.top_k_tensor, max_top_k)
167 | else:
168 | max_top_k = self.max_top_k
169 | top_k = self.top_k_tensor
170 |
171 | # Get the kth score for each member of the batch
172 | kth_scores = torch.gather(torch.topk(scores, max_top_k)[0], 1, top_k)
173 |
174 | # Mask member of kth_scores that do not want to use top_k warping
175 | if self.top_k_disabled_mask is not None:
176 | kth_scores.masked_fill_(self.top_k_disabled_mask, self.filter_value)
177 |
178 | # Remove all tokens with a probability less than the last token of the top-k
179 | indices_to_remove = scores < kth_scores
180 | scores.masked_fill_(indices_to_remove, self.filter_value)
181 | return scores
182 |
183 |
184 | class HeterogeneousTypicalLogitsWarper(LogitsWarper):
185 | r"""
186 | [`LogitsWarper`] that performs typical decoding. See [Typical Decoding for Natural Language
187 | Generation](https://arxiv.org/abs/2202.00666) for more information.
188 | This version allows for a separate value for each sample and runs inplace when possible.
189 | It doesn't validate inputs.
190 |
191 | Args:
192 |         mass (`List[float]`):
193 |             Values of typical_p between 0 and 1 inclusive, one per sample.
194 | filter_value (`float`, *optional*, defaults to `-float("Inf")`):
195 | All filtered values will be set to this float value.
196 | min_tokens_to_keep (`int`, *optional*, defaults to 1):
197 | Minimum number of tokens that cannot be filtered.
198 | """
199 |
200 | def __init__(
201 | self,
202 | mass: List[float],
203 | dtype: torch.dtype,
204 | device: torch.device,
205 | filter_value: float = -math.inf,
206 | min_tokens_to_keep: int = 1,
207 | ):
208 | self.mass = mass
209 | self.mass_tensor = torch.tensor(mass, dtype=dtype, device=device).unsqueeze(1)
210 |
211 | # 1 is a special value that disables typical_p warping for this member of the batch
212 | disabled = [x == 1.0 for x in mass]
213 |
214 | if any(disabled):
215 | self.disabled_mask = torch.tensor(disabled, dtype=torch.bool, device=device)
216 | else:
217 | self.disabled_mask = None
218 |
219 | self.filter_value = filter_value
220 | self.min_tokens_to_keep = min_tokens_to_keep
221 |
222 | def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
223 | # calculate entropy
224 | normalized = torch.nn.functional.log_softmax(scores, dim=-1)
225 | p = torch.exp(normalized)
226 | ent = -(normalized * p).nansum(-1, keepdim=True)
227 |
228 | # shift and sort
229 | shifted_scores = torch.abs((-normalized) - ent)
230 | sorted_scores, sorted_indices = torch.sort(shifted_scores, descending=False)
231 | sorted_logits = scores.gather(-1, sorted_indices)
232 | probs = sorted_logits.softmax(dim=-1)
233 | # This is way faster for some reason
234 | for i in range(probs.shape[0]):
235 | probs[i] = probs[i].cumsum(dim=-1)
236 |
237 | # Remove tokens with cumulative mass above the threshold
238 | last_ind = (probs < self.mass_tensor).sum(dim=1)
239 | last_ind[last_ind < 0] = 0
240 |
241 | if self.disabled_mask is not None:
242 | last_ind.masked_fill_(self.disabled_mask, scores.shape[-1] - 1)
243 |
244 | sorted_indices_to_remove = sorted_scores > sorted_scores.gather(
245 | 1, last_ind.view(-1, 1)
246 | )
247 | if self.min_tokens_to_keep > 1:
248 | # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
249 | sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0
250 | indices_to_remove = sorted_indices_to_remove.scatter(
251 | 1, sorted_indices, sorted_indices_to_remove
252 | )
253 |
254 | warped_scores = scores.masked_fill_(indices_to_remove, self.filter_value)
255 |
256 | return warped_scores
257 |
258 |
259 | class HeterogeneousProcessorWrapper(LogitsProcessor):
260 | r"""
261 | A wrapper for logit warpers or processors without heterogeneous parameter support.
262 | Args:
263 | processors (`Dict[int, Union[LogitsProcessor, LogitsWarper]]`):
264 | A mapping of sample indices to logit warpers or processors, to be run sequentially.
265 | """
266 |
267 | def __init__(
268 | self,
269 | processors: Dict[int, Union[LogitsProcessor, LogitsWarper]],
270 | ):
271 | self.processors = processors
272 |
273 | def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
274 | for i, processor in self.processors.items():
275 | scores[i: i + 1] = processor(input_ids[i: i + 1], scores[i: i + 1])
276 | return scores
277 |
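
A minimal usage sketch of the heterogeneous warpers above, assuming the repo's `code/` directory is on `sys.path` (so the same package imports used elsewhere in the project resolve); the batch size, vocab size, and sampling parameters are illustrative only:

```python
import torch

from server.continuous_batching_server.generation_utils.logits_process import (
    HeterogeneousTemperatureLogitsWarper,
    HeterogeneousTopKLogitsWarper,
)

# Two requests batched together: request 0 samples with temperature 0.7 and top_k 50,
# request 1 keeps temperature 1.0 and disables top_k via the special value 0.
device, dtype = torch.device("cpu"), torch.float32
scores = torch.randn(2, 32000, dtype=dtype, device=device)   # [batch, vocab]
input_ids = torch.randint(0, 32000, (2, 16), device=device)  # previously generated ids

warpers = [
    HeterogeneousTemperatureLogitsWarper([0.7, 1.0], dtype, device),
    HeterogeneousTopKLogitsWarper([50, 0], device),
]
for warper in warpers:
    scores = warper(input_ids, scores)   # each warper runs in place where possible

print(scores.shape)  # torch.Size([2, 32000])
```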
--------------------------------------------------------------------------------
/code/server/continuous_batching_server/generation_utils/tokens.py:
--------------------------------------------------------------------------------
1 | # Adapted from Hugging Face text-generation-inference
2 |
3 | from typing import *
4 |
5 | import torch
6 | from pydantic import BaseModel, Field, Required
7 |
8 | from .logits_process import (
9 | HeterogeneousRepetitionPenaltyLogitsProcessor,
10 | HeterogeneousTemperatureLogitsWarper,
11 | HeterogeneousTopKLogitsWarper,
12 | HeterogeneousTopPLogitsWarper,
13 | HeterogeneousTypicalLogitsWarper,
14 | HeterogeneousProcessorWrapper,
15 | )
16 |
17 |
18 | class NextTokensChooserOutput(BaseModel):
19 | next_probs: List[float] = Field(default=Required)
20 | next_token_ids: List[int] = Field(default=Required)
21 |
22 |
23 | class HeterogeneousNextTokensChooser:
24 | def __init__(
25 | self,
26 | dtype: torch.dtype,
27 | device: torch.device,
28 | temperature: List[float],
29 | repetition_penalty: List[float],
30 | top_k: List[int],
31 | top_p: List[float],
32 | typical_p: List[float],
33 | do_sample: List[bool],
34 | num_beams: List[int],
35 | seeds: List[int],
36 | ):
37 | warpers = []
38 |
39 | self.repetition_processor = (
40 | HeterogeneousRepetitionPenaltyLogitsProcessor(
41 | repetition_penalty, dtype, device
42 | )
43 | if any([x != 1.0 for x in repetition_penalty])
44 | else None
45 | )
46 |
47 | if any([x != 1.0 for x in temperature]):
48 | do_sample = [
49 | sample or x != 1.0 for x, sample in zip(temperature, do_sample)
50 | ]
51 | warpers.append(
52 | HeterogeneousTemperatureLogitsWarper(temperature, dtype, device)
53 | )
54 |
55 | if any([x != 0 for x in top_k]):
56 | do_sample = [sample or x != 0 for x, sample in zip(top_k, do_sample)]
57 | warpers.append(HeterogeneousTopKLogitsWarper(top_k, device))
58 |
59 | if any([x < 1.0 for x in top_p]):
60 | do_sample = [sample or x < 1.0 for x, sample in zip(top_p, do_sample)]
61 | warpers.append(HeterogeneousTopPLogitsWarper(top_p, dtype, device))
62 |
63 | if any([x < 1.0 for x in typical_p]):
64 | do_sample = [sample or x < 1.0 for x, sample in zip(typical_p, do_sample)]
65 | warpers.append(HeterogeneousTypicalLogitsWarper(typical_p, dtype, device))
66 |
67 | self.warpers = warpers
68 |
69 | if any(do_sample):
70 | self.choice = HeterogeneousSampling(do_sample, num_beams, seeds, device)
71 | else:
72 | self.choice = Greedy(num_beams)
73 |
74 | self.seeds = seeds
75 | self.do_sample = do_sample
76 | self.dtype = dtype
77 | self.device = device
78 |
79 | def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> List[NextTokensChooserOutput]:
80 | if self.repetition_processor is not None:
81 | scores = self.repetition_processor(input_ids, scores)
82 |
83 | for warper in self.warpers:
84 | scores = warper(input_ids, scores)
85 |
86 | log_scores = torch.log_softmax(scores, dim=-1)
87 |
88 | token_ids = self.choice(scores)
89 | log_probs = [
90 | [log_scores[i, token_id].item() for token_id in beam_nxt_token_ids]
91 | for i, beam_nxt_token_ids in enumerate(token_ids)
92 | ]
93 |
94 | return [
95 | NextTokensChooserOutput(next_token_ids=nxt_token_ids, next_probs=nxt_probs)
96 | for nxt_token_ids, nxt_probs in zip(token_ids, log_probs)
97 | ]
98 |
99 |
100 | class Sampling:
101 | def __init__(self, num_beams: int, seed: int, device: str = "cpu"):
102 | self.num_beams = num_beams
103 | self.generator = torch.Generator(device)
104 | self.generator.manual_seed(seed)
105 | self.seed = seed
106 |
107 | def __call__(self, logits) -> List[int]:
108 | probs = torch.nn.functional.softmax(logits, -1)
109 | # Avoid GPU<->CPU sync done by torch multinomial
110 | # See: https://github.com/pytorch/pytorch/blob/925a3788ec5c06db62ca732a0e9425a26a00916f/aten/src/ATen/native/Distributions.cpp#L631-L637
111 | q = torch.empty_like(probs).exponential_(1, generator=self.generator)
112 | return torch.topk(probs.div_(q), k=self.num_beams + 1).indices.tolist()
113 |
114 |
115 | class Greedy:
116 | def __init__(self, num_beams: List[int]):
117 | self.num_beams = num_beams
118 |
119 | def __call__(self, logits) -> List[List[int]]:
120 | return torch.topk(logits, k=max(self.num_beams) + 1, dim=-1).indices.tolist()
121 |
122 |
123 | class HeterogeneousSampling:
124 | r"""
125 | Mixed greedy and probabilistic sampling. Compute both and pick the right one for each sample.
126 | """
127 |
128 | def __init__(self, do_sample: List[bool], num_beams: List[int], seeds: List[int], device: torch.device):
129 | self.num_beams = num_beams
130 | self.seeds = seeds
131 |
132 | greedy_indices = []
133 | self.sampling_mapping = {}
134 | for i, (sample, seed) in enumerate(zip(do_sample, seeds)):
135 | if sample:
136 | self.sampling_mapping[i] = Sampling(num_beams[i], seed, device)
137 | else:
138 | greedy_indices.append(i)
139 |
140 | self.greedy_indices = greedy_indices
141 |
142 | def __call__(self, logits) -> List[List[int]]:
143 | out = [None for _ in range(logits.shape[0])]
144 | if self.greedy_indices:
145 | # Computing for all indices is faster than slicing
146 | greedy = Greedy(self.num_beams)
147 | out = greedy(logits)
148 |
149 | for i, sampling in self.sampling_mapping.items():
150 | out[i] = sampling(logits[i])
151 | return out
152 |
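
A minimal end-to-end sketch of the chooser, again assuming the repo's `code/` directory is on `sys.path` and using illustrative sizes; one request samples while the other stays greedy, and each output carries `num_beams + 1` candidate token ids with their log-probabilities:

```python
import torch

from server.continuous_batching_server.generation_utils.tokens import (
    HeterogeneousNextTokensChooser,
)

chooser = HeterogeneousNextTokensChooser(
    dtype=torch.float32,
    device=torch.device("cpu"),
    temperature=[0.8, 1.0],
    repetition_penalty=[1.1, 1.0],
    top_k=[50, 0],
    top_p=[0.9, 1.0],
    typical_p=[1.0, 1.0],
    do_sample=[True, False],   # request 0 samples, request 1 decodes greedily
    num_beams=[1, 1],
    seeds=[42, 0],
)

input_ids = torch.randint(0, 32000, (2, 16))
scores = torch.randn(2, 32000)
outputs = chooser(input_ids, scores)
for out in outputs:
    # num_beams + 1 candidate ids per request, plus their log-probabilities
    print(out.next_token_ids, out.next_probs)
```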
--------------------------------------------------------------------------------
/code/server/continuous_batching_server/modeling/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/modelize-ai/LLM-Inference-Deployment-Tutorial/e3264eacc4752a7d829241fb614a0b81892cdc8f/code/server/continuous_batching_server/modeling/__init__.py
--------------------------------------------------------------------------------
/code/server/continuous_batching_server/modeling/llama.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3 | #
4 | # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5 | # and OPT implementations in this library. It has been modified from its
6 | # original forms to accommodate minor architectural differences compared
7 | # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8 | #
9 | # Licensed under the Apache License, Version 2.0 (the "License");
10 | # you may not use this file except in compliance with the License.
11 | # You may obtain a copy of the License at
12 | #
13 | # http://www.apache.org/licenses/LICENSE-2.0
14 | #
15 | # Unless required by applicable law or agreed to in writing, software
16 | # distributed under the License is distributed on an "AS IS" BASIS,
17 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18 | # See the License for the specific language governing permissions and
19 | # limitations under the License.
20 |
21 |
22 | from typing import *
23 |
24 | import torch
25 | import torch.nn as nn
26 | import xformers.ops as xops
27 | from accelerate import init_empty_weights
28 | from auto_gptq import BaseQuantizeConfig as GPTQConfig
29 | from transformers.activations import ACT2FN
30 | from transformers.configuration_utils import PretrainedConfig
31 |
32 |
33 | from .utils.attention import VarLenAttentionWithRoPE
34 | from .utils.linear import DynamicLinear
35 | from .utils.weights import Weights
36 |
37 |
38 | class LlamaConfig(PretrainedConfig):
39 | def __init__(
40 | self,
41 | vocab_size=32000,
42 | hidden_size=4096,
43 | intermediate_size=11008,
44 | num_hidden_layers=32,
45 | num_attention_heads=32,
46 | num_key_value_heads=None,
47 | hidden_act="silu",
48 | max_position_embeddings=2048,
49 | initializer_range=0.02,
50 | rms_norm_eps=1e-6,
51 | use_cache=True,
52 | pad_token_id=0,
53 | bos_token_id=1,
54 | eos_token_id=2,
55 | pretraining_tp=1,
56 | tie_word_embeddings=False,
57 | rope_scaling=None,
58 | **kwargs,
59 | ):
60 | self.vocab_size = vocab_size
61 | self.max_position_embeddings = max_position_embeddings
62 | self.hidden_size = hidden_size
63 | self.intermediate_size = intermediate_size
64 | self.num_hidden_layers = num_hidden_layers
65 | self.num_attention_heads = num_attention_heads
66 |
67 | # for backward compatibility
68 | if num_key_value_heads is None:
69 | num_key_value_heads = num_attention_heads
70 |
71 | self.num_key_value_heads = num_key_value_heads
72 | self.hidden_act = hidden_act
73 | self.initializer_range = initializer_range
74 | self.rms_norm_eps = rms_norm_eps
75 | self.pretraining_tp = pretraining_tp
76 | self.use_cache = use_cache
77 | self.rope_scaling = rope_scaling
78 |
79 | super().__init__(
80 | pad_token_id=pad_token_id,
81 | bos_token_id=bos_token_id,
82 | eos_token_id=eos_token_id,
83 | tie_word_embeddings=tie_word_embeddings,
84 | **kwargs,
85 | )
86 |
87 |
88 | class LlamaRMSNorm(nn.Module):
89 | def __init__(self, hidden_size, eps=1e-6):
90 | """
91 | LlamaRMSNorm is equivalent to T5LayerNorm
92 | """
93 | super().__init__()
94 | self.weight = nn.Parameter(torch.ones(hidden_size))
95 | self.variance_epsilon = eps
96 |
97 | def forward(self, hidden_states):
98 | input_dtype = hidden_states.dtype
99 | hidden_states = hidden_states.to(torch.float32)
100 | variance = hidden_states.pow(2).mean(-1, keepdim=True)
101 | hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
102 | return self.weight * hidden_states.to(input_dtype)
103 |
104 | @classmethod
105 | def load(cls, prefix: str, weights: Weights, eps: float = 1e-6):
106 | weight = weights.get_tensor(f"{prefix}.weight")
107 | with init_empty_weights():
108 | ln = cls(weight.shape[0], eps)
109 | ln.weight = nn.Parameter(weight)
110 | return ln
111 |
112 |
113 | def _load_gqa(config, prefix: str, weights: Weights, gptq_config: Optional[GPTQConfig] = None):
114 | w = [
115 | weights.get_tensor(f"{prefix}.q_proj.weight"),
116 | weights.get_tensor(f"{prefix}.k_proj.weight"),
117 | weights.get_tensor(f"{prefix}.v_proj.weight")
118 | ]
119 | weight = torch.cat(w, dim=0)
120 | weight = weight.to(dtype=weights.dtype).to(device=weights.device)
121 |
122 | assert config.hidden_size % config.num_attention_heads == 0
123 | head_size = config.hidden_size // config.num_attention_heads
124 | num_heads = config.num_attention_heads
125 | num_key_value_heads = config.num_key_value_heads
126 | assert list(weight.shape) == [
127 | (num_heads + 2 * num_key_value_heads) * head_size,
128 | config.hidden_size,
129 | ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
130 |
131 | return DynamicLinear.load(config, prefix, weights, False, gptq_config)
132 |
133 |
134 | class LlamaMLP(nn.Module):
135 | def __init__(self, prefix: str, config: LlamaConfig, weights: Weights, gptq_config: Optional[GPTQConfig] = None):
136 | super().__init__()
137 | act = config.hidden_act
138 | self.act = (
139 | ACT2FN[act]
140 | if "gelu" not in act
141 | else lambda x: torch.nn.functional.gelu(
142 | x,
143 | approximate="tanh"
144 | if act in ["gelu_fast", "gelu_pytorch_tanh"]
145 | else "none",
146 | )
147 | )
148 | # Fuse gate and up proj
149 | self.gate_up_proj = DynamicLinear.load_multi(
150 | config,
151 | prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
152 | weights=weights,
153 | dim=0,
154 | bias=False,
155 | gptq_config=gptq_config
156 | )
157 | self.down_proj = DynamicLinear.load(
158 | config,
159 | prefix=f"{prefix}.down_proj",
160 | weights=weights,
161 | bias=False,
162 | gptq_config=gptq_config
163 | )
164 | self.intermediate_size = config.intermediate_size
165 |
166 | def forward(self, hidden_states):
167 | gate_up_states = self.gate_up_proj(hidden_states)
168 | gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
169 | return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])
170 |
171 |
172 | class LlamaLayer(nn.Module):
173 | def __init__(self, layer_id: int, config: LlamaConfig, weights: Weights, gptq_config: Optional[GPTQConfig] = None):
174 | super().__init__()
175 |
176 | prefix = f"model.layers.{layer_id}"
177 |
178 | self.config = config
179 | self.gptq_config = gptq_config
180 |
181 | self.self_attn = self._init_attention_module(config, f"{prefix}.self_attn", weights)
182 | self.mlp = LlamaMLP(prefix=f"{prefix}.mlp", config=config, weights=weights, gptq_config=gptq_config)
183 |
184 | self.input_layernorm = LlamaRMSNorm.load(
185 | prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
186 | )
187 | self.post_attention_layernorm = LlamaRMSNorm.load(
188 | prefix=f"{prefix}.post_attention_layernorm", weights=weights, eps=config.rms_norm_eps
189 | )
190 |
191 | def _init_attention_module(self, config: LlamaConfig, prefix: str, weights: Weights) -> nn.Module:
192 | qkv_proj = DynamicLinear.load_multi(
193 | config=config,
194 | prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
195 | weights=weights,
196 | bias=False,
197 | dim=0,
198 | gptq_config=self.gptq_config
199 | )
200 | o_proj = DynamicLinear.load(
201 | config=config,
202 | prefix=f"{prefix}.o_proj",
203 | weights=weights,
204 | bias=False,
205 | gptq_config=self.gptq_config
206 | )
207 |
208 | cos_sin_cache = VarLenAttentionWithRoPE.build_rope_cache(
209 | rotary_dim=config.hidden_size // config.num_attention_heads,
210 | max_position=config.max_position_embeddings,
211 | base=10000,
212 | device=o_proj.weight.device if not self.gptq_config else o_proj.scales.device,
213 | dtype=o_proj.weight.dtype if not self.gptq_config else o_proj.scales.dtype
214 | )
215 |
216 | head_dim = config.hidden_size // config.num_attention_heads
217 | attn_fw_op = xops.fmha.flash.FwOp if head_dim <= 128 else xops.fmha.cutlass.FwOp
218 | return VarLenAttentionWithRoPE(
219 | qkv_proj=qkv_proj,
220 | out_proj=o_proj,
221 | cos_sin_cache=cos_sin_cache,
222 | num_query_heads=config.num_attention_heads,
223 | num_key_heads=config.num_key_value_heads,
224 | num_value_heads=config.num_key_value_heads,
225 | dropout=0.0,
226 | scale=head_dim ** -0.5,
227 | attention_ops=(attn_fw_op, None)
228 | )
229 |
230 | def forward(
231 | self,
232 | hidden_states: torch.Tensor,
233 | position_ids: torch.Tensor,
234 | kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
235 | prefill: bool,
236 | block_tables: torch.Tensor,
237 | slots: torch.Tensor,
238 | context_lengths: torch.Tensor,
239 | cache_event: Optional[torch.cuda.Event] = None,
240 | ):
241 | residual = hidden_states
242 |
243 | hidden_states = self.input_layernorm(hidden_states)
244 |
245 | # Self Attention
246 | hidden_states = self.self_attn(
247 | hidden_states,
248 | position_ids,
249 | kv_cache,
250 | prefill,
251 | block_tables,
252 | slots,
253 | context_lengths,
254 | cache_event
255 | )
256 | hidden_states = residual + hidden_states
257 |
258 | # Fully Connected
259 | residual = hidden_states
260 | hidden_states = self.post_attention_layernorm(hidden_states)
261 | hidden_states = self.mlp(hidden_states)
262 | hidden_states = residual + hidden_states
263 |
264 | return hidden_states
265 |
266 |
267 | class LlamaModel(torch.nn.Module):
268 | def __init__(self, config: LlamaConfig, weights: Weights, gptq_config: Optional[GPTQConfig] = None):
269 | super().__init__()
270 | self.config = config
271 | self.gptq_config = gptq_config
272 |
273 | self.embed_tokens = nn.Embedding.from_pretrained(
274 | embeddings=weights.get_tensor("model.embed_tokens.weight"),
275 | freeze=True,
276 | padding_idx=config.pad_token_id
277 | )
278 | self.layers = nn.ModuleList(
279 | [
280 | LlamaLayer(
281 | layer_id,
282 | config,
283 | weights,
284 | gptq_config
285 | )
286 | for layer_id in range(config.num_hidden_layers)
287 | ]
288 | )
289 | self.norm = LlamaRMSNorm.load(
290 | prefix="model.norm", weights=weights, eps=config.rms_norm_eps
291 | )
292 |
293 | self.gradient_checkpointing = False
294 |
295 | def forward(
296 | self,
297 | input_ids: torch.Tensor,
298 | position_ids: torch.Tensor,
299 | kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
300 | prefill: bool,
301 | block_tables: torch.Tensor,
302 | slots: torch.Tensor,
303 | context_lengths: torch.Tensor,
304 | cache_events: Optional[List[torch.cuda.Event]] = None,
305 | ) -> torch.Tensor:
306 | hidden_states = self.embed_tokens(input_ids)
307 |
308 | for i, layer in enumerate(self.layers):
309 | hidden_states = layer(
310 | hidden_states,
311 | position_ids,
312 | kv_cache[i],
313 | prefill,
314 | block_tables,
315 | slots,
316 | context_lengths,
317 | cache_events[i] if cache_events is not None else None
318 | )
319 |
320 | hidden_states = self.norm(hidden_states)
321 |
322 | return hidden_states
323 |
324 |
325 | class LlamaForCausalLM(torch.nn.Module):
326 | def __init__(self, config: LlamaConfig, weights: Weights, gptq_config: Optional[GPTQConfig] = None):
327 | super().__init__()
328 |
329 | self.config = config
330 | self.gptq_config = gptq_config
331 | self.model = LlamaModel(config, weights, gptq_config)
332 | self.lm_head = DynamicLinear.load(
333 | config,
334 | prefix="lm_head",
335 | weights=weights,
336 | bias=False,
337 | gptq_config=None
338 | )
339 |
340 | def forward(
341 | self,
342 | input_ids: torch.Tensor,
343 | position_ids: torch.Tensor,
344 | kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
345 | prefill: bool,
346 | block_tables: torch.Tensor,
347 | slots: torch.Tensor,
348 | context_lengths: torch.Tensor,
349 | cache_events: Optional[List[torch.cuda.Event]] = None,
350 | lm_head_indices: Optional[torch.Tensor] = None,
351 | ) -> torch.Tensor:
352 | hidden_states = self.model(
353 | input_ids,
354 | position_ids,
355 | kv_cache,
356 | prefill,
357 | block_tables,
358 | slots,
359 | context_lengths,
360 | cache_events,
361 | )
362 | if lm_head_indices is not None:
363 | hidden_states = hidden_states[lm_head_indices]
364 | logits = self.lm_head(hidden_states)
365 | return logits
366 |
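
For orientation, a loading sketch under assumptions: `/models/llama-2-7b-hf` is a hypothetical local directory of `*.safetensors` shards, a CUDA device is available, and the xformers / vllm / auto-gptq dependencies imported at the top of this file are installed. It only shows how `LlamaConfig`, `Weights`, and `LlamaForCausalLM` fit together, not a full serving loop:

```python
import torch

from server.continuous_batching_server.modeling.llama import LlamaConfig, LlamaForCausalLM
from server.continuous_batching_server.modeling.utils.weights import Weights

model_dir = "/models/llama-2-7b-hf"   # hypothetical checkpoint directory

config = LlamaConfig.from_pretrained(model_dir)
weights = Weights(model_dir, device=torch.device("cuda:0"), dtype=torch.float16)
model = LlamaForCausalLM(config, weights, gptq_config=None).eval()

# The forward pass consumes flattened token ids plus the paged-KV metadata
# (kv_cache, block_tables, slots, context_lengths) that worker.py assembles.
```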
--------------------------------------------------------------------------------
/code/server/continuous_batching_server/modeling/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/modelize-ai/LLM-Inference-Deployment-Tutorial/e3264eacc4752a7d829241fb614a0b81892cdc8f/code/server/continuous_batching_server/modeling/utils/__init__.py
--------------------------------------------------------------------------------
/code/server/continuous_batching_server/modeling/utils/attention.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Tuple
2 |
3 | import torch
4 | import torch.nn as nn
5 | import xformers.ops as xops
6 | from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask, LowerTriangularMask
7 |
8 | import vllm.attention_ops as vllm_attention_ops
9 | import vllm.cache_ops as vllm_cache_ops
10 | import vllm.pos_encoding_ops as vllm_pos_encoding_ops
11 |
12 |
13 | def prefill_attention(
14 | query: torch.Tensor,
15 | key: torch.Tensor,
16 | value: torch.Tensor,
17 | attention_ops: Optional[xops.AttentionOp] = None,
18 | attention_bias: Optional[xops.AttentionBias] = None,
19 | dropout: float = 0.0,
20 | scale: Optional[float] = None
21 | ):
22 | for each in [query, key, value]:
23 | if len(each.shape) != 4:
24 | raise ValueError(
25 | "input tensor must have 4-dim shape which are [bsz, seq_len, num_heads, head_size] respectively,"
26 | f"but get {each.shape}"
27 | )
28 |
29 | bsz, seq_len = query.shape[:2]
30 |
 31 |     if value.shape[2] != query.shape[2]:
 32 |         # MQA / GQA: the checkpoint has fewer key/value heads than query heads,
 33 |         # so expand key and value along the head dim to match the query head count.
 34 |         # NOTE: this is a simple, unoptimized expansion; xformers' grouped 5-D
 35 |         # layout could avoid the extra copy.
 36 |         num_groups = query.shape[2] // value.shape[2]
 37 |         key = key.repeat_interleave(num_groups, dim=2)
 38 |         value = value.repeat_interleave(num_groups, dim=2)
39 |
40 | return xops.memory_efficient_attention(
41 | query=query,
42 | key=key,
43 | value=value,
44 | attn_bias=attention_bias,
45 | p=dropout,
46 | scale=scale,
47 | op=attention_ops
48 | ).reshape(bsz, seq_len, -1)
49 |
50 |
51 | def decode_attention(
52 | query: torch.Tensor,
53 | key_cache: torch.Tensor,
54 | value_cache: torch.Tensor,
55 | kv_head_mapping: torch.Tensor,
56 | scale: float,
57 | block_tables: torch.Tensor,
58 | context_lengths: torch.Tensor,
59 | alibi_slopes: Optional[torch.Tensor] = None
60 | ):
61 | if len(query.shape) != 3:
62 | raise ValueError(
63 | "query must have 3-dim shape which are [seq_len, num_heads, head_size] respectively, "
64 | f"but get shape {query.shape}"
65 | )
66 |
67 | attention_output = torch.empty_like(query)
68 | block_size = value_cache.shape[-1]
69 | vllm_attention_ops.single_query_cached_kv_attention(
70 | attention_output,
71 | query,
72 | key_cache,
73 | value_cache,
74 | kv_head_mapping,
75 | scale,
76 | block_tables,
77 | context_lengths,
78 | block_size,
79 | context_lengths.max().item(),
80 | alibi_slopes
81 | )
82 | return attention_output
83 |
84 |
85 | class AttentionWithRoPE(nn.Module):
86 | def __init__(
87 | self,
88 | qkv_proj: nn.Module,
89 | out_proj: nn.Module,
90 | cos_sin_cache: torch.Tensor,
91 | num_query_heads: int,
92 | num_key_heads: int,
93 | num_value_heads: int,
94 | dropout: float = 0.0,
95 | scale: Optional[float] = None,
96 | attention_ops: Optional[xops.AttentionOp] = None
97 | ):
98 | super(AttentionWithRoPE, self).__init__()
99 |
100 | self.qkv_proj = qkv_proj
101 | self.out_proj = out_proj
102 |
103 | self.register_buffer("cos_sin_cache", cos_sin_cache, persistent=False)
104 |
105 | self.num_query_heads = num_query_heads
106 | self.num_key_heads = num_key_heads
107 | self.num_value_heads = num_value_heads
108 |
109 | self.dropout = dropout
110 | self.scale = scale
111 |
112 | self.attention_ops = attention_ops
113 |
114 | # TODO: for now only compatible with GQA, make it also compatible with MQA
115 | self.num_groups = self.num_query_heads // self.num_value_heads
116 | self.kv_head_mapping = torch.arange(
117 | 0, self.num_value_heads, dtype=torch.int32, device=cos_sin_cache.device
118 | ).repeat_interleave(self.num_groups)
119 |
120 | def _qkv_forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
121 | hidden_size = hidden_states.shape[-1]
122 | # for each out tensor, shape ==> [total_tokens, hidden_size]
123 | return self.qkv_proj(hidden_states.view(-1, hidden_size)).chunk(chunks=3, dim=-1)
124 |
125 | def _rope_forward(self, query: torch.Tensor, key: torch.Tensor, position_ids: Optional[torch.Tensor]) -> None:
126 | if position_ids is None:
127 | return
128 |
129 | hidden_size = query.shape[-1]
130 | position_ids = position_ids.view(-1)
131 |
132 | vllm_pos_encoding_ops.rotary_embedding_neox(
133 | position_ids,
134 | query,
135 | key,
136 | hidden_size // self.num_query_heads,
137 | self.cos_sin_cache
138 | )
139 |
140 | def _out_forward(self, hidden_states: torch.Tensor, shape: tuple) -> torch.Tensor:
141 | return self.out_proj(hidden_states).view(shape)
142 |
143 | def forward(
144 | self,
145 | hidden_states: torch.Tensor,
146 | position_ids: Optional[torch.Tensor],
147 | kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]],
148 | prefill: bool,
149 | block_tables: Optional[torch.Tensor],
150 | slots: Optional[torch.Tensor],
151 | context_lengths: torch.Tensor,
152 | cache_event: Optional[torch.cuda.Event] = None
153 | ) -> torch.Tensor:
154 | # The shape of hidden_states and position_ids:
155 | # if is prefill ==> [bsz, max(context_lengths), hidden_size]
156 | # otherwise ==> [bsz, 1, hidden_size]
157 | if len(hidden_states.shape) != 3:
158 | raise ValueError("hidden_states must have 3-dim shape.")
159 | bsz, max_len, hidden_size = hidden_states.shape
160 |
161 | # QKV projection
162 | query, key, value = self._qkv_forward(hidden_states) # for each: shape ==> [total_tokens, hidden_size]
163 |
164 | # Add RoPE info
165 | self._rope_forward(query, key, position_ids)
166 |
167 | # Prefill Attention
168 | if prefill:
169 | attn_out = prefill_attention(
170 | query.view(bsz, max_len, self.num_query_heads, -1),
171 | key.view(bsz, max_len, self.num_key_heads, -1),
172 | value.view(bsz, max_len, self.num_value_heads, -1),
173 | self.attention_ops,
174 | LowerTriangularMask(),
175 | self.dropout,
176 | self.scale
177 | )
178 |
179 | # Wait until the cache op is done
180 | if cache_event is not None:
181 | cache_event.wait()
182 |
183 | # Cache key and value
184 | if kv_cache is not None:
185 | if prefill:
186 | valid_token_indices = []
187 | for i, start_idx in enumerate(range(0, bsz * max_len, max_len)):
188 | end_idx = start_idx + max_len
189 | indices = list(range(start_idx, end_idx))[-context_lengths[i]:]
190 | valid_token_indices += indices
191 | key_to_cache = key[valid_token_indices]
192 | value_to_cache = value[valid_token_indices]
193 | else:
194 | key_to_cache = key[:len(slots)]
195 | value_to_cache = value[:len(slots)]
196 | num_valid_tokens = key_to_cache.shape[0]
197 | key_to_cache = key_to_cache.reshape(num_valid_tokens, self.num_key_heads, -1)
198 | value_to_cache = value_to_cache.reshape(num_valid_tokens, self.num_value_heads, -1)
199 | vllm_cache_ops.reshape_and_cache(
200 | key_to_cache, value_to_cache, kv_cache[0], kv_cache[1], slots
201 | )
202 | elif not prefill:
203 | raise ValueError("kv_cache can't be None when in decode stage.")
204 |
205 | # Decode Attention
206 | if not prefill:
207 | attn_out = decode_attention(
208 | query.view(bsz * max_len, self.num_query_heads, -1),
209 | kv_cache[0],
210 | kv_cache[1],
211 | self.kv_head_mapping,
212 | self.scale,
213 | block_tables,
214 | context_lengths,
215 | None
216 | ).view(bsz, max_len, -1)
217 | return self._out_forward(attn_out, (bsz, max_len, hidden_size))
218 |
219 | @staticmethod
220 | def build_rope_cache(
221 | rotary_dim: int,
222 | max_position: int = 2048,
223 | base: int = 10000,
224 | device: torch.device = torch.device("cuda:0"),
225 | dtype: torch.dtype = torch.float16
226 | ):
227 | inv_freq = (1.0 / (base ** (torch.arange(0, rotary_dim, 2, device=device, dtype=dtype) / rotary_dim)))
228 | t = torch.arange(max_position, device=device, dtype=dtype)
229 | freqs = torch.einsum("i,j -> ij", t, inv_freq)
230 | cos = freqs.cos()
231 | sin = freqs.sin()
232 | cache = torch.cat((cos, sin), dim=-1)
233 |
234 | return cache
235 |
236 |
237 | class VarLenAttentionWithRoPE(AttentionWithRoPE):
238 | def forward(
239 | self,
240 | hidden_states: torch.Tensor,
241 | position_ids: Optional[torch.Tensor],
242 | kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]],
243 | prefill: bool,
244 | block_tables: Optional[torch.Tensor],
245 | slots: Optional[torch.Tensor],
246 | context_lengths: torch.Tensor,
247 | cache_event: Optional[torch.cuda.Event] = None
248 | ) -> torch.Tensor:
249 | # The shape of hidden_states and position_ids for both prefill and decode:
250 | # [total_tokens, hidden_size]
251 |
252 | # QKV projection
253 | query, key, value = self._qkv_forward(hidden_states) # for each: shape ==> [total_tokens, hidden_size]
254 |
255 | # Add RoPE info
256 | self._rope_forward(query, key, position_ids)
257 |
258 | total_tokens = query.shape[0]
259 | query = query.view(total_tokens, self.num_query_heads, -1)
260 | key = key.view(total_tokens, self.num_key_heads, -1)
261 | value = value.view(total_tokens, self.num_value_heads, -1)
262 |
263 | # Prefill Attention
264 | if prefill:
265 | attn_out = prefill_attention(
266 | query.unsqueeze(0),
267 | key.unsqueeze(0),
268 | value.unsqueeze(0),
269 | self.attention_ops,
270 | BlockDiagonalCausalMask.from_seqlens(context_lengths.tolist()),
271 | self.dropout,
272 | self.scale
273 | ).squeeze(0)
274 |
275 | # Wait until the cache op is done
276 | if cache_event is not None:
277 | cache_event.wait()
278 |
279 | # Cache key and value
280 | if kv_cache is not None:
281 | vllm_cache_ops.reshape_and_cache(
282 | key[:len(slots)], value[:len(slots)], kv_cache[0], kv_cache[1], slots
283 | )
284 | elif not prefill:
285 | raise ValueError("kv_cache can't be None when in decode stage.")
286 |
287 | # Decode Attention
288 | if not prefill:
289 | attn_out = decode_attention(
290 | query,
291 | kv_cache[0],
292 | kv_cache[1],
293 | self.kv_head_mapping,
294 | self.scale,
295 | block_tables,
296 | context_lengths,
297 | None
298 | ).view(total_tokens, -1)
299 |
300 | return self._out_forward(attn_out, hidden_states.shape)
301 |
302 |
303 | __all__ = ["AttentionWithRoPE", "VarLenAttentionWithRoPE"]
304 |
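
A small sketch of the RoPE cache layout produced by `build_rope_cache`, assuming xformers and a vLLM build that still exposes `vllm.attention_ops` / `vllm.cache_ops` / `vllm.pos_encoding_ops` are installed so this module imports; the head size is the Llama-2-7B value and is illustrative:

```python
import torch

from server.continuous_batching_server.modeling.utils.attention import AttentionWithRoPE

head_size = 128   # hidden_size // num_attention_heads for Llama-2-7B
cos_sin_cache = AttentionWithRoPE.build_rope_cache(
    rotary_dim=head_size,
    max_position=2048,
    base=10000,
    device=torch.device("cpu"),   # any device works for inspecting the cache
    dtype=torch.float32,
)

# One row per position: the first half of each row holds cos(pos * inv_freq),
# the second half sin(pos * inv_freq); the rotary kernel indexes it by position id.
print(cos_sin_cache.shape)        # torch.Size([2048, 128])
```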
--------------------------------------------------------------------------------
/code/server/continuous_batching_server/modeling/utils/linear.py:
--------------------------------------------------------------------------------
1 | from typing import *
2 |
3 | import torch
4 | import torch.nn as nn
5 | from accelerate import init_empty_weights
6 | from auto_gptq import BaseQuantizeConfig as GPTQConfig
7 | from auto_gptq.utils.import_utils import dynamically_import_QuantLinear
8 | from torch.nn import functional as F
9 |
10 | from .weights import Weights
11 |
12 |
13 | class FastLinear(nn.Module):
14 | def __init__(
15 | self,
16 | weight: torch.Tensor,
17 | bias: torch.Tensor,
18 | ) -> None:
19 | super().__init__()
20 | self.weight = nn.Parameter(weight, requires_grad=False)
21 | if bias is not None:
22 | self.bias = nn.Parameter(bias, requires_grad=False)
23 | else:
24 | self.bias = None
25 |
26 | @classmethod
27 | def load(
28 | cls,
29 | config,
30 | prefix: str,
31 | weights: Weights,
32 | bias: bool
33 | ):
34 | weight = weights.get_tensor(f"{prefix}.weight")
35 | if bias:
36 | bias = weights.get_tensor(f"{prefix}.bias")
37 | else:
38 | bias = None
39 | return cls(weight, bias)
40 |
41 | @classmethod
42 | def load_multi(
43 | cls,
44 | config,
45 | prefixes: List[str],
46 | weights: Weights,
47 | bias: bool,
48 | dim: int
49 | ):
50 | w = [
51 | weights.get_tensor(f"{prefix}.weight") for prefix in prefixes
52 | ]
53 | weight = torch.cat(w, dim=dim)
54 |
55 | if bias:
56 | b = [weights.get_tensor(f"{p}.bias") for p in prefixes]
57 | bias = torch.cat(b, dim=dim)
58 | else:
59 | bias = None
60 | return cls(weight, bias)
61 |
62 | def forward(self, input: torch.Tensor) -> torch.Tensor:
63 | return F.linear(input, self.weight, self.bias)
64 |
65 |
66 | class DynamicLinear:
67 | @classmethod
68 | def load(
69 | cls,
70 | config,
71 | prefix: str,
72 | weights: Weights,
73 | bias: bool,
74 | gptq_config: Optional[GPTQConfig] = None
75 | ):
76 | if not gptq_config:
77 | return FastLinear.load(config, prefix, weights, bias)
78 |
79 | disable_exllama = False
 80 |         if gptq_config.bits != 4 or gptq_config.desc_act:  # for the latter condition, this can be removed once auto-gptq fixes it
81 | disable_exllama = True
82 | QuantLinear = dynamically_import_QuantLinear(
83 | use_triton=False,
84 | desc_act=gptq_config.desc_act,
85 | group_size=gptq_config.group_size,
86 | bits=gptq_config.bits,
87 | disable_exllama=disable_exllama
88 | )
89 |
90 | qweight, qzeros, scales, g_idx, bias = weights.get_gptq_weight(prefix)
91 |
92 | init_args = (
93 | gptq_config.bits,
94 | gptq_config.group_size,
95 | qweight.shape[0] * 32 // gptq_config.bits,
96 | qweight.shape[1],
97 | bias is not None
98 | )
99 | with init_empty_weights(include_buffers=True):
100 | quant_linear = QuantLinear(*init_args, trainable=False)
101 | quant_linear.qweight = qweight
102 | quant_linear.qzeros = qzeros
103 | quant_linear.scales = scales
104 | quant_linear.g_idx = g_idx
105 | quant_linear.bias = bias
106 |
107 | return quant_linear
108 |
109 | @classmethod
110 | def load_multi(
111 | cls,
112 | config,
113 | prefixes: List[str],
114 | weights: Weights,
115 | bias: bool,
116 | dim: int,
117 | gptq_config: Optional[GPTQConfig] = None
118 | ):
119 | if not gptq_config:
120 | return FastLinear.load_multi(config, prefixes, weights, bias, dim)
121 |
122 | disable_exllama = False
123 |         if gptq_config.bits != 4 or gptq_config.desc_act:  # for the latter condition, this can be removed once auto-gptq fixes it
124 | disable_exllama = True
125 | QuantLinear = dynamically_import_QuantLinear(
126 | use_triton=False,
127 | desc_act=gptq_config.desc_act,
128 | group_size=gptq_config.group_size,
129 | bits=gptq_config.bits,
130 | disable_exllama=disable_exllama
131 | )
132 |
133 | qweight_li, qzeros_li, scales_li, g_idx_li, bias_li = [], [], [], [], []
134 | outfeatures = 0
135 | for prefix in prefixes:
136 | qweight, qzeros, scales, g_idx, bias = weights.get_gptq_weight(prefix)
137 | qweight_li.append(qweight)
138 | qzeros_li.append(qzeros)
139 | scales_li.append(scales)
140 | g_idx_li.append(g_idx)
141 | bias_li.append(bias)
142 | outfeatures += qweight.shape[1]
143 |
144 | qweight = torch.cat(qweight_li, dim=1)
145 | qzeros = torch.cat(qzeros_li, dim=1)
146 | scales = torch.cat(scales_li, dim=1)
147 | g_idx = torch.cat(g_idx_li, dim=0)
148 | if bias_li[0] is not None:
149 | bias = torch.cat(bias_li, dim=0)
150 | else:
151 | bias = None
152 |
153 | init_args = (
154 | gptq_config.bits,
155 | gptq_config.group_size,
156 | qweight.shape[0] * 32 // gptq_config.bits,
157 | qweight.shape[1],
158 | bias is not None
159 | )
160 |
161 | with init_empty_weights(include_buffers=True):
162 | quant_linear = QuantLinear(*init_args, trainable=False)
163 | quant_linear.qweight = qweight
164 | quant_linear.qzeros = qzeros
165 | quant_linear.scales = scales
166 | quant_linear.g_idx = g_idx
167 | quant_linear.bias = bias
168 |
169 | return quant_linear
170 |
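
A quick sketch of the non-quantized path, assuming auto-gptq and accelerate are installed so this module imports; `DynamicLinear.load` returns exactly this kind of `FastLinear` when `gptq_config` is `None`, and an auto-gptq `QuantLinear` otherwise. The shapes are illustrative:

```python
import torch

from server.continuous_batching_server.modeling.utils.linear import FastLinear

weight = torch.randn(11008, 4096)   # [out_features, in_features], e.g. a Llama gate_proj
layer = FastLinear(weight, bias=None)

x = torch.randn(4, 4096)
y = layer(x)                         # plain F.linear with frozen parameters
print(y.shape)                       # torch.Size([4, 11008])
```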
--------------------------------------------------------------------------------
/code/server/continuous_batching_server/modeling/utils/weights.py:
--------------------------------------------------------------------------------
1 | from glob import glob
2 | from os import path
3 | from typing import *
4 |
5 | import torch
6 | from safetensors import safe_open
7 |
8 |
9 | class Weights:
10 | def __init__(
11 | self,
12 | model_name_or_path: str,
13 | device: torch.device,
14 | dtype: torch.dtype,
15 | quantize_method: Optional[str] = None,
16 | gptq_model_base_name: Optional[str] = None
17 | ):
18 | if not path.isdir(model_name_or_path):
 19 |             raise NotADirectoryError(f"{model_name_or_path} does not exist or is not a directory.")
20 | routing = {}
21 | file_pattern = "*.safetensors"
22 | if quantize_method == "gptq":
23 | file_pattern = "gptq_model*.safetensors"
24 | if gptq_model_base_name:
25 | file_pattern = f"{gptq_model_base_name}*.safetensors"
26 | for model_file in glob(path.join(model_name_or_path, file_pattern)):
27 | with safe_open(model_file, framework="pt") as f:
28 | for k in f.keys():
29 | if k in routing:
30 | raise RuntimeError(
31 | f"Key {k} was found in multiple files: {model_file} and {routing[k]}"
32 | )
33 | routing[k] = model_file
34 | self.routing = routing
35 | self.device = device
36 | self.dtype = dtype
37 | self._handles = {}
38 |
39 | def _get_handle(self, filename: str):
40 | if filename not in self._handles:
41 | f = safe_open(filename, framework="pt")
42 | self._handles[filename] = f
43 |
44 | return self._handles[filename]
45 |
 46 |     def get_filename(self, tensor_name: str) -> Tuple[str, str]:
47 | filename = self.routing.get(tensor_name, None)
48 | if filename is None:
49 | raise RuntimeError(f"weight {tensor_name} does not exist")
50 | return str(filename), tensor_name
51 |
52 | def _get_slice(self, tensor_name: str):
53 | filename, tensor_name = self.get_filename(tensor_name)
54 | f = self._get_handle(filename)
55 | slice_ = f.get_slice(tensor_name)
56 | return slice_
57 |
58 | def get_shape(self, tensor_name: str):
59 | return self._get_slice(tensor_name).get_shape()
60 |
61 | def get_tensor(self, tensor_name: str):
62 | filename, tensor_name = self.get_filename(tensor_name)
63 | f = self._get_handle(filename)
64 | tensor = f.get_tensor(tensor_name)
65 | # Special case for gptq which shouldn't convert
66 | # u4 which are disguised as int32
67 | if tensor.dtype not in [torch.int32, torch.int64]:
68 | tensor = tensor.to(dtype=self.dtype)
69 | tensor = tensor.to(device=self.device)
70 | return tensor
71 |
72 | def get_gptq_weight(self, prefix: str):
73 | try:
74 | qweight = self.get_tensor(f"{prefix}.qweight")
75 | except RuntimeError:
76 | raise RuntimeError(
77 | "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
78 | )
79 | qzeros = self.get_tensor(f"{prefix}.qzeros")
80 | scales = self.get_tensor(f"{prefix}.scales")
81 | g_idx = self.get_tensor(f"{prefix}.g_idx")
82 | try:
83 | bias = self.get_tensor(f"{prefix}.bias")
 84 |         except RuntimeError:  # the checkpoint has no bias tensor for this prefix
85 | bias = None
86 |
87 | return qweight, qzeros, scales, g_idx, bias
88 |
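
A usage sketch, assuming `/models/llama-2-7b-hf` is a hypothetical directory of `*.safetensors` shards and a CUDA device is available; `Weights` simply routes each tensor name to the shard that contains it and casts/moves it on access:

```python
import torch

from server.continuous_batching_server.modeling.utils.weights import Weights

model_dir = "/models/llama-2-7b-hf"   # hypothetical checkpoint directory

weights = Weights(
    model_dir,
    device=torch.device("cuda:0"),
    dtype=torch.float16,
    quantize_method=None,             # "gptq" would route to gptq_model*.safetensors shards
)

embed = weights.get_tensor("model.embed_tokens.weight")
print(embed.shape, embed.dtype, embed.device)   # cast to fp16 and moved to cuda:0
```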
--------------------------------------------------------------------------------
/code/server/continuous_batching_server/server.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import time
3 | import threading
4 | from logging import getLogger, Logger
5 | from typing import *
6 | from uuid import uuid4, UUID
7 |
8 | from transformers import AutoTokenizer
9 |
10 | from .batcher import Batcher
11 | from .beam import Beam, BeamGroup, BeamStatus
12 | from .worker import Worker
13 | from .config import ServerConfig
14 | from protocol.completion_task import (
15 | TokenUsage,
16 | HuggingFaceCompletionChoice,
17 | HuggingFaceGenerationConfig,
18 | HuggingFaceCompletionInputs,
19 | HuggingFaceCompletionOutputs
20 | )
21 | from protocol.error import Error
22 |
23 |
24 | SERVER_SINGLETON = None
25 |
26 |
27 | class ServerNotInitializedError(Exception):
28 | def __repr__(self):
29 | return "server is not initialized, please initialize a server object first."
30 |
31 | def __str__(self):
32 | return self.__repr__()
33 |
34 |
35 | class ServerDoubleInitializeError(Exception):
36 | def __repr__(self):
37 | return "server is initialized, do not initialize again, please use get_server() instead."
38 |
39 | def __str__(self):
40 | return self.__repr__()
41 |
42 |
43 | class Server:
44 | def __init__(
45 | self,
46 | config: ServerConfig,
47 | logger: Optional[Logger] = None
48 | ):
49 | global SERVER_SINGLETON
50 | if SERVER_SINGLETON is not None:
51 | raise ServerDoubleInitializeError()
52 |
53 | self.config = config
54 |
55 | batcher_config = config.batcher_config
56 | cache_config = config.cache_config
57 | model_loading_config = config.model_loading_config
58 | parallel_config = config.parallel_config
59 |
60 | assert parallel_config.tp_size == 1, "we don't provide model parallelism support for now."
61 |
62 | self.logger = logger if logger else getLogger(__name__)
63 |
64 | self.worker = Worker(cache_config, model_loading_config, parallel_config, logger)
65 | self.batcher = Batcher(batcher_config, cache_config, logger)
66 |
67 | self.tokenizer_max_length = model_loading_config.model_max_length
68 | self.tokenizer = AutoTokenizer.from_pretrained(
69 | model_loading_config.tokenizer_name_or_path or model_loading_config.model_name_or_path,
70 | use_fast=model_loading_config.use_fast_tokenizer,
71 | trust_remote_code=model_loading_config.trust_remote_code,
72 | truncation_side="left",
73 | )
74 | if not self.tokenizer.pad_token_id:
75 | self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
76 | self.tokenizer.pad_token = self.tokenizer.eos_token
77 |
78 | self.finished_table: Dict[UUID, Tuple[HuggingFaceCompletionOutputs, Optional[Error], int]] = dict()
79 |
80 | threading.Thread(target=self._run, daemon=True).start()
81 |
82 | SERVER_SINGLETON = self
83 |
84 | def _encode(self, text: str) -> List[int]:
85 | return self.tokenizer(text, truncation=True, max_length=self.tokenizer_max_length)["input_ids"]
86 |
87 | def _decode(self, token_ids: List[int]) -> str:
88 | return self.tokenizer.decode(token_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
89 |
90 | def _construct_generation_result(
91 | self,
92 | request: BeamGroup
93 | ) -> Tuple[HuggingFaceCompletionOutputs, Optional[Error], int]:
94 | final_beams = request.get_final_beams()
95 | choices = [
96 | HuggingFaceCompletionChoice(
97 | text=self._decode(beam.generated_token_ids),
98 | index=idx,
99 | finish_reason=beam.finish_reason
100 | )
101 | for idx, beam in enumerate(final_beams)
102 | ]
103 | usage = TokenUsage(
104 | prompt_tokens=len(request.get_final_beams()[0].prompt_token_ids),
105 | completion_tokens=sum([beam.num_generated_tokens for beam in request.get_beams(BeamStatus.FINISHED)])
106 | )
107 | usage.total_tokens = usage.prompt_tokens + usage.completion_tokens
108 | return HuggingFaceCompletionOutputs(choices=choices, usage=usage), None, 200
109 |
110 | async def wait_task_done(
111 | self,
112 | inp: HuggingFaceCompletionInputs
113 | ) -> Tuple[HuggingFaceCompletionOutputs, Optional[Error], int, float]:
114 | request_id = uuid4()
115 | start = time.time()
116 |
117 | inp.generation_config.eos_token_id = self.tokenizer.eos_token_id
118 | inp.generation_config.pad_token_id = self.tokenizer.pad_token_id
119 |
120 | request = BeamGroup(
121 | request_id=request_id,
122 | arrival_time=time.time(),
123 | beams=[
124 | Beam(
125 | request_id,
126 | prompt=inp.prompt,
127 | prompt_token_ids=self._encode(inp.prompt),
128 | block_size=self.batcher.cache_config.block_size
129 | )
130 | ],
131 | generation_config=inp.generation_config
132 | )
133 | self.logger.info(msg=f"Task-{request_id} is added.")
134 | self.batcher.add_request(request)
135 |
136 | while True:
137 | await asyncio.sleep(0.1)
138 | if request_id in self.finished_table:
139 | end = time.time()
140 | outputs, error, status_code = self.finished_table.pop(request_id)
141 | wall_time = end - start
142 | self.logger.info(msg=f"Task-{request_id} is finished, {wall_time=: .4f}s")
143 | return outputs, error, status_code, wall_time
144 |
145 | def _run(self) -> None:
146 | steps = 0
147 | while True:
148 | steps += 1
149 | batch, blocks_to_copy, blocks_to_swap_in, blocks_to_swap_out, finished_requests = self.batcher.schedule()
150 | for req in finished_requests:
151 | self.finished_table[req.request_id] = self._construct_generation_result(req)
152 | self.worker.forward(batch, blocks_to_copy, blocks_to_swap_in, blocks_to_swap_out)
153 | self.batcher.batch_generate(batch)
154 | time.sleep(0.001)
155 |
156 |
157 | def get_server():
158 | if SERVER_SINGLETON is None:
159 | raise ServerNotInitializedError()
160 | return SERVER_SINGLETON
161 |
162 |
163 | __all__ = ["Server", "get_server"]
164 |
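
A hypothetical FastAPI route showing how the singleton is consumed; the actual wiring lives in `continuous_batching_server_app.py` and `protocol/routes.py`, so the route path and response handling here are illustrative only and assume the protocol models are pydantic v1 `BaseModel`s:

```python
from fastapi import FastAPI
from fastapi.responses import JSONResponse

from protocol.completion_task import HuggingFaceCompletionInputs
from server.continuous_batching_server.server import get_server

app = FastAPI()


@app.post("/v1/completion")
async def completion(inputs: HuggingFaceCompletionInputs):
    server = get_server()  # raises ServerNotInitializedError if Server(...) was never built
    outputs, error, status_code, wall_time = await server.wait_task_done(inputs)
    body = (error if error is not None else outputs).dict()
    return JSONResponse(content=body, status_code=status_code)
```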
--------------------------------------------------------------------------------
/code/server/continuous_batching_server/worker.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from logging import getLogger, Logger
4 | from typing import *
5 |
6 | import torch
7 | from auto_gptq import BaseQuantizeConfig as GPTQConfig
8 | from auto_gptq.modeling._utils import autogptq_post_init
9 |
10 | from .batcher import Batch
11 | from .cache.cache import Cache
12 | from .modeling.utils.weights import Weights
13 | from .config import (
14 | CacheConfig,
15 | ModelConfig,
16 | ModelLoadingConfig,
17 | MODEL_AUTO_TABLE,
18 | ParallelConfig,
19 | TORCH_FLOAT_DTYPE_MAP
20 | )
21 |
22 |
23 | def _get_gptq_config(config_dir: str, config_file_base_name: Optional[str] = None) -> GPTQConfig:
24 | config_path = os.path.join(config_dir, "quantize_config.json")
25 | if config_file_base_name:
26 | config_path = os.path.join(config_dir, f"{config_file_base_name}.json")
 27 |     with open(config_path, "r", encoding="utf-8") as f:
 28 |         return GPTQConfig(**json.load(f))
29 |
30 | def _pad_to_alignment(x: List[int], multiple_of: int) -> List[int]:
31 | return x + [0] * ((-len(x)) % multiple_of)
32 |
33 |
34 | def _pad_to_max(x: List[int], max_len: int) -> List[int]:
35 | return x + [0] * (max_len - len(x))
36 |
37 |
38 | class Worker:
39 | def __init__(
40 | self,
41 | cache_config: CacheConfig,
42 | model_loading_config: ModelLoadingConfig,
43 | parallel_config: ParallelConfig,
44 | logger: Optional[Logger] = None
45 | ):
46 | self.cache_config = cache_config
47 | self.model_loading_config = model_loading_config
48 | self.parallel_config = parallel_config
49 |
50 | self.logger = logger if logger else getLogger(__name__)
51 |
52 | # load model
53 | self.device = torch.device(self.model_loading_config.device)
54 | self.dtype = TORCH_FLOAT_DTYPE_MAP[self.model_loading_config.torch_dtype]
55 |
56 | torch.cuda.set_device(self.device)
57 |
58 | factory = MODEL_AUTO_TABLE[self.model_loading_config.model_type]
59 | model_config_cls = factory.model_config_cls
60 | model_cls = factory.model_cls
61 | model_config = model_config_cls.from_pretrained(
62 | self.model_loading_config.model_name_or_path,
63 | trust_remote_code=self.model_loading_config.trust_remote_code
64 | )
65 | model_weights = Weights(
66 | self.model_loading_config.model_name_or_path,
67 | device=self.device,
68 | dtype=self.dtype,
69 | quantize_method=self.model_loading_config.quantize_method,
70 | gptq_model_base_name=self.model_loading_config.gptq_model_base_name
71 | )
72 | if self.model_loading_config.quantize_method == "gptq":
73 | gptq_config = _get_gptq_config(
74 | self.model_loading_config.model_name_or_path,
75 | self.model_loading_config.gptq_config_base_name
76 | )
77 | else:
78 | gptq_config = None
79 | self.model = model_cls(config=model_config, weights=model_weights, gptq_config=gptq_config)
80 | if self.model_loading_config.quantize_method == "gptq":
81 | self.model = autogptq_post_init(self.model, gptq_config.desc_act)
82 | self.model.eval()
83 | self.model_config = ModelConfig(self.model.config, self.parallel_config)
84 |
85 | self.cache = Cache(
86 | cache_config=self.cache_config,
87 | model_config=self.model_config,
88 | parallel_config=self.parallel_config,
89 | dtype=self.dtype,
90 | device=self.device
91 | )
92 |
93 | def _prepare_inputs(self, batch: Batch, prefill: bool) -> Optional[Dict[str, Optional[torch.Tensor]]]:
94 | input_ids: List[int] = []
95 | position_ids: List[int] = []
96 | block_tables: Optional[List[List[int]]] = []
97 | slots: List[int] = []
98 | context_lengths: List[int] = []
99 | lm_head_indices: Optional[List[int]] = []
100 |
101 | beams = batch.prefill_beams if prefill else batch.generation_beams
102 | if not beams:
103 | return
104 | if prefill:
105 | block_tables = None
106 | for beam in beams:
107 | input_ids += beam.prompt_token_ids
108 | position_ids += list(range(beam.num_tokens))
109 | context_lengths.append(beam.num_tokens)
110 | lm_head_indices.append(sum(context_lengths) - 1)
111 |
112 | block_ids = batch.block_tables[beam.beam_id]
113 | if block_ids is None:
114 | slots += [0] * beam.num_tokens
115 | else:
116 | for i in range(beam.num_tokens):
117 | block_id = block_ids[i // self.cache.block_size]
118 | block_offset = i % self.cache.block_size
119 | slot = block_id * self.cache.block_size + block_offset
120 | slots.append(slot)
121 | else:
122 | lm_head_indices = None
123 | for beam in beams:
124 | input_ids.append(beam.last_token_id)
125 | position_ids.append(beam.num_tokens - 1)
126 | context_lengths.append(beam.num_tokens)
127 |
128 | block_ids = batch.block_tables[beam.beam_id]
129 | block_tables.append(block_ids)
130 |
131 | block_id = block_ids[position_ids[-1] // self.cache.block_size]
132 | block_offset = position_ids[-1] % self.cache_config.block_size
133 | slot = block_id * self.cache_config.block_size + block_offset
134 | slots.append(slot)
135 |
136 | # Optimization: Pad the input length to be a multiple of 8.
137 | # This is required for utilizing the Tensor Cores in NVIDIA GPUs.
138 | # FIXME: after padding, model execution will fail if bsz is not a multiple of 8
139 | # input_ids = _pad_to_alignment(input_ids, multiple_of=8)
140 | # position_ids = _pad_to_alignment(position_ids, multiple_of=8)
141 |
142 | # Convert to tensors.
143 | input_ids = torch.tensor(input_ids, dtype=torch.long, device=self.device)
144 | position_ids = torch.tensor(position_ids, dtype=torch.long, device=self.device)
145 | context_lengths = torch.tensor(context_lengths, dtype=torch.int32, device=self.device)
146 | if block_tables is not None:
147 | max_num_blocks = max([len(block_table) for block_table in block_tables])
148 | block_tables = torch.IntTensor(
149 | [_pad_to_max(block_table, max_num_blocks) for block_table in block_tables]
150 | ).to(self.device)
151 | slots = torch.IntTensor(slots).to(self.device)
152 |
153 | return {
154 | "input_ids": input_ids,
155 | "position_ids": position_ids,
156 | "kv_cache": self.cache.cache,
157 | "prefill": prefill,
158 | "block_tables": block_tables,
159 | "slots": slots,
160 | "context_lengths": context_lengths,
161 | "lm_head_indices": lm_head_indices
162 | }
163 |
164 | @torch.no_grad()
165 | @torch.cuda.amp.autocast()
166 | def _forward(self, batch: Batch, cache_events: Optional[List[torch.cuda.Event]] = None):
167 | prefill_inputs = self._prepare_inputs(batch, prefill=True)
168 | generation_inputs = self._prepare_inputs(batch, prefill=False)
169 |
170 | if prefill_inputs:
171 | self.logger.debug("executing model for prefilling.")
172 | batch.prefill_logits = self.model(cache_events=cache_events, **prefill_inputs)
173 | self.logger.debug("executed model for prefilling.")
174 | if generation_inputs:
175 | self.logger.debug("executing model for decoding.")
176 | batch.generation_logits = self.model(cache_events=cache_events, **generation_inputs)
177 | self.logger.debug("executed model for decoding.")
178 |
179 | def forward(
180 | self,
181 | batch: Batch,
182 | blocks_to_copy: Dict[int, List[int]],
183 | blocks_to_swap_in: Dict[int, int],
184 | blocks_to_swap_out: Dict[int, int]
185 | ):
186 | # Issue cache operations.
187 | issued_cache_op = False
188 | if blocks_to_swap_in:
189 | self.logger.debug("executing cache swap in operation.")
190 | self.cache.swap_in(blocks_to_swap_in)
191 | issued_cache_op = True
192 | self.logger.debug("executed cache swap in operation.")
193 | if blocks_to_swap_out:
194 | self.logger.debug("executing cache swap out operation.")
195 | self.cache.swap_out(blocks_to_swap_out)
196 | issued_cache_op = True
197 | self.logger.debug("executed cache swap out operation.")
198 | if blocks_to_copy:
199 |             self.logger.debug("executing cache copy operation.")
200 | self.cache.copy(blocks_to_copy)
201 | issued_cache_op = True
202 |             self.logger.debug("executed cache copy operation.")
203 |
204 | if issued_cache_op:
205 | cache_events = self.cache.events
206 | else:
207 | cache_events = None
208 |
209 | if batch.num_beams == 0:
210 | if cache_events is not None:
211 | for event in cache_events:
212 | event.wait()
213 | self.logger.debug("no beams need to be processed, return directly.")
214 | return
215 |
216 | self._forward(batch, cache_events=cache_events)
217 |
218 |
219 | __all__ = ["Worker"]
220 |
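
The slot computation in `_prepare_inputs` above maps each token position through the beam's block table into a flat index ("slot") of the paged KV cache. A tiny standalone sketch of that arithmetic, with a hypothetical block table and block size:

```python
# A standalone sketch of the slot arithmetic used in _prepare_inputs (values are hypothetical).
block_size = 16
block_table = [7, 3, 12]  # physical block ids assigned to one beam, in logical order

def slot_for(position: int) -> int:
    # logical token position -> physical block id -> flat index into the paged KV cache
    block_id = block_table[position // block_size]
    block_offset = position % block_size
    return block_id * block_size + block_offset

print(slot_for(0))   # 112: token 0 sits at offset 0 of physical block 7
print(slot_for(17))  # 49:  token 17 sits at offset 1 of physical block 3
```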
--------------------------------------------------------------------------------
/code/server/static_batching_server/__init__.py:
--------------------------------------------------------------------------------
1 | from .server import get_server, Server
2 | from .config import ServerConfig
3 |
4 |
5 | __all__ = ["get_server", "Server", "ServerConfig"]
6 |
--------------------------------------------------------------------------------
/code/server/static_batching_server/batcher.py:
--------------------------------------------------------------------------------
1 | from logging import getLogger, Logger
2 | from typing import List, Optional, Tuple
3 | from uuid import UUID
4 |
5 | from pydantic import BaseModel, Field
6 |
7 | from .config import BatcherConfig
8 | from protocol.completion_task import HuggingFaceCompletionInputs, HuggingFaceGenerationConfig
9 |
10 |
11 | class Package(BaseModel):
12 | ids: List[UUID] = Field(default=...)
13 | prompts: List[str] = Field(default=...)
14 | generation_config: HuggingFaceGenerationConfig = Field(default=...)
15 |
16 | def __hash__(self):
17 | return hash(self.generation_config)
18 |
19 | @property
20 | def workload(self):
21 | return len(self.prompts) * self.generation_config.num_beams
22 |
23 | def add(self, prompt: str, uid: UUID):
24 | self.prompts.append(prompt)
25 | self.ids.append(uid)
26 |
27 | def __repr__(self):
28 | return f"Package(workload={self.workload})"
29 |
30 | def __str__(self):
31 | return self.__repr__()
32 |
33 |
34 | class Batcher:
35 | def __init__(self, config: BatcherConfig, logger: Optional[Logger] = None):
36 | self.config = config
37 | self.logger = logger if logger else getLogger(__name__)
38 |
39 | self.inputs: List[Tuple[HuggingFaceCompletionInputs, UUID]] = []
40 |
41 | def pack(self) -> Optional[Package]:
42 | if not self.inputs:
43 | return None
44 |
45 | # =============================
46 | # Strategy 1:
47 |         # pack the first input, then keep selecting inputs that share its generation_config until the package
48 |         # is full or there are no inputs left
49 | # =============================
50 |
51 | inp, inp_id = self.inputs.pop(0)
52 | package = Package(ids=[inp_id], prompts=[inp.prompt], generation_config=inp.generation_config)
53 | inputs = []
54 | while self.inputs:
55 | if package.workload > self.config.package_max_workload: # package is full, put back and return
56 | self.inputs = inputs + self.inputs
57 | self.logger.debug(msg=str(package))
58 | return package
59 | inp, inp_id = self.inputs.pop(0)
60 | if hash(inp.generation_config) != hash(package): # gen_config is different, put back
61 | inputs.append((inp, inp_id))
62 | else:
63 | package.add(inp.prompt, inp_id)
64 |         self.inputs = inputs + self.inputs  # put deferred inputs (different generation_config) back in the queue
65 |         self.logger.debug(msg=str(package))
66 |         return package
67 | # =============================
68 | # Strategy 2:
69 | # pack input one by one, return immediately when package is full or generation_config is different
70 | # =============================
71 |
72 | # package = None
73 | # while self.inputs:
74 | # inp, inp_id = self.inputs.pop(0)
75 | # if package is None: # the first input, initialize package
76 | # package = Package(ids=[inp_id], prompts=[inp.prompt], generation_config=inp.generation_config)
77 | # else: # if gen_config not the same or package is full, return package, otherwise add prompt into package
78 | # hash_value = hash(inp.generation_config)
79 | # if hash_value != hash(package) or package.workload > self.config.package_max_workload:
80 | # self.inputs.insert(0, (inp, inp_id))
81 | # self.logger.debug(msg=str(package))
82 | # return package
83 | # package.add(inp.prompt, inp_id)
84 | # self.logger.debug(msg=str(package))
85 | # return package
86 |
87 | def add(self, inp: HuggingFaceCompletionInputs, inp_id: UUID):
88 | self.inputs.append((inp, inp_id))
89 |
90 |
91 | __all__ = ["Batcher"]
92 |
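
The "Strategy 1" comment in `pack()` describes greedy packing: take the first pending input, then pull later inputs that share its generation_config until the workload cap is reached, deferring the rest. A self-contained toy of that idea, using simplified stand-in types rather than the repo's pydantic models:

```python
# A toy version of the greedy "Strategy 1" packing described above; types are simplified stand-ins.
from dataclasses import dataclass
from typing import List, Tuple

@dataclass(frozen=True)
class GenConfig:
    num_beams: int = 1

@dataclass
class ToyPackage:
    prompts: List[str]
    gen_config: GenConfig

    @property
    def workload(self) -> int:
        return len(self.prompts) * self.gen_config.num_beams

def pack(pending: List[Tuple[str, GenConfig]], max_workload: int) -> ToyPackage:
    prompt, cfg = pending.pop(0)
    package = ToyPackage(prompts=[prompt], gen_config=cfg)
    deferred: List[Tuple[str, GenConfig]] = []
    while pending:
        if package.workload > max_workload:  # package is full
            break
        prompt, cfg = pending.pop(0)
        if cfg != package.gen_config:        # different config: defer to a later package
            deferred.append((prompt, cfg))
        else:
            package.prompts.append(prompt)
    pending[:0] = deferred                   # put deferred inputs back at the front of the queue
    return package

queue = [("a", GenConfig(1)), ("b", GenConfig(4)), ("c", GenConfig(1))]
print(pack(queue, max_workload=4).prompts)   # ['a', 'c']
print([prompt for prompt, _ in queue])       # ['b'] stays queued for the next package
```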
--------------------------------------------------------------------------------
/code/server/static_batching_server/config.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Optional, Union
2 |
3 | from pydantic import BaseModel, Field
4 |
5 |
6 | class BatcherConfig(BaseModel):
7 | package_max_workload: int = Field(default=1)
8 | packaging_interval_seconds: int = Field(default=2)
9 |
10 |
11 | class WorkerConfig(BaseModel):
12 | model_name_or_path: str = Field(default="dummy_model_name_or_path")
13 | tokenizer_name_or_path: Optional[str] = Field(default=None)
14 | revision: str = Field(default="main")
15 | low_cpu_mem_usage: bool = Field(default=True)
16 |     torch_dtype: str = Field(default="float16", regex="(float16|bfloat16)")
17 | device: Optional[Union[int, str]] = Field(default=None)
18 | max_memory: Optional[Dict[Union[str, int], str]] = Field(default=None)
19 | device_map: Union[str, Dict[str, Union[int, str]]] = Field(default="auto")
20 | use_fast_tokenizer: bool = Field(default=False)
21 | trust_remote_code: bool = Field(default=False)
22 | use_safetensors: bool = Field(default=False)
23 | batch_size: int = Field(default=-1) # -1 means execute all inputs together no matter how many they are
24 | is_gptq_quantized: bool = Field(default=False)
25 |
26 |
27 | class ServerConfig(BaseModel):
28 | batcher_config: BatcherConfig = Field(default=BatcherConfig())
29 | worker_config: WorkerConfig = Field(default=WorkerConfig())
30 |
31 |
32 | __all__ = [
33 | "BatcherConfig",
34 | "WorkerConfig",
35 | "ServerConfig"
36 | ]
37 |
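
For reference, a minimal sketch of building a `ServerConfig` the way the SB server app does at startup, i.e. from a parsed JSON object. The keys shown are real config fields, but the values (including the model name) are illustrative:

```python
# A minimal sketch: ServerConfig is built from nested dicts, just as
# static_batching_server_app.py does with json.load(...) at startup.
from server.static_batching_server import ServerConfig

raw = {
    "batcher_config": {"package_max_workload": 8, "packaging_interval_seconds": 2},
    "worker_config": {"model_name_or_path": "my-org/my-llm", "torch_dtype": "float16"},  # hypothetical model id
}
config = ServerConfig(**raw)
print(config.batcher_config.package_max_workload)  # 8
print(config.worker_config.batch_size)             # -1: run all packed prompts in a single batch
```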
--------------------------------------------------------------------------------
/code/server/static_batching_server/server.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import time
3 | from logging import getLogger, Logger
4 | from typing import Dict, Optional, Tuple
5 | from threading import Thread
6 | from uuid import uuid4, UUID
7 |
8 | from .batcher import Batcher
9 | from .config import ServerConfig
10 | from .worker import Worker
11 | from protocol.completion_task import HuggingFaceCompletionInputs, HuggingFaceCompletionOutputs
12 | from protocol.error import Error
13 |
14 |
15 | SERVER_SINGLETON = None
16 |
17 |
18 | class ServerNotInitializedError(Exception):
19 | def __repr__(self):
20 | return "server is not initialized, please initialize a server object first."
21 |
22 | def __str__(self):
23 | return self.__repr__()
24 |
25 |
26 | class ServerDoubleInitializeError(Exception):
27 | def __repr__(self):
28 | return "server is initialized, do not initialize again, please use get_server() instead."
29 |
30 | def __str__(self):
31 | return self.__repr__()
32 |
33 |
34 | class Server:
35 | def __init__(self, config: ServerConfig, logger: Optional[Logger] = None):
36 | global SERVER_SINGLETON
37 | if SERVER_SINGLETON is not None:
38 | raise ServerDoubleInitializeError()
39 |
40 | self.config = config
41 | self.batcher_config = config.batcher_config
42 | self.worker_config = config.worker_config
43 | self.logger = logger if logger else getLogger(__name__)
44 |
45 | self.batcher = Batcher(config=self.batcher_config, logger=logger)
46 | self.worker = Worker(config=self.worker_config, logger=logger)
47 |
48 | self.outputs: Dict[UUID, Tuple[HuggingFaceCompletionOutputs, Optional[Error], int, float]] = dict()
49 |
50 | Thread(target=self._run, daemon=True).start()
51 |
52 | SERVER_SINGLETON = self
53 |
54 | def _run(self):
55 | while True:
56 | package = self.batcher.pack()
57 | if package is not None:
58 | self.outputs.update(self.worker.execute(package.prompts, package.ids, package.generation_config))
59 | else:
60 | time.sleep(self.batcher_config.packaging_interval_seconds)
61 |
62 | async def wait_task_done(
63 | self,
64 | inp: HuggingFaceCompletionInputs
65 | ) -> Tuple[HuggingFaceCompletionOutputs, Optional[Error], int, float, float]:
66 | uid = uuid4()
67 | start = time.time()
68 |
69 | self.logger.info(msg=f"Task-{uid} is added.")
70 | self.batcher.add(inp, uid)
71 |
72 | while True:
73 | await asyncio.sleep(0.1)
74 | if uid in self.outputs:
75 | end = time.time()
76 | outputs, error, status_code, cpu_time = self.outputs.pop(uid)
77 | wall_time = end - start
78 | self.logger.info(msg=f"Task-{uid} is finished, {cpu_time=: .4f}s, {wall_time=: .4f}s")
79 | return outputs, error, status_code, cpu_time, wall_time
80 |
81 |
82 | def get_server():
83 | if SERVER_SINGLETON is None:
84 | raise ServerNotInitializedError()
85 | return SERVER_SINGLETON
86 |
87 |
88 | __all__ = ["Server", "get_server"]
89 |
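
A minimal sketch of the intended usage pattern: construct the `Server` singleton once at process startup, then have request handlers fetch it with `get_server()` and await `wait_task_done`. Building a real `HuggingFaceCompletionInputs` payload and loading a model are out of scope here, so `request_inputs` is left as a placeholder:

```python
# A usage sketch of the singleton + polling pattern above; request_inputs stands in
# for a real HuggingFaceCompletionInputs payload.
from server.static_batching_server import Server, ServerConfig, get_server

async def handle_one_request(request_inputs):
    server = get_server()  # raises ServerNotInitializedError if Server(...) was never constructed
    outputs, error, status_code, cpu_time, wall_time = await server.wait_task_done(request_inputs)
    return outputs, status_code

# At process startup (e.g. in a FastAPI startup hook), run once:
#     Server(config=ServerConfig(), logger=my_logger)  # my_logger: any logging.Logger
# After that, any coroutine in the same process can call handle_one_request(...).
```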
--------------------------------------------------------------------------------
/code/server/static_batching_server/worker.py:
--------------------------------------------------------------------------------
1 | import time
2 | from logging import getLogger, Logger
3 | from typing import Dict, List, Optional, Tuple
4 | from uuid import UUID
5 |
6 | import torch
7 | from auto_gptq import AutoGPTQForCausalLM
8 | from transformers import AutoTokenizer, AutoModelForCausalLM, BatchEncoding, PreTrainedTokenizerBase
9 |
10 | from .config import WorkerConfig
11 | from protocol.completion_task import (
12 | TokenUsage,
13 | HuggingFaceGenerationConfig,
14 | HuggingFaceCompletionChoice,
15 | HuggingFaceCompletionOutputs
16 | )
17 | from protocol.error import Error
18 |
19 |
20 | class TextGenerationPipeline:
21 |     """A simplified pipeline showing what HF's TextGenerationPipeline mainly does under the hood."""
22 | def __init__(self, model: torch.nn.Module, tokenizer: PreTrainedTokenizerBase, batch_size: int = -1):
23 | self.model = model
24 | self.tokenizer = tokenizer
25 | device = model.device
26 |         if getattr(model, "hf_device_map", None):
27 | device = model.hf_device_map[next(iter(model.hf_device_map))]
28 | self.device = torch.device(device) if not isinstance(device, torch.device) else device
29 | self.batch_size = batch_size
30 |
31 | def _preprocess(
32 | self,
33 | prompt_texts: List[str],
34 | generation_config: HuggingFaceGenerationConfig,
35 | handle_long_generation=None,
36 |     ) -> BatchEncoding:
37 | inputs = self.tokenizer(
38 | prompt_texts,
39 | padding=True,
40 | truncation=True,
41 | return_tensors="pt"
42 | ).to(self.device)
43 |
44 | if handle_long_generation == "hole":
45 | cur_len = inputs["input_ids"].shape[-1]
46 | new_tokens = generation_config.max_new_tokens
47 | if cur_len + new_tokens > self.tokenizer.model_max_length:
48 | keep_length = self.tokenizer.model_max_length - new_tokens
49 | if keep_length <= 0:
50 | raise ValueError(
51 |                         "We cannot use `hole` to handle this generation: the number of desired tokens exceeds"
52 |                         " the model's max length"
53 | )
54 |
55 | inputs["input_ids"] = inputs["input_ids"][:, -keep_length:]
56 | if "attention_mask" in inputs:
57 | inputs["attention_mask"] = inputs["attention_mask"][:, -keep_length:]
58 |
59 | return inputs
60 |
61 | def _model_generate(
62 | self,
63 | input_ids: torch.Tensor,
64 | attention_mask: torch.Tensor,
65 | generation_config: HuggingFaceGenerationConfig
66 | ) -> torch.Tensor:
67 | generation_config.eos_token_id = self.tokenizer.eos_token_id
68 | generation_config.pad_token_id = self.tokenizer.pad_token_id
69 | decode_dict = generation_config.dict(by_alias=True)
70 | decode_dict.pop("seed")
71 | batch_gen_sequences = self.model.generate(
72 | input_ids=input_ids,
73 | attention_mask=attention_mask,
74 | **decode_dict
75 | )
76 | return batch_gen_sequences
77 |
78 | def _postprocess(
79 | self,
80 | input_ids: torch.Tensor,
81 | generated_sequences: torch.Tensor,
82 | generation_config: HuggingFaceGenerationConfig,
83 | clean_up_tokenization_spaces=True
84 | ) -> List[HuggingFaceCompletionOutputs]:
85 | input_ids = input_ids.cpu()
86 | generated_sequences = generated_sequences.cpu()
87 |
88 | num_return_sequences = len(generated_sequences) // len(input_ids)
89 | batch_outputs = []
90 | for idx, start in enumerate(range(0, len(generated_sequences), num_return_sequences)):
91 | inp = input_ids[idx].tolist()
92 | if self.tokenizer.pad_token_id in inp:
93 | inp = inp[:inp.index(self.tokenizer.pad_token_id)]
94 |
95 | sequences = generated_sequences[start: start + num_return_sequences]
96 | sequences = sequences[..., input_ids[idx].size(0):].tolist()
97 | for i, seq in enumerate(sequences):
98 | if self.tokenizer.pad_token_id in seq:
99 | sequences[i] = seq[: seq.index(self.tokenizer.pad_token_id)]
100 | sequences_num_tokens = [len(seq) for seq in sequences]
101 |
102 | usage = TokenUsage(prompt_tokens=len(inp), completion_tokens=sum(sequences_num_tokens))
103 | usage.total_tokens = usage.prompt_tokens + usage.completion_tokens
104 |
105 | generated_texts = self.tokenizer.batch_decode(
106 | sequences,
107 | clean_up_tokenization_spaces=clean_up_tokenization_spaces,
108 | skip_special_tokens=True
109 | )
110 | choices = [
111 | HuggingFaceCompletionChoice(
112 | text=text,
113 | index=index,
114 | finish_reason="stop" if num_new_tokens < generation_config.max_new_tokens else "length"
115 | )
116 | for index, (text, num_new_tokens) in enumerate(zip(generated_texts, sequences_num_tokens))
117 | ]
118 |
119 | batch_outputs.append(HuggingFaceCompletionOutputs(choices=choices, usage=usage))
120 |
121 | return batch_outputs
122 |
123 | def __call__(
124 | self,
125 | text_inputs,
126 | generation_config: HuggingFaceGenerationConfig,
127 | clean_up_tokenization_spaces=False,
128 | handle_long_generation="hole",
129 | ) -> List[HuggingFaceCompletionOutputs]:
130 | if isinstance(text_inputs, str):
131 | text_inputs = [text_inputs]
132 |
133 | outputs = []
134 | batch_size = self.batch_size
135 | if batch_size == -1:
136 | batch_size = len(text_inputs)
137 | for start in range(0, len(text_inputs), batch_size):
138 | batch_input_texts = text_inputs[start: start + batch_size]
139 |
140 | batch_inputs = self._preprocess(batch_input_texts, generation_config, handle_long_generation)
141 | batch_input_ids = batch_inputs.input_ids
142 | batch_attention_mask = batch_inputs.attention_mask
143 |
144 | batch_gen_sequences = self._model_generate(batch_input_ids, batch_attention_mask, generation_config)
145 |
146 | outputs += self._postprocess(
147 | batch_input_ids,
148 | batch_gen_sequences,
149 | generation_config,
150 | clean_up_tokenization_spaces
151 | )
152 | return outputs
153 |
154 |
155 | class Worker:
156 | def __init__(self, config: WorkerConfig, logger: Optional[Logger] = None):
157 | self.config = config
158 | self.logger = logger if logger else getLogger(__name__)
159 |
160 | self.model, self.tokenizer = self._load_model_tokenizer()
161 | self.pipeline = TextGenerationPipeline(self.model, self.tokenizer, batch_size=self.config.batch_size)
162 |
163 | def _load_model_tokenizer(self):
164 | tokenizer = AutoTokenizer.from_pretrained(
165 | self.config.tokenizer_name_or_path or self.config.model_name_or_path,
166 | use_fast=self.config.use_fast_tokenizer,
167 | padding_side="left",
168 | truncation_side="left",
169 | trust_remote_code=self.config.trust_remote_code
170 | )
171 | if not tokenizer.pad_token_id:
172 | tokenizer.pad_token_id = tokenizer.eos_token_id
173 |
174 | max_memory = self.config.max_memory
175 |         if max_memory:
176 |             # JSON object keys are always strings; convert GPU indices back to int, keep keys like "cpu" as-is
177 |             max_memory = {(int(k) if isinstance(k, str) and k.isdigit() else k): v for k, v in max_memory.items()}
178 | if self.config.is_gptq_quantized:
179 | model = AutoGPTQForCausalLM.from_quantized(
180 | model_name_or_path=self.config.model_name_or_path,
181 | device_map=self.config.device_map,
182 | max_memory=max_memory,
183 | low_cpu_mem_usage=self.config.low_cpu_mem_usage,
184 | trust_remote_code=self.config.trust_remote_code,
185 | use_safetensors=self.config.use_safetensors
186 | )
187 | else:
188 | model = AutoModelForCausalLM.from_pretrained(
189 | pretrained_model_name_or_path=self.config.model_name_or_path,
190 | torch_dtype=getattr(torch, self.config.torch_dtype),
191 | device_map=self.config.device_map,
192 | max_memory=max_memory,
193 | low_cpu_mem_usage=self.config.low_cpu_mem_usage,
194 | revision=self.config.revision,
195 | trust_remote_code=self.config.trust_remote_code
196 | )
197 |
198 | return model, tokenizer
199 |
200 | def execute(
201 | self,
202 | prompts: List[str],
203 | uids: List[UUID],
204 | generation_config: HuggingFaceGenerationConfig
205 | ) -> Dict[UUID, Tuple[HuggingFaceCompletionOutputs, Optional[Error], int, float]]:
206 | start = time.time()
207 | try:
208 | pipeline_results = self.pipeline(prompts, generation_config)
209 | end = time.time()
210 | return {uid: (outputs, None, 200, end - start) for uid, outputs in zip(uids, pipeline_results)}
211 | except Exception as e:
212 | end = time.time()
213 | error = Error(type=e.__class__.__name__, detail=str(e))
214 | self.logger.error(msg=str(error), exc_info=e)
215 | return {uid: (HuggingFaceCompletionOutputs(), error, 500, end - start) for uid in uids}
216 |
217 |
218 | __all__ = ["Worker"]
219 |
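
To make the `TextGenerationPipeline` docstring concrete, here is a self-contained sketch of the three stages it wraps (preprocess, generate, postprocess) using plain `transformers` and a tiny stand-in model. The model name and generation settings are illustrative only; the real server drives everything from `WorkerConfig` instead:

```python
# A sketch of preprocess -> generate -> postprocess, mirroring TextGenerationPipeline above.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")           # tiny stand-in model
tokenizer = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2", padding_side="left")
tokenizer.pad_token_id = tokenizer.eos_token_id

prompts = ["Hello, my name is", "The capital of France is"]

# preprocess: left-padded batch encoding
inputs = tokenizer(prompts, padding=True, truncation=True, return_tensors="pt")

# generate: bounded by max_new_tokens
with torch.no_grad():
    sequences = model.generate(**inputs, max_new_tokens=8, pad_token_id=tokenizer.pad_token_id)

# postprocess: strip the prompt tokens and decode only the newly generated part
new_tokens = sequences[:, inputs["input_ids"].shape[1]:]
print(tokenizer.batch_decode(new_tokens, skip_special_tokens=True))
```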
--------------------------------------------------------------------------------
/code/server_requirements.txt:
--------------------------------------------------------------------------------
1 | aiohttp
2 | requests
3 | openai
4 | pydantic<2.0.0
5 | fastapi[all]==0.96.0
6 | setproctitle
7 | uvicorn[standard]
8 | gunicorn
9 | accelerate
10 | sentencepiece
11 | transformers
12 | xformers
13 | auto-gptq
--------------------------------------------------------------------------------
/code/start_app.py:
--------------------------------------------------------------------------------
1 | import os
2 | from argparse import ArgumentParser
3 | from os.path import abspath, dirname, join
4 |
5 |
6 | CURRENT_DIR = dirname(abspath(__file__))
7 |
8 |
9 | def build_gunicorn_cmd_str(
10 | gunicorn_config_file_path: str,
11 | proc_name: str,
12 | log_dir: str,
13 | port: int
14 | ):
15 | access_log_file = join(log_dir, f"{proc_name}_access.log")
16 | error_log_file = join(log_dir, f"{proc_name}_error.log")
17 | cmd_str = (
18 | f'-c "{gunicorn_config_file_path}" '
19 | f'--bind "0.0.0.0:{port}" '
20 | f'--name "{proc_name}" '
21 | f'--access-logfile "{access_log_file}" '
22 | f'--error-logfile "{error_log_file}" '
23 | f'-D' # running in background
24 | )
25 | return cmd_str
26 |
27 |
28 | def main():
29 | parser = ArgumentParser()
30 |
31 | # args to start client app
32 | parser.add_argument(
33 | "--start_client_app", action="store_true",
34 | help="whether to start client app"
35 | )
36 | parser.add_argument(
37 | "--client_config_file_path", type=str, default="client_config.json",
38 | help="local path to read client config file"
39 | )
40 | parser.add_argument(
41 | "--client_config_hot_update_interval_minutes", type=int, default=1,
42 | help="specify the interval minutes to hot update client config"
43 | )
44 | parser.add_argument(
45 | "--client_debug", action="store_true",
46 | help="whether to start client app in debug mode"
47 | )
48 | parser.add_argument(
49 | "--client_port", type=int, default=8000,
50 | help="the port the client app will use"
51 | )
52 |
53 | # args to start CB server app
54 | parser.add_argument(
55 | "--start_cb_server_app", action="store_true",
56 | help="whether to start CB server app"
57 | )
58 | parser.add_argument(
59 | "--cb_server_model_id", type=str,
60 | help="model id for the CB server that will be started"
61 | )
62 | parser.add_argument(
63 | "--cb_server_config_file_path", type=str, default="cb_server_config.json",
64 | help="local path to read CB server config file"
65 | )
66 | parser.add_argument(
67 | "--cb_server_debug", action="store_true",
68 | help="whether to start CB server app in debug mode"
69 | )
70 | parser.add_argument(
71 | "--cb_server_port", type=int, default=8001,
72 |         help="the port the CB server app will use"
73 | )
74 |
75 | # args to start SB server app
76 | parser.add_argument(
77 | "--start_sb_server_app", action="store_true",
78 | help="whether to start SB server app"
79 | )
80 | parser.add_argument(
81 | "--sb_server_model_id", type=str,
82 | help="model id for the SB server that will be started"
83 | )
84 | parser.add_argument(
85 | "--sb_server_config_file_path", type=str, default="sb_server_config.json",
86 | help="local path to read SB server config file"
87 | )
88 | parser.add_argument(
89 | "--sb_server_debug", action="store_true",
90 | help="whether to start SB server app in debug mode"
91 | )
92 | parser.add_argument(
93 | "--sb_server_port", type=int, default=8002,
94 |         help="the port the SB server app will use"
95 | )
96 |
97 | # args that are shared among apps
98 | parser.add_argument(
99 | "--client_url", type=str, default=None,
100 | help="URL of client, only be used by servers, has no effect for now"
101 | )
102 | parser.add_argument(
103 | "--gunicorn_config_file_path", type=str, default="gunicorn_config.py",
104 | help="local path to read a python script that stores some common gunicorn settings for APPs"
105 | )
106 |
107 | args = parser.parse_args()
108 |
109 | # start client app if triggered
110 | if args.start_client_app:
111 | cmd_str = (
112 | f"""gunicorn 'client_app:build_app("""
113 | f"""client_config_file_path="{args.client_config_file_path}","""
114 |             f"""client_config_hot_update_interval_minutes={args.client_config_hot_update_interval_minutes},"""
115 | f"""debug={args.client_debug})' """
116 | )
117 | cmd_str += build_gunicorn_cmd_str(
118 | gunicorn_config_file_path=args.gunicorn_config_file_path,
119 | proc_name="llm_inference_client",
120 | log_dir=join(CURRENT_DIR, "logs"),
121 | port=args.client_port
122 | )
123 | print(f"start client app, command being executed is:\n{cmd_str}")
124 | os.system(cmd_str)
125 |
126 | # start CB server app if triggered
127 | if args.start_cb_server_app:
128 | cmd_str = (
129 | f"""gunicorn 'continuous_batching_server_app:build_app("""
130 | f"""model_id="{args.cb_server_model_id}","""
131 | f"""server_config_file_path="{args.cb_server_config_file_path}","""
132 | # f"""client_url={args.client_url}"""
133 | f"""debug={args.cb_server_debug})' """
134 | )
135 | cmd_str += build_gunicorn_cmd_str(
136 | gunicorn_config_file_path=args.gunicorn_config_file_path,
137 | proc_name="llm_inference_cb_server",
138 | log_dir=join(CURRENT_DIR, "logs"),
139 | port=args.cb_server_port
140 | )
141 | print(f"start CB server app, command being executed is:\n{cmd_str}")
142 | os.system(cmd_str)
143 |
144 | # start SB server app if triggered
145 | if args.start_sb_server_app:
146 | cmd_str = (
147 | f"""gunicorn 'static_batching_server_app:build_app("""
148 | f"""model_id="{args.sb_server_model_id}","""
149 | f"""server_config_file_path="{args.sb_server_config_file_path}","""
150 | # f"""client_url={args.client_url}"""
151 | f"""debug={args.sb_server_debug})' """
152 | )
153 | cmd_str += build_gunicorn_cmd_str(
154 | gunicorn_config_file_path=args.gunicorn_config_file_path,
155 | proc_name="llm_inference_sb_server",
156 | log_dir=join(CURRENT_DIR, "logs"),
157 | port=args.sb_server_port
158 | )
159 | print(f"start SB server app, command being executed is:\n{cmd_str}")
160 | os.system(cmd_str)
161 |
162 |
163 | if __name__ == "__main__":
164 | main()
165 |
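
For a concrete picture of what `build_gunicorn_cmd_str` produces, a small sketch that renders the flag suffix appended for the client app with its default CLI values (assuming `start_app.py` is importable from the working directory; output wrapped here for readability):

```python
# Render the gunicorn flag suffix start_app.main() appends for the client app (defaults assumed).
from start_app import build_gunicorn_cmd_str

print(build_gunicorn_cmd_str(
    gunicorn_config_file_path="gunicorn_config.py",
    proc_name="llm_inference_client",
    log_dir="logs",
    port=8000,
))
# -c "gunicorn_config.py" --bind "0.0.0.0:8000" --name "llm_inference_client"
# --access-logfile "logs/llm_inference_client_access.log"
# --error-logfile "logs/llm_inference_client_error.log" -D
```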
--------------------------------------------------------------------------------
/code/static_batching_server_app.py:
--------------------------------------------------------------------------------
1 | import json
2 | from logging import getLogger, DEBUG, INFO
3 | from typing import Optional
4 |
5 | from fastapi import HTTPException, FastAPI
6 | from fastapi.responses import JSONResponse
7 | from pydantic import BaseModel, Field
8 |
9 | from server.static_batching_server import get_server, Server, ServerConfig
10 | from protocol.completion_task import (
11 | HuggingFaceCompletionInputs,
12 | HuggingFaceCompletionOutputs
13 | )
14 | from protocol.error import Error
15 | from protocol.routes import (
16 | ROUTE_GET_MODEL_ID,
17 | ROUTE_POST_CONTINUOUS_BATCHING_COMPLETION,
18 | ROUTE_CLIENT_ONLY_POST_SERVER_STARTUP_EVENT,
19 | ROUTE_CLIENT_ONLY_POST_SERVER_SHUTDOWN_EVENT
20 | )
21 | from utils.log_util import RequestLoggingMiddleware
22 |
23 |
24 | logger = getLogger("gunicorn.logger") # by default, we use gunicorn to wrap the app
25 |
26 |
27 | class AppConfig(BaseModel):
28 | model_id: str = Field(default=...)
29 | server_config_file_path: str = Field(default=...)
30 | client_url: Optional[str] = Field(default=None)
31 | debug: bool = Field(default=False)
32 |
33 |
34 | APP_NAME = "LLM-Inference-SB-Server"
35 |
36 | app = FastAPI(title=APP_NAME, version="0.1.0")
37 | app_config: Optional[AppConfig] = None
38 |
39 |
40 | def build_app(
41 | model_id: str = None,
42 | server_config_file_path: str = "sb_server_config.json",
43 | client_url: Optional[str] = None,
44 | debug: bool = False
45 | ):
46 | global app, app_config
47 |
48 | if model_id is None:
49 | raise ValueError("You must specify a real value to model_id.")
50 |
51 | logger.setLevel(DEBUG if debug else INFO)
52 |
53 | app_config = AppConfig(
54 | model_id=model_id,
55 | server_config_file_path=server_config_file_path,
56 | client_url=client_url,
57 | debug=debug
58 | )
59 |
60 | app.add_middleware(RequestLoggingMiddleware, logger=logger)
61 |
62 | return app
63 |
64 |
65 | @app.on_event("startup")
66 | def startup():
67 | # initialize server
68 | Server(
69 | config=ServerConfig(**json.load(open(app_config.server_config_file_path, "r", encoding="utf-8"))),
70 | logger=logger
71 | )
72 | # TODO: implement logic to inform client that server is startup
73 |
74 |
75 | @app.on_event("shutdown")
76 | def shutdown():
77 | pass # TODO: implement logic to inform client that server is shutdown
78 |
79 |
80 | @app.get(ROUTE_GET_MODEL_ID)
81 | async def get_model_id():
82 | return JSONResponse(content=app_config.model_id)
83 |
84 |
85 | @app.post(ROUTE_POST_CONTINUOUS_BATCHING_COMPLETION, response_model=HuggingFaceCompletionOutputs)
86 | async def execute_completion(request_inputs: HuggingFaceCompletionInputs):
87 | server = get_server()
88 | outputs, error, status_code, cpu_time, wall_time = await server.wait_task_done(request_inputs)
89 | if status_code != 200:
90 | logger.error(msg=str(error))
91 | raise HTTPException(status_code=status_code, detail=str(error))
92 | return outputs
93 |
94 |
95 | if __name__ == "__main__":
96 | import uvicorn
97 | from argparse import ArgumentParser
98 | from logging import basicConfig
99 |
100 | parser = ArgumentParser()
101 | parser.add_argument("--model_id", type=str)
102 | parser.add_argument("--server_config_file_path", type=str, default="sb_server_config.json")
103 | parser.add_argument("--client_url", type=str, default=None)
104 | parser.add_argument("--debug", action="store_true")
105 | parser.add_argument("--port", type=int, default=8002)
106 | args = parser.parse_args()
107 |
108 | logger = getLogger(__name__) # override gunicorn logger if we use uvicorn directly
109 | basicConfig(
110 | format="%(asctime)s %(levelname)s [%(name)s] %(message)s",
111 | datefmt="%Y-%m-%d %H:%M:%S"
112 | )
113 |
114 | uvicorn.run(
115 | build_app(
116 | model_id=args.model_id,
117 | server_config_file_path=args.server_config_file_path,
118 | client_url=args.client_url,
119 | debug=args.debug
120 | ),
121 | host="0.0.0.0",
122 | port=args.port
123 | )
124 |
--------------------------------------------------------------------------------
/code/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/code/utils/log_util.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import re
3 | from datetime import datetime
4 | from typing import *
5 |
6 | from fastapi import FastAPI, Request, Response
7 | from starlette.middleware.base import BaseHTTPMiddleware
8 |
9 |
10 | # adopt from:
11 | # https://github.com/jeffsiver/fastapi-route-logger/blob/main/fastapi_route_logger_middleware/__init__.py
12 | class RequestLoggingMiddleware(BaseHTTPMiddleware):
13 | def __init__(
14 | self,
15 | app: FastAPI,
16 | *,
17 | logger: Optional[logging.Logger] = None,
18 | skip_routes: List[str] = None,
19 | skip_regexes: List[str] = None,
20 | skip_request_methods: List[str] = None,
21 | ):
22 | self._logger = logger if logger else logging.getLogger(__name__)
23 | self._skip_routes = skip_routes if skip_routes else []
24 | self._skip_regexes = (
25 | list(map(lambda regex: re.compile(regex), skip_regexes))
26 | if skip_regexes
27 | else []
28 | )
29 | self._skip_request_methods = skip_request_methods if skip_request_methods else []
30 |
31 | BaseHTTPMiddleware.__init__(self, app)
32 |
33 | async def dispatch(self, request: Request, call_next: Callable) -> Response:
34 | if self._should_route_be_skipped(request):
35 | return await call_next(request)
36 |
37 | return await self._execute_request_with_logging(request, call_next)
38 |
39 | def _should_route_be_skipped(self, request: Request) -> bool:
40 | return any(
41 | [True for path in self._skip_routes if request.url.path.startswith(path)]
42 | + [True for regex in self._skip_regexes if regex.match(request.url.path)]
43 | + [True for method in self._skip_request_methods if request.method.lower() == method.lower()]
44 | )
45 |
46 | async def _execute_request_with_logging(
47 | self, request: Request, call_next: Callable
48 | ) -> Response:
49 | self._logging_before_execution(request)
50 | start_time = datetime.utcnow()
51 | try:
52 | response = await call_next(request)
53 | except Exception as e:
54 | self._logging_when_error_raised(e)
55 | raise
56 |
57 | finish_time = datetime.utcnow()
58 | self._logging_after_execution(response, (finish_time - start_time).total_seconds())
59 |
60 | return response
61 |
62 | def _logging_before_execution(self, request: Request):
63 | content = (
64 |             f"received a request from {request.client.host}:{request.client.port} "
65 | f"to {request.url.path}, method={request.method}"
66 | )
67 | self._logger.info(content)
68 |
69 | def _logging_after_execution(self, response: Response, execution_time: float):
70 |         overall_status = "successfully" if response.status_code < 400 else "unsuccessfully"
71 | content = (
72 | f"{overall_status} executed a request, duration={execution_time}s, "
73 | f"status_code={response.status_code}"
74 | )
75 | self._logger.info(content)
76 |
77 | def _logging_when_error_raised(self, exception: Exception):
78 | content = (
79 |             f"error occurred when executing a request, "
80 | f"error_type=[{exception.__class__.__name__}], error_msg=[{str(exception)}]"
81 | )
82 | self._logger.error(content)
83 |
84 |
85 | __all__ = ["RequestLoggingMiddleware"]
86 |
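
A minimal sketch of wiring `RequestLoggingMiddleware` into an application, skipping a health-check route and all OPTIONS requests (the route path and logger name are illustrative):

```python
import logging

from fastapi import FastAPI
from utils.log_util import RequestLoggingMiddleware

app = FastAPI()
app.add_middleware(
    RequestLoggingMiddleware,
    logger=logging.getLogger("my_app"),
    skip_routes=["/health"],           # don't log health-check probes
    skip_request_methods=["OPTIONS"],  # don't log CORS preflight requests
)
```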
--------------------------------------------------------------------------------
/docs/tutorial/0_概述.md:
--------------------------------------------------------------------------------
1 | # Overview
2 | 
3 | This chapter gives an overview of the whole tutorial and introduces the core content of each later chapter, so that readers get a clear picture of how the material fits together and can go straight to the topics they find most interesting or least familiar, saving valuable time.
4 | 
5 | ## Scope
6 | 
7 | The current version does not try to cover every aspect of taking large language model inference to production. For example, we do not discuss how to design a good logging, monitoring, and alerting system, how to design and operate distributed containerized deployment and management, or how to implement true load balancing. Instead, this tutorial focuses on: the design and code walkthrough of a large language model inference engine, optimization directions for the inference engine and how to implement them, performance testing methods for the inference engine, and how to choose hardware.
8 | 
9 | > If there are topics outside this tutorial's scope that you strongly want to learn about, please raise them in an issue; we will consider adding the corresponding chapters based on community feedback.
10 | 
11 | ## Chapter overview
12 | 
13 | ### Chapter 1: Design and code walkthrough of a large language model inference engine
14 | 
15 | Chapter 1 starts with an overall introduction to the large language model inference engine; the remaining sections then walk through each module of the engine in more detail, explaining the corresponding code, so that readers can ultimately implement a simple high-performance inference engine by themselves.
16 | 
17 | The sections of Chapter 1 are:
18 | 0. Overview
19 | 1. Request batching module
20 | 2. Cache management module
21 | 3. Model execution module
22 | 4. Model utilities module
23 | 5. Generation utilities module
24 | 6. Server wrapping and serving over HTTP
25 | 
26 | ### Chapter 2: Optimization directions for the inference engine and how to implement them
27 | 
28 | Chapter 2 combines today's popular inference engines with the latest research to analyze where the engine can be further optimized, in terms of inference performance, generation quality, and more, along with possible implementation approaches.
29 | 
30 | The sections of Chapter 2 are:
31 | 1. Model and cache compression
32 | 2. Generation strategy optimization
33 | 3. Request batching strategy optimization
34 | 4. Other possible optimizations
35 | 
36 | ### Chapter 3: Performance testing methods for the inference engine
37 | 
38 | Chapter 3 explains how to evaluate the inference engine's performance effectively and how to choose a testing strategy based on the characteristics of real production traffic.
39 | 
40 | ### Chapter 4: Choosing hardware
41 | 
42 | Chapter 4 explains how to choose GPUs in a targeted way for your actual business scenarios.
--------------------------------------------------------------------------------