├── .clang-format ├── .github └── ISSUE_TEMPLATE │ ├── config.yml │ └── issue.md ├── .gitignore ├── CMakeLists.txt ├── LICENSE ├── README.md ├── build.sh ├── cmake ├── deps.cmake ├── grpc_serving.cmake ├── install.cmake ├── llm.cmake ├── pplllmserving-config.cmake.in ├── sentencepiece.cmake └── xxhash.cmake ├── docs └── llama_guide.md ├── samples └── integration-cuda │ ├── CMakeLists.txt │ └── main.cc ├── src ├── backends │ └── cuda │ │ ├── post_processor.cc │ │ ├── post_processor.h │ │ ├── resource_manager.cc │ │ └── resource_manager.h ├── common │ ├── config.cc │ ├── config.h │ ├── connection.h │ ├── post_processor.h │ ├── profiler.cc │ ├── profiler.h │ ├── request.h │ ├── resource.h │ └── response.h ├── engine │ ├── llm_engine.cc │ └── llm_engine.h ├── generated │ └── onnx │ │ ├── v23.4 │ │ ├── llm.grpc.pb.cc │ │ ├── llm.grpc.pb.h │ │ ├── llm.pb.cc │ │ ├── llm.pb.h │ │ ├── sentencepiece.pb.cc │ │ ├── sentencepiece.pb.h │ │ ├── sentencepiece_model.pb.cc │ │ └── sentencepiece_model.pb.h │ │ └── v3.1.0 │ │ ├── sentencepiece.pb.cc │ │ ├── sentencepiece.pb.h │ │ ├── sentencepiece.proto │ │ ├── sentencepiece_model.pb.cc │ │ ├── sentencepiece_model.pb.h │ │ └── sentencepiece_model.proto ├── generator │ ├── llm_generator.cc │ └── llm_generator.h ├── onnx │ └── onnx.proto ├── serving │ └── grpc │ │ ├── grpc_server.cc │ │ ├── grpc_server.h │ │ └── proto │ │ └── llm.proto ├── tokenizer │ ├── models │ │ ├── baichuan │ │ │ └── baichuan_tokenizer.h │ │ ├── internlm │ │ │ └── internlm_tokenizer.h │ │ ├── llama │ │ │ └── llama_tokenizer.h │ │ └── llama3 │ │ │ ├── llama3_tokenizer.h │ │ │ └── tokenizer_config.json │ ├── tokenizer.h │ ├── tokenizer_factory.h │ ├── tokenizer_impl.h │ ├── tokenizer_impl_hf.h │ └── tokenizer_impl_sp.h └── utils │ ├── index_manager.h │ ├── mpsc_request_scheduler.h │ ├── prefix_cache_manager.h │ ├── utils.cc │ └── utils.h ├── test ├── CMakeLists.txt └── test_prefix_cache_mgr.cc └── tools ├── CMakeLists.txt ├── backtrace.h ├── benchmark_prefix_cache_offline.cc ├── client_pressure.cc ├── client_qps_measure.cc ├── client_qps_measure_token_in_out.cc ├── client_sample.cc ├── client_sample_token_in_out.cc ├── llm_server.cc ├── offline_inference.cc ├── samples_1024.json ├── samples_2048.json ├── samples_4096.json ├── samples_8192.json ├── simple_flags.cc └── simple_flags.h /.clang-format: -------------------------------------------------------------------------------- 1 | --- 2 | Language: Cpp 3 | # BasedOnStyle: Google 4 | AccessModifierOffset: -4 5 | AlignAfterOpenBracket: Align 6 | AlignConsecutiveAssignments: false 7 | AlignConsecutiveDeclarations: false 8 | AlignEscapedNewlines: Left 9 | AlignOperands: false 10 | AlignTrailingComments: false 11 | AllowAllParametersOfDeclarationOnNextLine: true 12 | AllowShortBlocksOnASingleLine: false 13 | AllowShortCaseLabelsOnASingleLine: false 14 | AllowShortFunctionsOnASingleLine: Empty 15 | AllowShortIfStatementsOnASingleLine: false 16 | AllowShortLambdasOnASingleLine: Empty 17 | AllowShortLoopsOnASingleLine: false 18 | AlwaysBreakAfterDefinitionReturnType: None 19 | AlwaysBreakAfterReturnType: None 20 | AlwaysBreakBeforeMultilineStrings: false 21 | AlwaysBreakTemplateDeclarations: true 22 | BinPackArguments: true 23 | BinPackParameters: true 24 | BraceWrapping: 25 | AfterClass: false 26 | AfterControlStatement: false 27 | AfterEnum: false 28 | AfterFunction: false 29 | AfterNamespace: false 30 | AfterObjCDeclaration: false 31 | AfterStruct: false 32 | AfterUnion: false 33 | BeforeCatch: false 34 | BeforeElse: false 35 | IndentBraces: 
false 36 | BreakBeforeBinaryOperators: None 37 | BreakBeforeBraces: Attach 38 | BreakBeforeTernaryOperators: true 39 | BreakConstructorInitializers: BeforeComma 40 | ColumnLimit: 120 41 | CommentPragmas: '^ IWYU pragma:' 42 | CompactNamespaces: true 43 | ConstructorInitializerAllOnOneLineOrOnePerLine: true 44 | ConstructorInitializerIndentWidth: 4 45 | ContinuationIndentWidth: 4 46 | Cpp11BracedListStyle: true 47 | DerivePointerAlignment: false 48 | DisableFormat: false 49 | #EmptyLineBeforeAccessModifier: Always 50 | ExperimentalAutoDetectBinPacking: false 51 | FixNamespaceComments: true 52 | ForEachMacros: [ foreach, Q_FOREACH, BOOST_FOREACH ] 53 | IncludeBlocks: Preserve 54 | IncludeCategories: 55 | - Regex: '^".*\.h"' 56 | Priority: 1 57 | - Regex: '^<.*\.h>' 58 | Priority: 2 59 | - Regex: '^<.*' 60 | Priority: 2 61 | - Regex: '.*' 62 | Priority: 3 63 | #IndentAccessModifiers: false 64 | IndentCaseLabels: true 65 | #IndentExternBlock: NoIndent 66 | #IndentGotoLabels: false 67 | IndentWidth: 4 68 | IndentWrappedFunctionNames: false 69 | KeepEmptyLinesAtTheStartOfBlocks: false 70 | MacroBlockBegin: '' 71 | MacroBlockEnd: '' 72 | MaxEmptyLinesToKeep: 1 73 | NamespaceIndentation: None 74 | PenaltyBreakBeforeFirstCallParameter: 1 75 | PenaltyBreakComment: 300 76 | #PenaltyBreakFirstLessLess: 100 77 | PenaltyBreakString: 1000 78 | PenaltyExcessCharacter: 1000000 79 | PenaltyReturnTypeOnItsOwnLine: 200 80 | PointerAlignment: Left 81 | ReflowComments: true 82 | SortIncludes: false 83 | SpaceAfterCStyleCast: false 84 | #SpaceAfterLogicalNot: false 85 | SpaceAfterTemplateKeyword: true 86 | SpaceBeforeAssignmentOperators: true 87 | #SpaceBeforeCaseColon: false 88 | #SpaceBeforeCtorInitializerColon: true 89 | #SpaceBeforeInheritanceColon: true 90 | SpaceBeforeParens: ControlStatements 91 | #SpaceBeforeRangeBasedForLoopColon: true 92 | #SpaceBeforeSquareBrackets: false 93 | #SpaceInEmptyBlock: false 94 | SpaceInEmptyParentheses: false 95 | SpacesBeforeTrailingComments: 1 96 | SpacesInAngles: false 97 | SpacesInContainerLiterals: false 98 | SpacesInCStyleCastParentheses: false 99 | SpacesInParentheses: false 100 | SpacesInSquareBrackets: false 101 | Standard: Auto 102 | TabWidth: 8 103 | UseTab: Never 104 | ... 105 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/config.yml: -------------------------------------------------------------------------------- 1 | blank_issues_enabled: false -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/issue.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Issue 3 | about: what happened? 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | ## What are the problems?(screenshots or detailed error messages) 11 | 12 | ## What are the types of GPU/CPU you are using? 13 | 14 | ## What's the operating system ppl.llm.serving runs on? 15 | 16 | ## What's the compiler and its version? 17 | 18 | ## Which version(commit id or tag) of ppl.llm.serving is used? 19 | 20 | ## What are the commands used to build ppl.llm.serving? 21 | 22 | ## What are the execution commands? 
23 | 24 | ## minimal code snippets for reproducing these problems(if necessary) 25 | 26 | ## models and inputs for reproducing these problems (send them to openppl.ai@hotmail.com if necessary) 27 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | .vscode/ 3 | deps/ 4 | ppl-build/ -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.14) 2 | project(ppl.llm.serving) 3 | 4 | option(PPL_LLM_ENABLE_LLAMA "" ON) 5 | option(PPL_LLM_ENABLE_DEBUG "" OFF) 6 | option(PPL_LLM_SERVING_BUILD_TOOLS "" ON) 7 | option(PPL_LLM_INSTALL "" ON) 8 | option(PPL_LLM_ENABLE_GRPC_SERVING "" ON) 9 | option(PPL_LLM_SERVING_SYNC_DECODE "" OFF) 10 | 11 | option(PPL_LLM_ENABLE_HF_TOKENIZER "" OFF) 12 | option(PPL_LLM_ENABLE_TEST "" OFF) 13 | 14 | # --------------------------------------------------------------------------- # 15 | if(HPCC) 16 | include(hpcc-common) 17 | endif() 18 | 19 | if(PPL_LLM_ENABLE_LLAMA) 20 | set(CMAKE_CXX_STANDARD 17) 21 | endif() 22 | 23 | set(PPLNN_INSTALL ${PPL_LLM_INSTALL}) 24 | 25 | set(CMAKE_INSTALL_LIBDIR "lib") 26 | 27 | include(cmake/deps.cmake) 28 | 29 | # --------------------------------------------------------------------------- # 30 | 31 | if (PPL_LLM_ENABLE_HF_TOKENIZER) 32 | hpcc_populate_dep(tokenizer_cpp) 33 | endif() 34 | 35 | # import grpc first. for protobuf 36 | # set protobuf version before importing ppl.nn 37 | set(PPLNN_DEP_PROTOBUF_VERSION v23.4) 38 | set(protobuf_WITH_ZLIB OFF CACHE BOOL "") 39 | set(protobuf_BUILD_TESTS OFF CACHE BOOL "disable protobuf tests") 40 | 41 | hpcc_populate_dep(grpc) 42 | 43 | # use specified protobuf required by c++17 44 | find_package(Git REQUIRED) 45 | execute_process(COMMAND ${GIT_EXECUTABLE} checkout ${PPLNN_DEP_PROTOBUF_VERSION} 46 | WORKING_DIRECTORY ${grpc_SOURCE_DIR}/third_party/protobuf) 47 | 48 | set(PPLNN_PROTOC_EXECUTABLE ${grpc_BINARY_DIR}/third_party/protobuf/protoc) 49 | 50 | # --------------------------------------------------------------------------- # 51 | 52 | # generate new onnx.pb.* for pplnn 53 | set(__LLM_GENERATED_DIR__ ${CMAKE_CURRENT_BINARY_DIR}/generated) 54 | file(MAKE_DIRECTORY ${__LLM_GENERATED_DIR__}) 55 | 56 | set(__PROTO_DIR__ ${PROJECT_SOURCE_DIR}/src/onnx) 57 | set(__ONNX_GENERATED_FILES__ "${__LLM_GENERATED_DIR__}/onnx.pb.h;${__LLM_GENERATED_DIR__}/onnx.pb.cc") 58 | add_custom_command( 59 | OUTPUT ${__ONNX_GENERATED_FILES__} 60 | COMMAND ${PPLNN_PROTOC_EXECUTABLE} 61 | ARGS --cpp_out ${__LLM_GENERATED_DIR__} -I ${__PROTO_DIR__} 62 | ${__PROTO_DIR__}/onnx.proto 63 | DEPENDS protoc ${__PROTO_DIR__}/onnx.proto) 64 | add_library(pplnn_onnx_generated_static STATIC ${__ONNX_GENERATED_FILES__}) 65 | target_link_libraries(pplnn_onnx_generated_static PUBLIC libprotobuf) 66 | target_include_directories(pplnn_onnx_generated_static PUBLIC ${__LLM_GENERATED_DIR__}) 67 | set(PPLNN_ONNX_GENERATED_LIBS pplnn_onnx_generated_static) 68 | 69 | unset(__ONNX_GENERATED_FILES__) 70 | unset(__PROTO_DIR__) 71 | unset(__LLM_GENERATED_DIR__) 72 | 73 | # pplnn after serving, depends on libprotobuf provided by grpc 74 | hpcc_populate_dep(pplnn) 75 | 76 | # --------------------------------------------------------------------------- # 77 | 78 | # serving after pplnn. 
depends on pplcommon 79 | if(PPL_LLM_ENABLE_GRPC_SERVING) 80 | include(cmake/grpc_serving.cmake) 81 | endif() 82 | 83 | # --------------------------------------------------------------------------- # 84 | if(PPL_LLM_ENABLE_LLAMA) 85 | include(cmake/llm.cmake) 86 | endif() 87 | 88 | # --------------------------------------------------------------------------- # 89 | 90 | if(PPL_LLM_INSTALL) 91 | include(cmake/install.cmake) 92 | endif() 93 | 94 | # --------------------------------------------------------------------------- # 95 | 96 | if(PPL_LLM_SERVING_BUILD_TOOLS) 97 | include(tools/CMakeLists.txt) 98 | endif() 99 | 100 | if(PPL_LLM_ENABLE_TEST) 101 | add_subdirectory(test) 102 | endif() 103 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2023 OpenPPL 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PPL LLM Serving 2 | 3 | ## Overview 4 | 5 | `ppl.llm.serving` is a part of `PPL.LLM` system. 
6 | 7 | ![SYSTEM_OVERVIEW](https://github.com/openppl-public/ppl.nn/blob/master/docs/images/llm-system-overview.png) 8 | 9 | **We recommend that users who are new to this project read the [overview of the system](https://github.com/openppl-public/ppl.nn/blob/master/docs/en/llm-system-overview.md).** 10 | 11 | `ppl.llm.serving` is a serving framework based on [ppl.nn](https://github.com/openppl-public/ppl.nn) for various Large Language Models (LLMs). This repository contains a gRPC-based server and inference support for [LLaMA](https://github.com/facebookresearch/llama). 12 | 13 | ## Prerequisites 14 | 15 | * Linux running on x86_64 or arm64 CPUs 16 | * GCC >= 9.4.0 17 | * [CMake](https://cmake.org/download/) >= 3.18 18 | * [Git](https://git-scm.com/downloads) >= 2.7.0 19 | * [CUDA Toolkit](https://developer.nvidia.com/cuda-toolkit-archive) >= 11.4. 11.6 recommended. (for CUDA) 20 | * Rust & cargo >= 1.8.0. (for Huggingface Tokenizer) 21 | 22 | ## PPL Server Quick Start 23 | 24 | Here is a brief tutorial; refer to the [LLaMA Guide](docs/llama_guide.md) for more details. 25 | 26 | * Installing Prerequisites (on Debian or Ubuntu, for example) 27 | 28 | ```bash 29 | apt-get install build-essential cmake git 30 | ``` 31 | 32 | * Cloning Source Code 33 | 34 | ```bash 35 | git clone https://github.com/openppl-public/ppl.llm.serving.git 36 | ``` 37 | 38 | * Building from Source 39 | 40 | ```bash 41 | ./build.sh -DPPLNN_USE_LLM_CUDA=ON -DPPLNN_CUDA_ENABLE_NCCL=ON -DPPLNN_ENABLE_CUDA_JIT=OFF -DPPLNN_CUDA_ARCHITECTURES="'80;86;87'" -DPPLCOMMON_CUDA_ARCHITECTURES="'80;86;87'" -DPPL_LLM_ENABLE_GRPC_SERVING=ON 42 | ``` 43 | 44 | NCCL is required if multiple GPU devices are used. 45 | 46 | We support a **Sync Decode** feature (mainly for offline_inference), which means model forward and decoding run in the same thread. To enable this feature, compile with the macro `-DPPL_LLM_SERVING_SYNC_DECODE=ON`. 47 | 48 | * Exporting Models 49 | 50 | Refer to [ppl.pmx](https://github.com/openppl-public/ppl.pmx) for details. 51 | 52 | * Running Server 53 | 54 | ```bash 55 | ./ppl_llm_server \ 56 | --model-dir /data/model \ 57 | --model-param-path /data/model/params.json \ 58 | --tokenizer-path /data/tokenizer.model \ 59 | --tensor-parallel-size 1 \ 60 | --top-p 0.0 \ 61 | --top-k 1 \ 62 | --max-tokens-scale 0.94 \ 63 | --max-input-tokens-per-request 4096 \ 64 | --max-output-tokens-per-request 4096 \ 65 | --max-total-tokens-per-request 8192 \ 66 | --max-running-batch 1024 \ 67 | --max-tokens-per-step 8192 \ 68 | --host 127.0.0.1 \ 69 | --port 23333 70 | ``` 71 | 72 | You are expected to give the correct values before running the server. 73 | 74 | - `model-dir`: path of the model exported by [ppl.pmx](https://github.com/openppl-public/ppl.pmx). 75 | - `model-param-path`: model parameter file, i.e. `$model_dir/params.json`. 76 | - `tokenizer-path`: tokenizer file for `sentencepiece`. 77 | 78 | * Running the client: send requests through gRPC to query the model 79 | 80 | ```bash 81 | ./ppl-build/client_sample 127.0.0.1:23333 82 | ``` 83 | See [tools/client_sample.cc](tools/client_sample.cc) for more details. 84 | 85 | * Benchmarking 86 | 87 | ```bash 88 | ./ppl-build/client_qps_measure --target=127.0.0.1:23333 --tokenizer=/path/to/tokenizer/path --dataset=tools/samples_1024.json --request_rate=inf 89 | ``` 90 | See [tools/client_qps_measure.cc](tools/client_qps_measure.cc) for more details. `--request_rate` is the number of requests per second, and the value `inf` means sending all client requests with no interval. 
91 | 92 | * Running inference offline: 93 | 94 | ```bash 95 | ./offline_inference \ 96 | --model-dir /data/model \ 97 | --model-param-path /data/model/params.json \ 98 | --tokenizer-path /data/tokenizer.model \ 99 | --tensor-parallel-size 1 \ 100 | --top-p 0.0 \ 101 | --top-k 1 \ 102 | --max-tokens-scale 0.94 \ 103 | --max-input-tokens-per-request 4096 \ 104 | --max-output-tokens-per-request 4096 \ 105 | --max-total-tokens-per-request 8192 \ 106 | --max-running-batch 1024 \ 107 | --max-tokens-per-step 8192 \ 108 | --host 127.0.0.1 \ 109 | --port 23333 110 | ``` 111 | See [tools/offline_inference.cc](tools/offline_inference.cc) for more details. 112 | 113 | ### License 114 | 115 | This project is distributed under the [Apache License, Version 2.0](LICENSE). 116 | -------------------------------------------------------------------------------- /build.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | workdir=`pwd` 4 | 5 | if [ -z "$PPL_BUILD_THREAD_NUM" ]; then 6 | PPL_BUILD_THREAD_NUM=16 7 | echo -e "env 'PPL_BUILD_THREAD_NUM' is not set. use PPL_BUILD_THREAD_NUM=${PPL_BUILD_THREAD_NUM} by default." 8 | fi 9 | 10 | if [ -z "$BUILD_TYPE" ]; then 11 | build_type='Release' 12 | else 13 | build_type="$BUILD_TYPE" 14 | fi 15 | options="-DCMAKE_BUILD_TYPE=${build_type} -DCMAKE_INSTALL_PREFIX=install $*" 16 | 17 | ppl_build_dir="${workdir}/ppl-build" 18 | if [ ! -d "$ppl_build_dir" ]; then 19 | mkdir ${ppl_build_dir} 20 | fi 21 | cd ${ppl_build_dir} 22 | cmd="cmake $options .. && cmake --build . -j ${PPL_BUILD_THREAD_NUM} --config ${build_type}" 23 | echo "cmd -> $cmd" 24 | eval "$cmd" 25 | -------------------------------------------------------------------------------- /cmake/deps.cmake: -------------------------------------------------------------------------------- 1 | if(NOT HPCC_DEPS_DIR) 2 | set(HPCC_DEPS_DIR ${CMAKE_CURRENT_SOURCE_DIR}/deps) 3 | endif() 4 | 5 | include(FetchContent) 6 | 7 | set(FETCHCONTENT_BASE_DIR ${HPCC_DEPS_DIR}) 8 | set(FETCHCONTENT_QUIET OFF) 9 | 10 | if(PPLNN_HOLD_DEPS) 11 | set(FETCHCONTENT_UPDATES_DISCONNECTED ON) 12 | endif() 13 | 14 | # --------------------------------------------------------------------------- # 15 | 16 | if(PPL_LLM_ENABLE_LLAMA AND CMAKE_COMPILER_IS_GNUCC) 17 | if(CMAKE_CXX_COMPILER_VERSION VERSION_LESS 9.0.0) 18 | message(FATAL_ERROR "gcc >= 9.0.0 is required.") 19 | endif() 20 | if(CMAKE_CXX_COMPILER_VERSION VERSION_EQUAL 10.3.0) 21 | message(FATAL_ERROR "gcc 10.3.0 has known bugs. 
use another version >= 9.0.0.") 22 | endif() 23 | endif() 24 | 25 | # --------------------------------------------------------------------------- # 26 | 27 | find_package(Git QUIET) 28 | if(NOT Git_FOUND) 29 | message(FATAL_ERROR "git is required.") 30 | endif() 31 | 32 | if(NOT PPL_LLM_DEP_HPCC_VERSION) 33 | set(PPL_LLM_DEP_HPCC_VERSION master) 34 | endif() 35 | 36 | if(PPL_LLM_DEP_HPCC_PKG) 37 | FetchContent_Declare(hpcc 38 | URL ${PPL_LLM_DEP_HPCC_PKG} 39 | SOURCE_DIR ${HPCC_DEPS_DIR}/hpcc 40 | BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR}/hpcc-build 41 | SUBBUILD_DIR ${HPCC_DEPS_DIR}/hpcc-subbuild) 42 | else() 43 | if(NOT PPL_LLM_DEP_HPCC_GIT) 44 | set(PPL_LLM_DEP_HPCC_GIT "https://github.com/OpenPPL/hpcc.git") 45 | endif() 46 | FetchContent_Declare(hpcc 47 | GIT_REPOSITORY ${PPL_LLM_DEP_HPCC_GIT} 48 | GIT_TAG ${PPL_LLM_DEP_HPCC_VERSION} 49 | SOURCE_DIR ${HPCC_DEPS_DIR}/hpcc 50 | BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR}/hpcc-build 51 | SUBBUILD_DIR ${HPCC_DEPS_DIR}/hpcc-subbuild) 52 | endif() 53 | 54 | FetchContent_GetProperties(hpcc) 55 | if(NOT hpcc_POPULATED) 56 | FetchContent_Populate(hpcc) 57 | include(${hpcc_SOURCE_DIR}/cmake/hpcc-common.cmake) 58 | endif() 59 | 60 | # ------------------------------------------------------------------------- # 61 | 62 | set(gRPC_BUILD_TESTS OFF CACHE BOOL "") 63 | set(gRPC_BUILD_CSHARP_EXT OFF CACHE BOOL "") 64 | set(gRPC_BUILD_GRPC_CSHARP_PLUGIN OFF CACHE BOOL "") 65 | set(gRPC_BUILD_GRPC_NODE_PLUGIN OFF CACHE BOOL "") 66 | set(gRPC_BUILD_GRPC_OBJECTIVE_C_PLUGIN OFF CACHE BOOL "") 67 | set(gRPC_BUILD_GRPC_PHP_PLUGIN OFF CACHE BOOL "") 68 | set(gRPC_BUILD_GRPC_PYTHON_PLUGIN OFF CACHE BOOL "") 69 | set(gRPC_BUILD_GRPC_RUBY_PLUGIN OFF CACHE BOOL "") 70 | set(ABSL_PROPAGATE_CXX_STD ON CACHE BOOL "") 71 | set(ABSL_ENABLE_INSTALL ON CACHE BOOL "required by protobuf") 72 | 73 | # --------------------------------------------------------------------------- # 74 | 75 | if(PPL_LLM_DEP_GRPC_PKG) 76 | hpcc_declare_pkg_dep(grpc 77 | ${PPL_LLM_DEP_GRPC_PKG}) 78 | else() 79 | if(NOT PPL_LLM_DEP_GRPC_GIT) 80 | set(PPL_LLM_DEP_GRPC_GIT "https://github.com/grpc/grpc.git") 81 | endif() 82 | hpcc_declare_git_dep_depth1(grpc 83 | ${PPL_LLM_DEP_GRPC_GIT} 84 | v1.56.2) 85 | endif() 86 | 87 | # --------------------------------------------------------------------------- # 88 | 89 | set(PPLNN_BUILD_TESTS OFF CACHE BOOL "") 90 | set(PPLNN_BUILD_SAMPLES OFF CACHE BOOL "") 91 | 92 | if(NOT PPL_LLM_DEP_PPLNN_VERSION) 93 | set(PPL_LLM_DEP_PPLNN_VERSION master) 94 | endif() 95 | 96 | if(PPL_LLM_DEP_PPLNN_PKG) 97 | hpcc_declare_pkg_dep(pplnn 98 | ${PPL_LLM_DEP_PPLNN_PKG}) 99 | else() 100 | if(NOT PPL_LLM_DEP_PPLNN_GIT) 101 | set(PPL_LLM_DEP_PPLNN_GIT "https://github.com/OpenPPL/ppl.nn.git") 102 | endif() 103 | hpcc_declare_git_dep(pplnn 104 | ${PPL_LLM_DEP_PPLNN_GIT} 105 | ${PPL_LLM_DEP_PPLNN_VERSION}) 106 | endif() 107 | 108 | # --------------------------------------------------------------------------- # 109 | 110 | if(PPL_LLM_DEP_ABSL_PKG) 111 | hpcc_declare_pkg_dep(absl 112 | ${PPL_LLM_DEP_ABSL_PKG}) 113 | else() 114 | if(NOT PPL_LLM_DEP_ABSL_GIT) 115 | set(PPL_LLM_DEP_ABSL_GIT "https://github.com/abseil/abseil-cpp.git") 116 | endif() 117 | hpcc_declare_git_dep_depth1(absl 118 | ${PPL_LLM_DEP_ABSL_GIT} 119 | lts_2023_01_25) 120 | endif() 121 | 122 | # --------------------------------------------------------------------------- # 123 | 124 | if(PPL_LLM_DEP_SENTENCEPIECE_PKG) 125 | hpcc_declare_pkg_dep(sentencepiece 126 | ${PPL_LLM_DEP_SENTENCEPIECE_PKG}) 127 | else() 128 | if(NOT 
PPL_LLM_DEP_SENTENCEPIECE_GIT) 129 | set(PPL_LLM_DEP_SENTENCEPIECE_GIT "https://github.com/OpenPPL/sentencepiece.git") 130 | endif() 131 | hpcc_declare_git_dep_depth1(sentencepiece 132 | ${PPL_LLM_DEP_SENTENCEPIECE_GIT} 133 | ppl) 134 | endif() 135 | 136 | # --------------------------------------------------------------------------- # 137 | 138 | set(BUILD_SHARED_LIBS OFF CACHE BOOL "") 139 | set(ENABLE_PUSH OFF CACHE BOOL "") 140 | set(ENABLE_COMPRESSION OFF CACHE BOOL "") 141 | set(ENABLE_TESTING OFF CACHE BOOL "") 142 | set(GENERATE_PKGCONFIG OFF CACHE BOOL "") 143 | set(OVERRIDE_CXX_STANDARD_FLAGS OFF CACHE BOOL "") 144 | 145 | # --------------------------------------------------------------------------- # 146 | 147 | if(PPL_LLM_DEP_PROMETHEUS_PKG) 148 | hpcc_declare_pkg_dep(prometheus 149 | ${PPL_LLM_DEP_PROMETHEUS_PKG}) 150 | else() 151 | if(NOT PPL_LLM_DEP_PROMETHEUS_GIT) 152 | set(PPL_LLM_DEP_PROMETHEUS_GIT "https://github.com/jupp0r/prometheus-cpp.git") 153 | endif() 154 | hpcc_declare_git_dep_depth1(prometheus 155 | ${PPL_LLM_DEP_PROMETHEUS_GIT} 156 | v1.2.4) 157 | endif() 158 | 159 | # --------------------------------------------------------------------------- # 160 | 161 | if(PPL_LLM_DEP_XXHASH_PKG) 162 | hpcc_declare_pkg_dep(xxhash 163 | ${PPL_LLM_DEP_XXHASH_PKG}) 164 | else() 165 | if(NOT PPL_LLM_DEP_XXHASH_GIT) 166 | set(PPL_LLM_DEP_XXHASH_GIT "https://github.com/Cyan4973/xxHash.git") 167 | endif() 168 | hpcc_declare_git_dep_depth1(xxhash 169 | ${PPL_LLM_DEP_XXHASH_GIT} 170 | v0.8.2) 171 | endif() 172 | 173 | # --------------------------------------------------------------------------- # 174 | 175 | if(PPL_LLM_DEP_TOKENIZER_CPP_PKG) 176 | hpcc_declare_pkg_dep(tokenizer_cpp 177 | ${PPL_LLM_DEP_TOKENIZER_CPP_PKG}) 178 | else() 179 | if(NOT PPL_LLM_DEP_TOKENIZER_CPP_GIT) 180 | set(PPL_LLM_DEP_TOKENIZER_CPP_GIT "https://github.com/OpenPPL/tokenizers-cpp.git") 181 | endif() 182 | hpcc_declare_git_dep_depth1(tokenizer_cpp 183 | ${PPL_LLM_DEP_TOKENIZER_CPP_GIT} 184 | llm_v2) 185 | endif() 186 | -------------------------------------------------------------------------------- /cmake/grpc_serving.cmake: -------------------------------------------------------------------------------- 1 | set(__LLM_SERVING_GENERATED_DIR__ "${CMAKE_CURRENT_BINARY_DIR}/generated") 2 | file(MAKE_DIRECTORY ${__LLM_SERVING_GENERATED_DIR__}) 3 | 4 | set(__PROTO_DIR__ ${PROJECT_SOURCE_DIR}/src/serving/grpc/proto) 5 | set(PROTOC_EXECUTABLE "${CMAKE_CURRENT_BINARY_DIR}/grpc-build/third_party/protobuf/protoc") 6 | 7 | # ----- cannot disable zlib tests and examples ----- # 8 | 9 | if(TARGET minigzip) 10 | set_target_properties(minigzip PROPERTIES LINK_LIBRARIES z) 11 | endif() 12 | if(TARGET minigzip64) 13 | set_target_properties(minigzip64 PROPERTIES LINK_LIBRARIES z) 14 | endif() 15 | if(TARGET example) 16 | set_target_properties(example PROPERTIES LINK_LIBRARIES z) 17 | endif() 18 | if(TARGET example64) 19 | set_target_properties(example64 PROPERTIES LINK_LIBRARIES z) 20 | endif() 21 | 22 | # ----- grpc serving pb files ----- # 23 | 24 | set(__LLM_GENERATED_FILES__ "${__LLM_SERVING_GENERATED_DIR__}/llm.pb.cc;${__LLM_SERVING_GENERATED_DIR__}/llm.pb.h;${__LLM_SERVING_GENERATED_DIR__}/llm.grpc.pb.cc;${__LLM_SERVING_GENERATED_DIR__}/llm.grpc.pb.h") 25 | 26 | set(GRPC_CPP_PLUGIN_EXECUTABLE "${CMAKE_CURRENT_BINARY_DIR}/grpc-build/grpc_cpp_plugin") 27 | add_custom_command( 28 | OUTPUT ${__LLM_GENERATED_FILES__} 29 | COMMAND ${PROTOC_EXECUTABLE} 30 | ARGS --grpc_out "${__LLM_SERVING_GENERATED_DIR__}" --cpp_out 
"${__LLM_SERVING_GENERATED_DIR__}" 31 | -I "${__PROTO_DIR__}" --plugin=protoc-gen-grpc="${GRPC_CPP_PLUGIN_EXECUTABLE}" 32 | "${__PROTO_DIR__}/llm.proto" 33 | DEPENDS protoc grpc_cpp_plugin ${__PROTO_DIR__}/llm.proto) 34 | 35 | add_library(ppl_llm_grpc_proto_static STATIC ${__LLM_GENERATED_FILES__}) 36 | target_link_libraries(ppl_llm_grpc_proto_static PUBLIC libprotobuf grpc) 37 | target_include_directories(ppl_llm_grpc_proto_static PUBLIC 38 | ${HPCC_DEPS_DIR}/grpc/include 39 | ${__LLM_SERVING_GENERATED_DIR__}) 40 | 41 | # ----- # 42 | 43 | file(GLOB __SRC__ src/serving/grpc/*.cc) 44 | add_library(ppl_llm_grpc_serving_static STATIC ${__SRC__}) 45 | target_link_libraries(ppl_llm_grpc_serving_static PUBLIC ppl_llm_grpc_proto_static ppl_llm_static grpc++ pplcommon_static pthread) 46 | target_include_directories(ppl_llm_grpc_serving_static PUBLIC src) 47 | -------------------------------------------------------------------------------- /cmake/install.cmake: -------------------------------------------------------------------------------- 1 | if(NOT PPL_LLM_ENABLE_LLAMA) 2 | return() 3 | endif() 4 | 5 | set(__PPLNN_CMAKE_CONFIG_FILE__ ${CMAKE_CURRENT_BINARY_DIR}/generated/pplllmserving-config.cmake) 6 | configure_file(${CMAKE_CURRENT_SOURCE_DIR}/cmake/pplllmserving-config.cmake.in 7 | ${__PPLNN_CMAKE_CONFIG_FILE__} 8 | @ONLY) 9 | install(FILES ${__PPLNN_CMAKE_CONFIG_FILE__} DESTINATION lib/cmake/ppl) 10 | unset(__PPLNN_CMAKE_CONFIG_FILE__) 11 | 12 | file(GLOB __TMP__ src/common/*.h) 13 | install(FILES ${__TMP__} DESTINATION include/ppl/llm/common) 14 | 15 | file(GLOB __TMP__ src/utils/*.h) 16 | install(FILES ${__TMP__} DESTINATION include/ppl/llm/utils) 17 | 18 | file(GLOB __TMP__ src/models/*.h) 19 | install(FILES ${__TMP__} DESTINATION include/ppl/llm/models) 20 | 21 | if(PPL_LLM_ENABLE_LLAMA) 22 | file(GLOB __TMP__ src/models/llama/*.h) 23 | install(FILES ${__TMP__} DESTINATION include/ppl/llm/models/llama) 24 | endif() 25 | 26 | if(PPL_LLM_ENABLE_GRPC_SERVING) 27 | file(GLOB __TMP__ src/serving/*.h) 28 | install(FILES ${__TMP__} DESTINATION include/ppl/llm/serving) 29 | endif() 30 | 31 | if(PPLNN_USE_LLM_CUDA) 32 | file(GLOB __TMP__ src/backends/cuda/*.h) 33 | install(FILES ${__TMP__} DESTINATION include/ppl/llm/backends/cuda) 34 | endif() 35 | 36 | unset(__TMP__) 37 | -------------------------------------------------------------------------------- /cmake/llm.cmake: -------------------------------------------------------------------------------- 1 | # sentencepiece after serving, libprotobuf and absl wanted 2 | 3 | file(GLOB __PPL_LLM_SRC__ 4 | src/common/*.cc 5 | src/engine/*.cc 6 | src/generator/*.cc 7 | src/utils/*.cc 8 | ) 9 | 10 | if(PPLNN_USE_LLM_CUDA) 11 | if(HPCC_ENABLE_SANITIZE_OPTIONS) 12 | message(FATAL_ERROR "`HPCC_ENABLE_SANITIZE_OPTIONS` can not be used with nccl now.") 13 | endif() 14 | file(GLOB __TMP_SRC__ src/backends/cuda/*.cc) 15 | list(APPEND __PPL_LLM_SRC__ ${__TMP_SRC__}) 16 | unset(__TMP_SRC__) 17 | endif() 18 | 19 | include(cmake/sentencepiece.cmake) 20 | 21 | add_library(ppl_llm_static STATIC ${__PPL_LLM_SRC__}) 22 | target_link_libraries(ppl_llm_static PUBLIC 23 | pplnn_static 24 | ppl_sentencepiece_static) 25 | if (PPL_LLM_ENABLE_HF_TOKENIZER) 26 | target_link_libraries(ppl_llm_static PUBLIC tokenizers_cpp tokenizers_c) 27 | target_compile_definitions(ppl_llm_static PUBLIC PPL_LLM_ENABLE_HF_TOKENIZER) 28 | endif() 29 | 30 | target_compile_options(ppl_llm_static PUBLIC ${HPCC_SANITIZE_COMPILE_OPTIONS}) 31 | target_include_directories(ppl_llm_static PUBLIC 
${HPCC_DEPS_DIR}/rapidjson/include) 32 | target_link_options(ppl_llm_static PUBLIC ${HPCC_SANITIZE_LINK_OPTIONS}) 33 | install(TARGETS ppl_llm_static DESTINATION lib) 34 | 35 | if(PPL_LLM_ENABLE_DEBUG) 36 | target_compile_definitions(ppl_llm_static PUBLIC PPL_LLM_ENABLE_DEBUG) 37 | endif() 38 | if(PPLNN_CUDA_ENABLE_NCCL) 39 | target_compile_definitions(ppl_llm_static PUBLIC PPLNN_CUDA_ENABLE_NCCL) 40 | endif() 41 | if(PPLNN_USE_LLM_CUDA) 42 | target_compile_definitions(ppl_llm_static PUBLIC PPLNN_USE_LLM_CUDA) 43 | endif() 44 | if (PPL_LLM_SERVING_SYNC_DECODE) 45 | target_compile_definitions(ppl_llm_static PUBLIC PPL_LLM_SERVING_SYNC_DECODE) 46 | endif() 47 | -------------------------------------------------------------------------------- /cmake/pplllmserving-config.cmake.in: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.10) 2 | 3 | # cmake config without grpc 4 | 5 | if(TARGET "ppl_llm_static") 6 | return() 7 | endif() 8 | 9 | add_library(ppl_llm_static STATIC IMPORTED) 10 | 11 | # --------------------------------------------------------------------------- # 12 | 13 | get_filename_component(__PPL_LLM_PACKAGE_ROOTDIR__ "${CMAKE_CURRENT_LIST_DIR}/../../.." ABSOLUTE) 14 | 15 | # --------------------------------------------------------------------------- # 16 | 17 | # exported definitions 18 | 19 | option(PPL_LLM_ENABLE_LLAMA "" @PPL_LLM_ENABLE_LLAMA@) 20 | 21 | if(NOT TARGET "pplnn_basic_static") 22 | include(${CMAKE_CURRENT_LIST_DIR}/pplnn-config.cmake) 23 | endif() 24 | 25 | set(PPL_LLM_LIBRARIES "ppl_llm_static") 26 | set(PPL_LLM_INCLUDE_DIRS ${__PPL_LLM_PACKAGE_ROOTDIR__}/include) 27 | set(PPL_LLM_LINK_DIRS ${__PPL_LLM_PACKAGE_ROOTDIR__}/lib) 28 | 29 | # --------------------------------------------------------------------------- # 30 | 31 | # sentencepiece 32 | include(${__PPL_LLM_PACKAGE_ROOTDIR__}/lib/cmake/absl/abslConfig.cmake) 33 | 34 | add_library(ppl_sentencepiece_pb_static STATIC IMPORTED) 35 | get_filename_component(__LIB_PATH__ "${__PPL_LLM_PACKAGE_ROOTDIR__}/lib/@HPCC_STATIC_LIB_PREFIX@@PPL_LLM_SENTENCEPIECE_PROTOBUF_LIBS@@HPCC_STATIC_LIB_SUFFIX@" ABSOLUTE) 36 | set_target_properties(ppl_sentencepiece_pb_static PROPERTIES 37 | INTERFACE_LINK_LIBRARIES "protobuf::libprotobuf;absl::strings;absl::flags;absl::flags_parse" 38 | IMPORTED_LOCATION "${__LIB_PATH__}" 39 | IMPORTED_LOCATION_DEBUG "${__LIB_PATH__}" 40 | IMPORTED_LOCATION_RELEASE "${__LIB_PATH__}") 41 | unset(__LIB_PATH__) 42 | 43 | add_library(ppl_sentencepiece_static STATIC IMPORTED) 44 | get_filename_component(__LIB_PATH__ "${__PPL_LLM_PACKAGE_ROOTDIR__}/lib/@HPCC_STATIC_LIB_PREFIX@ppl_sentencepiece_static@HPCC_STATIC_LIB_SUFFIX@" ABSOLUTE) 45 | set_target_properties(ppl_sentencepiece_static PROPERTIES 46 | INTERFACE_LINK_LIBRARIES "ppl_sentencepiece_pb_static" 47 | IMPORTED_LOCATION "${__LIB_PATH__}" 48 | IMPORTED_LOCATION_DEBUG "${__LIB_PATH__}" 49 | IMPORTED_LOCATION_RELEASE "${__LIB_PATH__}") 50 | unset(__LIB_PATH__) 51 | 52 | # --------------------------------------------------------------------------- # 53 | 54 | get_filename_component(__LIB_PATH__ "${__PPL_LLM_PACKAGE_ROOTDIR__}/lib/@HPCC_STATIC_LIB_PREFIX@ppl_llm_static@HPCC_STATIC_LIB_SUFFIX@" ABSOLUTE) 55 | set_target_properties(ppl_llm_static PROPERTIES 56 | INTERFACE_LINK_LIBRARIES "ppl_sentencepiece_static;${PPLNN_LIBRARIES}" 57 | INTERFACE_LINK_DIRECTORIES "${PPLNN_LINK_DIRS}" 58 | INTERFACE_INCLUDE_DIRECTORIES "${PPLNN_INCLUDE_DIRS}" 59 | IMPORTED_LOCATION "${__LIB_PATH__}" 60 | 
IMPORTED_LOCATION_DEBUG "${__LIB_PATH__}" 61 | IMPORTED_LOCATION_RELEASE "${__LIB_PATH__}") 62 | unset(__LIB_PATH__) 63 | 64 | # --------------------------------------------------------------------------- # 65 | 66 | unset(__PPL_LLM_PACKAGE_ROOTDIR__) 67 | -------------------------------------------------------------------------------- /cmake/sentencepiece.cmake: -------------------------------------------------------------------------------- 1 | FetchContent_GetProperties(sentencepiece) 2 | if(NOT sentencepiece_POPULATED) 3 | FetchContent_Populate(sentencepiece) 4 | endif() 5 | 6 | set(__PPL_LLAMA_GENERATED_DIR__ "${CMAKE_CURRENT_BINARY_DIR}/generated") 7 | file(MAKE_DIRECTORY ${__PPL_LLAMA_GENERATED_DIR__}) 8 | 9 | set(__SENTENCEPIECE_ROOT_DIR__ ${HPCC_DEPS_DIR}/sentencepiece) 10 | 11 | if(NOT PPL_LLM_SENTENCEPIECE_PROTOBUF_LIBS) 12 | set(__SP_GENERATED_FILES__ "${__PPL_LLAMA_GENERATED_DIR__}/sentencepiece.pb.h;${__PPL_LLAMA_GENERATED_DIR__}/sentencepiece.pb.cc") 13 | set(__SPM_GENERATED_FILES__ "${__PPL_LLAMA_GENERATED_DIR__}/sentencepiece_model.pb.h;${__PPL_LLAMA_GENERATED_DIR__}/sentencepiece_model.pb.cc") 14 | 15 | set(__PROTO_DIR__ ${__SENTENCEPIECE_ROOT_DIR__}/src) 16 | add_custom_command( 17 | OUTPUT ${__SP_GENERATED_FILES__} 18 | COMMAND ${PPLNN_PROTOC_EXECUTABLE} 19 | ARGS --cpp_out "${__PPL_LLAMA_GENERATED_DIR__}" -I "${__PROTO_DIR__}" "${__PROTO_DIR__}/sentencepiece.proto" 20 | DEPENDS protoc) 21 | add_custom_command( 22 | OUTPUT ${__SPM_GENERATED_FILES__} 23 | COMMAND ${PPLNN_PROTOC_EXECUTABLE} 24 | ARGS --cpp_out "${__PPL_LLAMA_GENERATED_DIR__}" -I "${__PROTO_DIR__}" "${__PROTO_DIR__}/sentencepiece_model.proto" 25 | DEPENDS protoc) 26 | unset(__PROTO_DIR__) 27 | 28 | add_library(ppl_sentencepiece_pb_static STATIC 29 | ${__SP_GENERATED_FILES__} 30 | ${__SPM_GENERATED_FILES__}) 31 | target_link_libraries(ppl_sentencepiece_pb_static PUBLIC libprotobuf) 32 | 33 | unset(__SPM_GENERATED_FILES__) 34 | unset(__SP_GENERATED_FILES__) 35 | 36 | set(PPL_LLM_SENTENCEPIECE_PROTOBUF_LIBS ppl_sentencepiece_pb_static) 37 | endif() 38 | 39 | configure_file("${__SENTENCEPIECE_ROOT_DIR__}/config.h.in" ${__PPL_LLAMA_GENERATED_DIR__}/config.h) 40 | 41 | set(__PPL_SPM_SRCS__ 42 | ${__SENTENCEPIECE_ROOT_DIR__}/src/bpe_model.h 43 | ${__SENTENCEPIECE_ROOT_DIR__}/src/common.h 44 | ${__SENTENCEPIECE_ROOT_DIR__}/src/normalizer.h 45 | ${__SENTENCEPIECE_ROOT_DIR__}/src/util.h 46 | ${__SENTENCEPIECE_ROOT_DIR__}/src/freelist.h 47 | ${__SENTENCEPIECE_ROOT_DIR__}/src/filesystem.h 48 | ${__SENTENCEPIECE_ROOT_DIR__}/src/init.h 49 | ${__SENTENCEPIECE_ROOT_DIR__}/src/sentencepiece_processor.h 50 | ${__SENTENCEPIECE_ROOT_DIR__}/src/word_model.h 51 | ${__SENTENCEPIECE_ROOT_DIR__}/src/model_factory.h 52 | ${__SENTENCEPIECE_ROOT_DIR__}/src/char_model.h 53 | ${__SENTENCEPIECE_ROOT_DIR__}/src/model_interface.h 54 | ${__SENTENCEPIECE_ROOT_DIR__}/src/testharness.h 55 | ${__SENTENCEPIECE_ROOT_DIR__}/src/unigram_model.h 56 | ${__SENTENCEPIECE_ROOT_DIR__}/src/bpe_model.cc 57 | ${__SENTENCEPIECE_ROOT_DIR__}/src/char_model.cc 58 | ${__SENTENCEPIECE_ROOT_DIR__}/src/error.cc 59 | ${__SENTENCEPIECE_ROOT_DIR__}/src/filesystem.cc 60 | ${__SENTENCEPIECE_ROOT_DIR__}/src/model_factory.cc 61 | ${__SENTENCEPIECE_ROOT_DIR__}/src/model_interface.cc 62 | ${__SENTENCEPIECE_ROOT_DIR__}/src/normalizer.cc 63 | ${__SENTENCEPIECE_ROOT_DIR__}/src/sentencepiece_processor.cc 64 | ${__SENTENCEPIECE_ROOT_DIR__}/src/unigram_model.cc 65 | ${__SENTENCEPIECE_ROOT_DIR__}/src/util.cc 66 | ${__SENTENCEPIECE_ROOT_DIR__}/src/word_model.cc) 67 | 68 | 
add_library(ppl_sentencepiece_static STATIC ${__PPL_SPM_SRCS__}) 69 | target_link_libraries(ppl_sentencepiece_static PUBLIC 70 | ${PPL_LLM_SENTENCEPIECE_PROTOBUF_LIBS} 71 | libprotobuf absl::strings absl::flags absl::flags_parse) 72 | target_include_directories(ppl_sentencepiece_static PUBLIC 73 | ${__PPL_LLAMA_GENERATED_DIR__} 74 | ${__SENTENCEPIECE_ROOT_DIR__} 75 | ${__SENTENCEPIECE_ROOT_DIR__}/src) 76 | target_compile_definitions(ppl_sentencepiece_static PRIVATE _USE_EXTERNAL_PROTOBUF) 77 | 78 | if(PPL_LLM_INSTALL) 79 | install(TARGETS ${PPL_LLM_SENTENCEPIECE_PROTOBUF_LIBS} DESTINATION lib) 80 | install(TARGETS ppl_sentencepiece_static DESTINATION lib) 81 | endif() 82 | 83 | unset(__SENTENCEPIECE_ROOT_DIR__) 84 | unset(__PPL_LLAMA_GENERATED_DIR__) 85 | -------------------------------------------------------------------------------- /cmake/xxhash.cmake: -------------------------------------------------------------------------------- 1 | hpcc_populate_dep(xxhash) 2 | 3 | set(__XXHASH_SRC__ ${xxhash_SOURCE_DIR}/xxhash.c) 4 | 5 | if(CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64") 6 | list(APPEND __XXHASH_SRC__ ${xxhash_SOURCE_DIR}/xxh_x86dispatch.c) 7 | endif() 8 | 9 | add_library(xxhash_static STATIC ${__XXHASH_SRC__}) 10 | target_compile_definitions(xxhash_static PRIVATE DISPATCH=1) 11 | target_include_directories(xxhash_static PUBLIC ${xxhash_SOURCE_DIR}) 12 | unset(__XXHASH_SRC__) 13 | -------------------------------------------------------------------------------- /docs/llama_guide.md: -------------------------------------------------------------------------------- 1 | ## LLaMA Guide 2 | 3 | 1. Download model 4 | 5 | Download the LLaMA model and tokenizer following the documents provided by Facebook: 6 | [LLaMA](https://github.com/facebookresearch/llama/tree/llama_v1#llama). 7 | 8 | Here we save the downloaded files in the folder `/model_data/llama_fb/`. 9 | 10 | 2. Convert model 11 | 12 | Convert the model with our ppl.pmx, following this [guide](https://github.com/openppl-public/ppl.pmx/blob/master/model_zoo/llama/facebook/README.md). Here we use llama_7b as an example, exporting the model to `/model_data/llama_7b_ppl/`. 13 | 14 | ```bash 15 | git clone https://github.com/openppl-public/ppl.pmx.git 16 | cd ppl.pmx/model_zoo/llama/facebook 17 | pip install -r requirements.txt # requirements 18 | MP=1 19 | OMP_NUM_THREADS=${MP} torchrun --nproc_per_node ${MP} \ 20 | Export.py --ckpt_dir /model_data/llama_fb/7B/ \ 21 | --tokenizer_path /model_data/llama_fb/tokenizer.model \ 22 | --export_path /model_data/llama_7b_ppl/ \ 23 | --fused_qkv 1 --fused_kvcache 1 --auto_causal 1 \ 24 | --quantized_cache 1 --dynamic_batching 1 25 | ``` 26 | Different models require different MP values; for llama_7b, `MP=1`. 27 | | Model | MP | 28 | |----------------------|-------| 29 | | LLaMA-7B | 1 | 30 | | LLaMA-13B | 2 | 31 | | LLaMA-30B | 4 | 32 | | LLaMA-65B | 8 | 33 | 34 | Here, `model_slice_0` and `params.json` are generated in `/model_data/llama_7b_ppl/`. 35 | * The folder `model_slice_0` contains a tensor-parallel slice of the llama_7b weights and graph structure in ONNX format. The number of slices equals the number of GPUs used. For example, llama_13b has two slice folders, `model_slice_0` and `model_slice_1`. 36 | * The file `params.json` describes the llama_7b model configuration, which differs from that of llama_13b and llama_65b (a sketch of its keys is shown below). 37 | 
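For reference, a minimal sketch of what `params.json` contains is shown below. The keys are those read by `ParseModelConfig` in `src/common/config.cc` (`num_kv_heads` is optional and defaults to `num_heads`); the values are only illustrative for LLaMA-7B, and your exported file may carry different values or additional keys.

```json
{
  "num_heads": 32,
  "num_kv_heads": 32,
  "num_layers": 32,
  "hidden_dim": 4096,
  "intermediate_dim": 11008,
  "vocab_size": 32000,
  "cache_quant_bit": 8,
  "cache_quant_group": 8
}
```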
38 | 3. Build from source 39 | 40 | ```bash 41 | cd /xx/ppl.llm.serving 42 | ./build.sh -DPPLNN_USE_LLM_CUDA=ON -DPPLNN_CUDA_ENABLE_NCCL=ON -DPPLNN_ENABLE_CUDA_JIT=OFF -DPPLNN_CUDA_ARCHITECTURES="'80;86;87'" -DPPLCOMMON_CUDA_ARCHITECTURES="'80;86;87'" 43 | ``` 44 | 45 | 4. Launch server 46 | 47 | Launch the server with the model, params and tokenizer files from steps 1 and 2. 48 | ```bash 49 | ./ppl_llm_server \ 50 | --model-dir /data/model \ 51 | --model-param-path /data/model/params.json \ 52 | --tokenizer-path /data/tokenizer.model \ 53 | --tensor-parallel-size 1 \ 54 | --top-p 0.0 \ 55 | --top-k 1 \ 56 | --max-tokens-scale 0.94 \ 57 | --max-input-tokens-per-request 4096 \ 58 | --max-output-tokens-per-request 4096 \ 59 | --max-total-tokens-per-request 8192 \ 60 | --max-running-batch 1024 \ 61 | --max-tokens-per-step 8192 \ 62 | --host 127.0.0.1 \ 63 | --port 23333 64 | ``` 65 | 66 | where the params `model-dir`, `model-param-path` and `tokenizer-path` come from steps 1 and 2; `tensor-parallel-size` is 1 for llama_7b and differs for other LLaMA models. 67 | 68 | | Model | tensor-parallel-size | 69 | |----------------------|----------------------| 70 | | LLaMA-7B | 1 | 71 | | LLaMA-13B | 2 | 72 | | LLaMA-30B | 4 | 73 | | LLaMA-65B | 8 | 74 | 75 | 5. Launch client 76 | 77 | Send requests through [gRPC](https://github.com/grpc/grpc) to query the model. 78 | 79 | ```bash 80 | ./client_sample 127.0.0.1:23333 81 | ``` 82 | The prompt is written in the source file `tools/client_sample.cc`. If you want to change the prompt, revise the source file and rebuild it. -------------------------------------------------------------------------------- /samples/integration-cuda/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.10) 2 | project(ppl-llm-integration-cuda) 3 | 4 | set(CMAKE_POLICY_DEFAULT_CMP0077 NEW) 5 | 6 | # cuda env initializations 7 | find_package(CUDA REQUIRED) 8 | if(NOT CMAKE_CUDA_COMPILER) 9 | set(CMAKE_CUDA_COMPILER ${CUDA_TOOLKIT_ROOT_DIR}/bin/nvcc) 10 | endif() 11 | if(NOT CMAKE_CUDA_HOST_COMPILER) 12 | set(CMAKE_CUDA_HOST_COMPILER ${CMAKE_CXX_COMPILER}) 13 | endif() 14 | enable_language(CUDA) 15 | 16 | # import ppl llm serving(without grpc) 17 | get_filename_component(pplllmserving_DIR "${CMAKE_CURRENT_LIST_DIR}/../../ppl-build/install/lib/cmake/ppl" ABSOLUTE) 18 | # optional: disable unused devices to avoid linking extra deps 19 | find_package(pplllmserving REQUIRED) 20 | 21 | # ------------------- # 22 | 23 | # sample target 24 | add_executable(ppl-llm-integration-cuda main.cc) 25 | target_include_directories(ppl-llm-integration-cuda PRIVATE 26 | ${PPL_LLM_INCLUDE_DIRS} 27 | ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) # imported from `enable_language(CUDA)` 28 | target_link_directories(ppl-llm-integration-cuda PRIVATE 29 | ${PPL_LLM_LINK_DIRS} 30 | ${CMAKE_CUDA_HOST_IMPLICIT_LINK_DIRECTORIES}) # imported from `enable_language(CUDA)` 31 | target_link_libraries(ppl-llm-integration-cuda PRIVATE ${PPL_LLM_LIBRARIES}) 32 | -------------------------------------------------------------------------------- /samples/integration-cuda/main.cc: -------------------------------------------------------------------------------- 1 | // Licensed to the Apache Software Foundation (ASF) under one 2 | // or more contributor license agreements. See the NOTICE file 3 | // distributed with this work for additional information 4 | // regarding copyright ownership. 
The ASF licenses this file 5 | // to you under the Apache License, Version 2.0 (the 6 | // "License"); you may not use this file except in compliance 7 | // with the License. You may obtain a copy of the License at 8 | // 9 | // http://www.apache.org/licenses/LICENSE-2.0 10 | // 11 | // Unless required by applicable law or agreed to in writing, 12 | // software distributed under the License is distributed on an 13 | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | // KIND, either express or implied. See the License for the 15 | // specific language governing permissions and limitations 16 | // under the License. 17 | 18 | #include "ppl/llm/common/request.h" 19 | #include "ppl/llm/common/response.h" 20 | #include "ppl/llm/common/processor.h" 21 | #include "ppl/llm/common/connection.h" 22 | #include "ppl/llm/backends/cuda/resource_manager.h" 23 | #include "ppl/llm/models/factory.h" 24 | #include "ppl/llm/models/resource.h" 25 | #include "ppl/llm/utils/utils.h" 26 | #include "ppl/llm/utils/tokenizer.h" 27 | #include "ppl/llm/utils/config_utils.h" 28 | 29 | #include "ppl/common/log.h" 30 | 31 | int main(int argc, char* argv[]) { 32 | 33 | return 0; 34 | } 35 | -------------------------------------------------------------------------------- /src/backends/cuda/post_processor.h: -------------------------------------------------------------------------------- 1 | // Licensed to the Apache Software Foundation (ASF) under one 2 | // or more contributor license agreements. See the NOTICE file 3 | // distributed with this work for additional information 4 | // regarding copyright ownership. The ASF licenses this file 5 | // to you under the Apache License, Version 2.0 (the 6 | // "License"); you may not use this file except in compliance 7 | // with the License. You may obtain a copy of the License at 8 | // 9 | // http://www.apache.org/licenses/LICENSE-2.0 10 | // 11 | // Unless required by applicable law or agreed to in writing, 12 | // software distributed under the License is distributed on an 13 | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | // KIND, either express or implied. See the License for the 15 | // specific language governing permissions and limitations 16 | // under the License. 
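// Summary (added for readability, based on the declarations below): CudaPostProcessor is the
// GPU-side sampler. ApplyPenalty() adjusts the logits in place using the per-request temperature
// and repetition/presence/frequency penalties; SampleTopKTopP() then draws the next token for
// each sequence with top-k/top-p sampling on the given CUDA stream and copies the sampled token
// ids and log-probs back to host memory.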
17 | 18 | #ifndef __PPL_LLM_CUDA_POST_PROCESSOR_H__ 19 | #define __PPL_LLM_CUDA_POST_PROCESSOR_H__ 20 | 21 | #include "../../common/post_processor.h" 22 | 23 | #include 24 | 25 | namespace ppl { namespace llm { namespace cuda { 26 | 27 | class CudaPostProcessor final : public PostProcessor { 28 | public: 29 | CudaPostProcessor(cudaStream_t stream) : stream_(stream) {} 30 | virtual ~CudaPostProcessor() { 31 | Clear(); 32 | } 33 | 34 | ppl::common::RetCode InitPostProcessorMem(int max_running_batch, int vocab_size, bool enable_penalty) override; 35 | 36 | ppl::common::RetCode SampleTopKTopP(const float* logits_device, const float* temperatures_host, 37 | const int32_t* top_k_host, const float* top_p_host, int32_t batch, 38 | int32_t vocab_size, int32_t batch_stride, int32_t default_top_k, 39 | float default_top_p, bool req_list_changed, int32_t* output_host, 40 | float* logprobs_host, bool enable_penalty) override; 41 | 42 | ppl::common::RetCode ApplyPenalty(const float* temperatures_host, const float* repetition_penalties_host, 43 | const float* presence_penalties_host, const float* frequency_penalties_host, 44 | const int64_t* batch_slots_host, const int64_t* token_inputs, 45 | const int64_t* seqstarts, const int64_t* start_pos, int32_t batch, 46 | int32_t vocab_size, bool req_list_changed, float* logits) override; 47 | 48 | private: 49 | void Clear(); 50 | 51 | private: 52 | cudaStream_t stream_ = 0; 53 | 54 | int32_t* workspace_ = nullptr; 55 | int64_t workspace_size_ = 0; 56 | 57 | float* temperatures_device_ = nullptr; 58 | int32_t* top_k_device_ = nullptr; 59 | float* top_p_device_ = nullptr; 60 | float* rand_device_ = nullptr; 61 | int32_t* output_device_ = nullptr; 62 | float* logprobs_device_ = nullptr; 63 | 64 | uint16_t* penalty_count_map_ = nullptr; 65 | int64_t* batch_slots_device_ = nullptr; 66 | float* repetition_penalties_device_ = nullptr; 67 | float* presence_penalties_device_ = nullptr; 68 | float* frequency_penalties_device_ = nullptr; 69 | }; 70 | 71 | }}}; // namespace ppl::llm::cuda 72 | 73 | #endif -------------------------------------------------------------------------------- /src/backends/cuda/resource_manager.h: -------------------------------------------------------------------------------- 1 | // Licensed to the Apache Software Foundation (ASF) under one 2 | // or more contributor license agreements. See the NOTICE file 3 | // distributed with this work for additional information 4 | // regarding copyright ownership. The ASF licenses this file 5 | // to you under the Apache License, Version 2.0 (the 6 | // "License"); you may not use this file except in compliance 7 | // with the License. You may obtain a copy of the License at 8 | // 9 | // http://www.apache.org/licenses/LICENSE-2.0 10 | // 11 | // Unless required by applicable law or agreed to in writing, 12 | // software distributed under the License is distributed on an 13 | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | // KIND, either express or implied. See the License for the 15 | // specific language governing permissions and limitations 16 | // under the License. 
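// Summary (added for readability, based on the declarations below): CudaResourceManager owns
// the per-GPU inference resources: one InferRuntimeParam (CUDA stream, ppl.nn engine and
// input/output device) per device, the KV-cache and KV-scale memory, a device worker thread
// pool, the post processor, and the NCCL communicators when PPLNN_CUDA_ENABLE_NCCL is defined.
// Init() builds all of this from a ModelConfig and a ResourceConfig, and the destructor
// releases it.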
17 | 18 | #ifndef __PPL_LLM_CUDA_RESOURCE_MANAGER_H__ 19 | #define __PPL_LLM_CUDA_RESOURCE_MANAGER_H__ 20 | 21 | #include "../../common/config.h" 22 | #include "../../common/resource.h" 23 | #include "../../common/post_processor.h" 24 | 25 | #include "ppl/common/log.h" 26 | #include "ppl/common/barrier.h" 27 | #include "ppl/common/threadpool.h" 28 | #include "ppl/common/retcode.h" 29 | #include "ppl/nn/engines/llm_cuda/engine_factory.h" 30 | #include "ppl/nn/runtime/runtime.h" 31 | #include 32 | #ifdef PPLNN_CUDA_ENABLE_NCCL 33 | #include "nccl.h" 34 | #endif 35 | 36 | #include 37 | #include 38 | #include 39 | 40 | namespace ppl { namespace llm { namespace cuda { 41 | 42 | struct InferRuntimeParam final { 43 | cudaStream_t stream = 0; 44 | std::unique_ptr engine; 45 | std::unique_ptr input_output_device; 46 | 47 | InferRuntimeParam() {} 48 | InferRuntimeParam(InferRuntimeParam&& rhs) { 49 | if (this != &rhs) { 50 | DoMove(std::move(rhs)); 51 | } 52 | } 53 | ~InferRuntimeParam() { 54 | DoDestroy(); 55 | } 56 | void operator=(InferRuntimeParam&& rhs) { 57 | if (this != &rhs) { 58 | DoDestroy(); 59 | DoMove(std::move(rhs)); 60 | } 61 | } 62 | 63 | // ----- private functions ----- // 64 | 65 | void DoMove(InferRuntimeParam&& rhs) { 66 | stream = rhs.stream; 67 | rhs.stream = 0; 68 | engine = std::move(rhs.engine); 69 | input_output_device = std::move(rhs.input_output_device); 70 | } 71 | 72 | void DoDestroy() { 73 | input_output_device.reset(); 74 | engine.reset(); 75 | if (stream) { 76 | cudaStreamSynchronize(stream); 77 | cudaStreamDestroy(stream); 78 | } 79 | } 80 | 81 | void operator=(InferRuntimeParam&) = delete; 82 | InferRuntimeParam(const InferRuntimeParam&) = delete; 83 | }; 84 | 85 | struct CudaResourceManager final { 86 | ~CudaResourceManager() { 87 | post_processor.reset(); 88 | 89 | for (auto it = items.begin(); it != items.end(); ++it) { 90 | cudaFree(it->kv_cache_mem); 91 | if (it->kv_scale_mem) { 92 | cudaFree(it->kv_scale_mem); 93 | } 94 | delete it->runtime; 95 | delete it->host_device; 96 | } 97 | 98 | runtime_param_list.clear(); 99 | 100 | #ifdef PPLNN_CUDA_ENABLE_NCCL 101 | for (auto it = nccl_comm_list.begin(); it != nccl_comm_list.end(); ++it) { 102 | auto e = ncclCommDestroy(*it); 103 | if (e != ncclSuccess) { 104 | LOG(ERROR) << "NCCL error(code:" << (int)e << ") on " 105 | << "(ncclCommDestroy)"; 106 | } 107 | } 108 | #endif 109 | } 110 | 111 | std::unique_ptr CreateCudaPostProcessor(ppl::nn::Runtime* runtime); 112 | ppl::common::RetCode Init(const ModelConfig&, const ResourceConfig&); 113 | 114 | ppl::common::StaticThreadPool device_worker_pool_; 115 | std::vector runtime_param_list; 116 | std::vector items; 117 | std::unique_ptr post_processor; 118 | uint64_t kv_cache_max_tokens; 119 | 120 | #ifdef PPLNN_CUDA_ENABLE_NCCL 121 | std::vector nccl_comm_list; 122 | #endif 123 | }; 124 | 125 | }}} // namespace ppl::llm::cuda 126 | 127 | #endif 128 | -------------------------------------------------------------------------------- /src/common/config.cc: -------------------------------------------------------------------------------- 1 | // Licensed to the Apache Software Foundation (ASF) under one 2 | // or more contributor license agreements. See the NOTICE file 3 | // distributed with this work for additional information 4 | // regarding copyright ownership. The ASF licenses this file 5 | // to you under the Apache License, Version 2.0 (the 6 | // "License"); you may not use this file except in compliance 7 | // with the License. 
You may obtain a copy of the License at 8 | // 9 | // http://www.apache.org/licenses/LICENSE-2.0 10 | // 11 | // Unless required by applicable law or agreed to in writing, 12 | // software distributed under the License is distributed on an 13 | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | // KIND, either express or implied. See the License for the 15 | // specific language governing permissions and limitations 16 | // under the License. 17 | 18 | #include "config.h" 19 | 20 | #include "ppl/common/retcode.h" 21 | #include "ppl/common/log.h" 22 | #include "rapidjson/document.h" 23 | #include "rapidjson/istreamwrapper.h" 24 | 25 | #include <fstream> 26 | #include <string> 27 | #include 28 | 29 | namespace ppl { namespace llm { 30 | 31 | bool ParseModelConfig(const std::string& model_param_path, ModelConfig* model_config) { 32 | std::ifstream ifs(model_param_path); 33 | rapidjson::IStreamWrapper isw(ifs); 34 | rapidjson::Document document; 35 | if (document.ParseStream(isw).HasParseError()) { 36 | LOG(ERROR) << "ParseStream failed"; 37 | return false; 38 | } 39 | 40 | 41 | auto it = document.FindMember("num_heads"); 42 | if (it == document.MemberEnd()) { 43 | LOG(ERROR) << "find key [num_heads] failed"; 44 | return false; 45 | } 46 | model_config->num_heads = it->value.GetInt(); 47 | LOG(INFO) << "model_config.num_heads: " << model_config->num_heads; 48 | 49 | it = document.FindMember("num_kv_heads"); 50 | if (it == document.MemberEnd()) { 51 | model_config->num_kv_heads = model_config->num_heads; 52 | } else { 53 | model_config->num_kv_heads = it->value.GetInt(); 54 | } 55 | LOG(INFO) << "model_config.num_kv_heads: " << model_config->num_kv_heads; 56 | 57 | it = document.FindMember("num_layers"); 58 | if (it == document.MemberEnd()) { 59 | LOG(ERROR) << "find key [num_layers] failed"; 60 | return false; 61 | } 62 | model_config->num_layers = it->value.GetInt(); 63 | LOG(INFO) << "model_config.num_layers: " << model_config->num_layers; 64 | 65 | it = document.FindMember("hidden_dim"); 66 | if (it == document.MemberEnd()) { 67 | LOG(ERROR) << "find key [hidden_dim] failed"; 68 | return false; 69 | } 70 | model_config->hidden_dim = it->value.GetInt(); 71 | LOG(INFO) << "model_config.hidden_dim: " << model_config->hidden_dim; 72 | 73 | it = document.FindMember("intermediate_dim"); 74 | if (it == document.MemberEnd()) { 75 | LOG(ERROR) << "find key [intermediate_dim] failed"; 76 | return false; 77 | } 78 | model_config->intermediate_dim = it->value.GetInt(); 79 | LOG(INFO) << "model_config.intermediate_dim: " << model_config->intermediate_dim; 80 | 81 | it = document.FindMember("vocab_size"); 82 | if (it == document.MemberEnd()) { 83 | LOG(ERROR) << "find key [vocab_size] failed"; 84 | return false; 85 | } 86 | model_config->vocab_size = it->value.GetInt(); 87 | LOG(INFO) << "model_config.vocab_size: " << model_config->vocab_size; 88 | 89 | it = document.FindMember("cache_quant_bit"); 90 | if (it == document.MemberEnd()) { 91 | LOG(ERROR) << "find key [cache_quant_bit] failed"; 92 | return false; 93 | } 94 | model_config->cache_quant_bit = it->value.GetInt(); 95 | LOG(INFO) << "model_config.cache_quant_bit: " << model_config->cache_quant_bit; 96 | 97 | it = document.FindMember("cache_quant_group"); 98 | if (it == document.MemberEnd()) { 99 | LOG(ERROR) << "find key [cache_quant_group] failed"; 100 | return false; 101 | } 102 | model_config->cache_quant_group = it->value.GetInt(); 103 | LOG(INFO) << "model_config.cache_quant_group: " << model_config->cache_quant_group; 104 | 105
| it = document.FindMember("cache_layout"); 106 | if (it == document.MemberEnd()) { 107 | LOG(ERROR) << "find key [cache_layout] failed"; 108 | return false; 109 | } 110 | model_config->cache_layout = it->value.GetInt(); 111 | LOG(INFO) << "model_config.cache_layout: " << model_config->cache_layout; 112 | 113 | it = document.FindMember("cache_mode"); 114 | if (it == document.MemberEnd()) { 115 | LOG(ERROR) << "find key [cache_mode] failed"; 116 | return false; 117 | } 118 | model_config->cache_mode = it->value.GetInt(); 119 | LOG(INFO) << "model_config.cache_mode: " << model_config->cache_mode; 120 | 121 | if (model_config->cache_mode == 1) { 122 | it = document.FindMember("page_size"); 123 | if (it == document.MemberEnd()) { 124 | LOG(ERROR) << "find key [page_size] failed"; 125 | return false; 126 | } 127 | model_config->page_size = it->value.GetInt(); 128 | LOG(INFO) << "model_config.page_size: " << model_config->page_size; 129 | } 130 | 131 | it = document.FindMember("dynamic_batching"); 132 | if (it == document.MemberEnd()) { 133 | LOG(ERROR) << "find key [dynamic_batching] failed"; 134 | return false; 135 | } 136 | model_config->dynamic_batching = it->value.GetBool(); 137 | LOG(INFO) << "model_config.dynamic_batching: " << model_config->dynamic_batching; 138 | 139 | it = document.FindMember("auto_causal"); 140 | if (it == document.MemberEnd()) { 141 | LOG(ERROR) << "find key [auto_causal] failed"; 142 | return false; 143 | } 144 | model_config->auto_causal = it->value.GetBool(); 145 | LOG(INFO) << "model_config.auto_causal: " << model_config->auto_causal; 146 | 147 | return true; 148 | } 149 | 150 | }} // namespace ppl::llm -------------------------------------------------------------------------------- /src/common/config.h: -------------------------------------------------------------------------------- 1 | // Licensed to the Apache Software Foundation (ASF) under one 2 | // or more contributor license agreements. See the NOTICE file 3 | // distributed with this work for additional information 4 | // regarding copyright ownership. The ASF licenses this file 5 | // to you under the Apache License, Version 2.0 (the 6 | // "License"); you may not use this file except in compliance 7 | // with the License. You may obtain a copy of the License at 8 | // 9 | // http://www.apache.org/licenses/LICENSE-2.0 10 | // 11 | // Unless required by applicable law or agreed to in writing, 12 | // software distributed under the License is distributed on an 13 | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | // KIND, either express or implied. See the License for the 15 | // specific language governing permissions and limitations 16 | // under the License.
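// ============================================================================
// [Editor's illustration - not part of the repository]
// The keys below are exactly those read by ParseModelConfig() in config.cc
// above; the concrete values are placeholders in the style of a 7B
// LLaMA-like model, not values shipped with this project. "num_kv_heads" may
// be omitted (it falls back to "num_heads"), and "page_size" is only required
// when "cache_mode" is 1.
//
// {
//     "num_heads": 32,
//     "num_kv_heads": 32,
//     "num_layers": 32,
//     "hidden_dim": 4096,
//     "intermediate_dim": 11008,
//     "vocab_size": 32000,
//     "cache_quant_bit": 8,
//     "cache_quant_group": 8,
//     "cache_layout": 0,
//     "cache_mode": 0,
//     "dynamic_batching": true,
//     "auto_causal": true
// }
// ============================================================================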
17 | 18 | #ifndef __PPL_LLM_CONFIG_H__ 19 | #define __PPL_LLM_CONFIG_H__ 20 | 21 | #include 22 | #include 23 | #include 24 | 25 | namespace ppl { namespace llm { 26 | 27 | struct ResourceConfig final { 28 | std::string model_type; 29 | std::string model_format; 30 | std::string model_dir; 31 | std::string model_param_path; 32 | int32_t tensor_parallel_size = 0; 33 | float max_tokens_scale = 0.f; 34 | int32_t max_running_batch = 0; 35 | bool enable_penalty = false; 36 | struct EngineConfig { 37 | std::string cublas_layout_hint = "default"; 38 | bool disable_graph_fusion = false; 39 | bool disable_decoding_shm_mha = false; 40 | bool disable_decoding_inf_mha = false; 41 | bool disable_decoding_inf_gqa = false; 42 | int32_t configure_decoding_attn_split_k = 1; 43 | int32_t specify_decoding_attn_tpb = 0; 44 | std::string quant_method; 45 | }; 46 | EngineConfig engine_config; 47 | }; 48 | 49 | struct GeneratorConfig final { 50 | float top_p = 0.0f; 51 | int32_t top_k = 1; 52 | bool enable_penalty = false; 53 | int32_t max_running_batch = 0; 54 | int32_t max_input_tokens_per_request = 0; 55 | int32_t max_output_tokens_per_request = 0; 56 | int32_t max_total_tokens_per_request = 0; 57 | int32_t max_tokens_per_step = 0; 58 | std::set stop_tokens; 59 | std::set special_tokens; 60 | int max_cooldown_request = 2; 61 | bool enable_prefix_cache = false; 62 | int32_t max_prefill_batch = 0; 63 | bool enable_profiling = false; 64 | }; 65 | 66 | struct ModelConfig final { 67 | int32_t hidden_dim = 0; 68 | int32_t intermediate_dim = 0; 69 | int32_t num_layers = 0; 70 | int32_t num_heads = 0; 71 | int32_t num_kv_heads = 0; 72 | int32_t vocab_size = 0; 73 | 74 | float norm_eps = 0.0f; // not used 75 | 76 | int32_t cache_quant_bit = 0; 77 | int32_t cache_quant_group = 0; 78 | 79 | int32_t cache_layout = 0; 80 | int32_t cache_mode = 0; 81 | int32_t page_size = 0; 82 | 83 | bool dynamic_batching = true; 84 | bool auto_causal = true; 85 | }; 86 | 87 | bool ParseModelConfig(const std::string& model_param_path, ModelConfig* model_config); 88 | 89 | }} // namespace ppl::llm 90 | 91 | #endif -------------------------------------------------------------------------------- /src/common/connection.h: -------------------------------------------------------------------------------- 1 | // Licensed to the Apache Software Foundation (ASF) under one 2 | // or more contributor license agreements. See the NOTICE file 3 | // distributed with this work for additional information 4 | // regarding copyright ownership. The ASF licenses this file 5 | // to you under the Apache License, Version 2.0 (the 6 | // "License"); you may not use this file except in compliance 7 | // with the License. You may obtain a copy of the License at 8 | // 9 | // http://www.apache.org/licenses/LICENSE-2.0 10 | // 11 | // Unless required by applicable law or agreed to in writing, 12 | // software distributed under the License is distributed on an 13 | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | // KIND, either express or implied. See the License for the 15 | // specific language governing permissions and limitations 16 | // under the License. 
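// ============================================================================
// [Editor's illustration - not part of the repository]
// A minimal sketch of filling the structs declared in config.h above. The
// include path, the "params.json" file name and all numeric values are
// assumptions; only ParseModelConfig() and the struct fields come from the
// sources.
#include "ppl/llm/common/config.h" // assumed install path
#include "ppl/common/log.h"

#include <string>

bool LoadConfigs(const std::string& model_dir, ppl::llm::ModelConfig* model_config,
                 ppl::llm::ResourceConfig* resource_config, ppl::llm::GeneratorConfig* generator_config) {
    // model hyper-parameters come from the JSON file parsed by ParseModelConfig()
    const std::string model_param_path = model_dir + "/params.json";
    if (!ppl::llm::ParseModelConfig(model_param_path, model_config)) {
        LOG(ERROR) << "ParseModelConfig failed: " << model_param_path;
        return false;
    }

    // serving-side knobs; every field is declared above, the values are illustrative defaults only
    resource_config->model_dir = model_dir;
    resource_config->model_param_path = model_param_path;
    resource_config->tensor_parallel_size = 1;
    resource_config->max_tokens_scale = 0.94f;
    resource_config->max_running_batch = 1024;

    generator_config->top_k = 1;
    generator_config->top_p = 0.0f;
    generator_config->max_running_batch = resource_config->max_running_batch;
    generator_config->max_tokens_per_step = 8192;
    return true;
}
// ============================================================================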
17 | 18 | #ifndef __PPL_LLM_CONNECTION_H__ 19 | #define __PPL_LLM_CONNECTION_H__ 20 | 21 | #include "response.h" 22 | #include "profiler.h" 23 | #include "ppl/common/retcode.h" 24 | #include 25 | 26 | namespace ppl { namespace llm { 27 | 28 | class Connection { 29 | public: 30 | virtual ~Connection() {} 31 | virtual void OnProfiling(const std::shared_ptr&) = 0; 32 | virtual void OnTokenize(uint64_t id, const std::vector&) = 0; 33 | virtual void Send(const std::vector&) = 0; 34 | virtual void NotifyFailure(uint64_t id, ppl::common::RetCode, const std::string& errmsg) = 0; 35 | }; 36 | 37 | }} // namespace ppl::llm 38 | 39 | #endif 40 | -------------------------------------------------------------------------------- /src/common/post_processor.h: -------------------------------------------------------------------------------- 1 | // Licensed to the Apache Software Foundation (ASF) under one 2 | // or more contributor license agreements. See the NOTICE file 3 | // distributed with this work for additional information 4 | // regarding copyright ownership. The ASF licenses this file 5 | // to you under the Apache License, Version 2.0 (the 6 | // "License"); you may not use this file except in compliance 7 | // with the License. You may obtain a copy of the License at 8 | // 9 | // http://www.apache.org/licenses/LICENSE-2.0 10 | // 11 | // Unless required by applicable law or agreed to in writing, 12 | // software distributed under the License is distributed on an 13 | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | // KIND, either express or implied. See the License for the 15 | // specific language governing permissions and limitations 16 | // under the License. 17 | 18 | #ifndef __PPL_LLM_POST_PROCESSOR_H__ 19 | #define __PPL_LLM_POST_PROCESSOR_H__ 20 | 21 | #include "ppl/common/retcode.h" 22 | 23 | namespace ppl { namespace llm { 24 | 25 | class PostProcessor { 26 | public: 27 | virtual ~PostProcessor() {} 28 | 29 | virtual ppl::common::RetCode InitPostProcessorMem(int max_running_batch, int vocab_size, bool enable_penalty) = 0; 30 | 31 | virtual ppl::common::RetCode SampleTopKTopP(const float* logits_device, const float* temperatures_host, 32 | const int32_t* top_k_host, const float* top_p_host, int32_t batch, 33 | int32_t vocab_size, int32_t batch_stride, int32_t default_top_k, 34 | float default_top_p, bool req_list_changed, int32_t* output_host, 35 | float* logprob_host, bool enable_penalty) = 0; 36 | 37 | virtual ppl::common::RetCode ApplyPenalty(const float* temperatures_host, const float* repetition_penalties_host, 38 | const float* presence_penalties_host, 39 | const float* frequency_penalties_host, const int64_t* batch_slots_host, 40 | const int64_t* token_inputs, const int64_t* seqstarts, 41 | const int64_t* start_pos, int32_t batch, int32_t vocab_size, 42 | bool req_list_changed, float* logits) = 0; 43 | }; 44 | 45 | }} // namespace ppl::llm 46 | 47 | #endif 48 | -------------------------------------------------------------------------------- /src/common/profiler.cc: -------------------------------------------------------------------------------- 1 | #include "profiler.h" 2 | #include "stdio.h" 3 | 4 | namespace ppl { namespace llm { 5 | 6 | void PrintProfiler(const WorkerProfiler& worker_profiler) { 7 | float qps = float(worker_profiler.finished_task_cnt) / worker_profiler.step_counter.global.total_cost * 1e6; 8 | float tps = float(worker_profiler.step_counter.global.output_token_cnt) / 9 | worker_profiler.step_counter.global.total_cost * 1e6; 10 | float 
cache_hit_rate = float(worker_profiler.step_counter.global.cache_hit_count) / 11 | worker_profiler.step_counter.global.input_token_cnt; 12 | 13 | fprintf(stderr, "[PERF] --- step %ld -------------------------------------------------\n", 14 | worker_profiler.step_counter.global.step_cnt); 15 | fprintf(stderr, "[PERF] |- memory usage: (%.2f - %.2f) -> %.2f GiB\n", float(worker_profiler.dev_mem_total) / 1e9, 16 | float(worker_profiler.dev_mem_free) / 1e9, 17 | float(worker_profiler.dev_mem_total - worker_profiler.dev_mem_free) / 1e9); 18 | fprintf(stderr, "[PERF] |- kv cache usage: %.2f %%\n", 19 | (1.0f - (float)worker_profiler.kv_rest_blk / worker_profiler.kv_max_blk) * 100.0); 20 | fprintf(stderr, "[PERF] |- pending task number: %ld\n", worker_profiler.pending_task_size); 21 | fprintf(stderr, "[PERF] |- running batch: %ld, max running batch: %ld\n", worker_profiler.running_task, 22 | worker_profiler.max_running_task); 23 | fprintf(stderr, "[PERF] |- prefill batch: %ld , prefill tokens: %ld\n", worker_profiler.prefill_batch, 24 | worker_profiler.prefill_tokens); 25 | fprintf(stderr, "[PERF] |- prefix cache hit rate: %.2f %%\n", cache_hit_rate * 100); 26 | fprintf(stderr, "[PERF] |- finished query count: %ld, QPS: %.2f\n", worker_profiler.finished_task_cnt, qps); 27 | fprintf(stderr, "[PERF] |- gen token count: %ld, avg gen len: %.2f, TPS: %.2f\n", 28 | worker_profiler.step_counter.global.output_token_cnt, 29 | worker_profiler.finished_task_cnt 30 | ? worker_profiler.step_counter.global.output_token_cnt / float(worker_profiler.finished_task_cnt) 31 | : 0.0f, 32 | tps); 33 | 34 | fprintf(stderr, "[PERF] |- pipeline | cur: %.2f ms, | avg: %.2f ms, | total: %.2f ms\n", 35 | float(worker_profiler.step_counter.current.total_cost) / 1e3, 36 | float(worker_profiler.step_counter.global.total_cost / 1e3) / worker_profiler.step_counter.global.step_cnt, 37 | float(worker_profiler.step_counter.global.total_cost) / 1e3); 38 | fprintf( 39 | stderr, "[PERF] |-- batching | cur: %.2f ms, | avg: %.2f ms, | total: %.2f ms\n", 40 | float(worker_profiler.step_counter.current.prepare_cost) / 1e3, 41 | float(worker_profiler.step_counter.global.prepare_cost) / 1e3 / worker_profiler.step_counter.global.step_cnt, 42 | float(worker_profiler.step_counter.global.prepare_cost) / 1e3); 43 | fprintf( 44 | stderr, "[PERF] |-- set inputs | cur: %.2f ms, | avg: %.2f ms, | total: %.2f ms\n", 45 | float(worker_profiler.step_counter.current.set_input_cost) / 1e3, 46 | float(worker_profiler.step_counter.global.set_input_cost) / 1e3 / worker_profiler.step_counter.global.step_cnt, 47 | float(worker_profiler.step_counter.global.set_input_cost) / 1e3); 48 | fprintf(stderr, "[PERF] |-- model inference | cur: %.2f ms, | avg: %.2f ms, | total: %.2f ms\n", 49 | float(worker_profiler.step_counter.current.model_forward_cost) / 1e3, 50 | float(worker_profiler.step_counter.global.model_forward_cost) / 1e3 / 51 | worker_profiler.step_counter.global.step_cnt, 52 | float(worker_profiler.step_counter.global.model_forward_cost) / 1e3); 53 | fprintf(stderr, "[PERF] |-- choose token | cur: %.2f ms, | avg: %.2f ms, | total: %.2f ms\n", 54 | float(worker_profiler.step_counter.current.choose_token_cost) / 1e3, 55 | float(worker_profiler.step_counter.global.choose_token_cost) / 1e3 / 56 | worker_profiler.step_counter.global.step_cnt, 57 | float(worker_profiler.step_counter.global.choose_token_cost) / 1e3); 58 | fprintf(stderr, "[PERF] |-- post process | cur: %.2f ms, | avg: %.2f ms, | total: %.2f ms\n", 59 | 
float(worker_profiler.step_counter.current.post_process_cost) / 1e3, 60 | float(worker_profiler.step_counter.global.post_process_cost) / 1e3 / 61 | worker_profiler.step_counter.global.step_cnt, 62 | float(worker_profiler.step_counter.global.post_process_cost) / 1e3); 63 | fprintf( 64 | stderr, "[PERF] |- schedule cost: %.2f %%\n", 65 | float(worker_profiler.step_counter.global.total_cost - worker_profiler.step_counter.global.model_forward_cost - 66 | worker_profiler.step_counter.global.choose_token_cost) / 67 | worker_profiler.step_counter.global.total_cost * 100); 68 | } 69 | 70 | }} // namespace ppl::llm -------------------------------------------------------------------------------- /src/common/profiler.h: -------------------------------------------------------------------------------- 1 | // Licensed to the Apache Software Foundation (ASF) under one 2 | // or more contributor license agreements. See the NOTICE file 3 | // distributed with this work for additional information 4 | // regarding copyright ownership. The ASF licenses this file 5 | // to you under the Apache License, Version 2.0 (the 6 | // "License"); you may not use this file except in compliance 7 | // with the License. You may obtain a copy of the License at 8 | // 9 | // http://www.apache.org/licenses/LICENSE-2.0 10 | // 11 | // Unless required by applicable law or agreed to in writing, 12 | // software distributed under the License is distributed on an 13 | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | // KIND, either express or implied. See the License for the 15 | // specific language governing permissions and limitations 16 | // under the License. 17 | 18 | #ifndef __PPL_LLM_PROFILER_H__ 19 | #define __PPL_LLM_PROFILER_H__ 20 | 21 | #include 22 | #include 23 | #include 24 | 25 | namespace ppl { namespace llm { 26 | 27 | struct ServerCounter final { 28 | // micro seconds 29 | uint64_t request_arrived_cnt = 0; // sum 30 | 31 | uint64_t deliver_task_step_cnt = 0; 32 | uint64_t deliver_task_step_cost = 0; 33 | 34 | uint64_t first_token_cnt = 0; 35 | uint64_t first_token_cost = 0; 36 | 37 | uint64_t generate_token_cnt = 0; // tps=(generated_token_cnt + first_token_cnt) / total_time 38 | uint64_t generate_token_cost = 0; // only including decoding procedure 39 | 40 | uint64_t total_cnt = 0; // finished_task count, used to compute qps=total_cnt / total_time 41 | uint64_t total_cost = 0; 42 | 43 | uint64_t parse_input_cnt = 0; 44 | uint64_t parse_input_cost = 0; 45 | }; 46 | 47 | struct GeneratorReqCounter final { 48 | // micro seconds 49 | uint64_t encode_cnt = 0; 50 | uint64_t encode_cost = 0; 51 | 52 | uint64_t output_tokens_per_req = 0; 53 | 54 | char padding[40]; // avoid false sharing 55 | 56 | uint64_t waiting_cnt = 0; 57 | uint64_t waiting_cost = 0; 58 | }; 59 | 60 | struct WorkerPerStepCounter { 61 | struct { 62 | uint64_t step_cnt = 0; 63 | uint64_t prepare_cost = 0; 64 | uint64_t set_input_cost = 0; 65 | uint64_t model_forward_cost = 0; 66 | uint64_t choose_token_cost = 0; // penalty + sampling 67 | uint64_t post_process_cost = 0; 68 | uint64_t total_cost = 0; 69 | uint64_t input_token_cnt = 0; // per step 70 | uint64_t output_token_cnt = 0; 71 | uint64_t cache_hit_count = 0; 72 | } global, current; 73 | }; 74 | 75 | struct WorkerProfiler { 76 | uint64_t finished_task_cnt = 0; 77 | 78 | uint64_t kv_rest_blk = 0; 79 | uint64_t kv_max_blk = 0; 80 | 81 | uint64_t running_task = 0; 82 | uint64_t prefill_batch = 0; 83 | uint64_t prefill_tokens = 0; 84 | 85 | uint64_t max_running_task = 0; 86 | 
uint64_t pending_task_size = 0; 87 | 88 | uint64_t dev_mem_total = 0; 89 | uint64_t dev_mem_free = 0; 90 | 91 | WorkerPerStepCounter step_counter; 92 | GeneratorReqCounter req_counter; 93 | }; 94 | 95 | struct ServerProfiler final { 96 | 97 | ServerProfiler& operator=(const ServerProfiler& other) { 98 | if (this == &other) { 99 | return *this; 100 | } 101 | // deep copy 102 | *worker_profiler = *other.worker_profiler; 103 | server_counter = other.server_counter; 104 | qps = other.qps; 105 | tps = other.tps; 106 | global_time_cost = other.global_time_cost; 107 | return *this; 108 | } 109 | 110 | std::shared_ptr worker_profiler; 111 | ServerCounter server_counter; 112 | double qps = 0; // qps=finished_cnt / global_time_cost 113 | double tps = 0; // tps=generate_token_cnt / global_time_cost 114 | uint64_t global_time_cost; 115 | }; 116 | 117 | void PrintProfiler(const WorkerProfiler& worker_profiler); 118 | 119 | }} // namespace ppl::llm 120 | 121 | #endif 122 | -------------------------------------------------------------------------------- /src/common/request.h: -------------------------------------------------------------------------------- 1 | // Licensed to the Apache Software Foundation (ASF) under one 2 | // or more contributor license agreements. See the NOTICE file 3 | // distributed with this work for additional information 4 | // regarding copyright ownership. The ASF licenses this file 5 | // to you under the Apache License, Version 2.0 (the 6 | // "License"); you may not use this file except in compliance 7 | // with the License. You may obtain a copy of the License at 8 | // 9 | // http://www.apache.org/licenses/LICENSE-2.0 10 | // 11 | // Unless required by applicable law or agreed to in writing, 12 | // software distributed under the License is distributed on an 13 | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | // KIND, either express or implied. See the License for the 15 | // specific language governing permissions and limitations 16 | // under the License. 17 | 18 | #ifndef __PPL_LLM_REQUEST_H__ 19 | #define __PPL_LLM_REQUEST_H__ 20 | 21 | #include 22 | #include 23 | #include 24 | #include 25 | #include 26 | 27 | namespace ppl { namespace llm { 28 | 29 | struct Request final { 30 | Request() {} 31 | Request(uint64_t _id, std::string _prompt, float _temperature, uint32_t _generation_length) 32 | : id(_id), prompt(_prompt), temperature(_temperature), generation_length(_generation_length) {} 33 | uint64_t id; 34 | std::string prompt; 35 | float temperature = 1.f; 36 | float top_p = 0.f; 37 | int32_t top_k = 1; 38 | float repetition_penalty = 1.f; 39 | float presence_penalty = 0.f; 40 | float frequency_penalty = 0.f; 41 | int32_t generation_length = 0; 42 | bool early_stopping = true; 43 | bool is_token_in_out = false; 44 | std::shared_ptr> token_ids; 45 | std::shared_ptr> stop_tokens; 46 | }; 47 | 48 | }} // namespace ppl::llm 49 | 50 | #endif 51 | -------------------------------------------------------------------------------- /src/common/resource.h: -------------------------------------------------------------------------------- 1 | // Licensed to the Apache Software Foundation (ASF) under one 2 | // or more contributor license agreements. See the NOTICE file 3 | // distributed with this work for additional information 4 | // regarding copyright ownership. The ASF licenses this file 5 | // to you under the Apache License, Version 2.0 (the 6 | // "License"); you may not use this file except in compliance 7 | // with the License. 
You may obtain a copy of the License at 8 | // 9 | // http://www.apache.org/licenses/LICENSE-2.0 10 | // 11 | // Unless required by applicable law or agreed to in writing, 12 | // software distributed under the License is distributed on an 13 | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | // KIND, either express or implied. See the License for the 15 | // specific language governing permissions and limitations 16 | // under the License. 17 | 18 | #ifndef __PPL_LLM_RESOURCE_H__ 19 | #define __PPL_LLM_RESOURCE_H__ 20 | 21 | #include "../tokenizer/tokenizer.h" 22 | #include "../common/post_processor.h" 23 | 24 | #include "ppl/common/threadpool.h" 25 | #include "ppl/nn/runtime/runtime.h" 26 | #include "ppl/nn/engines/engine.h" 27 | 28 | #include 29 | #include 30 | 31 | namespace ppl { namespace llm { 32 | 33 | struct ResourceItem final { 34 | void* kv_cache_mem = nullptr; 35 | void* kv_scale_mem = nullptr; 36 | void* penalty_count_map = nullptr; 37 | ppl::nn::Runtime* runtime = nullptr; 38 | ppl::nn::DeviceContext* host_device = nullptr; 39 | ppl::nn::Engine* engine = nullptr; 40 | }; 41 | 42 | struct Resource final { 43 | uint32_t tensor_parallel_size = 0; 44 | uint64_t kv_cache_max_tokens = 0; 45 | std::vector items; 46 | PostProcessor* post_processor = nullptr; 47 | ppl::common::StaticThreadPool* device_worker_pool_ = nullptr; 48 | const Tokenizer* tokenizer = nullptr; 49 | }; 50 | 51 | }} // namespace ppl::llm 52 | 53 | #endif 54 | -------------------------------------------------------------------------------- /src/common/response.h: -------------------------------------------------------------------------------- 1 | // Licensed to the Apache Software Foundation (ASF) under one 2 | // or more contributor license agreements. See the NOTICE file 3 | // distributed with this work for additional information 4 | // regarding copyright ownership. The ASF licenses this file 5 | // to you under the Apache License, Version 2.0 (the 6 | // "License"); you may not use this file except in compliance 7 | // with the License. You may obtain a copy of the License at 8 | // 9 | // http://www.apache.org/licenses/LICENSE-2.0 10 | // 11 | // Unless required by applicable law or agreed to in writing, 12 | // software distributed under the License is distributed on an 13 | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | // KIND, either express or implied. See the License for the 15 | // specific language governing permissions and limitations 16 | // under the License. 17 | 18 | #ifndef __PPL_LLM_RESPONSE_H__ 19 | #define __PPL_LLM_RESPONSE_H__ 20 | 21 | #include 22 | #include 23 | 24 | namespace ppl { namespace llm { 25 | 26 | // ref: https://huggingface.github.io/text-generation-inference/ 27 | enum class FinishFlag { 28 | NOT_FINISHED, 29 | LENGTH, 30 | EOS_TOKEN, 31 | STOP_SEQUENCE // not used yet 32 | }; 33 | 34 | struct Response final { 35 | uint64_t id; 36 | std::string generated; 37 | int token; 38 | FinishFlag finish_flag; 39 | float logprob; 40 | bool is_special; 41 | }; 42 | 43 | }} // namespace ppl::llm 44 | 45 | #endif 46 | -------------------------------------------------------------------------------- /src/engine/llm_engine.cc: -------------------------------------------------------------------------------- 1 | // Licensed to the Apache Software Foundation (ASF) under one 2 | // or more contributor license agreements. See the NOTICE file 3 | // distributed with this work for additional information 4 | // regarding copyright ownership. 
The ASF licenses this file 5 | // to you under the Apache License, Version 2.0 (the 6 | // "License"); you may not use this file except in compliance 7 | // with the License. You may obtain a copy of the License at 8 | // 9 | // http://www.apache.org/licenses/LICENSE-2.0 10 | // 11 | // Unless required by applicable law or agreed to in writing, 12 | // software distributed under the License is distributed on an 13 | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | // KIND, either express or implied. See the License for the 15 | // specific language governing permissions and limitations 16 | // under the License. 17 | 18 | #include "llm_engine.h" 19 | #include "../common/profiler.h" 20 | #include "../utils/utils.h" 21 | #include "ppl/nn/engines/llm_cuda/options.h" 22 | 23 | using namespace std; 24 | using namespace ppl::common; 25 | using namespace ppl::nn; 26 | 27 | namespace ppl { namespace llm { 28 | 29 | RetCode SetInputTask(uint32_t id, uint32_t cache_mode, bool req_list_changed, const ModelInput& model_input, 30 | EngineInferItem* infer_items) { 31 | infer_items[id].token_ids->FreeBuffer(); 32 | infer_items[id].seq_starts->FreeBuffer(); 33 | infer_items[id].kv_starts->FreeBuffer(); 34 | infer_items[id].start_pos->FreeBuffer(); 35 | infer_items[id].logits->FreeBuffer(); 36 | 37 | RetCode rc; 38 | int32_t bs = model_input.start_pos.size(); 39 | // token ids 40 | infer_items[id].token_ids->GetShape()->Reshape({int64_t(model_input.token_inputs.size())}); 41 | rc = infer_items[id].token_ids->CopyFromHostAsync(model_input.token_inputs.data()); 42 | if (rc != RC_SUCCESS) { 43 | LOG(ERROR) << "set token_ids [" << infer_items[id].token_ids->GetName() << "] failed: " << GetRetCodeStr(rc); 44 | return rc; 45 | } 46 | 47 | // seq_start 48 | infer_items[id].seq_starts->GetShape()->Reshape({int64_t(model_input.seq_starts.size())}); 49 | rc = infer_items[id].seq_starts->CopyFromHostAsync(model_input.seq_starts.data()); 50 | if (rc != RC_SUCCESS) { 51 | LOG(ERROR) << "set seq_starts [" << infer_items[id].seq_starts->GetName() << "] failed: " << GetRetCodeStr(rc); 52 | return rc; 53 | } 54 | 55 | // kv_starts 56 | infer_items[id].kv_starts->GetShape()->Reshape({int64_t(model_input.kv_starts.size())}); 57 | rc = infer_items[id].kv_starts->CopyFromHostAsync(model_input.kv_starts.data()); 58 | if (rc != RC_SUCCESS) { 59 | LOG(ERROR) << "set kv_starts " << infer_items[id].kv_starts->GetName() << " failed: " << GetRetCodeStr(rc); 60 | return rc; 61 | } 62 | 63 | // cache_indices 64 | if (cache_mode == 0) { 65 | infer_items[id].cache_indices->GetShape()->Reshape({int64_t(model_input.cache_indices.size())}); 66 | rc = infer_items[id].cache_indices->CopyFromHostAsync(model_input.cache_indices.data()); 67 | } else if (cache_mode == 1) { 68 | if (req_list_changed) { 69 | infer_items[id].cache_indices->GetShape()->Reshape({int64_t(bs), int64_t(model_input.max_pages)}); 70 | rc = infer_items[id].cache_indices->CopyFromHostAsync(model_input.page_list.data()); 71 | } 72 | } 73 | if (rc != RC_SUCCESS) { 74 | LOG(ERROR) << "set cache_indices [" << infer_items[id].cache_indices->GetName() 75 | << "] failed: " << GetRetCodeStr(rc); 76 | return rc; 77 | } 78 | 79 | // decoding batches 80 | rc = infer_items[id].decoding_batches->CopyFromHostAsync(&model_input.decoding_batches); 81 | if (rc != RC_SUCCESS) { 82 | LOG(ERROR) << "set decoding_batches [" << infer_items[id].decoding_batches->GetName() 83 | << "] failed: " << GetRetCodeStr(rc); 84 | return rc; 85 | } 86 | 87 | // start_pos 88 | 
infer_items[id].start_pos->GetShape()->Reshape({int64_t(model_input.start_pos.size())}); 89 | rc = infer_items[id].start_pos->CopyFromHostAsync(model_input.start_pos.data()); 90 | if (rc != RC_SUCCESS) { 91 | LOG(ERROR) << "set start_pos [" << infer_items[id].start_pos->GetName() << "] failed: " << GetRetCodeStr(rc); 92 | return rc; 93 | } 94 | 95 | // max_seq_len 96 | rc = infer_items[id].max_seq_len->CopyFromHostAsync(&model_input.max_seq_len); 97 | if (rc != RC_SUCCESS) { 98 | LOG(ERROR) << "set max_seq_len [" << infer_items[id].max_seq_len->GetName() 99 | << "] failed: " << GetRetCodeStr(rc); 100 | return rc; 101 | } 102 | 103 | // max_kv_len 104 | rc = infer_items[id].max_kv_len->CopyFromHostAsync(&model_input.max_kv_len); 105 | if (rc != RC_SUCCESS) { 106 | LOG(ERROR) << "set max_kv_len [" << infer_items[id].max_kv_len->GetName() << "] failed: " << GetRetCodeStr(rc); 107 | return rc; 108 | } 109 | 110 | return rc; 111 | } 112 | 113 | static RetCode RunModelTask(uint32_t id, bool is_prefix_cache_hit, EngineInferItem* infer_items) { 114 | infer_items[id].nn_engine->Configure(ppl::nn::llm::cuda::ENGINE_CONF_CACHE_PREFILL, is_prefix_cache_hit ? 1 : 0); 115 | return infer_items[id].runtime->Run(); 116 | } 117 | 118 | RetCode LLMEngine::Init(WorkerPerStepCounter* step_counter) { 119 | step_counter_ = step_counter; 120 | for (int i = 0; i < tensor_parallel_size_; i++) { 121 | auto item = &infer_items_[i]; 122 | if (model_config_.cache_layout == 0) { 123 | item->kv_cache->GetShape()->Reshape({(int64_t)kv_cache_max_tokens_, model_config_.num_layers, 2, 124 | model_config_.num_kv_heads / tensor_parallel_size_, 125 | model_config_.hidden_dim / model_config_.num_heads}); 126 | if (model_config_.cache_quant_bit > 0) { 127 | item->kv_scale->GetShape()->Reshape( 128 | {(int64_t)kv_cache_max_tokens_, model_config_.num_layers, 2, 129 | model_config_.num_kv_heads / tensor_parallel_size_, 130 | model_config_.hidden_dim / model_config_.num_heads / model_config_.cache_quant_group}); 131 | } 132 | 133 | } else if (model_config_.cache_layout == 1) { 134 | item->kv_cache->GetShape()->Reshape({model_config_.num_layers, (int64_t)kv_cache_max_tokens_, 2, 135 | model_config_.num_kv_heads / tensor_parallel_size_, 136 | model_config_.hidden_dim / model_config_.num_heads}); 137 | if (model_config_.cache_quant_bit > 0) { 138 | item->kv_scale->GetShape()->Reshape( 139 | {model_config_.num_layers, (int64_t)kv_cache_max_tokens_, 2, 140 | model_config_.num_kv_heads / tensor_parallel_size_, 141 | model_config_.hidden_dim / model_config_.num_heads / model_config_.cache_quant_group}); 142 | } 143 | } else if (model_config_.cache_layout == 2) { 144 | item->kv_cache->GetShape()->Reshape({model_config_.num_layers, 2, (int64_t)kv_cache_max_tokens_, 145 | model_config_.num_kv_heads / tensor_parallel_size_, 146 | model_config_.hidden_dim / model_config_.num_heads}); 147 | if (model_config_.cache_quant_bit > 0) { 148 | item->kv_scale->GetShape()->Reshape( 149 | {model_config_.num_layers, 2, (int64_t)kv_cache_max_tokens_, 150 | model_config_.num_kv_heads / tensor_parallel_size_, 151 | model_config_.hidden_dim / model_config_.num_heads / model_config_.cache_quant_group}); 152 | } 153 | } else if (model_config_.cache_layout == 3) { 154 | item->kv_cache->GetShape()->Reshape( 155 | {model_config_.num_layers, 2, model_config_.num_kv_heads / tensor_parallel_size_, 156 | (int64_t)kv_cache_max_tokens_, model_config_.hidden_dim / model_config_.num_heads}); 157 | if (model_config_.cache_quant_bit > 0) { 158 | 
item->kv_scale->GetShape()->Reshape( 159 | {model_config_.num_layers, 2, model_config_.num_kv_heads / tensor_parallel_size_, 160 | (int64_t)kv_cache_max_tokens_, 161 | model_config_.hidden_dim / model_config_.num_heads / model_config_.cache_quant_group}); 162 | } 163 | } else { 164 | LOG(ERROR) << "impossible status: cache_layout = [" << model_config_.cache_layout << "]"; 165 | return RC_INVALID_VALUE; 166 | } 167 | } 168 | return RC_SUCCESS; 169 | } 170 | 171 | RetCode LLMEngine::Execute(const ModelInput& model_input, bool req_list_changed, bool is_prefix_cache_hit, ModelOutput* model_output, 172 | std::string* error_msg) { 173 | int32_t running_batch = model_input.start_pos.size(); 174 | // set inputs tensor 175 | RetCode rc; 176 | { 177 | utils::TimingGuard __timing__(&step_counter_->current.set_input_cost); 178 | 179 | rc = utils::ParallelExecute(SetInputTask, device_worker_pool_, model_config_.cache_mode, req_list_changed, 180 | model_input, infer_items_.data()); 181 | if (rc != RC_SUCCESS) { 182 | *error_msg = "ParallelExecute(SetInputTask) failed: " + std::string(GetRetCodeStr(rc)); 183 | LOG(ERROR) << *error_msg; 184 | return RC_OTHER_ERROR; 185 | } 186 | } 187 | step_counter_->global.set_input_cost += step_counter_->current.set_input_cost; 188 | // model forward 189 | { 190 | utils::TimingGuard __timing__(&step_counter_->current.model_forward_cost); 191 | rc = utils::ParallelExecute(RunModelTask, device_worker_pool_, is_prefix_cache_hit, infer_items_.data()); 192 | if (rc != RC_SUCCESS) { 193 | *error_msg = "ParallelExecute(RunModelTask) failed: " + std::string(GetRetCodeStr(rc)); 194 | LOG(ERROR) << *error_msg; 195 | return RC_OTHER_ERROR; 196 | } 197 | } 198 | step_counter_->global.model_forward_cost += step_counter_->current.model_forward_cost; 199 | 200 | auto* logits = infer_items_[0].logits; 201 | { 202 | utils::TimingGuard __timing__(&step_counter_->current.choose_token_cost); 203 | // penalty 204 | if (engine_config_.enable_penalty) { 205 | rc = post_processor_->ApplyPenalty( 206 | model_input.temperatures.data(), model_input.repetition_penalty_list.data(), nullptr, nullptr, 207 | model_input.batch_slots.data(), (int64_t*)infer_items_[0].token_ids->GetBufferPtr(), 208 | (int64_t*)infer_items_[0].seq_starts->GetBufferPtr(), 209 | (int64_t*)infer_items_[0].start_pos->GetBufferPtr(), running_batch, model_config_.vocab_size, 210 | req_list_changed, (float*)logits->GetBufferPtr()); 211 | if (rc != RC_SUCCESS) { 212 | *error_msg = "Apply Penalty failed: " + std::string(GetRetCodeStr(rc)); 213 | LOG(ERROR) << *error_msg; 214 | return RC_OTHER_ERROR; 215 | } 216 | } 217 | 218 | // sampling 219 | int32_t default_top_k = model_input.top_k_list.empty() ? 
engine_config_.top_k : model_input.top_k_list[0]; 220 | rc = post_processor_->SampleTopKTopP( 221 | (float*)logits->GetBufferPtr(), model_input.temperatures.data(), model_input.top_k_list.data(), 222 | model_input.top_p_list.data(), running_batch, model_config_.vocab_size, logits->GetShape()->GetDim(1), 223 | default_top_k, engine_config_.top_p, req_list_changed, model_output->output_token.data(), 224 | model_output->logprobs.data(), engine_config_.enable_penalty); 225 | if (rc != RC_SUCCESS) { 226 | *error_msg = "SampleTopKTopP failed: " + std::string(GetRetCodeStr(rc)); 227 | LOG(ERROR) << *error_msg; 228 | return RC_OTHER_ERROR; 229 | } 230 | } 231 | step_counter_->global.choose_token_cost += step_counter_->current.choose_token_cost; 232 | step_counter_->current.output_token_cnt = running_batch; 233 | step_counter_->global.output_token_cnt += step_counter_->current.output_token_cnt; 234 | 235 | return RC_SUCCESS; 236 | } 237 | 238 | }} // namespace ppl::llm -------------------------------------------------------------------------------- /src/engine/llm_engine.h: -------------------------------------------------------------------------------- 1 | // Licensed to the Apache Software Foundation (ASF) under one 2 | // or more contributor license agreements. See the NOTICE file 3 | // distributed with this work for additional information 4 | // regarding copyright ownership. The ASF licenses this file 5 | // to you under the Apache License, Version 2.0 (the 6 | // "License"); you may not use this file except in compliance 7 | // with the License. You may obtain a copy of the License at 8 | // 9 | // http://www.apache.org/licenses/LICENSE-2.0 10 | // 11 | // Unless required by applicable law or agreed to in writing, 12 | // software distributed under the License is distributed on an 13 | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | // KIND, either express or implied. See the License for the 15 | // specific language governing permissions and limitations 16 | // under the License. 
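// ============================================================================
// [Editor's illustration - not part of the repository]
// A sketch of running a single prefill step with the LLMEngine declared
// below, assuming a Resource and ModelConfig have already been prepared
// (e.g. by CudaResourceManager::Init() and ParseModelConfig()). The include
// path, the token ids and the prefix-sum layout of seq_starts/kv_starts are
// assumptions inferred from how SetInputTask() and Execute() consume
// ModelInput; treat them as illustrative only.
#include "ppl/llm/engine/llm_engine.h" // assumed install path

#include <string>
#include <vector>

bool RunOnePrefillStep(const ppl::llm::Resource& resource, const ppl::llm::ModelConfig& model_config) {
    ppl::llm::WorkerPerStepCounter step_counter;
    ppl::llm::LLMEngine engine(resource, model_config, /*enable_penalty*/ false, /*top_k*/ 1, /*top_p*/ 0.0f);
    if (engine.Init(&step_counter) != ppl::common::RC_SUCCESS) {
        return false;
    }

    // one request with 5 prompt tokens and no decoding requests yet
    ppl::llm::ModelInput model_input;
    model_input.token_inputs = {1, 306, 4658, 278, 6575}; // placeholder token ids
    model_input.seq_starts = {0, 5}; // assumed prefix sums over per-request sequence lengths
    model_input.kv_starts = {0, 5};  // assumed prefix sums over per-request kv lengths
    model_input.start_pos = {0};
    model_input.cache_indices = {0}; // assumed first kv-cache slot, cache_mode == 0
    model_input.decoding_batches = 0;
    model_input.max_seq_len = 5;
    model_input.max_kv_len = 5;
    model_input.temperatures = {1.0f};

    ppl::llm::ModelOutput model_output;
    model_output.Resize(model_input.start_pos.size()); // one sampled token and logprob per running request

    std::string error_msg;
    return engine.Execute(model_input, /*req_list_changed*/ true, /*is_prefix_cache_hit*/ false, &model_output,
                          &error_msg) == ppl::common::RC_SUCCESS;
}
// ============================================================================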
17 | 18 | #ifndef __PPL_LLM_LLM_ENGINE_H__ 19 | #define __PPL_LLM_LLM_ENGINE_H__ 20 | 21 | #include "../common/resource.h" 22 | #include "../common/config.h" 23 | #include "../common/profiler.h" 24 | #include "ppl/nn/models/onnx/runtime_builder_factory.h" 25 | #include "ppl/nn/runtime/tensor.h" 26 | #include "ppl/common/threadpool.h" 27 | #include "ppl/common/typed_mpsc_queue.h" 28 | #include "ppl/common/event_count.h" 29 | #include "ppl/common/page_manager.h" 30 | 31 | #include 32 | #include 33 | #include 34 | #include 35 | #include 36 | #include 37 | 38 | namespace ppl { namespace llm { 39 | 40 | struct ModelInput { 41 | int64_t decoding_batches = 0; 42 | int64_t max_seq_len = 0; 43 | int64_t max_kv_len = 0; 44 | int64_t max_pages = 0; 45 | 46 | std::vector token_inputs; 47 | std::vector seq_starts; 48 | std::vector start_pos; 49 | std::vector cache_indices; 50 | std::vector page_list; 51 | std::vector kv_starts; 52 | std::vector temperatures; 53 | std::vector top_p_list; 54 | std::vector top_k_list; 55 | 56 | std::vector repetition_penalty_list; 57 | std::vector presence_penalty_list; 58 | std::vector frequency_penalty_list; 59 | std::vector batch_slots; 60 | }; 61 | 62 | struct ModelOutput { 63 | std::vector output_token; 64 | std::vector logprobs; 65 | void Clear() { 66 | output_token.clear(); 67 | logprobs.clear(); 68 | } 69 | void Resize(int32_t n) { 70 | output_token.resize(n); 71 | logprobs.resize(n); 72 | } 73 | }; 74 | 75 | struct EngineInferItem { 76 | void* kv_cache_mem = nullptr; 77 | void* kv_scale_mem = nullptr; 78 | ppl::nn::Runtime* runtime = nullptr; 79 | ppl::nn::DeviceContext* host_device = nullptr; 80 | ppl::nn::Engine* nn_engine = nullptr; 81 | 82 | ppl::nn::Tensor* token_ids; 83 | ppl::nn::Tensor* attn_mask; 84 | ppl::nn::Tensor* seq_starts; 85 | ppl::nn::Tensor* kv_starts; 86 | ppl::nn::Tensor* cache_indices; 87 | ppl::nn::Tensor* decoding_batches; 88 | ppl::nn::Tensor* start_pos; 89 | ppl::nn::Tensor* max_seq_len; 90 | ppl::nn::Tensor* max_kv_len; 91 | ppl::nn::Tensor* kv_cache; 92 | ppl::nn::Tensor* kv_scale; 93 | 94 | ppl::nn::Tensor* logits; 95 | }; 96 | 97 | struct EngineConfig { 98 | EngineConfig(bool _enable_penalty, int32_t _top_k, float _top_p) 99 | : enable_penalty(_enable_penalty), top_k(_top_k), top_p(_top_p) {} 100 | 101 | bool enable_penalty; 102 | int32_t top_k; 103 | float top_p; 104 | }; 105 | class LLMEngine final { 106 | public: 107 | LLMEngine(const Resource& resource, const ModelConfig& model_config, bool enable_penalty, int32_t top_k, 108 | float top_p) 109 | : tensor_parallel_size_(resource.tensor_parallel_size) 110 | , device_worker_pool_(resource.device_worker_pool_) 111 | , infer_items_(resource.tensor_parallel_size) 112 | , kv_cache_max_tokens_(resource.kv_cache_max_tokens) 113 | , post_processor_(resource.post_processor) 114 | , model_config_(model_config) 115 | , engine_config_(enable_penalty, top_k, top_p) { 116 | for (int i = 0; i < tensor_parallel_size_; i++) { 117 | auto* item = &infer_items_[i]; 118 | item->kv_cache_mem = resource.items[i].kv_cache_mem; 119 | item->kv_scale_mem = resource.items[i].kv_scale_mem; 120 | item->runtime = resource.items[i].runtime; 121 | item->host_device = resource.items[i].host_device; 122 | item->nn_engine = resource.items[i].engine; 123 | 124 | item->token_ids = item->runtime->GetInputTensor(0); 125 | item->attn_mask = item->runtime->GetInputTensor(1); 126 | item->seq_starts = item->runtime->GetInputTensor(2); 127 | item->kv_starts = item->runtime->GetInputTensor(3); 128 | item->cache_indices = 
item->runtime->GetInputTensor(4); 129 | item->decoding_batches = item->runtime->GetInputTensor(5); 130 | item->start_pos = item->runtime->GetInputTensor(6); 131 | item->max_seq_len = item->runtime->GetInputTensor(7); 132 | item->max_kv_len = item->runtime->GetInputTensor(8); 133 | item->kv_cache = item->runtime->GetInputTensor(9); 134 | if (model_config_.cache_quant_bit > 0) { 135 | item->kv_scale = item->runtime->GetInputTensor(10); 136 | } 137 | 138 | item->logits = item->runtime->GetOutputTensor(0); 139 | 140 | item->decoding_batches->SetDeviceContext(item->host_device); 141 | item->max_seq_len->SetDeviceContext(item->host_device); 142 | item->max_kv_len->SetDeviceContext(item->host_device); 143 | 144 | item->kv_cache->SetBufferPtr(item->kv_cache_mem); 145 | if (model_config_.cache_quant_bit > 0) { 146 | item->kv_scale->SetBufferPtr(item->kv_scale_mem); 147 | } 148 | } 149 | } 150 | 151 | public: 152 | ppl::common::RetCode Init(WorkerPerStepCounter*); 153 | ppl::common::RetCode Execute(const ModelInput&, bool, bool, ModelOutput*, std::string*); 154 | 155 | private: 156 | int32_t tensor_parallel_size_; 157 | ppl::common::StaticThreadPool* device_worker_pool_; 158 | std::vector infer_items_; 159 | uint64_t kv_cache_max_tokens_; 160 | PostProcessor* post_processor_; 161 | ModelConfig model_config_; 162 | EngineConfig engine_config_; 163 | WorkerPerStepCounter* step_counter_; 164 | }; 165 | 166 | }} // namespace ppl::llm 167 | 168 | #endif -------------------------------------------------------------------------------- /src/generated/onnx/v23.4/llm.grpc.pb.cc: -------------------------------------------------------------------------------- 1 | // Generated by the gRPC C++ plugin. 2 | // If you make any local change, they will be lost. 3 | // source: llm.proto 4 | 5 | #include "llm.pb.h" 6 | #include "llm.grpc.pb.h" 7 | 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include 18 | #include 19 | #include 20 | #include 21 | #include 22 | namespace ppl { 23 | namespace llm { 24 | namespace proto { 25 | 26 | static const char* LLMService_method_names[] = { 27 | "/ppl.llm.proto.LLMService/Generation", 28 | }; 29 | 30 | std::unique_ptr< LLMService::Stub> LLMService::NewStub(const std::shared_ptr< ::grpc::ChannelInterface>& channel, const ::grpc::StubOptions& options) { 31 | (void)options; 32 | std::unique_ptr< LLMService::Stub> stub(new LLMService::Stub(channel, options)); 33 | return stub; 34 | } 35 | 36 | LLMService::Stub::Stub(const std::shared_ptr< ::grpc::ChannelInterface>& channel, const ::grpc::StubOptions& options) 37 | : channel_(channel), rpcmethod_Generation_(LLMService_method_names[0], options.suffix_for_stats(),::grpc::internal::RpcMethod::SERVER_STREAMING, channel) 38 | {} 39 | 40 | ::grpc::ClientReader< ::ppl::llm::proto::Response>* LLMService::Stub::GenerationRaw(::grpc::ClientContext* context, const ::ppl::llm::proto::BatchedRequest& request) { 41 | return ::grpc::internal::ClientReaderFactory< ::ppl::llm::proto::Response>::Create(channel_.get(), rpcmethod_Generation_, context, request); 42 | } 43 | 44 | void LLMService::Stub::async::Generation(::grpc::ClientContext* context, const ::ppl::llm::proto::BatchedRequest* request, ::grpc::ClientReadReactor< ::ppl::llm::proto::Response>* reactor) { 45 | ::grpc::internal::ClientCallbackReaderFactory< ::ppl::llm::proto::Response>::Create(stub_->channel_.get(), stub_->rpcmethod_Generation_, context, request, reactor); 46 | } 47 | 48 | ::grpc::ClientAsyncReader< 
::ppl::llm::proto::Response>* LLMService::Stub::AsyncGenerationRaw(::grpc::ClientContext* context, const ::ppl::llm::proto::BatchedRequest& request, ::grpc::CompletionQueue* cq, void* tag) { 49 | return ::grpc::internal::ClientAsyncReaderFactory< ::ppl::llm::proto::Response>::Create(channel_.get(), cq, rpcmethod_Generation_, context, request, true, tag); 50 | } 51 | 52 | ::grpc::ClientAsyncReader< ::ppl::llm::proto::Response>* LLMService::Stub::PrepareAsyncGenerationRaw(::grpc::ClientContext* context, const ::ppl::llm::proto::BatchedRequest& request, ::grpc::CompletionQueue* cq) { 53 | return ::grpc::internal::ClientAsyncReaderFactory< ::ppl::llm::proto::Response>::Create(channel_.get(), cq, rpcmethod_Generation_, context, request, false, nullptr); 54 | } 55 | 56 | LLMService::Service::Service() { 57 | AddMethod(new ::grpc::internal::RpcServiceMethod( 58 | LLMService_method_names[0], 59 | ::grpc::internal::RpcMethod::SERVER_STREAMING, 60 | new ::grpc::internal::ServerStreamingHandler< LLMService::Service, ::ppl::llm::proto::BatchedRequest, ::ppl::llm::proto::Response>( 61 | [](LLMService::Service* service, 62 | ::grpc::ServerContext* ctx, 63 | const ::ppl::llm::proto::BatchedRequest* req, 64 | ::grpc::ServerWriter<::ppl::llm::proto::Response>* writer) { 65 | return service->Generation(ctx, req, writer); 66 | }, this))); 67 | } 68 | 69 | LLMService::Service::~Service() { 70 | } 71 | 72 | ::grpc::Status LLMService::Service::Generation(::grpc::ServerContext* context, const ::ppl::llm::proto::BatchedRequest* request, ::grpc::ServerWriter< ::ppl::llm::proto::Response>* writer) { 73 | (void) context; 74 | (void) request; 75 | (void) writer; 76 | return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, ""); 77 | } 78 | 79 | 80 | } // namespace ppl 81 | } // namespace llm 82 | } // namespace proto 83 | 84 | -------------------------------------------------------------------------------- /src/generated/onnx/v23.4/llm.grpc.pb.h: -------------------------------------------------------------------------------- 1 | // Generated by the gRPC C++ plugin. 2 | // If you make any local change, they will be lost. 
3 | // source: llm.proto 4 | #ifndef GRPC_llm_2eproto__INCLUDED 5 | #define GRPC_llm_2eproto__INCLUDED 6 | 7 | #include "llm.pb.h" 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include 18 | #include 19 | #include 20 | #include 21 | #include 22 | #include 23 | #include 24 | #include 25 | #include 26 | #include 27 | 28 | namespace ppl { 29 | namespace llm { 30 | namespace proto { 31 | 32 | class LLMService final { 33 | public: 34 | static constexpr char const* service_full_name() { 35 | return "ppl.llm.proto.LLMService"; 36 | } 37 | class StubInterface { 38 | public: 39 | virtual ~StubInterface() {} 40 | std::unique_ptr< ::grpc::ClientReaderInterface< ::ppl::llm::proto::Response>> Generation(::grpc::ClientContext* context, const ::ppl::llm::proto::BatchedRequest& request) { 41 | return std::unique_ptr< ::grpc::ClientReaderInterface< ::ppl::llm::proto::Response>>(GenerationRaw(context, request)); 42 | } 43 | std::unique_ptr< ::grpc::ClientAsyncReaderInterface< ::ppl::llm::proto::Response>> AsyncGeneration(::grpc::ClientContext* context, const ::ppl::llm::proto::BatchedRequest& request, ::grpc::CompletionQueue* cq, void* tag) { 44 | return std::unique_ptr< ::grpc::ClientAsyncReaderInterface< ::ppl::llm::proto::Response>>(AsyncGenerationRaw(context, request, cq, tag)); 45 | } 46 | std::unique_ptr< ::grpc::ClientAsyncReaderInterface< ::ppl::llm::proto::Response>> PrepareAsyncGeneration(::grpc::ClientContext* context, const ::ppl::llm::proto::BatchedRequest& request, ::grpc::CompletionQueue* cq) { 47 | return std::unique_ptr< ::grpc::ClientAsyncReaderInterface< ::ppl::llm::proto::Response>>(PrepareAsyncGenerationRaw(context, request, cq)); 48 | } 49 | class async_interface { 50 | public: 51 | virtual ~async_interface() {} 52 | virtual void Generation(::grpc::ClientContext* context, const ::ppl::llm::proto::BatchedRequest* request, ::grpc::ClientReadReactor< ::ppl::llm::proto::Response>* reactor) = 0; 53 | }; 54 | typedef class async_interface experimental_async_interface; 55 | virtual class async_interface* async() { return nullptr; } 56 | class async_interface* experimental_async() { return async(); } 57 | private: 58 | virtual ::grpc::ClientReaderInterface< ::ppl::llm::proto::Response>* GenerationRaw(::grpc::ClientContext* context, const ::ppl::llm::proto::BatchedRequest& request) = 0; 59 | virtual ::grpc::ClientAsyncReaderInterface< ::ppl::llm::proto::Response>* AsyncGenerationRaw(::grpc::ClientContext* context, const ::ppl::llm::proto::BatchedRequest& request, ::grpc::CompletionQueue* cq, void* tag) = 0; 60 | virtual ::grpc::ClientAsyncReaderInterface< ::ppl::llm::proto::Response>* PrepareAsyncGenerationRaw(::grpc::ClientContext* context, const ::ppl::llm::proto::BatchedRequest& request, ::grpc::CompletionQueue* cq) = 0; 61 | }; 62 | class Stub final : public StubInterface { 63 | public: 64 | Stub(const std::shared_ptr< ::grpc::ChannelInterface>& channel, const ::grpc::StubOptions& options = ::grpc::StubOptions()); 65 | std::unique_ptr< ::grpc::ClientReader< ::ppl::llm::proto::Response>> Generation(::grpc::ClientContext* context, const ::ppl::llm::proto::BatchedRequest& request) { 66 | return std::unique_ptr< ::grpc::ClientReader< ::ppl::llm::proto::Response>>(GenerationRaw(context, request)); 67 | } 68 | std::unique_ptr< ::grpc::ClientAsyncReader< ::ppl::llm::proto::Response>> AsyncGeneration(::grpc::ClientContext* context, const ::ppl::llm::proto::BatchedRequest& request, ::grpc::CompletionQueue* cq, void* tag) { 69 
| return std::unique_ptr< ::grpc::ClientAsyncReader< ::ppl::llm::proto::Response>>(AsyncGenerationRaw(context, request, cq, tag)); 70 | } 71 | std::unique_ptr< ::grpc::ClientAsyncReader< ::ppl::llm::proto::Response>> PrepareAsyncGeneration(::grpc::ClientContext* context, const ::ppl::llm::proto::BatchedRequest& request, ::grpc::CompletionQueue* cq) { 72 | return std::unique_ptr< ::grpc::ClientAsyncReader< ::ppl::llm::proto::Response>>(PrepareAsyncGenerationRaw(context, request, cq)); 73 | } 74 | class async final : 75 | public StubInterface::async_interface { 76 | public: 77 | void Generation(::grpc::ClientContext* context, const ::ppl::llm::proto::BatchedRequest* request, ::grpc::ClientReadReactor< ::ppl::llm::proto::Response>* reactor) override; 78 | private: 79 | friend class Stub; 80 | explicit async(Stub* stub): stub_(stub) { } 81 | Stub* stub() { return stub_; } 82 | Stub* stub_; 83 | }; 84 | class async* async() override { return &async_stub_; } 85 | 86 | private: 87 | std::shared_ptr< ::grpc::ChannelInterface> channel_; 88 | class async async_stub_{this}; 89 | ::grpc::ClientReader< ::ppl::llm::proto::Response>* GenerationRaw(::grpc::ClientContext* context, const ::ppl::llm::proto::BatchedRequest& request) override; 90 | ::grpc::ClientAsyncReader< ::ppl::llm::proto::Response>* AsyncGenerationRaw(::grpc::ClientContext* context, const ::ppl::llm::proto::BatchedRequest& request, ::grpc::CompletionQueue* cq, void* tag) override; 91 | ::grpc::ClientAsyncReader< ::ppl::llm::proto::Response>* PrepareAsyncGenerationRaw(::grpc::ClientContext* context, const ::ppl::llm::proto::BatchedRequest& request, ::grpc::CompletionQueue* cq) override; 92 | const ::grpc::internal::RpcMethod rpcmethod_Generation_; 93 | }; 94 | static std::unique_ptr NewStub(const std::shared_ptr< ::grpc::ChannelInterface>& channel, const ::grpc::StubOptions& options = ::grpc::StubOptions()); 95 | 96 | class Service : public ::grpc::Service { 97 | public: 98 | Service(); 99 | virtual ~Service(); 100 | virtual ::grpc::Status Generation(::grpc::ServerContext* context, const ::ppl::llm::proto::BatchedRequest* request, ::grpc::ServerWriter< ::ppl::llm::proto::Response>* writer); 101 | }; 102 | template 103 | class WithAsyncMethod_Generation : public BaseClass { 104 | private: 105 | void BaseClassMustBeDerivedFromService(const Service* /*service*/) {} 106 | public: 107 | WithAsyncMethod_Generation() { 108 | ::grpc::Service::MarkMethodAsync(0); 109 | } 110 | ~WithAsyncMethod_Generation() override { 111 | BaseClassMustBeDerivedFromService(this); 112 | } 113 | // disable synchronous version of this method 114 | ::grpc::Status Generation(::grpc::ServerContext* /*context*/, const ::ppl::llm::proto::BatchedRequest* /*request*/, ::grpc::ServerWriter< ::ppl::llm::proto::Response>* /*writer*/) override { 115 | abort(); 116 | return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, ""); 117 | } 118 | void RequestGeneration(::grpc::ServerContext* context, ::ppl::llm::proto::BatchedRequest* request, ::grpc::ServerAsyncWriter< ::ppl::llm::proto::Response>* writer, ::grpc::CompletionQueue* new_call_cq, ::grpc::ServerCompletionQueue* notification_cq, void *tag) { 119 | ::grpc::Service::RequestAsyncServerStreaming(0, context, request, writer, new_call_cq, notification_cq, tag); 120 | } 121 | }; 122 | typedef WithAsyncMethod_Generation AsyncService; 123 | template 124 | class WithCallbackMethod_Generation : public BaseClass { 125 | private: 126 | void BaseClassMustBeDerivedFromService(const Service* /*service*/) {} 127 | public: 128 | 
WithCallbackMethod_Generation() { 129 | ::grpc::Service::MarkMethodCallback(0, 130 | new ::grpc::internal::CallbackServerStreamingHandler< ::ppl::llm::proto::BatchedRequest, ::ppl::llm::proto::Response>( 131 | [this]( 132 | ::grpc::CallbackServerContext* context, const ::ppl::llm::proto::BatchedRequest* request) { return this->Generation(context, request); })); 133 | } 134 | ~WithCallbackMethod_Generation() override { 135 | BaseClassMustBeDerivedFromService(this); 136 | } 137 | // disable synchronous version of this method 138 | ::grpc::Status Generation(::grpc::ServerContext* /*context*/, const ::ppl::llm::proto::BatchedRequest* /*request*/, ::grpc::ServerWriter< ::ppl::llm::proto::Response>* /*writer*/) override { 139 | abort(); 140 | return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, ""); 141 | } 142 | virtual ::grpc::ServerWriteReactor< ::ppl::llm::proto::Response>* Generation( 143 | ::grpc::CallbackServerContext* /*context*/, const ::ppl::llm::proto::BatchedRequest* /*request*/) { return nullptr; } 144 | }; 145 | typedef WithCallbackMethod_Generation CallbackService; 146 | typedef CallbackService ExperimentalCallbackService; 147 | template 148 | class WithGenericMethod_Generation : public BaseClass { 149 | private: 150 | void BaseClassMustBeDerivedFromService(const Service* /*service*/) {} 151 | public: 152 | WithGenericMethod_Generation() { 153 | ::grpc::Service::MarkMethodGeneric(0); 154 | } 155 | ~WithGenericMethod_Generation() override { 156 | BaseClassMustBeDerivedFromService(this); 157 | } 158 | // disable synchronous version of this method 159 | ::grpc::Status Generation(::grpc::ServerContext* /*context*/, const ::ppl::llm::proto::BatchedRequest* /*request*/, ::grpc::ServerWriter< ::ppl::llm::proto::Response>* /*writer*/) override { 160 | abort(); 161 | return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, ""); 162 | } 163 | }; 164 | template 165 | class WithRawMethod_Generation : public BaseClass { 166 | private: 167 | void BaseClassMustBeDerivedFromService(const Service* /*service*/) {} 168 | public: 169 | WithRawMethod_Generation() { 170 | ::grpc::Service::MarkMethodRaw(0); 171 | } 172 | ~WithRawMethod_Generation() override { 173 | BaseClassMustBeDerivedFromService(this); 174 | } 175 | // disable synchronous version of this method 176 | ::grpc::Status Generation(::grpc::ServerContext* /*context*/, const ::ppl::llm::proto::BatchedRequest* /*request*/, ::grpc::ServerWriter< ::ppl::llm::proto::Response>* /*writer*/) override { 177 | abort(); 178 | return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, ""); 179 | } 180 | void RequestGeneration(::grpc::ServerContext* context, ::grpc::ByteBuffer* request, ::grpc::ServerAsyncWriter< ::grpc::ByteBuffer>* writer, ::grpc::CompletionQueue* new_call_cq, ::grpc::ServerCompletionQueue* notification_cq, void *tag) { 181 | ::grpc::Service::RequestAsyncServerStreaming(0, context, request, writer, new_call_cq, notification_cq, tag); 182 | } 183 | }; 184 | template 185 | class WithRawCallbackMethod_Generation : public BaseClass { 186 | private: 187 | void BaseClassMustBeDerivedFromService(const Service* /*service*/) {} 188 | public: 189 | WithRawCallbackMethod_Generation() { 190 | ::grpc::Service::MarkMethodRawCallback(0, 191 | new ::grpc::internal::CallbackServerStreamingHandler< ::grpc::ByteBuffer, ::grpc::ByteBuffer>( 192 | [this]( 193 | ::grpc::CallbackServerContext* context, const::grpc::ByteBuffer* request) { return this->Generation(context, request); })); 194 | } 195 | ~WithRawCallbackMethod_Generation() override { 196 | 
BaseClassMustBeDerivedFromService(this); 197 | } 198 | // disable synchronous version of this method 199 | ::grpc::Status Generation(::grpc::ServerContext* /*context*/, const ::ppl::llm::proto::BatchedRequest* /*request*/, ::grpc::ServerWriter< ::ppl::llm::proto::Response>* /*writer*/) override { 200 | abort(); 201 | return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, ""); 202 | } 203 | virtual ::grpc::ServerWriteReactor< ::grpc::ByteBuffer>* Generation( 204 | ::grpc::CallbackServerContext* /*context*/, const ::grpc::ByteBuffer* /*request*/) { return nullptr; } 205 | }; 206 | typedef Service StreamedUnaryService; 207 | template 208 | class WithSplitStreamingMethod_Generation : public BaseClass { 209 | private: 210 | void BaseClassMustBeDerivedFromService(const Service* /*service*/) {} 211 | public: 212 | WithSplitStreamingMethod_Generation() { 213 | ::grpc::Service::MarkMethodStreamed(0, 214 | new ::grpc::internal::SplitServerStreamingHandler< 215 | ::ppl::llm::proto::BatchedRequest, ::ppl::llm::proto::Response>( 216 | [this](::grpc::ServerContext* context, 217 | ::grpc::ServerSplitStreamer< 218 | ::ppl::llm::proto::BatchedRequest, ::ppl::llm::proto::Response>* streamer) { 219 | return this->StreamedGeneration(context, 220 | streamer); 221 | })); 222 | } 223 | ~WithSplitStreamingMethod_Generation() override { 224 | BaseClassMustBeDerivedFromService(this); 225 | } 226 | // disable regular version of this method 227 | ::grpc::Status Generation(::grpc::ServerContext* /*context*/, const ::ppl::llm::proto::BatchedRequest* /*request*/, ::grpc::ServerWriter< ::ppl::llm::proto::Response>* /*writer*/) override { 228 | abort(); 229 | return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, ""); 230 | } 231 | // replace default version of method with split streamed 232 | virtual ::grpc::Status StreamedGeneration(::grpc::ServerContext* context, ::grpc::ServerSplitStreamer< ::ppl::llm::proto::BatchedRequest,::ppl::llm::proto::Response>* server_split_streamer) = 0; 233 | }; 234 | typedef WithSplitStreamingMethod_Generation SplitStreamedService; 235 | typedef WithSplitStreamingMethod_Generation StreamedService; 236 | }; 237 | 238 | } // namespace proto 239 | } // namespace llm 240 | } // namespace ppl 241 | 242 | 243 | #endif // GRPC_llm_2eproto__INCLUDED 244 | -------------------------------------------------------------------------------- /src/generated/onnx/v3.1.0/sentencepiece.proto: -------------------------------------------------------------------------------- 1 | // Copyright 2016 Google Inc. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License.! 14 | 15 | syntax = "proto2"; 16 | 17 | // TODO(taku): Needs to use LITE RUNTIME in OSS release. 18 | option optimize_for = LITE_RUNTIME; 19 | 20 | package sentencepiece; 21 | 22 | // SentencePieceText manages a user-facing source sentence, 23 | // postprocessed target sentence, and internal segmentation 24 | // with byte offsets. 
25 | message SentencePieceText { 26 | message SentencePiece { 27 | // Internal representation for the decoder. 28 | // - Decoder can use |piece| as a basic token. 29 | // - the piece must be non-empty. 30 | // - A whitespace is replaced with a meta symbol. 31 | // - Concatenation of pieces is not always the same as the |text|. 32 | optional string piece = 1; 33 | 34 | // Vocabulary id. 35 | optional uint32 id = 2; 36 | 37 | // External representation for the client. 38 | // - It is always guaranteed that 39 | // text.substr(begin, end - begin) == surface. 40 | // - Concatenation of surface is always the same as the |text|. 41 | // - |surface| may contain whitespaces. 42 | // - |surface| may be empty if the piece encodes 43 | // a control vocabulary. e.g., , , . 44 | // - When |surface| is empty, always begin == end. (zero-length span). 45 | optional string surface = 3; 46 | 47 | optional uint32 begin = 4; 48 | optional uint32 end = 5; 49 | 50 | // Customized extensions: the range of field numbers 51 | // are open to third-party extensions. 52 | extensions 200 to max; 53 | } 54 | 55 | // User input or postprocessed text. This should be immutable 56 | // since the byte range in SentencePiece is pointing to a span over this 57 | // text. Meta symbols for whitespaces are not included. 58 | optional string text = 1; 59 | 60 | // A sequence of sentence pieces. 61 | repeated SentencePiece pieces = 2; 62 | 63 | // Score (usually log probability) for MultiSentencePieceText. 64 | optional float score = 3; 65 | 66 | // Customized extensions: the range of field numbers 67 | // are open to third-party extensions. 68 | extensions 200 to max; 69 | } 70 | 71 | message NBestSentencePieceText { 72 | repeated SentencePieceText nbests = 1; 73 | } 74 | -------------------------------------------------------------------------------- /src/generator/llm_generator.h: -------------------------------------------------------------------------------- 1 | // Licensed to the Apache Software Foundation (ASF) under one 2 | // or more contributor license agreements. See the NOTICE file 3 | // distributed with this work for additional information 4 | // regarding copyright ownership. The ASF licenses this file 5 | // to you under the Apache License, Version 2.0 (the 6 | // "License"); you may not use this file except in compliance 7 | // with the License. You may obtain a copy of the License at 8 | // 9 | // http://www.apache.org/licenses/LICENSE-2.0 10 | // 11 | // Unless required by applicable law or agreed to in writing, 12 | // software distributed under the License is distributed on an 13 | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | // KIND, either express or implied. See the License for the 15 | // specific language governing permissions and limitations 16 | // under the License. 
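Before the generator header below, a minimal sketch of how an encoding result maps onto the SentencePieceText message defined above. This is not part of the generated sources: the example text, pieces, and vocabulary ids are made up, and the standard protobuf-generated setters are assumed.

```cpp
#include "sentencepiece.pb.h"  // generated from the proto above

// Illustrative only: fill a SentencePieceText for the input "Hello world",
// assuming the encoder produced the pieces "▁Hello" and "▁world".
sentencepiece::SentencePieceText MakeExample() {
    sentencepiece::SentencePieceText spt;
    spt.set_text("Hello world");         // original, unmodified input text

    auto* p0 = spt.add_pieces();
    p0->set_piece("\xe2\x96\x81Hello");  // internal piece; "▁" is the whitespace meta symbol
    p0->set_id(15043);                   // hypothetical vocabulary id
    p0->set_surface("Hello");            // external surface: text.substr(begin, end - begin)
    p0->set_begin(0);
    p0->set_end(5);

    auto* p1 = spt.add_pieces();
    p1->set_piece("\xe2\x96\x81world");
    p1->set_id(3186);                    // hypothetical vocabulary id
    p1->set_surface(" world");           // surface keeps the original whitespace
    p1->set_begin(5);
    p1->set_end(11);

    return spt;
}
```

Concatenating the surface fields reproduces text exactly, which is the invariant the comments in the message describe.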
17 | 18 | #ifndef __PPL_LLM_LLM_GENERATOR_H__ 19 | #define __PPL_LLM_LLM_GENERATOR_H__ 20 | 21 | #include "../engine/llm_engine.h" 22 | #include "../common/config.h" 23 | #include "../common/request.h" 24 | #include "../common/connection.h" 25 | #include "../common/resource.h" 26 | #include "../common/post_processor.h" 27 | #include "../tokenizer/tokenizer.h" 28 | #include "../utils/index_manager.h" 29 | #include "../utils/mpsc_request_scheduler.h" 30 | #include "../utils/prefix_cache_manager.h" 31 | #include "../common/profiler.h" 32 | 33 | #include "ppl/nn/models/onnx/runtime_builder_factory.h" 34 | #include "ppl/nn/runtime/tensor.h" 35 | #include "ppl/common/threadpool.h" 36 | #include "ppl/common/typed_mpsc_queue.h" 37 | #include "ppl/common/event_count.h" 38 | #include "ppl/common/page_manager.h" 39 | 40 | #include 41 | #include 42 | #include 43 | #include 44 | #include 45 | #include 46 | 47 | namespace ppl { namespace llm { 48 | 49 | struct TidGenToken final { 50 | TidGenToken(uint64_t _tid, int _token, float _logprob, FinishFlag _finish_flag, uint64_t _steps, 51 | bool _is_token_in_out, bool _is_special) 52 | : tid(_tid) 53 | , token(_token) 54 | , logprob(_logprob) 55 | , finish_flag(_finish_flag) 56 | , steps(_steps) 57 | , is_token_in_out(_is_token_in_out) 58 | , is_special(_is_special) {} 59 | uint64_t tid; 60 | int token; 61 | float logprob; 62 | FinishFlag finish_flag; 63 | uint64_t steps; 64 | bool is_token_in_out; 65 | bool is_special; 66 | }; 67 | 68 | struct FinishedTaskInfo final { 69 | FinishedTaskInfo(uint64_t fid = UINT64_MAX, uint32_t ftype = UNKNOWN) : id(fid), type(ftype) {} 70 | uint64_t id; 71 | enum { 72 | UNKNOWN, 73 | FROM_WORKER, 74 | FROM_CONN, 75 | }; 76 | uint32_t type; 77 | }; 78 | 79 | struct TidData final { 80 | uint64_t tid; 81 | float temperature; 82 | float top_p; 83 | int32_t top_k; 84 | float repetition_penalty; 85 | float presence_penalty; 86 | float frequency_penalty; 87 | bool early_stopping; 88 | int32_t rest_iters; 89 | bool is_token_in_out = false; 90 | int64_t total_len; 91 | std::shared_ptr> stop_tokens; 92 | 93 | std::shared_ptr> next_tokens; 94 | // int64_t seqlen; 95 | int64_t start_pos; 96 | uint64_t cache_index; 97 | std::vector page_list; 98 | int64_t slot_index; 99 | int32_t steps; 100 | int32_t gen_tokens_cnt = 0; 101 | std::vector hash_list; 102 | int64_t cache_hit_count = 0; 103 | }; 104 | 105 | struct Controller { 106 | bool req_list_changed = true; 107 | ppl::common::TypedMPSCQueue finished_tasks; 108 | std::vector tid_list; 109 | void Reset() { 110 | tid_list.clear(); 111 | req_list_changed = true; 112 | while (true) { 113 | FinishedTaskInfo info; 114 | bool ok = finished_tasks.Pop(&info); 115 | if (!ok) { 116 | break; 117 | } 118 | } 119 | } 120 | }; 121 | 122 | struct LlmRequest final : public ppl::common::MPSCQueue::Node { 123 | std::shared_ptr orig; 124 | std::chrono::time_point enqueue_ts; 125 | }; 126 | 127 | class LLMGenerator final { 128 | public: 129 | LLMGenerator(const Resource& resource, const GeneratorConfig& generator_config, const ModelConfig& model_config, 130 | Connection* conn); 131 | 132 | ~LLMGenerator() { 133 | bool is_active = generate_thread_active_.load(std::memory_order_relaxed); 134 | if (is_active) { 135 | generate_thread_active_.store(false, std::memory_order_release); 136 | req_signal_.NotifyOne(); 137 | pthread_join(generate_thread_, nullptr); 138 | } 139 | } 140 | ppl::common::RetCode Init(); 141 | void Process(const std::shared_ptr&); 142 | 143 | void ClearTask(uint64_t tid) { 144 | 
controller_.finished_tasks.Push(FinishedTaskInfo(tid, FinishedTaskInfo::FROM_CONN)); 145 | } 146 | 147 | uint32_t GetPendingTaskNum() const { 148 | return sched_.GetPendingSize(); 149 | } 150 | 151 | private: 152 | ppl::common::RetCode CheckParameters() const; 153 | void Generate(); 154 | void DeleteTasks(ModelInput*, ppl::common::TypedMPSCQueue*, std::map*, 155 | uint64_t*, uint64_t*); 156 | void ReleaseResource(); 157 | 158 | private: 159 | static void* GeneratorThreadFunc(void*); 160 | 161 | private: 162 | const Tokenizer* tokenizer_; 163 | GeneratorConfig generator_config_; 164 | ModelConfig model_config_; 165 | Connection* conn_; 166 | uint64_t kv_cache_max_tokens_ = 0; 167 | LLMEngine llm_engine_; 168 | ppl::common::StaticThreadPool decoder_thread_pool_; 169 | 170 | Controller controller_; 171 | utils::IndexManager idx_mgr_; 172 | utils::IndexManager batch_slots_mgr_; 173 | ppl::common::PageManager page_mgr_; 174 | std::shared_ptr worker_profiler_; 175 | 176 | std::atomic generate_thread_active_ = {false}; 177 | pthread_t generate_thread_; 178 | ppl::common::EventCount req_signal_; 179 | utils::MPSCRequestScheduler sched_; 180 | utils::PrefixCacheManager prefix_cache_mgr_; 181 | static constexpr int DECODER_THREAD_NUM = 1; 182 | }; 183 | 184 | }} // namespace ppl::llm 185 | 186 | #endif -------------------------------------------------------------------------------- /src/serving/grpc/grpc_server.h: -------------------------------------------------------------------------------- 1 | // Licensed to the Apache Software Foundation (ASF) under one 2 | // or more contributor license agreements. See the NOTICE file 3 | // distributed with this work for additional information 4 | // regarding copyright ownership. The ASF licenses this file 5 | // to you under the Apache License, Version 2.0 (the 6 | // "License"); you may not use this file except in compliance 7 | // with the License. You may obtain a copy of the License at 8 | // 9 | // http://www.apache.org/licenses/LICENSE-2.0 10 | // 11 | // Unless required by applicable law or agreed to in writing, 12 | // software distributed under the License is distributed on an 13 | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | // KIND, either express or implied. See the License for the 15 | // specific language governing permissions and limitations 16 | // under the License. 17 | 18 | #ifndef __SERVING_GRPC_SERVER_H__ 19 | #define __SERVING_GRPC_SERVER_H__ 20 | 21 | 22 | #include "llm.grpc.pb.h" 23 | #include "grpcpp/grpcpp.h" 24 | #include "../../common/connection.h" 25 | #include "../../generator/llm_generator.h" 26 | #include "ppl/common/retcode.h" 27 | 28 | #include 29 | #include 30 | #include 31 | #include 32 | 33 | namespace ppl { namespace llm { 34 | 35 | struct GRPCEvent final { 36 | GRPCEvent() : writer(&ctx), refcount(0) { 37 | pthread_mutex_init(&send_lock, nullptr); 38 | } 39 | ~GRPCEvent() { 40 | pthread_mutex_destroy(&send_lock); 41 | } 42 | 43 | enum { 44 | NEW, 45 | SENDING, 46 | FINISHED, 47 | } status = NEW; 48 | proto::BatchedRequest pb_req; 49 | grpc::ServerContext ctx; 50 | 51 | pthread_mutex_t send_lock; 52 | int nr_finished_req = 0; 53 | std::list send_queue; 54 | grpc::ServerAsyncWriter writer; 55 | 56 | /* mapped ids of pb_req. used to remove info when the connection is gone */ 57 | uint64_t mapped_id_start = UINT64_MAX; 58 | 59 | /* 60 | an event may be invalid during processing. for example, client is killed before processing is done. 
61 | */ 62 | std::atomic refcount; 63 | 64 | GRPCEvent(const GRPCEvent&) = delete; 65 | GRPCEvent(GRPCEvent&&) = delete; 66 | void operator=(const GRPCEvent&) = delete; 67 | void operator=(GRPCEvent&&) = delete; 68 | }; 69 | 70 | struct GRPCReqInfo final { 71 | uint64_t orig_id = 0; 72 | GRPCEvent* event = nullptr; 73 | }; 74 | 75 | class GRPCConnection final : public Connection { 76 | public: 77 | GRPCConnection() { 78 | pthread_mutex_init(&id2info_lock_, nullptr); 79 | } 80 | ~GRPCConnection() { 81 | pthread_mutex_destroy(&id2info_lock_); 82 | } 83 | 84 | bool AddInfo(uint64_t id, const GRPCReqInfo& info); 85 | void FindInfo(uint64_t id, GRPCReqInfo* info); 86 | void RemoveInfo(uint64_t id, GRPCReqInfo* info); 87 | 88 | void OnTokenize(uint64_t, const std::vector&) override {} 89 | void OnProfiling(const std::shared_ptr&) override; 90 | void Send(const std::vector&) override; 91 | void NotifyFailure(uint64_t, ppl::common::RetCode, const std::string& errmsg) override; 92 | 93 | private: 94 | pthread_mutex_t id2info_lock_; 95 | std::map id2info_; 96 | }; 97 | 98 | class GRPCServer final { 99 | public: 100 | GRPCServer(GRPCConnection*, const std::function& on_disconnected_func); 101 | ~GRPCServer(); 102 | ppl::common::RetCode Init(const std::string& addr); 103 | void Loop(LLMGenerator*); 104 | 105 | private: 106 | static void* NewCallThreadFunc(void*); 107 | 108 | public: 109 | struct ThreadArg final { 110 | ppl::llm::proto::LLMService::AsyncService service; 111 | std::function on_disconnected_func = {}; 112 | 113 | /* 114 | https://groups.google.com/g/grpc-io/c/V4NAQ77PMEo 115 | 116 | notification_cq gets the tag back indicating a call has started. 117 | All subsequent operations (reads, writes, etc) on that call report back to new_call_cq. 118 | */ 119 | std::unique_ptr notification_cq; 120 | std::unique_ptr new_call_cq; 121 | 122 | GRPCConnection* conn = nullptr; 123 | }; 124 | 125 | private: 126 | grpc::ServerBuilder builder_; 127 | std::unique_ptr server_; 128 | 129 | uint64_t uuid_seq_ = 0; 130 | bool new_call_thread_created_ = false; 131 | pthread_t new_call_thread_; 132 | ThreadArg arg_; 133 | 134 | private: 135 | GRPCServer(const GRPCServer&) = delete; 136 | void operator=(const GRPCServer&) = delete; 137 | GRPCServer(GRPCServer&&) = delete; 138 | void operator=(GRPCServer&&) = delete; 139 | }; 140 | 141 | }} // namespace ppl::llm 142 | 143 | #endif 144 | -------------------------------------------------------------------------------- /src/serving/grpc/proto/llm.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package ppl.llm.proto; 4 | 5 | service LLMService { 6 | rpc Generation (BatchedRequest) returns (stream BatchedResponse) {} 7 | } 8 | 9 | message Tokens { 10 | /// Token IDs 11 | repeated uint32 ids = 1; 12 | } 13 | 14 | message NextTokenChooserParameters { 15 | /// exponential scaling output probability distribution 16 | float temperature = 1; 17 | /// restricting to the k highest probability elements 18 | uint32 top_k = 2; 19 | /// restricting to top tokens summing to prob_cut_off <= prob_cut_off 20 | float top_p = 3; 21 | /// [not used] restricting to top tokens summing to prob_cut_off <= prob_cut_off, not used 22 | float typical_p = 4; 23 | /// apply sampling on the logits 24 | bool do_sample = 5; 25 | /// [not used] random seed for sampling, not used 26 | uint64 seed = 6; 27 | /// repetition penalty 28 | float repetition_penalty = 7; 29 | /// repetition penalty 30 | float presence_penalty = 8; 31 | /// 
frequency penalty 32 | float frequency_penalty = 9; 33 | /// [not used] token watermarking using "A Watermark for Large Language Models" 34 | bool watermark = 10; 35 | } 36 | 37 | message StoppingCriteriaParameters { 38 | /// Maximum number of generated tokens 39 | uint32 max_new_tokens = 1; 40 | /// Optional stopping tokens array 41 | Tokens stop_tokens = 2; 42 | /// Ignore end of sequence token 43 | /// used for benchmarking 44 | bool ignore_eos_token = 3; 45 | } 46 | 47 | message Request { 48 | uint64 id = 1; 49 | string prompt = 2; 50 | Tokens tokens = 3; 51 | NextTokenChooserParameters choosing_parameters = 4; 52 | StoppingCriteriaParameters stopping_parameters = 5; 53 | } 54 | 55 | message BatchedRequest { 56 | repeated Request req = 1; 57 | } 58 | 59 | enum Status { 60 | PROCESSING = 0; 61 | FINISHED = 1; 62 | FAILED = 2; 63 | } 64 | 65 | enum FinishReason { 66 | FINISH_REASON_LENGTH = 0; 67 | FINISH_REASON_EOS_TOKEN = 1; 68 | FINISH_REASON_STOP_SEQUENCE = 2; 69 | } 70 | 71 | message Detail { 72 | float logprobs = 1; 73 | bool is_special = 2; 74 | FinishReason finish_reason = 3; 75 | } 76 | 77 | message Response { 78 | Status status = 1; 79 | uint64 id = 2; 80 | string generated = 3; 81 | Tokens tokens = 4; 82 | Detail detail = 5; 83 | } 84 | 85 | message BatchedResponse { 86 | repeated Response rsp = 1; 87 | } -------------------------------------------------------------------------------- /src/tokenizer/models/baichuan/baichuan_tokenizer.h: -------------------------------------------------------------------------------- 1 | // Licensed to the Apache Software Foundation (ASF) under one 2 | // or more contributor license agreements. See the NOTICE file 3 | // distributed with this work for additional information 4 | // regarding copyright ownership. The ASF licenses this file 5 | // to you under the Apache License, Version 2.0 (the 6 | // "License"); you may not use this file except in compliance 7 | // with the License. You may obtain a copy of the License at 8 | // 9 | // http://www.apache.org/licenses/LICENSE-2.0 10 | // 11 | // Unless required by applicable law or agreed to in writing, 12 | // software distributed under the License is distributed on an 13 | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | // KIND, either express or implied. See the License for the 15 | // specific language governing permissions and limitations 16 | // under the License.
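The generated stub in llm.grpc.pb.h above exposes Generation as a server-streaming call, and llm.proto defines the request and response payloads. Before the tokenizer headers that follow, here is a minimal, illustrative client sketch against that stub. The server address, prompt, and sampling values are placeholders, and error handling is reduced to a single status check; the repository's own reference clients live under tools/. Note that the stub shown above streams proto::Response, while the llm.proto in this tree declares `stream BatchedResponse`, so code regenerated from that proto would read BatchedResponse messages instead.

```cpp
#include "llm.grpc.pb.h"
#include <grpcpp/grpcpp.h>
#include <iostream>

// Illustrative only: send one prompt and print the streamed generations.
int main() {
    // Placeholder address; the actual host/port is deployment-specific.
    auto channel = grpc::CreateChannel("127.0.0.1:23333", grpc::InsecureChannelCredentials());
    auto stub = ppl::llm::proto::LLMService::NewStub(channel);

    ppl::llm::proto::BatchedRequest batched_req;
    auto* req = batched_req.add_req();
    req->set_id(0);
    req->set_prompt("Building a website can be done in 10 simple steps:");
    req->mutable_choosing_parameters()->set_temperature(1.0f);   // placeholder sampling values
    req->mutable_choosing_parameters()->set_top_k(1);
    req->mutable_choosing_parameters()->set_top_p(0.0f);
    req->mutable_stopping_parameters()->set_max_new_tokens(64);
    req->mutable_stopping_parameters()->set_ignore_eos_token(false);

    grpc::ClientContext ctx;
    auto reader = stub->Generation(&ctx, batched_req);

    ppl::llm::proto::Response rsp;
    while (reader->Read(&rsp)) {
        std::cout << rsp.generated();  // incremental text for request id rsp.id()
    }
    std::cout << std::endl;

    grpc::Status status = reader->Finish();
    if (!status.ok()) {
        std::cerr << "Generation failed: " << status.error_message() << std::endl;
        return 1;
    }
    return 0;
}
```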
17 | 18 | #ifndef __PPL_LLM_BAICHUAN_TOKENIZER_H__ 19 | #define __PPL_LLM_BAICHUAN_TOKENIZER_H__ 20 | 21 | #include "tokenizer/tokenizer.h" 22 | #include "tokenizer/tokenizer_impl.h" 23 | #include "absl/strings/string_view.h" 24 | #include "ppl/nn/common/logger.h" 25 | 26 | namespace ppl { namespace llm { 27 | 28 | class BaiChuanTokenizer final : public Tokenizer { 29 | public: 30 | BaiChuanTokenizer(TokenizerImpl* impl) { 31 | impl_ = std::unique_ptr(impl); 32 | } 33 | ~BaiChuanTokenizer() {} 34 | 35 | void Encode(const char* prompt, uint32_t len, std::vector* token_ids) const override { 36 | impl_->Encode(prompt, len, token_ids); 37 | } 38 | 39 | void Decode(int* token_ids, uint32_t len, std::string* output) const override { 40 | impl_->Decode(token_ids, len, output); 41 | } 42 | 43 | int GetBosId() const override { 44 | return impl_->GetBosId(); 45 | } 46 | 47 | int GetEosId() const override { 48 | return impl_->GetEosId(); 49 | } 50 | 51 | private: 52 | std::unique_ptr impl_; 53 | }; 54 | 55 | }} // namespace ppl::llm 56 | 57 | #endif 58 | -------------------------------------------------------------------------------- /src/tokenizer/models/internlm/internlm_tokenizer.h: -------------------------------------------------------------------------------- 1 | // Licensed to the Apache Software Foundation (ASF) under one 2 | // or more contributor license agreements. See the NOTICE file 3 | // distributed with this work for additional information 4 | // regarding copyright ownership. The ASF licenses this file 5 | // to you under the Apache License, Version 2.0 (the 6 | // "License"); you may not use this file except in compliance 7 | // with the License. You may obtain a copy of the License at 8 | // 9 | // http://www.apache.org/licenses/LICENSE-2.0 10 | // 11 | // Unless required by applicable law or agreed to in writing, 12 | // software distributed under the License is distributed on an 13 | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | // KIND, either express or implied. See the License for the 15 | // specific language governing permissions and limitations 16 | // under the License. 
17 | 18 | #ifndef __PPL_LLM_INTERNLM_TOKENIZER_H__ 19 | #define __PPL_LLM_INTERNLM_TOKENIZER_H__ 20 | 21 | #include "tokenizer/tokenizer.h" 22 | #include "tokenizer/tokenizer_impl.h" 23 | #include "absl/strings/string_view.h" 24 | #include "ppl/nn/common/logger.h" 25 | 26 | namespace ppl { namespace llm { 27 | 28 | class InternLMTokenizer final : public Tokenizer { 29 | public: 30 | InternLMTokenizer(TokenizerImpl* impl) { 31 | impl_ = std::unique_ptr(impl); 32 | } 33 | ~InternLMTokenizer() {} 34 | void Encode(const char* prompt, uint32_t len, std::vector* token_ids) const override { 35 | impl_->Encode(prompt, len, token_ids); 36 | token_ids->insert(token_ids->begin(), impl_->GetBosId()); 37 | } 38 | 39 | void Decode(int* token_ids, uint32_t len, std::string* output) const override { 40 | impl_->Decode(token_ids, len, output); 41 | } 42 | 43 | int GetBosId() const override { 44 | return impl_->GetBosId(); 45 | } 46 | 47 | int GetEosId() const override { 48 | return impl_->GetEosId(); 49 | } 50 | 51 | private: 52 | std::unique_ptr impl_; 53 | }; 54 | 55 | }} // namespace ppl::llm 56 | 57 | #endif 58 | -------------------------------------------------------------------------------- /src/tokenizer/models/llama/llama_tokenizer.h: -------------------------------------------------------------------------------- 1 | // Licensed to the Apache Software Foundation (ASF) under one 2 | // or more contributor license agreements. See the NOTICE file 3 | // distributed with this work for additional information 4 | // regarding copyright ownership. The ASF licenses this file 5 | // to you under the Apache License, Version 2.0 (the 6 | // "License"); you may not use this file except in compliance 7 | // with the License. You may obtain a copy of the License at 8 | // 9 | // http://www.apache.org/licenses/LICENSE-2.0 10 | // 11 | // Unless required by applicable law or agreed to in writing, 12 | // software distributed under the License is distributed on an 13 | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | // KIND, either express or implied. See the License for the 15 | // specific language governing permissions and limitations 16 | // under the License. 
17 | 18 | #ifndef __PPL_LLM_LLAMA_TOKENIZER_H__ 19 | #define __PPL_LLM_LLAMA_TOKENIZER_H__ 20 | 21 | #include "tokenizer/tokenizer.h" 22 | #include "tokenizer/tokenizer_impl.h" 23 | #include "absl/strings/string_view.h" 24 | #include "ppl/nn/common/logger.h" 25 | 26 | namespace ppl { namespace llm { 27 | 28 | class LlamaTokenizer final : public Tokenizer { 29 | public: 30 | LlamaTokenizer(TokenizerImpl* impl) { 31 | impl_ = std::unique_ptr(impl); 32 | } 33 | ~LlamaTokenizer() {} 34 | 35 | void Encode(const char* prompt, uint32_t len, std::vector* token_ids) const override { 36 | impl_->Encode(prompt, len, token_ids); 37 | token_ids->insert(token_ids->begin(), impl_->GetBosId()); 38 | } 39 | 40 | void Decode(int* token_ids, uint32_t len, std::string* output) const override { 41 | impl_->Decode(token_ids, len, output); 42 | } 43 | 44 | int GetBosId() const override { 45 | return impl_->GetBosId(); 46 | } 47 | 48 | int GetEosId() const override { 49 | return impl_->GetEosId(); 50 | } 51 | 52 | private: 53 | std::unique_ptr impl_; 54 | }; 55 | 56 | }} // namespace ppl::llm 57 | 58 | #endif 59 | -------------------------------------------------------------------------------- /src/tokenizer/models/llama3/llama3_tokenizer.h: -------------------------------------------------------------------------------- 1 | // Licensed to the Apache Software Foundation (ASF) under one 2 | // or more contributor license agreements. See the NOTICE file 3 | // distributed with this work for additional information 4 | // regarding copyright ownership. The ASF licenses this file 5 | // to you under the Apache License, Version 2.0 (the 6 | // "License"); you may not use this file except in compliance 7 | // with the License. You may obtain a copy of the License at 8 | // 9 | // http://www.apache.org/licenses/LICENSE-2.0 10 | // 11 | // Unless required by applicable law or agreed to in writing, 12 | // software distributed under the License is distributed on an 13 | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | // KIND, either express or implied. See the License for the 15 | // specific language governing permissions and limitations 16 | // under the License. 
17 | 18 | #ifndef __PPL_LLM_LLAMA3_TOKENIZER_H__ 19 | #define __PPL_LLM_LLAMA3_TOKENIZER_H__ 20 | 21 | #include "tokenizer/tokenizer.h" 22 | #include "tokenizer/tokenizer_impl.h" 23 | #include "absl/strings/string_view.h" 24 | #include "ppl/nn/common/logger.h" 25 | 26 | namespace ppl { namespace llm { 27 | 28 | class Llama3Tokenizer final : public Tokenizer { 29 | public: 30 | Llama3Tokenizer(TokenizerImpl* impl) { 31 | impl_ = std::unique_ptr(impl); 32 | } 33 | ~Llama3Tokenizer() {} 34 | 35 | void Encode(const char* prompt, uint32_t len, std::vector* token_ids) const override { 36 | impl_->Encode(prompt, len, token_ids); 37 | token_ids->insert(token_ids->begin(), impl_->GetBosId()); 38 | } 39 | 40 | void Decode(int* token_ids, uint32_t len, std::string* output) const override { 41 | impl_->Decode(token_ids, len, output); 42 | } 43 | 44 | int GetBosId() const override { 45 | return impl_->GetBosId(); 46 | } 47 | 48 | int GetEosId() const override { 49 | return impl_->GetEosId(); 50 | } 51 | 52 | private: 53 | std::unique_ptr impl_; 54 | }; 55 | 56 | }} // namespace ppl::llm 57 | 58 | #endif 59 | -------------------------------------------------------------------------------- /src/tokenizer/models/llama3/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "bos_token": "<|begin_of_text|>", 3 | "eos_token": "<|end_of_text|>", 4 | "pad_token": "" 5 | } -------------------------------------------------------------------------------- /src/tokenizer/tokenizer.h: -------------------------------------------------------------------------------- 1 | // Licensed to the Apache Software Foundation (ASF) under one 2 | // or more contributor license agreements. See the NOTICE file 3 | // distributed with this work for additional information 4 | // regarding copyright ownership. The ASF licenses this file 5 | // to you under the Apache License, Version 2.0 (the 6 | // "License"); you may not use this file except in compliance 7 | // with the License. You may obtain a copy of the License at 8 | // 9 | // http://www.apache.org/licenses/LICENSE-2.0 10 | // 11 | // Unless required by applicable law or agreed to in writing, 12 | // software distributed under the License is distributed on an 13 | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | // KIND, either express or implied. See the License for the 15 | // specific language governing permissions and limitations 16 | // under the License. 17 | 18 | #ifndef __PPL_LLM_TOKENIZER_H__ 19 | #define __PPL_LLM_TOKENIZER_H__ 20 | 21 | #include "tokenizer_impl.h" 22 | 23 | #include 24 | #include 25 | #include 26 | 27 | namespace ppl { namespace llm { 28 | 29 | class Tokenizer { 30 | public: 31 | virtual ~Tokenizer() {} 32 | virtual void Encode(const char* prompt, uint32_t len, std::vector* token_ids) const = 0; 33 | virtual void Decode(int* token_ids, uint32_t len, std::string* output) const = 0; 34 | virtual int GetBosId() const = 0; 35 | virtual int GetEosId() const = 0; 36 | }; 37 | 38 | }} // namespace ppl::llm 39 | 40 | #endif -------------------------------------------------------------------------------- /src/tokenizer/tokenizer_factory.h: -------------------------------------------------------------------------------- 1 | // Licensed to the Apache Software Foundation (ASF) under one 2 | // or more contributor license agreements. See the NOTICE file 3 | // distributed with this work for additional information 4 | // regarding copyright ownership. 
The ASF licenses this file 5 | // to you under the Apache License, Version 2.0 (the 6 | // "License"); you may not use this file except in compliance 7 | // with the License. You may obtain a copy of the License at 8 | // 9 | // http://www.apache.org/licenses/LICENSE-2.0 10 | // 11 | // Unless required by applicable law or agreed to in writing, 12 | // software distributed under the License is distributed on an 13 | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | // KIND, either express or implied. See the License for the 15 | // specific language governing permissions and limitations 16 | // under the License. 17 | 18 | #ifndef __PPL_LLM_TOKENIZER_FACTORY_H__ 19 | #define __PPL_LLM_TOKENIZER_FACTORY_H__ 20 | 21 | #include "tokenizer/tokenizer.h" 22 | #ifdef PPL_LLM_ENABLE_HF_TOKENIZER 23 | #include "tokenizer/tokenizer_impl_hf.h" 24 | #endif 25 | #include "tokenizer/tokenizer_impl_sp.h" 26 | #include "models/llama/llama_tokenizer.h" 27 | #include "models/internlm/internlm_tokenizer.h" 28 | #include "models/baichuan/baichuan_tokenizer.h" 29 | #include "models/llama3/llama3_tokenizer.h" 30 | #include "common/resource.h" 31 | #include "common/config.h" 32 | 33 | #include 34 | 35 | namespace ppl { namespace llm { 36 | 37 | class TokenizerFactory final { 38 | public: 39 | static Tokenizer* Create(const std::string& model_type, const std::string& tokenizer_type, 40 | const std::string& tokenizer_path, const std::string& tokenizer_config_path) { 41 | std::unique_ptr tokenizer_impl; 42 | if (tokenizer_type == "sentencepiece") { 43 | tokenizer_impl = std::make_unique(); 44 | auto rc = tokenizer_impl->Init(tokenizer_path, tokenizer_config_path); 45 | if (rc != ppl::common::RC_SUCCESS) { 46 | LOG(ERROR) << "sentencepiece tokenizer init failed"; 47 | return nullptr; 48 | } 49 | #ifdef PPL_LLM_ENABLE_HF_TOKENIZER 50 | } else if (tokenizer_type == "huggingface") { 51 | tokenizer_impl = std::make_unique(); 52 | auto rc = tokenizer_impl->Init(tokenizer_path, tokenizer_config_path); 53 | if (rc != ppl::common::RC_SUCCESS) { 54 | LOG(ERROR) << "huggingface tokenizer init failed"; 55 | return nullptr; 56 | } 57 | #endif 58 | } else { 59 | LOG(ERROR) << "not supported tokenizer: " << tokenizer_type; 60 | return nullptr; 61 | } 62 | 63 | std::unique_ptr tokenizer; 64 | if (model_type == "llama") { 65 | tokenizer = std::make_unique(tokenizer_impl.release()); 66 | } else if (model_type == "internlm") { 67 | tokenizer = std::make_unique(tokenizer_impl.release()); 68 | } else if (model_type == "baichuan") { 69 | tokenizer = std::make_unique(tokenizer_impl.release()); 70 | } else if (model_type == "llama3") { 71 | tokenizer = std::make_unique(tokenizer_impl.release()); 72 | } else { 73 | LOG(ERROR) << "not supported model: " << model_type; 74 | return nullptr; 75 | } 76 | return tokenizer.release(); 77 | } 78 | }; 79 | 80 | }} // namespace ppl::llm 81 | 82 | #endif 83 | -------------------------------------------------------------------------------- /src/tokenizer/tokenizer_impl.h: -------------------------------------------------------------------------------- 1 | // Licensed to the Apache Software Foundation (ASF) under one 2 | // or more contributor license agreements. See the NOTICE file 3 | // distributed with this work for additional information 4 | // regarding copyright ownership. The ASF licenses this file 5 | // to you under the Apache License, Version 2.0 (the 6 | // "License"); you may not use this file except in compliance 7 | // with the License. 
You may obtain a copy of the License at 8 | // 9 | // http://www.apache.org/licenses/LICENSE-2.0 10 | // 11 | // Unless required by applicable law or agreed to in writing, 12 | // software distributed under the License is distributed on an 13 | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | // KIND, either express or implied. See the License for the 15 | // specific language governing permissions and limitations 16 | // under the License. 17 | 18 | #ifndef __PPL_LLM_TOKENIZER_IMPL_H__ 19 | #define __PPL_LLM_TOKENIZER_IMPL_H__ 20 | 21 | #include "tokenizer.h" 22 | #include "ppl/nn/common/logger.h" 23 | 24 | #include 25 | #include 26 | #include 27 | #include 28 | 29 | namespace ppl { namespace llm { 30 | 31 | class TokenizerImpl { 32 | public: 33 | virtual ~TokenizerImpl() {} 34 | virtual ppl::common::RetCode Init(const std::string& path, const std::string& conifg_path = "") = 0; 35 | virtual void Encode(const char* prompt, uint32_t len, std::vector* token_ids) const = 0; 36 | 37 | virtual void Decode(int* token_ids, uint32_t len, std::string* output) const = 0; 38 | virtual int GetBosId() const = 0; 39 | virtual int GetEosId() const = 0; 40 | }; 41 | 42 | }} // namespace ppl::llm 43 | 44 | #endif -------------------------------------------------------------------------------- /src/tokenizer/tokenizer_impl_hf.h: -------------------------------------------------------------------------------- 1 | // Licensed to the Apache Software Foundation (ASF) under one 2 | // or more contributor license agreements. See the NOTICE file 3 | // distributed with this work for additional information 4 | // regarding copyright ownership. The ASF licenses this file 5 | // to you under the Apache License, Version 2.0 (the 6 | // "License"); you may not use this file except in compliance 7 | // with the License. You may obtain a copy of the License at 8 | // 9 | // http://www.apache.org/licenses/LICENSE-2.0 10 | // 11 | // Unless required by applicable law or agreed to in writing, 12 | // software distributed under the License is distributed on an 13 | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | // KIND, either express or implied. See the License for the 15 | // specific language governing permissions and limitations 16 | // under the License. 
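Before the HuggingFace-backed implementation below, a short illustrative sketch of driving TokenizerFactory::Create from above with the SentencePiece backend. The model type, file paths, and prompt are placeholders, and token ids are assumed to be plain `int`, matching the Decode signatures in the wrappers above.

```cpp
#include "tokenizer/tokenizer_factory.h"

#include <iostream>
#include <memory>
#include <string>
#include <vector>

// Illustrative only: build a llama tokenizer from a sentencepiece model file
// and run an encode/decode round trip. Paths are placeholders.
int main() {
    std::unique_ptr<ppl::llm::Tokenizer> tokenizer(ppl::llm::TokenizerFactory::Create(
        "llama",                    // model_type: selects the LlamaTokenizer wrapper
        "sentencepiece",            // tokenizer_type: selects the SentencePiece implementation
        "/path/to/tokenizer.model", // tokenizer_path (placeholder)
        ""));                       // tokenizer_config_path: not needed for sentencepiece
    if (!tokenizer) {
        std::cerr << "tokenizer creation failed" << std::endl;
        return 1;
    }

    const std::string prompt = "Hello, world!";
    std::vector<int> token_ids;
    tokenizer->Encode(prompt.data(), prompt.size(), &token_ids);  // LlamaTokenizer prepends BOS

    std::string decoded;
    tokenizer->Decode(token_ids.data(), token_ids.size(), &decoded);
    std::cout << "tokens: " << token_ids.size() << ", decoded: " << decoded << std::endl;
    return 0;
}
```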
17 | 18 | #ifndef __PPL_LLM_TOKENIZER_IMPL_HF_H__ 19 | #define __PPL_LLM_TOKENIZER_IMPL_HF_H__ 20 | 21 | #include "tokenizer_impl.h" 22 | #ifdef PPL_LLM_ENABLE_HF_TOKENIZER 23 | #include "tokenizers_cpp.h" 24 | #endif 25 | #include "ppl/nn/common/logger.h" 26 | #include "rapidjson/document.h" 27 | #include "rapidjson/istreamwrapper.h" 28 | 29 | #include 30 | #include 31 | #include 32 | 33 | namespace ppl { namespace llm { 34 | 35 | static bool LoadBytesFromFile(const std::string& path, std::string* data) { 36 | std::ifstream fs(path, std::ios::in | std::ios::binary); 37 | if (fs.fail()) { 38 | LOG(ERROR) << "Cannot open " << path; 39 | return false; 40 | } 41 | fs.seekg(0, std::ios::end); 42 | size_t size = static_cast(fs.tellg()); 43 | fs.seekg(0, std::ios::beg); 44 | data->resize(size); 45 | fs.read(data->data(), size); 46 | return true; 47 | } 48 | 49 | static bool ParseTokenizerConfig(const std::string& conifg_path, std::string* bos_token, std::string* eos_token) { 50 | std::ifstream ifs(conifg_path); 51 | rapidjson::IStreamWrapper isw(ifs); 52 | rapidjson::Document document; 53 | if (document.ParseStream(isw) == false) { 54 | LOG(ERROR) << "ParseStream failed"; 55 | return false; 56 | } 57 | document.ParseStream(isw); 58 | 59 | auto it = document.FindMember("bos_token"); 60 | if (it == document.MemberEnd()) { 61 | LOG(ERROR) << "find key [bos_token] failed"; 62 | return false; 63 | } 64 | *bos_token = it->value.GetString(); 65 | 66 | it = document.FindMember("eos_token"); 67 | if (it == document.MemberEnd()) { 68 | LOG(ERROR) << "find key [eos_token] failed"; 69 | return false; 70 | } 71 | *eos_token = it->value.GetString(); 72 | 73 | return true; 74 | } 75 | 76 | class TokenizerImplHF final : public TokenizerImpl { 77 | public: 78 | ~TokenizerImplHF() {} 79 | ppl::common::RetCode Init(const std::string& path, const std::string& conifg_path = "") override { 80 | if (conifg_path.empty()) { 81 | LOG(ERROR) << "No config file for HuggingFace Tokenizer"; 82 | return ppl::common::RC_OTHER_ERROR; 83 | } 84 | std::string bos_token, eos_token; 85 | if (!ParseTokenizerConfig(conifg_path, &bos_token, &eos_token)) { 86 | LOG(ERROR) << "ParseTokenizerConfig failed"; 87 | return ppl::common::RC_OTHER_ERROR; 88 | } 89 | 90 | std::string blob; 91 | if (!LoadBytesFromFile(path, &blob)) { 92 | LOG(ERROR) << "LoadBytesFromFile failed"; 93 | return ppl::common::RC_OTHER_ERROR; 94 | } 95 | 96 | hf_processor_ = ::tokenizers::Tokenizer::FromBlobJSON(blob); 97 | if (!hf_processor_) { 98 | LOG(ERROR) << "Init HuggingFace Tokenizer failed"; 99 | return ppl::common::RC_OTHER_ERROR; 100 | } 101 | LOG(INFO) << "VOCAB_SIZE: " << hf_processor_->GetVocabSize(); 102 | LOG(INFO) << "bos_token: " << bos_token; 103 | LOG(INFO) << "eos_token: " << eos_token; 104 | bos_id_ = hf_processor_->TokenToId(bos_token); 105 | if (bos_id_ == -1) { 106 | LOG(ERROR) << "illegal bos token, bos_id_ is -1"; 107 | return ppl::common::RC_OTHER_ERROR; 108 | } 109 | eos_id_ = hf_processor_->TokenToId(eos_token); 110 | if (eos_id_ == -1) { 111 | LOG(ERROR) << "illegal eos token, eos_id_ is -1"; 112 | return ppl::common::RC_OTHER_ERROR; 113 | } 114 | return ppl::common::RC_SUCCESS; 115 | } 116 | 117 | void Encode(const char* prompt, uint32_t len, std::vector* token_ids) const override { 118 | hf_processor_->Encode(prompt, len, token_ids); 119 | } 120 | 121 | void Decode(int* token_ids, uint32_t len, std::string* output) const override { 122 | hf_processor_->Decode(token_ids, len, output); 123 | } 124 | 125 | int GetBosId() const override { 
126 | return bos_id_; 127 | } 128 | 129 | int GetEosId() const override { 130 | return eos_id_; 131 | } 132 | 133 | private: 134 | std::unique_ptr hf_processor_; 135 | int bos_id_; 136 | int eos_id_; 137 | }; 138 | 139 | }} // namespace ppl::llm 140 | 141 | #endif -------------------------------------------------------------------------------- /src/tokenizer/tokenizer_impl_sp.h: -------------------------------------------------------------------------------- 1 | // Licensed to the Apache Software Foundation (ASF) under one 2 | // or more contributor license agreements. See the NOTICE file 3 | // distributed with this work for additional information 4 | // regarding copyright ownership. The ASF licenses this file 5 | // to you under the Apache License, Version 2.0 (the 6 | // "License"); you may not use this file except in compliance 7 | // with the License. You may obtain a copy of the License at 8 | // 9 | // http://www.apache.org/licenses/LICENSE-2.0 10 | // 11 | // Unless required by applicable law or agreed to in writing, 12 | // software distributed under the License is distributed on an 13 | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | // KIND, either express or implied. See the License for the 15 | // specific language governing permissions and limitations 16 | // under the License. 17 | 18 | #ifndef __PPL_LLM_TOKENIZER_IMPL_SP_H__ 19 | #define __PPL_LLM_TOKENIZER_IMPL_SP_H__ 20 | 21 | #include "tokenizer_impl.h" 22 | #include "ppl/nn/common/logger.h" 23 | #include "sentencepiece_processor.h" 24 | 25 | #include 26 | #include 27 | #include 28 | 29 | namespace ppl { namespace llm { 30 | 31 | class TokenizerImplSP final : public TokenizerImpl { 32 | public: 33 | ~TokenizerImplSP() {} 34 | ppl::common::RetCode Init(const std::string& path, const std::string& conifg_path = "") override { 35 | sp_processor_ = std::make_unique(); 36 | auto tokenizer_status = sp_processor_->Load(path); 37 | if (!tokenizer_status.ok()) { 38 | LOG(ERROR) << tokenizer_status.ToString(); 39 | return ppl::common::RC_OTHER_ERROR; 40 | } 41 | bos_id_ = sp_processor_->bos_id(); 42 | eos_id_ = sp_processor_->eos_id(); 43 | 44 | LOG(INFO) << "VOCAB_SIZE: " << sp_processor_->GetPieceSize() << "; BOS ID: " << bos_id_ 45 | << "; EOS ID: " << eos_id_ << "; PAD ID: " << sp_processor_->pad_id(); 46 | return ppl::common::RC_SUCCESS; 47 | } 48 | 49 | void Encode(const char* prompt, uint32_t len, std::vector* token_ids) const override { 50 | sp_processor_->Encode(absl::string_view(prompt, len), token_ids); 51 | } 52 | 53 | void Decode(int* token_ids, uint32_t len, std::string* output) const override { 54 | sp_processor_->Decode(token_ids, len, output); 55 | if (len == 1 && sp_processor_->IdToPiece(token_ids[0]).substr(0, 3) == space_symbol_ && !output->empty() && 56 | output->at(0) != ' ') { 57 | output->insert(0, " "); 58 | } 59 | } 60 | 61 | int GetBosId() const override { 62 | return bos_id_; 63 | } 64 | 65 | int GetEosId() const override { 66 | return eos_id_; 67 | } 68 | 69 | private: 70 | std::unique_ptr sp_processor_; 71 | int bos_id_; 72 | int eos_id_; 73 | const std::string space_symbol_ = "\xe2\x96\x81"; 74 | }; 75 | 76 | }} // namespace ppl::llm 77 | 78 | #endif -------------------------------------------------------------------------------- /src/utils/index_manager.h: -------------------------------------------------------------------------------- 1 | // Licensed to the Apache Software Foundation (ASF) under one 2 | // or more contributor license agreements. 
See the NOTICE file 3 | // distributed with this work for additional information 4 | // regarding copyright ownership. The ASF licenses this file 5 | // to you under the Apache License, Version 2.0 (the 6 | // "License"); you may not use this file except in compliance 7 | // with the License. You may obtain a copy of the License at 8 | // 9 | // http://www.apache.org/licenses/LICENSE-2.0 10 | // 11 | // Unless required by applicable law or agreed to in writing, 12 | // software distributed under the License is distributed on an 13 | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | // KIND, either express or implied. See the License for the 15 | // specific language governing permissions and limitations 16 | // under the License. 17 | 18 | #ifndef __PPL_LLM_INDEX_MANAGER_H__ 19 | #define __PPL_LLM_INDEX_MANAGER_H__ 20 | 21 | #include "ppl/common/compact_addr_manager.h" 22 | 23 | namespace ppl { namespace llm { namespace utils { 24 | 25 | class IndexManager final { 26 | private: 27 | class IndexAllocator final : public ppl::common::CompactAddrManager::VMAllocator { 28 | public: 29 | void Init(uint64_t max_index) { 30 | max_ = max_index; 31 | } 32 | uintptr_t GetReservedBase() const override { 33 | return 0; 34 | } 35 | uint64_t GetAllocatedSize() const override { 36 | return used_; 37 | } 38 | uint64_t Extend(uint64_t needed) override { 39 | if (needed + used_ > max_) { 40 | return 0; 41 | } 42 | 43 | used_ += needed; 44 | return needed; 45 | } 46 | 47 | private: 48 | uint64_t max_ = 0; 49 | uint64_t used_ = 0; 50 | }; 51 | 52 | public: 53 | IndexManager() : mgr_(&vmr_) {} 54 | void Init(uint64_t max_index) { 55 | nr_avail_blk_ = max_index; 56 | vmr_.Init(max_index); 57 | } 58 | int64_t GetAvailableBlockNum() const { 59 | return nr_avail_blk_; 60 | } 61 | int64_t Alloc(uint64_t nr) { 62 | auto ret = mgr_.Alloc(nr); 63 | if (ret == UINTPTR_MAX) { 64 | return INT64_MAX; 65 | } 66 | nr_avail_blk_ -= nr; 67 | return (int64_t)ret; 68 | } 69 | void Free(uint64_t start, uint64_t nr) { 70 | mgr_.Free(start, nr); 71 | nr_avail_blk_ += nr; 72 | } 73 | 74 | private: 75 | uint64_t nr_avail_blk_; 76 | IndexAllocator vmr_; 77 | ppl::common::CompactAddrManager mgr_; 78 | }; 79 | 80 | }}} // namespace ppl::llm::utils 81 | 82 | #endif 83 | -------------------------------------------------------------------------------- /src/utils/mpsc_request_scheduler.h: -------------------------------------------------------------------------------- 1 | // Licensed to the Apache Software Foundation (ASF) under one 2 | // or more contributor license agreements. See the NOTICE file 3 | // distributed with this work for additional information 4 | // regarding copyright ownership. The ASF licenses this file 5 | // to you under the Apache License, Version 2.0 (the 6 | // "License"); you may not use this file except in compliance 7 | // with the License. You may obtain a copy of the License at 8 | // 9 | // http://www.apache.org/licenses/LICENSE-2.0 10 | // 11 | // Unless required by applicable law or agreed to in writing, 12 | // software distributed under the License is distributed on an 13 | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | // KIND, either express or implied. See the License for the 15 | // specific language governing permissions and limitations 16 | // under the License. 
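Before the request scheduler below, a small illustrative sketch of the IndexManager defined above, which hands out contiguous ranges of indices (e.g. KV-cache blocks) and tracks the remaining capacity. The capacity and range sizes here are arbitrary.

```cpp
#include "utils/index_manager.h"

#include <cassert>
#include <cstdint>

// Illustrative only: reserve and release contiguous index ranges.
int main() {
    ppl::llm::utils::IndexManager mgr;
    mgr.Init(1024);                            // manage indices in [0, 1024)

    int64_t a = mgr.Alloc(16);                 // contiguous range of 16 indices
    int64_t b = mgr.Alloc(32);
    assert(a != INT64_MAX && b != INT64_MAX);  // INT64_MAX signals allocation failure
    assert(mgr.GetAvailableBlockNum() == 1024 - 16 - 32);

    mgr.Free(a, 16);                           // return the first range to the pool
    assert(mgr.GetAvailableBlockNum() == 1024 - 32);

    int64_t c = mgr.Alloc(2048);               // larger than total capacity, must fail
    assert(c == INT64_MAX);
    return 0;
}
```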
17 | 18 | #ifndef __PPL_LLM_MPSC_REQUEST_SCHEDULER_H__ 19 | #define __PPL_LLM_MPSC_REQUEST_SCHEDULER_H__ 20 | 21 | #include "ppl/common/event_count.h" 22 | #include "ppl/common/mpsc_queue.h" 23 | #include 24 | 25 | /** multi-producer-single-consumer request scheduler */ 26 | 27 | namespace ppl { namespace llm { namespace utils { 28 | 29 | template 30 | class MPSCRequestScheduler final { 31 | public: 32 | MPSCRequestScheduler() { 33 | static_assert(std::is_base_of::value, 34 | "template parameter is not a derived class of ppl::common::MPSCQueue::Node"); 35 | } 36 | 37 | ~MPSCRequestScheduler() { 38 | delete stashed_req_; // it's ok to delete nullptr 39 | 40 | while (true) { 41 | bool is_empty; 42 | auto node = queue_.Pop(&is_empty); 43 | if (!node) { 44 | return; 45 | } 46 | auto req = static_cast(node); 47 | delete req; 48 | } 49 | } 50 | 51 | /** returns true if the queue MAY be empty before `req` is pushed */ 52 | bool PushRequest(ReqType* req) { 53 | queue_.Push(req); 54 | uint32_t prev = size_.fetch_add(1, std::memory_order_acq_rel); 55 | return (prev == 0); 56 | } 57 | 58 | ReqType* TryPopRequest(const std::function& check_req_func) { 59 | if (stashed_req_) { 60 | if (!check_req_func(*stashed_req_)) { 61 | return nullptr; 62 | } 63 | 64 | auto req = stashed_req_; 65 | stashed_req_ = nullptr; 66 | size_.fetch_sub(1, std::memory_order_acq_rel); 67 | return req; 68 | } 69 | 70 | bool is_empty = true; 71 | ppl::common::MPSCQueue::Node* node; 72 | do { 73 | node = queue_.Pop(&is_empty); 74 | } while (!node && !is_empty); 75 | 76 | if (is_empty) { 77 | return nullptr; 78 | } 79 | 80 | auto req = static_cast(node); 81 | if (check_req_func(*req)) { 82 | size_.fetch_sub(1, std::memory_order_acq_rel); 83 | return req; 84 | } 85 | 86 | stashed_req_ = req; 87 | return nullptr; 88 | } 89 | 90 | // approximate size 91 | uint32_t GetPendingSize() const { 92 | return size_.load(std::memory_order_relaxed); 93 | } 94 | 95 | private: 96 | ppl::common::MPSCQueue queue_; 97 | std::atomic size_ = {0}; 98 | // stashed request that will be popped if `check_req_func` returns true 99 | ReqType* stashed_req_ = nullptr; 100 | 101 | private: 102 | MPSCRequestScheduler(const MPSCRequestScheduler&) = delete; 103 | void operator=(const MPSCRequestScheduler&) = delete; 104 | MPSCRequestScheduler(MPSCRequestScheduler&&) = delete; 105 | void operator=(MPSCRequestScheduler&&) = delete; 106 | }; 107 | 108 | }}} // namespace ppl::llm::utils 109 | 110 | #endif 111 | -------------------------------------------------------------------------------- /src/utils/prefix_cache_manager.h: -------------------------------------------------------------------------------- 1 | #ifndef __PPL_LLM_PREFIX_CACHE_MANAGER_H__ 2 | #define __PPL_LLM_PREFIX_CACHE_MANAGER_H__ 3 | 4 | #include "ppl/common/log.h" 5 | 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | 12 | namespace ppl { namespace llm { namespace utils { 13 | 14 | class LRUCache { 15 | public: 16 | struct Node { 17 | Node(uint64_t _hash_val, int64_t _page_id) 18 | : hash_val(_hash_val), page_id(_page_id), prev(nullptr), next(nullptr) {} 19 | 20 | uint64_t hash_val; 21 | int64_t page_id; 22 | Node* prev; 23 | Node* next; 24 | }; 25 | 26 | LRUCache() { 27 | head_ = new Node(-1, -1); 28 | tail_ = new Node(-1, -1); 29 | head_->next = tail_; 30 | tail_->prev = head_; 31 | } 32 | 33 | ~LRUCache() { 34 | while (head_) { 35 | Node* tmp = head_; 36 | head_ = head_->next; 37 | delete tmp; 38 | } 39 | } 40 | 41 | bool Find(uint64_t hash_val) const { 42 | return 
cache_.find(hash_val) != cache_.end(); 43 | } 44 | 45 | void Insert(uint64_t hash_val, int64_t page_id) { 46 | if (cache_.find(hash_val) != cache_.end()) { // almost impossible to reach 47 | LOG(WARNING) << "hash_val [" << hash_val << "] page_id [" << page_id << "] already exists in cache_"; 48 | return; 49 | } 50 | Node* node = new Node(hash_val, page_id); 51 | cache_.insert({hash_val, node}); 52 | AddToHead(node); 53 | } 54 | 55 | void EvictNode(uint64_t hash_val) { 56 | Node* node = cache_[hash_val]; 57 | int success = cache_.erase(hash_val); 58 | if (success == 0) { 59 | LOG(WARNING) << "erase unexist hash [" << node->hash_val << "]"; 60 | } 61 | DeleteNode(node); 62 | } 63 | 64 | void EvictList(int64_t nums, std::vector* page_list, std::vector* hash_list) { 65 | int64_t erase_nums = (uint64_t)nums < cache_.size() ? nums : cache_.size(); 66 | for (int i = 0; i < erase_nums; ++i) { 67 | Node* node = tail_->prev; 68 | page_list->push_back(node->page_id); 69 | hash_list->push_back(node->hash_val); 70 | int success = cache_.erase(node->hash_val); 71 | if (success == 0) { 72 | LOG(WARNING) << "erase unexist hash [" << node->hash_val << "]"; 73 | } 74 | DeleteNode(node); 75 | } 76 | } 77 | 78 | int32_t Size() const { 79 | return cache_.size(); 80 | } 81 | 82 | void Reset() { 83 | while (head_->next != tail_) { 84 | Node* node = head_->next; 85 | head_->next = node->next; 86 | delete node; 87 | } 88 | cache_.clear(); 89 | } 90 | 91 | private: 92 | void AddToHead(Node* node) { 93 | node->next = head_->next; 94 | node->prev = head_; 95 | head_->next->prev = node; 96 | head_->next = node; 97 | } 98 | 99 | void DeleteNode(Node* node) { 100 | node->prev->next = node->next; 101 | node->next->prev = node->prev; 102 | delete node; 103 | } 104 | 105 | private: 106 | Node* head_; 107 | Node* tail_; 108 | std::unordered_map cache_; 109 | }; 110 | 111 | class PrefixCacheManager { 112 | public: 113 | struct PrefixItem { 114 | PrefixItem(uint64_t _hash_val, int64_t _page_id) : hash_val(_hash_val), page_id(_page_id) {} 115 | uint64_t hash_val; 116 | int64_t page_id; 117 | int32_t ref_count = 1; 118 | }; 119 | 120 | PrefixCacheManager() {} 121 | 122 | int64_t Find(uint64_t hash_val) const { 123 | auto iter = prefix_map_.find(hash_val); 124 | if (iter == prefix_map_.end()) { 125 | return -1; 126 | } 127 | return iter->second.page_id; 128 | } 129 | 130 | void Insert(uint64_t hash_val, int64_t page_id) { 131 | prefix_map_.insert({hash_val, PrefixItem(hash_val, page_id)}); 132 | } 133 | 134 | void IncRefCount(const uint64_t* hash_list, int64_t nums) { 135 | for (int i = 0; i < nums; ++i) { 136 | uint64_t hash_val = hash_list[i]; 137 | auto iter = prefix_map_.find(hash_val); 138 | if (iter == prefix_map_.end()) { 139 | LOG(WARNING) << "hash [" << hash_val << "] not found in prefix map"; 140 | break; 141 | } 142 | 143 | iter->second.ref_count++; 144 | if (lru_cache_.Find(hash_val)) { 145 | lru_cache_.EvictNode(hash_val); 146 | } 147 | } 148 | } 149 | 150 | void DecRefCount(const uint64_t* hash_list, int64_t nums) { 151 | for (int i = 0; i < nums; ++i) { 152 | auto iter = prefix_map_.find(hash_list[i]); 153 | if (iter == prefix_map_.end()) { 154 | LOG(WARNING) << "hash [" << hash_list[i] << "] not found in prefix map"; 155 | break; 156 | } 157 | iter->second.ref_count--; 158 | 159 | if (iter->second.ref_count == 0) { 160 | lru_cache_.Insert(iter->second.hash_val, iter->second.page_id); 161 | } 162 | } 163 | } 164 | 165 | void Evict(int64_t nums, std::vector* page_list) { 166 | std::vector hash_list; 167 | 
lru_cache_.EvictList(nums, page_list, &hash_list); 168 | for (size_t i = 0; i < hash_list.size(); ++i) { 169 | uint64_t hash_val = hash_list[i]; 170 | prefix_map_.erase(hash_val); 171 | } 172 | } 173 | 174 | int32_t Size() const { 175 | return prefix_map_.size(); 176 | } 177 | 178 | void Reset() { 179 | prefix_map_.clear(); 180 | lru_cache_.Reset(); 181 | } 182 | 183 | private: 184 | std::unordered_map prefix_map_; 185 | LRUCache lru_cache_; 186 | }; 187 | 188 | }}} // namespace ppl::llm::utils 189 | 190 | #endif -------------------------------------------------------------------------------- /src/utils/utils.cc: -------------------------------------------------------------------------------- 1 | // Licensed to the Apache Software Foundation (ASF) under one 2 | // or more contributor license agreements. See the NOTICE file 3 | // distributed with this work for additional information 4 | // regarding copyright ownership. The ASF licenses this file 5 | // to you under the Apache License, Version 2.0 (the 6 | // "License"); you may not use this file except in compliance 7 | // with the License. You may obtain a copy of the License at 8 | // 9 | // http://www.apache.org/licenses/LICENSE-2.0 10 | // 11 | // Unless required by applicable law or agreed to in writing, 12 | // software distributed under the License is distributed on an 13 | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | // KIND, either express or implied. See the License for the 15 | // specific language governing permissions and limitations 16 | // under the License. 17 | 18 | #include "utils.h" 19 | #include "ppl/nn/runtime/tensor.h" 20 | #include "ppl/common/log.h" 21 | #include "ppl/common/types.h" 22 | #include 23 | 24 | using namespace ppl::common; 25 | using namespace ppl::nn; 26 | using namespace std; 27 | 28 | namespace ppl { namespace llm { namespace utils { 29 | 30 | static const char* MemMem(const char* haystack, unsigned int haystack_len, const char* needle, 31 | unsigned int needle_len) { 32 | if (!haystack || haystack_len == 0 || !needle || needle_len == 0) { 33 | return nullptr; 34 | } 35 | 36 | for (auto h = haystack; haystack_len >= needle_len; ++h, --haystack_len) { 37 | if (memcmp(h, needle, needle_len) == 0) { 38 | return h; 39 | } 40 | } 41 | return nullptr; 42 | } 43 | static void SplitString(const char* str, unsigned int len, const char* delim, unsigned int delim_len, 44 | const std::function& f) { 45 | const char* end = str + len; 46 | 47 | while (str < end) { 48 | auto cursor = MemMem(str, len, delim, delim_len); 49 | if (!cursor) { 50 | f(str, end - str); 51 | return; 52 | } 53 | 54 | if (!f(str, cursor - str)) { 55 | return; 56 | } 57 | 58 | cursor += delim_len; 59 | str = cursor; 60 | len = end - cursor; 61 | } 62 | 63 | f("", 0); // the last empty field 64 | } 65 | 66 | void ParseTokens(const std::string& stop_tokens_str, std::set* stop_tokens) { 67 | SplitString(stop_tokens_str.data(), stop_tokens_str.size(), ",", 1, 68 | [stop_tokens](const char* s, unsigned int l) -> bool { 69 | if (l > 0) { 70 | stop_tokens->insert(std::atoi(s)); 71 | } 72 | return true; 73 | }); 74 | return; 75 | } 76 | 77 | // ref: https://stackoverflow.com/questions/20511347/a-good-hash-function-for-a-vector 78 | uint64_t HashStd(uint64_t prev, const int32_t* vec, int32_t len) { 79 | uint64_t ret = len; 80 | ret ^= std::hash()(prev); 81 | for (int i=0; i()((uint64_t)vec[i]); 83 | } 84 | return ret; 85 | } 86 | 87 | uint64_t HashCombine(uint64_t prev, const int32_t* vec, int32_t len) { 88 | uint64_t seed = len; 89 
| seed ^= prev + 0x9e3779b9 + (seed << 6) + (seed >> 2); 90 | for (int i = 0; i < len; ++i) { 91 | seed ^= vec[i] + 0x9e3779b9 + (seed << 6) + (seed >> 2); 92 | } 93 | return seed; 94 | } 95 | 96 | static const pair<string, datatype_t> g_str2datatype[] = { 97 | {"fp64", DATATYPE_FLOAT64}, {"fp32", DATATYPE_FLOAT32}, {"fp16", DATATYPE_FLOAT16}, {"int32", DATATYPE_INT32}, 98 | {"int64", DATATYPE_INT64}, {"int8", DATATYPE_INT8}, {"bool", DATATYPE_BOOL}, {"", DATATYPE_UNKNOWN}, 99 | }; 100 | 101 | static const char* FindDataTypeStr(datatype_t dt) { 102 | for (int i = 0; !g_str2datatype[i].first.empty(); ++i) { 103 | if (g_str2datatype[i].second == dt) { 104 | return g_str2datatype[i].first.c_str(); 105 | } 106 | } 107 | return nullptr; 108 | } 109 | 110 | static string GetDimsStr(const Tensor* tensor) { 111 | auto shape = tensor->GetShape(); 112 | if (shape->GetRealDimCount() == 0) { 113 | return string(); 114 | } 115 | 116 | string res = ToString(shape->GetDim(0)); 117 | for (uint32_t i = 1; i < shape->GetDimCount(); ++i) { 118 | res += "_" + ToString(shape->GetDim(i)); 119 | } 120 | 121 | return res; 122 | } 123 | 124 | bool SaveInputsOneByOne(const ppl::nn::Runtime* runtime, const std::string& save_dir, const std::string& tag = "") { 125 | for (uint32_t c = 0; c < runtime->GetInputCount(); ++c) { 126 | auto t = runtime->GetInputTensor(c); 127 | auto shape = t->GetShape(); 128 | 129 | auto bytes = shape->CalcBytesIncludingPadding(); 130 | vector<char> buffer(bytes); 131 | 132 | ppl::nn::TensorShape src_desc = *t->GetShape(); 133 | src_desc.SetDataFormat(DATAFORMAT_NDARRAY); 134 | auto status = t->ConvertToHost(buffer.data(), src_desc); 135 | if (status != RC_SUCCESS) { 136 | LOG(ERROR) << "convert data failed: " << GetRetCodeStr(status); 137 | return false; 138 | } 139 | 140 | const char* data_type_str = FindDataTypeStr(shape->GetDataType()); 141 | if (!data_type_str) { 142 | LOG(ERROR) << "unsupported data type[" << GetDataTypeStr(shape->GetDataType()) << "]"; 143 | return false; 144 | } 145 | 146 | char name_prefix[32]; 147 | if (tag.empty()) 148 | sprintf(name_prefix, "pplnn_input_%05u_", c); 149 | else 150 | sprintf(name_prefix, "pplnn_input_%s_%05u_", tag.c_str(), c); 151 | const string in_file_name = save_dir + "/" + string(name_prefix) + t->GetName() + "-" + 152 | GetDimsStr(t) + "-" + string(data_type_str) + ".dat"; 153 | ofstream ofs(in_file_name, ios_base::out | ios_base::binary | ios_base::trunc); 154 | if (!ofs.is_open()) { 155 | LOG(ERROR) << "save input file[" << in_file_name << "] failed."; 156 | return false; 157 | } 158 | 159 | ofs.write(buffer.data(), bytes); 160 | } 161 | 162 | return true; 163 | } 164 | 165 | }}} // namespace ppl::llm::utils -------------------------------------------------------------------------------- /src/utils/utils.h: -------------------------------------------------------------------------------- 1 | // Licensed to the Apache Software Foundation (ASF) under one 2 | // or more contributor license agreements. See the NOTICE file 3 | // distributed with this work for additional information 4 | // regarding copyright ownership. The ASF licenses this file 5 | // to you under the Apache License, Version 2.0 (the 6 | // "License"); you may not use this file except in compliance 7 | // with the License. You may obtain a copy of the License at 8 | // 9 | // http://www.apache.org/licenses/LICENSE-2.0 10 | // 11 | // Unless required by applicable law or agreed to in writing, 12 | // software distributed under the License is distributed on an 13 | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | // KIND, either express or implied.
See the License for the 15 | // specific language governing permissions and limitations 16 | // under the License. 17 | 18 | #ifndef __PPL_LLM_UTILS_H__ 19 | #define __PPL_LLM_UTILS_H__ 20 | 21 | #include "ppl/common/threadpool.h" 22 | #include "ppl/common/log.h" 23 | #include "ppl/nn/runtime/runtime.h" 24 | 25 | #include 26 | #include 27 | #include 28 | #include 29 | #include 30 | 31 | namespace ppl { namespace llm { namespace utils { 32 | 33 | void ParseTokens(const std::string& stop_tokens_str, std::set<int>* stop_tokens); 34 | 35 | inline void DummyTaskDeleter(ppl::common::ThreadTask*) {} 36 | 37 | template <typename F, typename... TaskArgType> 38 | ppl::common::RetCode ParallelExecute(F&& func, ppl::common::StaticThreadPool* pool, TaskArgType&&... rest_args) { 39 | auto n = pool->GetNumThreads(); 40 | ppl::common::RetCode thr_rc[n]; 41 | 42 | pool->Run([&](uint32_t nthr, uint32_t ithr) { 43 | thr_rc[ithr] = func(ithr, std::forward<TaskArgType>(rest_args)...); 44 | }); 45 | for (uint32_t i = 0; i < n; ++i) { 46 | if (thr_rc[i] != ppl::common::RC_SUCCESS) { 47 | LOG(ERROR) << "ParallelExecute task[" << i << "] failed"; 48 | return thr_rc[i]; } 49 | } 50 | 51 | return ppl::common::RC_SUCCESS; 52 | } 53 | 54 | class TimingGuard final { 55 | public: 56 | TimingGuard(uint64_t* res) { 57 | diff_microsec_ = res; 58 | begin_ = std::chrono::high_resolution_clock::now(); 59 | } 60 | ~TimingGuard() { 61 | auto end = std::chrono::high_resolution_clock::now(); 62 | *diff_microsec_ = uint64_t(std::chrono::duration_cast<std::chrono::microseconds>(end - begin_).count()); 63 | } 64 | 65 | private: 66 | uint64_t* diff_microsec_; 67 | std::chrono::time_point<std::chrono::high_resolution_clock> begin_; 68 | }; 69 | 70 | uint64_t HashStd(uint64_t prev, const int32_t* vec, int32_t len); 71 | 72 | uint64_t HashCombine(uint64_t prev, const int32_t* vec, int32_t len); 73 | 74 | bool SaveInputsOneByOne(const ppl::nn::Runtime* runtime, const std::string& save_dir, const std::string& tag); 75 | 76 | }}} // namespace ppl::llm::utils 77 | 78 | #endif 79 | -------------------------------------------------------------------------------- /test/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | if(PPLNN_USE_LLM_CUDA) 2 | add_executable(test_prefix_cache_mgr test_prefix_cache_mgr.cc) 3 | target_include_directories(test_prefix_cache_mgr PRIVATE ../src) 4 | target_link_libraries(test_prefix_cache_mgr PRIVATE ppl_llm_static) 5 | endif() 6 | -------------------------------------------------------------------------------- /test/test_prefix_cache_mgr.cc: -------------------------------------------------------------------------------- 1 | #include "utils/prefix_cache_manager.h" 2 | #include "utils/utils.h" 3 | #include "ppl/common/log.h" 4 | #include "ppl/common/cuda/cuda_env.h" 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | using namespace ppl::llm; 12 | using namespace ppl::llm::utils; 13 | using namespace ppl::common; 14 | using namespace std; 15 | 16 | template <typename T> 17 | static void PrintVector(vector<T> vec, const std::string& prefix = "") { 18 | stringstream ss; 19 | for (auto& ele : vec) { 20 | ss << ele << ", "; 21 | } 22 | std::cout << prefix << ": " << ss.str() << std::endl; 23 | } 24 | 25 | void test_hash() { 26 | uint64_t prev = 0; 27 | int32_t vec[] = {1, 2, 3, 4, 5}; 28 | int32_t len = 5; 29 | uint64_t hash_val = HashCombine(prev, vec, len); 30 | std::cout << "hash_val: " << hash_val << std::endl; 31 | } 32 | 33 | void test_prefix_mgr() { 34 | PrefixCacheManager prefix_mgr; 35 | std::vector hash_list = {0, 1, 2, 3}; 36 | std::vector page_list = {11, 12, 13, 14}; 37
| int nums = hash_list.size(); 38 | 39 | for (size_t i = 0; i < hash_list.size(); ++i) { 40 | prefix_mgr.Insert(hash_list[i], page_list[i]); 41 | } 42 | 43 | std::vector hash_list2 = {5, 6, 7, 8}; 44 | std::vector page_list2 = {15, 16, 17, 18}; 45 | 46 | for (size_t i = 0; i < hash_list.size(); ++i) { 47 | prefix_mgr.Insert(hash_list2[i], page_list2[i]); 48 | } 49 | 50 | prefix_mgr.DecRefCount(hash_list.data(), hash_list.size()); 51 | std::cout << prefix_mgr.Size() << std::endl; 52 | 53 | prefix_mgr.DecRefCount(hash_list2.data(), hash_list2.size()); 54 | 55 | std::vector evicted_page_list; 56 | prefix_mgr.Evict(nums, &evicted_page_list); 57 | std::cout << prefix_mgr.Size() << std::endl; 58 | evicted_page_list.clear(); 59 | prefix_mgr.Evict(nums, &evicted_page_list); 60 | } 61 | 62 | int main(int argc, char const* argv[]) { 63 | test_hash(); 64 | test_prefix_mgr(); 65 | return 0; 66 | } 67 | -------------------------------------------------------------------------------- /tools/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | if(PPLNN_USE_LLM_CUDA) 2 | add_executable(offline_inference tools/offline_inference.cc tools/simple_flags.cc) 3 | target_include_directories(offline_inference PRIVATE src) 4 | target_link_libraries(offline_inference PRIVATE ppl_llm_static ${NCCL_LIBRARIES}) 5 | install(TARGETS offline_inference DESTINATION bin) 6 | 7 | add_executable(benchmark_prefix_cache_offline tools/benchmark_prefix_cache_offline.cc tools/simple_flags.cc) 8 | target_include_directories(benchmark_prefix_cache_offline PRIVATE src) 9 | target_link_libraries(benchmark_prefix_cache_offline PRIVATE ppl_llm_static ${NCCL_LIBRARIES}) 10 | endif() 11 | 12 | if(PPLNN_USE_LLM_CUDA) 13 | if(PPL_LLM_ENABLE_GRPC_SERVING) 14 | add_executable(ppl_llm_server tools/llm_server.cc tools/simple_flags.cc) 15 | target_link_libraries(ppl_llm_server PRIVATE 16 | ppl_llm_static 17 | ppl_llm_grpc_serving_static 18 | ${NCCL_LIBRARIES}) 19 | target_include_directories(ppl_llm_server PRIVATE 20 | ${HPCC_DEPS_DIR}/rapidjson/include 21 | ${NCCL_INCLUDE_DIRS}) 22 | 23 | add_executable(client_sample tools/client_sample.cc) 24 | target_link_libraries(client_sample ppl_llm_grpc_proto_static grpc++) 25 | 26 | add_executable(client_pressure tools/client_pressure.cc) 27 | target_link_libraries(client_pressure ppl_llm_grpc_proto_static grpc++) 28 | 29 | add_executable(client_qps_measure tools/client_qps_measure.cc tools/simple_flags.cc) 30 | target_link_libraries(client_qps_measure pplnn_static ppl_llm_grpc_proto_static grpc++ ppl_sentencepiece_static) 31 | target_include_directories(client_qps_measure PUBLIC ${HPCC_DEPS_DIR}/rapidjson/include) 32 | 33 | add_executable(client_sample_token_in_out tools/client_sample_token_in_out.cc) 34 | target_link_libraries(client_sample_token_in_out ppl_llm_grpc_proto_static grpc++ ppl_sentencepiece_static) 35 | 36 | add_executable(client_qps_measure_token_in_out tools/client_qps_measure_token_in_out.cc tools/simple_flags.cc) 37 | target_link_libraries(client_qps_measure_token_in_out pplnn_static ppl_llm_grpc_proto_static grpc++) 38 | target_include_directories(client_qps_measure_token_in_out PUBLIC ${HPCC_DEPS_DIR}/rapidjson/include) 39 | endif() 40 | endif() 41 | -------------------------------------------------------------------------------- /tools/backtrace.h: -------------------------------------------------------------------------------- 1 | #ifndef __PPL_LLM_SERVING_BACKTRACE_H__ 2 | #define __PPL_LLM_SERVING_BACKTRACE_H__ 3 | 4 | #include 5 | 
#include 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | namespace ppl { namespace llm { 12 | 13 | class BackTrace final { 14 | public: 15 | static std::string Get() { 16 | const uint32_t MAX_STACK_SIZE = 128; 17 | void* stk[MAX_STACK_SIZE]; 18 | size_t stk_size; 19 | 20 | stk_size = backtrace(stk, MAX_STACK_SIZE); 21 | if (stk_size == 0) { 22 | return std::string(); 23 | } 24 | 25 | char** strings = backtrace_symbols(stk, stk_size); 26 | 27 | std::string ret; 28 | for (size_t i = 0; i < stk_size; ++i) { 29 | auto addr = ExtractAddr(strings[i]); 30 | if (addr.empty()) { 31 | continue; 32 | } 33 | auto cmd = "addr2line -fp -e " + GetBinPath() + " " + addr; 34 | auto line = ExecuteCmd(cmd); 35 | if (line.empty()) { 36 | continue; 37 | } 38 | 39 | std::string func, pos; 40 | ExtractFuncAndPos(line, &func, &pos); 41 | char tmp[64] = {'\0'}; 42 | sprintf(tmp, "#%ld\t", i); 43 | ret += tmp + func + " at " + pos + "\n"; 44 | } 45 | ret += "\n"; 46 | ::free(strings); 47 | 48 | return ret; 49 | } 50 | 51 | private: 52 | static std::string ExtractAddr(const char* content) { 53 | char tmp[128] = {'\0'}, tmp2[128] = {'\0'}; 54 | char binname[128] = {'\0'}; 55 | uint64_t addr; 56 | sscanf(content, "%s%s", binname, tmp); 57 | sscanf(tmp, "%3c%lx", tmp2, &addr); // tmp -> [0x1234567] 58 | int len = sprintf(tmp, "0x%lx", addr); 59 | return std::string(tmp, len); 60 | } 61 | 62 | static std::string ExecuteCmd(const std::string& cmd) { 63 | auto fp = popen(cmd.c_str(), "r"); 64 | if (!fp) { 65 | std::cerr << "exec cmd [" << cmd << "] failed." << std::endl; 66 | return std::string(); 67 | } 68 | 69 | char buf[1024] = {'\0'}; 70 | auto unused = fgets(buf, 1024, fp); 71 | (void)unused; 72 | int len = strlen(buf); 73 | fclose(fp); 74 | return std::string(buf, len - 1); // remove trailing '\n' 75 | } 76 | 77 | static std::string GetBinPath() { 78 | char tmp[256] = {'\0'}; 79 | sprintf(tmp, "%d", getpid()); 80 | auto exe_info = "/proc/" + std::string(tmp) + "/exe"; 81 | auto len = readlink(exe_info.c_str(), tmp, 256); 82 | if (len <= 0) { 83 | std::cerr << "GetBinPath() failed." << std::endl; 84 | return std::string(); 85 | } 86 | return std::string(tmp, len); 87 | } 88 | 89 | static std::string Demangle(const std::string& name) { 90 | std::string res; 91 | size_t size = 0; 92 | int status = 0; 93 | auto ret = abi::__cxa_demangle(name.c_str(), nullptr, &size, &status); 94 | if (ret) { 95 | res.assign(ret); 96 | free(ret); 97 | } else { 98 | /* 99 | status: 100 | 0: The demangling operation succeeded. 101 | -1: A memory allocation failiure occurred. 102 | -2: mangled_name is not a valid name under the C++ ABI mangling rules. 103 | -3: One of the arguments is invalid. 
104 | */ 105 | //std::cerr << "demangle failed, status = " << status << ", name -> " << name << std::endl; 106 | res = name; 107 | } 108 | 109 | return res; 110 | } 111 | 112 | static void ExtractFuncAndPos(const std::string& line, std::string* func, 113 | std::string* pos) { 114 | char tmp[64] = {'\0'}; 115 | char func_name[512] = {'\0'}; 116 | char pos_buf[1024] = {'\0'}; 117 | sscanf(line.c_str(), "%s%s%s", func_name, tmp, pos_buf); 118 | func->assign(Demangle(func_name)); 119 | pos->assign(pos_buf); 120 | } 121 | }; 122 | 123 | }} 124 | 125 | #endif 126 | -------------------------------------------------------------------------------- /tools/client_sample.cc: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | #include "absl/flags/flag.h" 7 | #include "absl/flags/parse.h" 8 | #include "llm.grpc.pb.h" 9 | 10 | #include 11 | #include 12 | 13 | using namespace grpc; 14 | using grpc::Channel; 15 | using grpc::ClientContext; 16 | using grpc::Status; 17 | using namespace std::chrono; 18 | using namespace ppl::llm; 19 | 20 | ABSL_FLAG(std::string, target, "localhost:50052", "Server address"); 21 | 22 | class GenerationClient { 23 | public: 24 | GenerationClient(std::shared_ptr<Channel> channel) : stub_(proto::LLMService::NewStub(channel)) {} 25 | 26 | int Generation(const std::vector<std::string>& prompts) { 27 | // Data we are sending to the server. 28 | ClientContext context; 29 | proto::BatchedRequest req_list; 30 | std::unordered_map<int, std::string> rsp_stream_store; 31 | for (size_t i = 0; i < prompts.size(); i++) { 32 | // request 33 | auto* req = req_list.add_req(); 34 | req->set_id(i); 35 | req->set_prompt(prompts[i]); 36 | auto* choosing_parameter = req->mutable_choosing_parameters(); 37 | choosing_parameter->set_do_sample(false); 38 | choosing_parameter->set_temperature(1.f); 39 | choosing_parameter->set_repetition_penalty(1.f); 40 | choosing_parameter->set_presence_penalty(0.f); 41 | choosing_parameter->set_frequency_penalty(0.f); 42 | 43 | auto* stopping_parameters = req->mutable_stopping_parameters(); 44 | stopping_parameters->set_max_new_tokens(16); 45 | stopping_parameters->set_ignore_eos_token(false); 46 | rsp_stream_store[i] = ""; 47 | } 48 | // response 49 | proto::BatchedResponse batched_rsp; 50 | std::unique_ptr<ClientReader<proto::BatchedResponse>> reader(stub_->Generation(&context, req_list)); 51 | 52 | // stream chat 53 | auto start = system_clock::now(); 54 | auto first_fill_time = system_clock::now(); 55 | bool is_first_fill = true; 56 | 57 | while (reader->Read(&batched_rsp)) { 58 | if (is_first_fill) { 59 | first_fill_time = system_clock::now(); 60 | is_first_fill = false; 61 | } 62 | for (const auto& rsp : batched_rsp.rsp()) { 63 | int tid = rsp.id(); 64 | std::string rsp_stream = rsp.generated(); 65 | rsp_stream_store[tid] += rsp_stream; 66 | } 67 | } 68 | auto end = system_clock::now(); 69 | 70 | std::cout << "------------------------------" << std::endl; 71 | std::cout << "--------- Answer -------------" << std::endl; 72 | std::cout << "------------------------------" << std::endl; 73 | 74 | for (auto rsp : rsp_stream_store) { 75 | std::cout << rsp.second << std::endl; 76 | std::cout << "--------------------" << std::endl; 77 | } 78 | 79 | auto first_till_duration = duration_cast<milliseconds>(first_fill_time - start); 80 | auto duration = duration_cast<milliseconds>(end - start); 81 | 82 | std::cout << "first fill: " << first_till_duration.count() << " ms" << std::endl; 83 | 84 | std::cout << "total: " << duration.count() << " ms" << std::endl; 85 | 86 | Status status =
reader->Finish(); 87 | if (status.ok()) { 88 | std::cout << "Generation rpc succeeded." << std::endl; 89 | } else { 90 | std::cerr << "Generation rpc failed." << std::endl; 91 | return -1; 92 | } 93 | return 0; 94 | } 95 | 96 | private: 97 | std::unique_ptr<proto::LLMService::Stub> stub_; 98 | }; 99 | 100 | int main(int argc, char** argv) { 101 | if (argc < 2) { 102 | std::cerr << "usage: " << argv[0] << " host:port" << std::endl; 103 | return -1; 104 | } 105 | 106 | const std::string target_str = argv[1]; 107 | 108 | GenerationClient generator(grpc::CreateChannel(target_str, grpc::InsecureChannelCredentials())); 109 | 110 | const std::string prompt = "Building a website can be done in 10 simple steps:\n"; 111 | const std::vector<std::string> prompts(3, prompt); 112 | 113 | std::cout << "------------------------------" << std::endl; 114 | std::cout << "--------- Question -------------" << std::endl; 115 | std::cout << "------------------------------" << std::endl; 116 | 117 | for (auto& str : prompts) { 118 | std::cout << str << std::endl; 119 | } 120 | 121 | generator.Generation(prompts); 122 | return 0; 123 | } 124 | -------------------------------------------------------------------------------- /tools/client_sample_token_in_out.cc: -------------------------------------------------------------------------------- 1 | #include "llm.grpc.pb.h" 2 | 3 | #include "absl/flags/flag.h" 4 | #include "absl/flags/parse.h" 5 | #include "grpc++/grpc++.h" 6 | 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | 13 | using namespace grpc; 14 | using grpc::Channel; 15 | using grpc::ClientContext; 16 | using grpc::Status; 17 | using namespace std::chrono; 18 | using namespace ppl::llm; 19 | 20 | class GenerationClient { 21 | public: 22 | GenerationClient(std::shared_ptr<Channel> channel) : stub_(proto::LLMService::NewStub(channel)) {} 23 | 24 | int Generation(const std::vector<std::vector<int>>& batch_prompt_token_ids) { 25 | // Data we are sending to the server.
26 | ClientContext context; 27 | proto::BatchedRequest req_list; 28 | std::unordered_map<int, std::vector<int>> rsp_stream_store; 29 | for (size_t i = 0; i < batch_prompt_token_ids.size(); i++) { 30 | const auto& prompt_token_ids = batch_prompt_token_ids[i]; 31 | // request 32 | auto req = req_list.add_req(); 33 | req->set_id(i); 34 | auto* pb_tokens = req->mutable_tokens(); 35 | for (auto token : prompt_token_ids) { 36 | pb_tokens->add_ids(token); 37 | } 38 | auto* choosing_parameter = req->mutable_choosing_parameters(); 39 | choosing_parameter->set_do_sample(false); 40 | choosing_parameter->set_temperature(1.f); 41 | choosing_parameter->set_repetition_penalty(1.f); 42 | choosing_parameter->set_presence_penalty(0.f); 43 | choosing_parameter->set_frequency_penalty(0.f); 44 | 45 | auto* stopping_parameters = req->mutable_stopping_parameters(); 46 | stopping_parameters->set_max_new_tokens(64); 47 | stopping_parameters->set_ignore_eos_token(false); 48 | rsp_stream_store[i] = {}; 49 | } 50 | // response 51 | proto::BatchedResponse batched_rsp; 52 | std::unique_ptr<ClientReader<proto::BatchedResponse>> reader(stub_->Generation(&context, req_list)); 53 | 54 | // stream chat 55 | auto start = system_clock::now(); 56 | auto first_fill_time = system_clock::now(); 57 | bool is_first_fill = true; 58 | 59 | while (reader->Read(&batched_rsp)) { 60 | if (is_first_fill) { 61 | first_fill_time = system_clock::now(); 62 | is_first_fill = false; 63 | } 64 | 65 | for (const auto& rsp : batched_rsp.rsp()) { 66 | int tid = rsp.id(); 67 | int token = rsp.tokens().ids().at(0); 68 | rsp_stream_store[tid].push_back(token); 69 | } 70 | } 71 | auto end = system_clock::now(); 72 | 73 | std::cout << "------------------------------" << std::endl; 74 | std::cout << "--------- Answer -------------" << std::endl; 75 | std::cout << "------------------------------" << std::endl; 76 | 77 | for (const auto rsp : rsp_stream_store) { 78 | for (const auto token : rsp.second) { 79 | std::cout << token << ", "; 80 | } 81 | std::cout << std::endl; 82 | std::cout << "--------------------" << std::endl; 83 | } 84 | 85 | auto first_till_duration = duration_cast<milliseconds>(first_fill_time - start); 86 | auto duration = duration_cast<milliseconds>(end - start); 87 | 88 | std::cout << "first fill: " << first_till_duration.count() << " ms" << std::endl; 89 | 90 | std::cout << "total: " << duration.count() << " ms" << std::endl; 91 | 92 | Status status = reader->Finish(); 93 | if (status.ok()) { 94 | std::cout << "Generation rpc succeeded." << std::endl; 95 | } else { 96 | std::cerr << "Generation rpc failed."
<< std::endl; 97 | return -1; 98 | } 99 | return 0; 100 | } 101 | 102 | private: 103 | std::unique_ptr<proto::LLMService::Stub> stub_; 104 | }; 105 | 106 | int main(int argc, char** argv) { 107 | if (argc < 2) { 108 | std::cerr << "usage: " << argv[0] << " host:port" << std::endl; 109 | return -1; 110 | } 111 | const std::string target_str = argv[1]; 112 | 113 | GenerationClient generator(grpc::CreateChannel(target_str, grpc::InsecureChannelCredentials())); 114 | 115 | std::vector<int> token_ids = {0, 1, 2, 3, 4, 5, 6, 7}; 116 | const std::vector<std::vector<int>> prompt_token_ids(3, token_ids); 117 | 118 | std::cout << "------------------------------" << std::endl; 119 | std::cout << "--------- Question -------------" << std::endl; 120 | std::cout << "------------------------------" << std::endl; 121 | 122 | for (auto& token_ids : prompt_token_ids) { 123 | for (int token : token_ids) { 124 | std::cout << token << ", "; 125 | } 126 | std::cout << std::endl; 127 | } 128 | 129 | generator.Generation(prompt_token_ids); 130 | return 0; 131 | } 132 | -------------------------------------------------------------------------------- /tools/simple_flags.h: -------------------------------------------------------------------------------- 1 | // Licensed to the Apache Software Foundation (ASF) under one 2 | // or more contributor license agreements. See the NOTICE file 3 | // distributed with this work for additional information 4 | // regarding copyright ownership. The ASF licenses this file 5 | // to you under the Apache License, Version 2.0 (the 6 | // "License"); you may not use this file except in compliance 7 | // with the License. You may obtain a copy of the License at 8 | // 9 | // http://www.apache.org/licenses/LICENSE-2.0 10 | // 11 | // Unless required by applicable law or agreed to in writing, 12 | // software distributed under the License is distributed on an 13 | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | // KIND, either express or implied. See the License for the 15 | // specific language governing permissions and limitations 16 | // under the License. 17 | 18 | #ifndef SIMPLE_FLAGS_H 19 | #define SIMPLE_FLAGS_H 20 | #include 21 | #include 22 | #include 23 | 24 | #define BEGIN_FLAGS_NAMESPACES namespace simple_flags { 25 | #define END_FLAGS_NAMESPACES } 26 | 27 | BEGIN_FLAGS_NAMESPACES 28 | typedef std::string flag_string; 29 | typedef bool flag_bool; 30 | typedef float flag_float; 31 | typedef double flag_double; 32 | typedef int16_t flag_int16; 33 | typedef uint16_t flag_uint16; 34 | typedef int32_t flag_int32; 35 | typedef uint32_t flag_uint32; 36 | typedef int64_t flag_int64; 37 | typedef uint64_t flag_uint64; 38 | typedef std::vector<flag_string> flag_stringlist; 39 | typedef std::vector<flag_bool> flag_boollist; 40 | typedef std::vector<flag_float> flag_floatlist; 41 | typedef std::vector<flag_double> flag_doublelist; 42 | typedef std::vector<flag_int16> flag_int16list; 43 | typedef std::vector<flag_int32> flag_int32list; 44 | typedef std::vector<flag_int64> flag_int64list; 45 | typedef std::vector<flag_uint16> flag_uint16list; 46 | typedef std::vector<flag_uint32> flag_uint32list; 47 | typedef std::vector<flag_uint64> flag_uint64list; 48 | 49 | int parse_args(int argc, char** argv); 50 | void print_args_info(); 51 | 52 | template <typename T> 53 | void registerFlag(const flag_string& opt, T* optPtr, const char* comment); 54 | 55 | const flag_stringlist& get_unknown_flags(); 56 | 57 | END_FLAGS_NAMESPACES 58 | 59 | /** 60 | * \brief Flag_To_Str 61 | * Convert the given literal expression to literal string. 62 | * \param x literal expression to be converted.
63 | */ 64 | #define Flag_To_Str(x) #x 65 | 66 | /** 67 | * \defgroup Single-parameter flag declaration macros 68 | * Declare a option which accept only one parameter, so you can access it by using Flag_* vairable. 69 | * @{ 70 | */ 71 | #define Declare_bool(opt) extern simple_flags::flag_bool Flag_##opt 72 | #define Declare_float(opt) extern simple_flags::flag_float Flag_##opt 73 | #define Declare_double(opt) extern simple_flags::flag_double Flag_##opt 74 | #define Declare_int32(opt) extern simple_flags::flag_int32 Flag_##opt 75 | #define Declare_uint32(opt) extern simple_flags::flag_uint32 Flag_##opt 76 | #define Declare_int64(opt) extern simple_flags::flag_int64 Flag_##opt 77 | #define Declare_uint64(opt) extern simple_flags::flag_uint64 Flag_##opt 78 | #define Declare_string(opt) extern simple_flags::flag_string Flag_##opt 79 | 80 | /** 81 | * @} //! Common declare macros. 82 | */ 83 | 84 | /** 85 | * \defgroup Muti-parameter flag declaration macros 86 | * Declare a option which accept parameters list, so you can access it by using Flag_* variable 87 | */ 88 | #define Declare_stringlist(opt) extern simple_flags::flag_stringlist Flag_##opt 89 | #define Declare_boollist(opt) extern simple_flags::flag_boollist Flag_##opt 90 | #define Declare_floatlist(opt) extern simple_flags::flag_floatlist Flag_##opt 91 | #define Declare_doublelist(opt) extern simple_flags::flag_doublelist Flag_##opt 92 | #define Declare_int32list(opt) extern simple_flags::flag_int32list Flag_##opt 93 | #define Declare_int64list(opt) extern simple_flags::flag_int64list Flag_##opt 94 | #define Declare_uint32list(opt) extern simple_flags::flag_uint32list Flag_##opt 95 | #define Declare_uint64list(opt) extern simple_flags::flag_uint64list Flag_##opt 96 | 97 | /**@}*/ 98 | 99 | /** 100 | * \defgroup Self-defined single-parameter flag declaration macros 101 | * Declare a self-defined option which accept only one parameter, so you can access it by using Flag_* variable 102 | * @{ 103 | */ 104 | #define Declare_bool_opt(flag) extern simple_flags::flag_bool flag 105 | #define Declare_float_opt(flag) extern simple_flags::flag_float flag 106 | #define Declare_double_opt(flag) extern simple_flags::flag_double flag 107 | #define Declare_int32_opt(flag) extern simple_flags::flag_int32 flag 108 | #define Declare_uint32_opt(flag) extern simple_flags::flag_uint32 flag 109 | #define Declare_int64_opt(flag) extern simple_flags::flag_int64 flag 110 | #define Declare_uint64_opt(flag) extern simple_flags::flag_uint64 flag 111 | #define Declare_string_opt(flag) extern simple_flags::flag_string flag 112 | 113 | /**@}*/ 114 | 115 | /** 116 | * \defgroup Self-defined multi-parameter flag declaration macros 117 | * Declare a self-define option which accept parameter list, so you can access it by using Flag_* variable. 
118 | * @{ 119 | */ 120 | #define Declare_stringlist_opt(flag) extern simple_flags::flag_stringlist flag 121 | #define Declare_boollist_opt(flag) extern simple_flags::flag_boollist flag 122 | #define Declare_floatlist_opt(flag) extern simple_flags::flag_floatlist flag 123 | #define Declare_doublelist_opt(flag) extern simple_flags::flag_doublelist flag 124 | #define Declare_int32list_opt(flag) extern simple_flags::flag_int32list flag 125 | #define Declare_int64list_opt(flag) extern simple_flags::flag_int64list flag 126 | #define Declare_uint32list_opt(flag) extern simple_flags::flag_uint32list flag 127 | #define Declare_uint64list_opt(flag) extern simple_flags::flag_uint64list flag 128 | 129 | /**@}*/ 130 | 131 | /** 132 | * \defgroup Define definitions 133 | */ 134 | #define Define_Implementer(type, opt, def, comment) \ 135 | simple_flags::type Flag_##opt = def; \ 136 | BEGIN_FLAGS_NAMESPACES \ 137 | class type##_Flag_Register_##opt { \ 138 | public: \ 139 | type##_Flag_Register_##opt() { \ 140 | registerFlag("-" Flag_To_Str(opt), &Flag_##opt, comment); \ 141 | } \ 142 | }; \ 143 | static type##_Flag_Register_##opt s_flag_##opt##_object; \ 144 | END_FLAGS_NAMESPACES 145 | 146 | #define Define_ImplementerOpt(type, opt, flag, def, comment) \ 147 | simple_flags::type flag = def; \ 148 | BEGIN_FLAGS_NAMESPACES \ 149 | class type##_Flag_Register_##flag { \ 150 | public: \ 151 | type##_Flag_Register_##flag() { \ 152 | registerFlag(opt, &flag, comment); \ 153 | } \ 154 | }; \ 155 | static type##_Flag_Register_##flag type##_Flag_Register_##flag##_object; \ 156 | END_FLAGS_NAMESPACES 157 | 158 | #define Define_Implementer_list(type, opt, comment) \ 159 | simple_flags::type Flag_##opt; \ 160 | BEGIN_FLAGS_NAMESPACES \ 161 | class type##_Flag_Register_##opt { \ 162 | public: \ 163 | type##_Flag_Register_##opt() { \ 164 | registerFlag("-" Flag_To_Str(opt), &Flag_##opt, comment); \ 165 | } \ 166 | }; \ 167 | static type##_Flag_Register_##opt type##_Flag_Register_##opt##_object; \ 168 | END_FLAGS_NAMESPACES 169 | 170 | #define Define_Implementer_listOpt(type, opt, flag, comment) \ 171 | simple_flags::type flag; \ 172 | BEGIN_FLAGS_NAMESPACES \ 173 | class type##_Flag_Register_##flag { \ 174 | public: \ 175 | type##_Flag_Register_##flag() { \ 176 | registerFlag(opt, &flag, comment); \ 177 | } \ 178 | }; \ 179 | static type##_Flag_Register_##flag type##_Flag_Register_##flag##_object; \ 180 | END_FLAGS_NAMESPACES 181 | 182 | /**@}*/ 183 | 184 | /** 185 | * \defgroup Single-parameter option define macros 186 | * Define a option which accept only one option, so it will be parsed. 
187 | * @{ 188 | */ 189 | #define Define_bool(opt, def, comment) Define_Implementer(flag_bool, opt, def, comment) 190 | #define Define_float(opt, def, comment) Define_Implementer(flag_float, opt, def, comment) 191 | #define Define_double(opt, def, comment) Define_Implementer(flag_double, opt, def, comment) 192 | #define Define_int32(opt, def, comment) Define_Implementer(flag_int32, opt, def, comment) 193 | #define Define_uint32(opt, def, comment) Define_Implementer(flag_uint32, opt, def, comment) 194 | #define Define_int64(opt, def, comment) Define_Implementer(flag_int64, opt, def, comment) 195 | #define Define_uint64(opt, def, comment) Define_Implementer(flag_uint64, opt, def, comment) 196 | #define Define_string(opt, def, comment) Define_Implementer(flag_string, opt, def, comment) 197 | 198 | /**@}*/ 199 | 200 | /** 201 | * \defgroup Self-define single-parameter option define macros 202 | * Define self-defined optiong, so it will be parsed. 203 | * @{ 204 | */ 205 | #define Define_bool_opt(opt, flag, def, comment) Define_ImplementerOpt(flag_bool, opt, flag, def, comment) 206 | #define Define_float_opt(opt, flag, def, comment) Define_ImplementerOpt(flag_float, opt, flag, def, comment) 207 | #define Define_double_opt(opt, flag, def, comment) Define_ImplementerOpt(flag_double, opt, flag, def, comment) 208 | #define Define_int32_opt(opt, flag, def, comment) Define_ImplementerOpt(flag_int32, opt, flag, def, comment) 209 | #define Define_uint32_opt(opt, flag, def, comment) Define_ImplementerOpt(flag_uint32, opt, flag, def, comment) 210 | #define Define_int64_opt(opt, flag, def, comment) Define_ImplementerOpt(flag_int64, opt, flag, def, comment) 211 | #define Define_uint64_opt(opt, flag, def, comment) Define_ImplementerOpt(flag_uint64, opt, flag, def, comment) 212 | #define Define_string_opt(opt, flag, def, comment) Define_ImplementerOpt(flag_string, opt, flag, def, comment) 213 | 214 | /**@}*/ 215 | 216 | /** 217 | * \defgroup Multi-parameter option define macros 218 | * Define a option which accept parameter list, so it will be parsed. 219 | * @{ 220 | */ 221 | #define Define_stringlist(opt, comment) Define_Implementer_list(flag_stringlist, opt, comment) 222 | #define Define_boollist(opt, comment) Define_Implementer_list(flag_boollist, opt, comment) 223 | #define Define_floatlist(opt, comment) Define_Implementer_list(flag_floatlist, opt, comment) 224 | #define Define_doublelist(opt, comment) Define_Implementer_list(flag_doublelist, opt, comment) 225 | #define Define_int32list(opt, comment) Define_Implementer_list(flag_int32list, opt, comment) 226 | #define Define_uint32list(opt, comment) Define_Implementer_list(flag_uint32list, opt, comment) 227 | #define Define_int64list(opt, comment) Define_Implementer_list(flag_int64list, opt, comment) 228 | #define Define_uint64list(opt, comment) Define_Implementer_list(flag_uint64list, opt, comment) 229 | 230 | /**@}*/ 231 | 232 | /** 233 | * \defgroup Self-define multi-parameter option define macros 234 | * Define a option which accept parameter list, so it will be parsed. 
235 | */ 236 | #define Define_stringlist_opt(opt, flag, comment) Define_Implementer_listOpt(flag_stringlist, opt, flag, comment) 237 | #define Define_int32list_opt(opt, flag, comment) Define_Implementer_listOpt(flag_int32list, opt, flag, comment) 238 | #define Define_uint32list_opt(opt, flag, comment) Define_Implementer_listOpt(flag_uint32list, opt, flag, comment) 239 | #define Define_int64list_opt(opt, flag, comment) Define_Implementer_listOpt(flag_int64list, opt, flag, comment) 240 | #define Define_uint64list_opt(opt, flag, comment) Define_Implementer_listOpt(flag_uint64list, opt, flag, comment) 241 | #define Define_floatlist_opt(opt, flag, comment) Define_Implementer_listOpt(flag_floatlist, opt, flag, comment) 242 | #define Define_doublelist_opt(opt, flag, comment) Define_Implementer_listOpt(flag_doublelist, opt, flag, comment) 243 | 244 | /**@}*/ 245 | 246 | Declare_bool(help); 247 | Declare_stringlist(unknown_trash); 248 | #endif // SIMPLE_FLAGS_H 249 | --------------------------------------------------------------------------------
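The Define_* macros above expand to a global Flag_<name> variable plus a static registrar object that calls registerFlag() at program start, and tools/simple_flags.cc (not reproduced here) supplies parse_args() and print_args_info(). Below is a minimal usage sketch, not part of the repository: the flag names model_dir and port are invented for illustration, and only the declarations visible in simple_flags.h are relied on.

// usage_sketch.cc -- illustrative only; assumed to be compiled together with tools/simple_flags.cc.
// "model_dir" and "port" are hypothetical example flags, not real server options.
#include "simple_flags.h"
#include <iostream>

Define_string(model_dir, "", "path of the model directory"); // defines Flag_model_dir, registered as "-model_dir"
Define_int32(port, 50052, "listening port");                 // defines Flag_port, registered as "-port"

int main(int argc, char** argv) {
    simple_flags::parse_args(argc, argv); // fills every registered Flag_* variable from the command line
    if (Flag_help) {                      // Flag_help is declared by simple_flags.h itself (Declare_bool(help)),
        simple_flags::print_args_info();  // assuming simple_flags.cc defines it as that declaration suggests
        return 0;
    }
    std::cout << "model_dir = " << Flag_model_dir << ", port = " << Flag_port << std::endl;
    return 0;
}

Because registration happens at static-initialization time inside each Define_* expansion, the tools in this directory (llm_server.cc, offline_inference.cc, the client_qps_measure variants) can declare their options next to main() without any central flag registry.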