├── examples ├── model_repo │ ├── facebook_m2m100_1.2B │ │ ├── 1 │ │ │ └── .gitkeep │ │ └── config.pbtxt │ └── Helsinki-NLP_opus-mt-en-de │ │ ├── 1 │ │ ├── .gitkeep │ │ └── README.md │ │ └── config.pbtxt └── client.py ├── .gitignore ├── LICENSE.txt ├── src ├── libtriton_ctranslate2.ldscript └── ctranslate2.cc ├── cmake └── CTranslate2BackendConfig.cmake.in ├── README.md └── CMakeLists.txt /examples/model_repo/facebook_m2m100_1.2B/1/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /examples/model_repo/Helsinki-NLP_opus-mt-en-de/1/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | build-* 2 | build 3 | .cache 4 | example_repo/Helsinki-* 5 | -------------------------------------------------------------------------------- /examples/model_repo/Helsinki-NLP_opus-mt-en-de/1/README.md: -------------------------------------------------------------------------------- 1 | Copy the converted model in here. -------------------------------------------------------------------------------- /examples/model_repo/Helsinki-NLP_opus-mt-en-de/config.pbtxt: -------------------------------------------------------------------------------- 1 | backend: "ctranslate2" 2 | name: "Helsinki-NLP_opus-mt-en-de" 3 | max_batch_size: 128 4 | input [ 5 | { 6 | name: "INPUT_IDS" 7 | data_type: TYPE_INT32 8 | dims: [ -1 ] 9 | allow_ragged_batch: true 10 | } 11 | ] 12 | output [ 13 | { 14 | name: "OUTPUT_IDS" 15 | data_type: TYPE_INT32 16 | dims: [ -1 ] 17 | } 18 | ] 19 | parameters [ 20 | { 21 | key: "compute_type" 22 | value { 23 | string_value: "float16" 24 | } 25 | }, 26 | { 27 | key: "max_decoding_length_multiple" 28 | value { 29 | string_value: "2" 30 | } 31 | }, 32 | { 33 | key: "beam_size" 34 | value { 35 | string_value: "4" 36 | } 37 | } 38 | ] 39 | 40 | instance_group [{ kind: KIND_GPU, count: 1 }] 41 | dynamic_batching { 42 | max_queue_delay_microseconds: 5000 43 | } 44 | -------------------------------------------------------------------------------- /examples/model_repo/facebook_m2m100_1.2B/config.pbtxt: -------------------------------------------------------------------------------- 1 | backend: "ctranslate2" 2 | name: "facebook_m2m100_1.2B" 3 | max_batch_size: 64 4 | input [ 5 | { 6 | name: "INPUT_IDS" 7 | data_type: TYPE_INT32 8 | dims: [ -1 ] 9 | allow_ragged_batch: true 10 | } 11 | ] 12 | input [ 13 | { 14 | name: "TARGET_PREFIX" 15 | data_type: TYPE_INT32 16 | dims: [ -1 ] 17 | allow_ragged_batch: true 18 | } 19 | ] 20 | output [ 21 | { 22 | name: "OUTPUT_IDS" 23 | data_type: TYPE_INT32 24 | dims: [ -1 ] 25 | } 26 | ] 27 | parameters [ 28 | { 29 | key: "compute_type" 30 | value { 31 | string_value: "float16" 32 | } 33 | }, 34 | { 35 | key: "max_decoding_length_multiple" 36 | value { 37 | string_value: "2" 38 | } 39 | }, 40 | { 41 | key: "beam_size" 42 | value { 43 | string_value: "5" 44 | } 45 | } 46 | ] 47 | 48 | instance_group [{ kind: KIND_GPU, count: 1 }] 49 | dynamic_batching { 50 | max_queue_delay_microseconds: 5000 51 | } 52 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023, Cantab Research 
Ltd. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /examples/client.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import tritonclient.grpc.aio 3 | from tritonclient.utils import np_to_triton_dtype 4 | from grpc import ChannelConnectivity 5 | from transformers import AutoTokenizer 6 | import logging 7 | import numpy as np 8 | import sys 9 | 10 | async def main(): 11 | MODEL_NAME = "opus-mt-en-de" 12 | URL = "127.0.0.1:8001" 13 | client = tritonclient.grpc.aio.InferenceServerClient(URL) 14 | 15 | en_text = sys.stdin.readline() 16 | tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/" + MODEL_NAME) 17 | 18 | input_ids = tokenizer(en_text, return_attention_mask=False, return_tensors="np").input_ids.astype(np.int32) 19 | logging.info(f"Tokenised input: {input_ids}") 20 | 21 | if client._channel.get_state() == ChannelConnectivity.SHUTDOWN: 22 | return 23 | 24 | inputs = [ 25 | tritonclient.grpc.aio.InferInput("INPUT_IDS", input_ids.shape, np_to_triton_dtype(input_ids.dtype)), 26 | ] 27 | inputs[0].set_data_from_numpy(input_ids) 28 | outputs = [tritonclient.grpc.aio.InferRequestedOutput("OUTPUT_IDS")] 29 | 30 | res = await client.infer(model_name=MODEL_NAME, inputs=inputs, outputs=outputs) 31 | out_tokens = res.as_numpy("OUTPUT_IDS") 32 | logging.info(f"Returned tokens: {out_tokens}") 33 | translated_text = tokenizer.batch_decode(out_tokens) 34 | print(translated_text) 35 | 36 | if __name__ == "__main__": 37 | logging.basicConfig(level=logging.INFO) 38 | asyncio.run(main()) 39 | -------------------------------------------------------------------------------- /src/libtriton_ctranslate2.ldscript: -------------------------------------------------------------------------------- 1 | # Copyright 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # Redistribution and use in source and binary forms, with or without 4 | # modification, are permitted provided that the following conditions 5 | # are met: 6 | # * Redistributions of source code must retain the above copyright 7 | # notice, this list of conditions and the following disclaimer. 8 | # * Redistributions in binary form must reproduce the above copyright 9 | # notice, this list of conditions and the following disclaimer in the 10 | # documentation and/or other materials provided with the distribution. 
11 | # * Neither the name of NVIDIA CORPORATION nor the names of its 12 | # contributors may be used to endorse or promote products derived 13 | # from this software without specific prior written permission. 14 | # 15 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY 16 | # EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 17 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 18 | # PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR 19 | # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 20 | # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 21 | # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 22 | # PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY 23 | # OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 24 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 26 | { 27 | global: 28 | TRITONBACKEND_*; 29 | local: *; 30 | }; 31 | -------------------------------------------------------------------------------- /cmake/CTranslate2BackendConfig.cmake.in: -------------------------------------------------------------------------------- 1 | # Copyright 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # Redistribution and use in source and binary forms, with or without 4 | # modification, are permitted provided that the following conditions 5 | # are met: 6 | # * Redistributions of source code must retain the above copyright 7 | # notice, this list of conditions and the following disclaimer. 8 | # * Redistributions in binary form must reproduce the above copyright 9 | # notice, this list of conditions and the following disclaimer in the 10 | # documentation and/or other materials provided with the distribution. 11 | # * Neither the name of NVIDIA CORPORATION nor the names of its 12 | # contributors may be used to endorse or promote products derived 13 | # from this software without specific prior written permission. 14 | # 15 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY 16 | # EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 17 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 18 | # PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR 19 | # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 20 | # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 21 | # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 22 | # PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY 23 | # OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 24 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
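# --- Illustrative note (not part of the original file): this package config is
# --- consumed by downstream CMake projects. Assuming the backend's install
# --- prefix is on CMAKE_PREFIX_PATH, usage could look roughly like:
# ---
# ---   find_package(CTranslate2Backend REQUIRED)
# ---   target_link_libraries(my_tool PRIVATE ${CTRANSLATE2BACKEND_LIBRARIES})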
26 | 
27 | include(CMakeFindDependencyMacro)
28 | 
29 | get_filename_component(
30 |   CTRANSLATE2BACKEND_CMAKE_DIR "${CMAKE_CURRENT_LIST_FILE}" PATH
31 | )
32 | 
33 | list(APPEND CMAKE_MODULE_PATH ${CTRANSLATE2BACKEND_CMAKE_DIR})
34 | 
35 | if(NOT TARGET CTranslate2Backend::ctranslate2-backend)
36 |   include("${CTRANSLATE2BACKEND_CMAKE_DIR}/CTranslate2BackendTargets.cmake")
37 | endif()
38 | 
39 | set(CTRANSLATE2BACKEND_LIBRARIES CTranslate2Backend::ctranslate2-backend)
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # CTranslate2 Backend for Triton Inference Server
2 | 
3 | This is a [backend](https://github.com/triton-inference-server/backend) based on [CTranslate2](https://github.com/OpenNMT/CTranslate2) for NVIDIA's [Triton Inference Server](https://developer.nvidia.com/nvidia-triton-inference-server), which can be used to deploy translation and language models supported by CTranslate2 on Triton with both CPU and GPU capabilities.
4 | 
5 | It supports ragged and dynamic batching as well as setting a subset of CTranslate2 decoding parameters in the model config.
6 | 
7 | ## Building
8 | 
9 | Make sure to have [cmake](https://cmake.org) installed on your system.
10 | 
11 | 1. Build and install CTranslate2: [https://opennmt.net/CTranslate2/installation.html#compile-the-c-library](https://opennmt.net/CTranslate2/installation.html#compile-the-c-library)
12 | 2. Build the backend
13 | ```bash
14 | mkdir build && cd build
15 | export BACKEND_INSTALL_DIR=$(pwd)/install
16 | cmake .. -DCMAKE_BUILD_TYPE=Release -DTRITON_ENABLE_GPU=1 -DCMAKE_INSTALL_PREFIX=$BACKEND_INSTALL_DIR
17 | make install
18 | ```
19 | 
20 | This builds the backend into `$BACKEND_INSTALL_DIR/backends/ctranslate2`.
21 | 
22 | ## Setting up the backend
23 | 
24 | First install the pip package to convert models: `pip install ctranslate2`. Then create a model repository, which consists of a configuration (config.pbtxt) and the converted model.
25 | 
26 | For example, for the [Helsinki-NLP/opus-mt-en-de](https://huggingface.co/Helsinki-NLP/opus-mt-en-de) HuggingFace transformer model, create a new directory, e.g. `mkdir $MODEL_DIR/opus-mt-en-de`.
27 | The converted model must be placed in a directory called `model`, nested inside a folder named after a numerical version of the model:
28 | 
29 | ```bash
30 | ct2-transformers-converter --model Helsinki-NLP/opus-mt-en-de --output_dir 1/model
31 | ```
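For orientation, the model repository for this example should then look roughly like the sketch below (illustrative layout; the exact files inside `1/model` depend on the converter and model, but typically include `model.bin` and vocabulary files):

```
$MODEL_DIR/
└── opus-mt-en-de/
    ├── config.pbtxt      # Triton model configuration (see below)
    └── 1/                # numerical model version
        └── model/        # converted CTranslate2 model
```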
32 | 
33 | The minimum configuration for the model and backend is the following; example configs can be found in [examples/model_repo](examples/model_repo):
34 | 
35 | ```protobuf
36 | backend: "ctranslate2" # must be ctranslate2
37 | name: "opus-mt-en-de" # must be the same as the model name
38 | max_batch_size: 128 # can be optimised based on available GPU memory
39 | input [
40 |   {
41 |     name: "INPUT_IDS"
42 |     data_type: TYPE_INT32
43 |     dims: [ -1 ]
44 |     allow_ragged_batch: true # needed for dynamic batching
45 |   }
46 | ]
47 | output [
48 |   {
49 |     name: "OUTPUT_IDS"
50 |     data_type: TYPE_INT32
51 |     dims: [ -1 ]
52 |   }
53 | ]
54 | 
55 | instance_group [{ kind: KIND_GPU, count: 1 }] # use KIND_CPU for CPU inference
56 | dynamic_batching {
57 |   max_queue_delay_microseconds: 5000 # can be tuned based on latency requirements
58 | }
59 | ```
60 | 
61 | Start tritonserver with the ctranslate2 backend directory and the model repository:
62 | ```bash
63 | tritonserver --backend-directory $BACKEND_INSTALL_DIR/backends --model-repository $MODEL_DIR
64 | ```
65 | 
66 | The backend is set up to use ragged, [dynamic batching](https://github.com/triton-inference-server/server/blob/main/docs/user_guide/model_configuration.md#dynamic-batcher).
67 | This means you should send each input in its own request; Triton takes care of batching, using the `dynamic_batching` and `max_batch_size` settings to build batches that give the best performance on GPUs.
68 | 
69 | 
70 | ### Providing a target prefix
71 | 
72 | Some models require a special prefix token for the decoder. For example, the M2M models need a token that specifies the target language; it can also be useful to start the translation with a specific prefix. An example config can be found in the [facebook M2M config.pbtxt](examples/model_repo/facebook_m2m100_1.2B/config.pbtxt). A sketch of how such a prefix can be built is shown below.
73 | 
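As a minimal sketch (an assumption-based illustration, not part of the repository's examples), the target-language prefix for the M2M model could be built with the HuggingFace tokenizer; this assumes the `facebook/m2m100_1.2B` checkpoint and the M2M100 tokenizer's `get_lang_id` helper:

```python
import numpy as np
from transformers import AutoTokenizer

# Illustrative only: build the decoder prefix for English -> German with M2M100.
tokenizer = AutoTokenizer.from_pretrained("facebook/m2m100_1.2B")
tokenizer.src_lang = "en"

# M2M100 expects the first decoder token to be the target-language token;
# get_lang_id maps a language code to that token's ID (assumed HF helper).
prefix_ids = np.array([[tokenizer.get_lang_id("de")]], dtype=np.int32)

# prefix_ids (shape [1, 1], INT32) is then sent as the "TARGET_PREFIX" input
# alongside "INPUT_IDS", as shown in the snippet under "Sending requests" below.
```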
74 | 
75 | ## Sending requests
76 | 
77 | The backend expects token IDs as input and output, using INT32 or a larger data type. Depending on the model, these might include beginning-of-sentence and end-of-sentence tokens.
78 | No padding tokens need to be added though; the backend takes care of padding and batching.
79 | 
80 | You can use the official Triton clients to make requests; both HTTP and gRPC protocols are supported. We provide an example for a Python client [here](examples/client.py).
81 | 
82 | You can try the working translation by running
83 | ```bash
84 | echo "How is your day going?" | python3 examples/client.py
85 | ```
86 | 
87 | In case you want to use a special prefix for decoding, the request also needs to have the input `TARGET_PREFIX` set, which could look like this in Python:
88 | 
89 | ```python
90 | inputs.append(
91 |     tritonclient.grpc.InferInput("TARGET_PREFIX", prefix_ids.shape, np_to_triton_dtype(prefix_ids.dtype))
92 | )
93 | inputs[1].set_data_from_numpy(prefix_ids)
94 | ```
95 | 
96 | ## Configuration
97 | 
98 | The backend exposes a few parameters for customisation. Currently these are a subset of the decoding options of the CTranslate2 C++ interface plus some special parameters that are useful for limiting inference compute requirements:
99 | 
100 | * `compute_type` overrides the numerical precision used for the majority of the Transformer computation; it corresponds to the [quantization types](https://opennmt.net/CTranslate2/quantization.html#quantization) used in model conversion.
101 | * `max_decoding_length_multiple` can be used to limit the number of output tokens as a multiple of the input tokens. E.g. if the longest input sequence is 10 tokens and `max_decoding_length_multiple: "2"` is set, decoding is limited to 20 tokens.
102 | 
103 | Decoding parameters passed through to CTranslate2 (more to be added):
104 | 
105 | * `beam_size`
106 | * `repetition_penalty`
107 | 
108 | The parameters can be set like this in the config.pbtxt:
109 | 
110 | ```protobuf
111 | parameters [
112 |   {
113 |     key: "compute_type"
114 |     value {
115 |       string_value: "float16" # optional, can be used to force a compute type
116 |     }
117 |   },
118 |   {
119 |     key: "max_decoding_length_multiple"
120 |     value {
121 |       string_value: "2"
122 |     }
123 |   },
124 |   {
125 |     key: "beam_size"
126 |     value {
127 |       string_value: "4"
128 |     }
129 |   },
130 |   {
131 |     key: "repetition_penalty"
132 |     value {
133 |       string_value: "1.5"
134 |     }
135 |   }
136 | ]
137 | ```
--------------------------------------------------------------------------------
/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | # Copyright 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # Redistribution and use in source and binary forms, with or without
4 | # modification, are permitted provided that the following conditions
5 | # are met:
6 | #  * Redistributions of source code must retain the above copyright
7 | #    notice, this list of conditions and the following disclaimer.
8 | #  * Redistributions in binary form must reproduce the above copyright
9 | #    notice, this list of conditions and the following disclaimer in the
10 | #    documentation and/or other materials provided with the distribution.
11 | #  * Neither the name of NVIDIA CORPORATION nor the names of its
12 | #    contributors may be used to endorse or promote products derived
13 | #    from this software without specific prior written permission.
14 | #
15 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
16 | # EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
18 | # PURPOSE ARE DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT OWNER OR 19 | # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 20 | # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 21 | # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 22 | # PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY 23 | # OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 24 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 26 | 27 | cmake_minimum_required(VERSION 3.17) 28 | 29 | project(ctranslate2backend LANGUAGES C CXX) 30 | 31 | # By default RPATH will be used in the build tree, but cleared when installing the targets, leaving an empty RPATH 32 | # Thus we set the RPATH explicit for the installed targets 33 | # add the automatically determined parts of the RPATH 34 | # which point to directories outside the build tree to the install RPATH 35 | set(CMAKE_SKIP_BUILD_RPATH FALSE) 36 | set(CMAKE_BUILD_WITH_INSTALL_RPATH TRUE) 37 | set(CMAKE_INSTALL_RPATH "\${ORIGIN}") 38 | 39 | # 40 | # Options 41 | # 42 | # Must include options required for this project as well as any 43 | # projects included in this one by FetchContent. 44 | # 45 | option(TRITON_ENABLE_GPU "Enable GPU support in backend" OFF) 46 | option(TRITON_ENABLE_STATS "Include statistics collections in backend" ON) 47 | 48 | set(TRITON_COMMON_REPO_TAG "main" CACHE STRING "Tag for triton-inference-server/common repo") 49 | set(TRITON_CORE_REPO_TAG "main" CACHE STRING "Tag for triton-inference-server/core repo") 50 | set(TRITON_BACKEND_REPO_TAG "main" CACHE STRING "Tag for triton-inference-server/backend repo") 51 | 52 | if(NOT CMAKE_BUILD_TYPE) 53 | set(CMAKE_BUILD_TYPE Release) 54 | endif() 55 | 56 | # 57 | # Dependencies 58 | # 59 | # FetchContent requires us to include the transitive closure of all 60 | # repos that we depend on so that we can override the tags. 61 | # 62 | include(FetchContent) 63 | 64 | FetchContent_Declare( 65 | repo-common 66 | GIT_REPOSITORY https://github.com/triton-inference-server/common.git 67 | GIT_TAG ${TRITON_COMMON_REPO_TAG} 68 | GIT_SHALLOW ON 69 | ) 70 | FetchContent_Declare( 71 | repo-core 72 | GIT_REPOSITORY https://github.com/triton-inference-server/core.git 73 | GIT_TAG ${TRITON_CORE_REPO_TAG} 74 | GIT_SHALLOW ON 75 | ) 76 | FetchContent_Declare( 77 | repo-backend 78 | GIT_REPOSITORY https://github.com/triton-inference-server/backend.git 79 | GIT_TAG ${TRITON_BACKEND_REPO_TAG} 80 | GIT_SHALLOW ON 81 | ) 82 | FetchContent_MakeAvailable(repo-common repo-core repo-backend) 83 | 84 | find_package(ctranslate2) 85 | 86 | # 87 | # The backend must be built into a shared library. Use an ldscript to 88 | # hide all symbols except for the TRITONBACKEND API. 
89 | # 90 | configure_file(src/libtriton_ctranslate2.ldscript libtriton_ctranslate2.ldscript COPYONLY) 91 | 92 | add_library( 93 | ctranslate2-backend SHARED 94 | src/ctranslate2.cc 95 | ) 96 | 97 | add_library( 98 | CTranslate2Backend::ctranslate2-backend ALIAS ctranslate2-backend 99 | ) 100 | 101 | target_include_directories( 102 | ctranslate2-backend 103 | PRIVATE 104 | ${CMAKE_CURRENT_SOURCE_DIR}/src 105 | ) 106 | 107 | target_compile_features(ctranslate2-backend PRIVATE cxx_std_17) 108 | target_compile_options( 109 | ctranslate2-backend PRIVATE 110 | $<$,$,$>: 111 | -Wall -Wextra -Wno-unused-parameter -Wno-type-limits -Werror> 112 | $<$:/Wall /D_WIN32_WINNT=0x0A00 /EHsc> 113 | ) 114 | 115 | target_link_libraries( 116 | ctranslate2-backend 117 | PRIVATE 118 | triton-core-serverapi # from repo-core 119 | triton-core-backendapi # from repo-core 120 | triton-core-serverstub # from repo-core 121 | triton-backend-utils # from repo-backend 122 | CTranslate2::ctranslate2 123 | ) 124 | 125 | if(WIN32) 126 | set_target_properties( 127 | ctranslate2-backend PROPERTIES 128 | POSITION_INDEPENDENT_CODE ON 129 | OUTPUT_NAME triton_ctranslate2 130 | ) 131 | else() 132 | set_target_properties( 133 | ctranslate2-backend PROPERTIES 134 | POSITION_INDEPENDENT_CODE ON 135 | OUTPUT_NAME triton_ctranslate2 136 | LINK_DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/libtriton_ctranslate2.ldscript 137 | LINK_FLAGS "-Wl,--version-script libtriton_ctranslate2.ldscript" 138 | ) 139 | endif() 140 | 141 | # 142 | # Install 143 | # 144 | include(GNUInstallDirs) 145 | set(INSTALL_CONFIGDIR ${CMAKE_INSTALL_LIBDIR}/cmake/CTranslate2Backend) 146 | 147 | 148 | install( 149 | TARGETS 150 | ctranslate2-backend 151 | EXPORT 152 | ctranslate2-backend-targets 153 | LIBRARY DESTINATION ${CMAKE_INSTALL_PREFIX}/backends/ctranslate2 154 | RUNTIME DESTINATION ${CMAKE_INSTALL_PREFIX}/backends/ctranslate2 155 | ) 156 | 157 | install(IMPORTED_RUNTIME_ARTIFACTS CTranslate2::ctranslate2 DESTINATION ${CMAKE_INSTALL_PREFIX}/backends/ctranslate2) 158 | 159 | 160 | install( 161 | EXPORT 162 | ctranslate2-backend-targets 163 | FILE 164 | CTranslate2BackendTargets.cmake 165 | NAMESPACE 166 | CTranslate2:: 167 | DESTINATION 168 | ${INSTALL_CONFIGDIR} 169 | ) 170 | 171 | include(CMakePackageConfigHelpers) 172 | configure_package_config_file( 173 | ${CMAKE_CURRENT_LIST_DIR}/cmake/CTranslate2BackendConfig.cmake.in 174 | ${CMAKE_CURRENT_BINARY_DIR}/CTranslate2BackendConfig.cmake 175 | INSTALL_DESTINATION ${INSTALL_CONFIGDIR} 176 | ) 177 | 178 | install( 179 | FILES 180 | ${CMAKE_CURRENT_BINARY_DIR}/CTranslate2BackendConfig.cmake 181 | DESTINATION ${INSTALL_CONFIGDIR} 182 | ) 183 | 184 | # 185 | # Export from build tree 186 | # 187 | export( 188 | EXPORT ctranslate2-backend-targets 189 | FILE ${CMAKE_CURRENT_BINARY_DIR}/CTranslate2BackendTargets.cmake 190 | NAMESPACE CTranslate2Backend:: 191 | ) 192 | 193 | export(PACKAGE CTranslate2Backend) 194 | -------------------------------------------------------------------------------- /src/ctranslate2.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // Redistribution and use in source and binary forms, with or without 4 | // modification, are permitted provided that the following conditions 5 | // are met: 6 | // * Redistributions of source code must retain the above copyright 7 | // notice, this list of conditions and the following disclaimer. 
8 | // * Redistributions in binary form must reproduce the above copyright 9 | // notice, this list of conditions and the following disclaimer in the 10 | // documentation and/or other materials provided with the distribution. 11 | // * Neither the name of NVIDIA CORPORATION nor the names of its 12 | // contributors may be used to endorse or promote products derived 13 | // from this software without specific prior written permission. 14 | // 15 | // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY 16 | // EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 17 | // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 18 | // PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR 19 | // CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 20 | // EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 21 | // PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 22 | // PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY 23 | // OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 24 | // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 26 | 27 | #include 28 | #include 29 | #include 30 | #include 31 | #include 32 | #include 33 | #include 34 | #include 35 | 36 | #include "triton/backend/backend_common.h" 37 | #include "triton/backend/backend_input_collector.h" 38 | #include "triton/backend/backend_model.h" 39 | #include "triton/backend/backend_model_instance.h" 40 | #include "triton/backend/backend_output_responder.h" 41 | #include "triton/common/triton_json.h" 42 | #include "triton/core/tritonbackend.h" 43 | 44 | #include "ctranslate2/models/model.h" 45 | #include "ctranslate2/models/sequence_to_sequence.h" 46 | #include "triton/core/tritonserver.h" 47 | 48 | namespace triton { 49 | namespace backend { 50 | namespace ctranslate2 { 51 | 52 | TRITONSERVER_Error * 53 | ReadParameter(const triton::common::TritonJson::Value ¶ms, 54 | const std::string &key, std::string *param) { 55 | triton::common::TritonJson::Value value; 56 | RETURN_ERROR_IF_FALSE( 57 | const_cast(params).Find(key.c_str(), 58 | &value), 59 | TRITONSERVER_ERROR_INVALID_ARG, 60 | std::string("model configuration is missing the parameter ") + key); 61 | RETURN_IF_ERROR(value.MemberAsString("string_value", param)); 62 | return nullptr; // success 63 | } 64 | 65 | TRITONSERVER_Error * 66 | ReadParameter(const triton::common::TritonJson::Value ¶ms, 67 | const std::string &key, int *param) { 68 | std::string tmp; 69 | RETURN_IF_ERROR(ReadParameter(params, key, &tmp)); 70 | *param = std::stoi(tmp); 71 | return nullptr; // success 72 | } 73 | 74 | TRITONSERVER_Error * 75 | ReadParameter(const triton::common::TritonJson::Value ¶ms, 76 | const std::string &key, size_t *param) { 77 | std::string tmp; 78 | RETURN_IF_ERROR(ReadParameter(params, key, &tmp)); 79 | *param = static_cast(std::stoi(tmp)); 80 | return nullptr; // success 81 | } 82 | 83 | TRITONSERVER_Error * 84 | ReadParameter(const triton::common::TritonJson::Value ¶ms, 85 | const std::string &key, float *param) { 86 | std::string tmp; 87 | RETURN_IF_ERROR(ReadParameter(params, key, &tmp)); 88 | *param = std::stof(tmp); 89 | return nullptr; // success 90 | } 91 | 92 | class ModelState : public BackendModel { 93 | public: 94 | static TRITONSERVER_Error *Create(TRITONBACKEND_Model *triton_model, 95 | ModelState **state); 96 | virtual ~ModelState() = default; 97 
| 98 | ModelState(TRITONBACKEND_Model *triton_model) 99 | : BackendModel(triton_model, false) { 100 | THROW_IF_BACKEND_MODEL_ERROR(ValidateModel()); 101 | THROW_IF_BACKEND_MODEL_ERROR(ValidateModelConfig()); 102 | } 103 | 104 | TRITONSERVER_Error *ValidateModelConfig() { 105 | // If verbose logging is enabled, dump the model's configuration as 106 | // JSON into the console output. 107 | if (TRITONSERVER_LogIsEnabled(TRITONSERVER_LOG_VERBOSE)) { 108 | common::TritonJson::WriteBuffer buffer; 109 | RETURN_IF_ERROR(ModelConfig().PrettyWrite(&buffer)); 110 | LOG_MESSAGE( 111 | TRITONSERVER_LOG_VERBOSE, 112 | (std::string("model configuration:\n") + buffer.Contents()).c_str()); 113 | } 114 | 115 | // ModelConfig is the model configuration as a TritonJson 116 | // object. Use the TritonJson utilities to parse the JSON and 117 | // determine if the configuration is supported by this backend. 118 | common::TritonJson::Value inputs, outputs; 119 | RETURN_IF_ERROR(ModelConfig().MemberAsArray("input", &inputs)); 120 | RETURN_IF_ERROR(ModelConfig().MemberAsArray("output", &outputs)); 121 | 122 | // The model must have exactly 1 input and 1 output. 123 | RETURN_ERROR_IF_FALSE( 124 | inputs.ArraySize() == 1 || inputs.ArraySize() == 2, 125 | TRITONSERVER_ERROR_INVALID_ARG, 126 | std::string("model configuration must have 1 or 2 inputs")); 127 | RETURN_ERROR_IF_FALSE( 128 | outputs.ArraySize() == 1, TRITONSERVER_ERROR_INVALID_ARG, 129 | std::string("model configuration must have 1 output")); 130 | 131 | common::TritonJson::Value input, output; 132 | RETURN_IF_ERROR(inputs.IndexAsObject(0, &input)); 133 | RETURN_IF_ERROR(outputs.IndexAsObject(0, &output)); 134 | 135 | // Record the input and output name in the model state. 136 | const char *input_name; 137 | size_t input_name_len; 138 | RETURN_IF_ERROR(input.MemberAsString("name", &input_name, &input_name_len)); 139 | input_name_ = std::string(input_name); 140 | 141 | if (inputs.ArraySize() == 2) { 142 | RETURN_IF_ERROR(inputs.IndexAsObject(1, &input)); 143 | const char *target_prefix_input_name; 144 | size_t target_prefix_input_name_len; 145 | RETURN_IF_ERROR(input.MemberAsString("name", &target_prefix_input_name, 146 | &target_prefix_input_name_len)); 147 | target_prefix_input_name_ = std::string(target_prefix_input_name); 148 | } 149 | 150 | const char *output_name; 151 | size_t output_name_len; 152 | RETURN_IF_ERROR( 153 | output.MemberAsString("name", &output_name, &output_name_len)); 154 | output_name_ = std::string(output_name); 155 | 156 | std::string io_dtype; 157 | RETURN_IF_ERROR(output.MemberAsString("data_type", &io_dtype)); 158 | output_type_ = 159 | triton::backend::ModelConfigDataTypeToTritonServerDataType(io_dtype); 160 | 161 | triton::common::TritonJson::Value params; 162 | bool has_params = ModelConfig().Find("parameters", ¶ms); 163 | if (has_params) { 164 | if (params.Find("compute_type")) { 165 | std::string compute_type_str; 166 | RETURN_IF_ERROR( 167 | ReadParameter(params, "compute_type", &compute_type_str)); 168 | compute_type_ = ::ctranslate2::str_to_compute_type(compute_type_str); 169 | 170 | LOG_MESSAGE(TRITONSERVER_LOG_INFO, 171 | (std::string("Running inference in compute type: ") + 172 | compute_type_str) 173 | .c_str()); 174 | } 175 | if (params.Find("max_decoding_length_multiple")) { 176 | size_t max_decode_length_multiple; 177 | RETURN_IF_ERROR(ReadParameter(params, "max_decoding_length_multiple", 178 | &max_decode_length_multiple)); 179 | max_decode_length_multiple_ = max_decode_length_multiple; 180 | } 181 | if 
(params.Find("beam_size")) { 182 | RETURN_IF_ERROR(ReadParameter( 183 | params, "beam_size", &(default_translation_options_.beam_size))); 184 | } 185 | if (params.Find("repetition_penalty")) { 186 | RETURN_IF_ERROR(ReadParameter(params, "repetition_penalty", 187 | &(default_translation_options_.repetition_penalty))); 188 | } 189 | } 190 | return nullptr; 191 | } 192 | 193 | const std::string &InputTensorName() const { return input_name_; } 194 | const std::optional &TargetPrefixInputName() const { 195 | return target_prefix_input_name_; 196 | } 197 | const std::string &OutputTensorName() const { return output_name_; } 198 | TRITONSERVER_DataType OutputDataType() const { return output_type_; } 199 | ::ctranslate2::TranslationOptions DefaultTranslationOptions() const { 200 | return default_translation_options_; 201 | } 202 | const std::optional &MaxDecodeLengthMultiple() const { 203 | return max_decode_length_multiple_; 204 | } 205 | 206 | TRITONSERVER_Error * 207 | LoadModel(const ::ctranslate2::Device device, std::int32_t device_index, 208 | std::shared_ptr *ct_model) { 209 | std::shared_ptr model; 210 | std::pair<::ctranslate2::Device, std::int32_t> device_pair = 211 | std::make_pair(device, device_index); 212 | auto mit = models_.find(device_pair); 213 | if (mit != models_.end()) { 214 | model = mit->second; 215 | } else { 216 | if (!models_.empty()) { 217 | model = models_.begin()->second->copy_to(device, device_index); 218 | } else { 219 | model = ::ctranslate2::models::Model::load(*model_reader_, device, 220 | device_index, compute_type_); 221 | } 222 | models_.emplace(device_pair, model); 223 | } 224 | *ct_model = model; 225 | 226 | return nullptr; 227 | } 228 | 229 | private: 230 | // TRITONBACKEND_Model *triton_model_; 231 | triton::common::TritonJson::Value model_config_; 232 | std::string input_name_; 233 | std::optional target_prefix_input_name_; 234 | std::string output_name_; 235 | TRITONSERVER_DataType output_type_; 236 | ::ctranslate2::ComputeType compute_type_ = 237 | ::ctranslate2::ComputeType::DEFAULT; 238 | ::ctranslate2::TranslationOptions default_translation_options_; 239 | std::optional max_decode_length_multiple_; 240 | std::string model_path_; 241 | std::shared_ptr<::ctranslate2::models::ModelReader> model_reader_; 242 | std::map, 243 | std::shared_ptr> 244 | models_; 245 | 246 | TRITONSERVER_Error *ValidateModel() { 247 | std::string artifact_filename; 248 | THROW_IF_BACKEND_MODEL_ERROR(ModelConfig().MemberAsString( 249 | "default_model_filename", &artifact_filename)); 250 | // if default_model_filename not set default to "model" 251 | if (artifact_filename.empty()) { 252 | artifact_filename = "model"; 253 | } 254 | 255 | model_path_ = JoinPath( 256 | {RepositoryPath(), std::to_string(Version()), artifact_filename}); 257 | model_reader_ = 258 | std::make_shared<::ctranslate2::models::ModelFileReader>(model_path_); 259 | auto contains_model = ::ctranslate2::models::contains_model(model_path_); 260 | 261 | RETURN_ERROR_IF_FALSE(contains_model, TRITONSERVER_ERROR_UNAVAILABLE, 262 | std::string("unable to find '") + model_path_ + 263 | "' for model instance '" + Name() + "'"); 264 | 265 | return nullptr; 266 | } 267 | }; 268 | 269 | TRITONSERVER_Error *ModelState::Create(TRITONBACKEND_Model *triton_model, 270 | ModelState **state) { 271 | 272 | try { 273 | *state = new ModelState(triton_model); 274 | } catch (const BackendModelException &ex) { 275 | RETURN_ERROR_IF_TRUE( 276 | ex.err_ == nullptr, TRITONSERVER_ERROR_INTERNAL, 277 | std::string("unexpected nullptr in 
BackendModelException")); 278 | RETURN_IF_ERROR(ex.err_); 279 | } 280 | 281 | return nullptr; // success 282 | } 283 | 284 | extern "C" { 285 | 286 | // Triton calls TRITONBACKEND_ModelInitialize when a model is loaded 287 | // to allow the backend to create any state associated with the model, 288 | // and to also examine the model configuration to determine if the 289 | // configuration is suitable for the backend. Any errors reported by 290 | // this function will prevent the model from loading. 291 | // 292 | TRITONSERVER_Error *TRITONBACKEND_ModelInitialize(TRITONBACKEND_Model *model) { 293 | const char *cname; 294 | RETURN_IF_ERROR(TRITONBACKEND_ModelName(model, &cname)); 295 | std::string name(cname); 296 | 297 | uint64_t version; 298 | RETURN_IF_ERROR(TRITONBACKEND_ModelVersion(model, &version)); 299 | 300 | LOG_MESSAGE(TRITONSERVER_LOG_INFO, 301 | (std::string("TRITONBACKEND_ModelInitialize: ") + name + 302 | " (version " + std::to_string(version) + ")") 303 | .c_str()); 304 | 305 | // Create a ModelState object and associate it with the 306 | // TRITONBACKEND_Model. If anything goes wrong with initialization 307 | // of the model state then an error is returned and Triton will fail 308 | // to load the model. 309 | ModelState *model_state; 310 | RETURN_IF_ERROR(ModelState::Create(model, &model_state)); 311 | RETURN_IF_ERROR(TRITONBACKEND_ModelSetState( 312 | model, reinterpret_cast(model_state))); 313 | 314 | return nullptr; // success 315 | } 316 | 317 | // Triton calls TRITONBACKEND_ModelFinalize when a model is no longer 318 | // needed. The backend should cleanup any state associated with the 319 | // model. This function will not be called until all model instances 320 | // of the model have been finalized. 321 | // 322 | TRITONSERVER_Error *TRITONBACKEND_ModelFinalize(TRITONBACKEND_Model *model) { 323 | void *vstate; 324 | RETURN_IF_ERROR(TRITONBACKEND_ModelState(model, &vstate)); 325 | ModelState *model_state = reinterpret_cast(vstate); 326 | delete model_state; 327 | 328 | return nullptr; // success 329 | } 330 | 331 | } // extern "C" 332 | 333 | template 334 | TRITONSERVER_Error * 335 | ToIdVectorTyped(const char *buffer, const size_t element_count, 336 | std::vector *ids, const size_t start_idx = 0) { 337 | const T *vals = reinterpret_cast(buffer); 338 | *ids = 339 | std::vector(vals + start_idx, vals + start_idx + element_count); 340 | return nullptr; 341 | } 342 | 343 | TRITONSERVER_Error *ToIdVector(const char *buffer, 344 | TRITONSERVER_DataType datatype, 345 | std::vector *ids, const size_t start_idx, 346 | const size_t element_cnt) { 347 | 348 | switch (datatype) { 349 | case TRITONSERVER_TYPE_UINT8: 350 | return ToIdVectorTyped(buffer, element_cnt, ids, start_idx); 351 | case TRITONSERVER_TYPE_UINT16: 352 | return ToIdVectorTyped(buffer, element_cnt, ids, start_idx); 353 | case TRITONSERVER_TYPE_UINT32: 354 | return ToIdVectorTyped(buffer, element_cnt, ids, start_idx); 355 | case TRITONSERVER_TYPE_UINT64: 356 | return ToIdVectorTyped(buffer, element_cnt, ids, start_idx); 357 | case TRITONSERVER_TYPE_INT8: 358 | return ToIdVectorTyped(buffer, element_cnt, ids, start_idx); 359 | case TRITONSERVER_TYPE_INT16: 360 | return ToIdVectorTyped(buffer, element_cnt, ids, start_idx); 361 | case TRITONSERVER_TYPE_INT32: 362 | return ToIdVectorTyped(buffer, element_cnt, ids, start_idx); 363 | case TRITONSERVER_TYPE_INT64: 364 | return ToIdVectorTyped(buffer, element_cnt, ids, start_idx); 365 | 366 | case TRITONSERVER_TYPE_FP32: 367 | return ToIdVectorTyped(buffer, element_cnt, 
ids, start_idx); 368 | case TRITONSERVER_TYPE_FP64: 369 | return ToIdVectorTyped(buffer, element_cnt, ids, start_idx); 370 | default: 371 | return TRITONSERVER_ErrorNew( 372 | TRITONSERVER_ERROR_INVALID_ARG, 373 | std::string(std::string("class result not available for output due to " 374 | "unsupported type '") + 375 | std::string(TRITONSERVER_DataTypeString(datatype)) + "'") 376 | .c_str()); 377 | } 378 | } 379 | 380 | template 381 | void ConvertToRawPointer(const std::vector &out_tokens, 382 | void *out_buffer) { 383 | T *buffer = static_cast(out_buffer); 384 | for (auto &token : out_tokens) { 385 | auto idx = &token - &out_tokens[0]; 386 | buffer[idx] = static_cast(token); 387 | } 388 | } 389 | 390 | size_t TritonTypeSize(TRITONSERVER_DataType datatype) { 391 | switch (datatype) { 392 | case TRITONSERVER_TYPE_UINT8: 393 | return sizeof(std::uint8_t); 394 | case TRITONSERVER_TYPE_UINT16: 395 | return sizeof(std::uint16_t); 396 | case TRITONSERVER_TYPE_UINT32: 397 | return sizeof(std::uint32_t); 398 | case TRITONSERVER_TYPE_UINT64: 399 | return sizeof(std::uint64_t); 400 | case TRITONSERVER_TYPE_INT8: 401 | return sizeof(std::int8_t); 402 | case TRITONSERVER_TYPE_INT16: 403 | return sizeof(std::int16_t); 404 | case TRITONSERVER_TYPE_INT32: 405 | return sizeof(std::int32_t); 406 | case TRITONSERVER_TYPE_INT64: 407 | return sizeof(std::int64_t); 408 | break; 409 | 410 | case TRITONSERVER_TYPE_FP32: 411 | return sizeof(std::float_t); 412 | case TRITONSERVER_TYPE_FP64: 413 | return sizeof(std::double_t); 414 | default: 415 | throw std::invalid_argument(std::string("Can't determine type size for ") + 416 | std::to_string(TRITONSERVER_TYPE_FP64)); 417 | } 418 | } 419 | 420 | TRITONSERVER_Error *ToOutBuffer(const std::vector &out_tokens, 421 | TRITONSERVER_DataType datatype, 422 | void *out_buffer) { 423 | switch (datatype) { 424 | case TRITONSERVER_TYPE_UINT8: 425 | ConvertToRawPointer(out_tokens, out_buffer); 426 | break; 427 | case TRITONSERVER_TYPE_UINT16: 428 | ConvertToRawPointer(out_tokens, out_buffer); 429 | break; 430 | case TRITONSERVER_TYPE_UINT32: 431 | ConvertToRawPointer(out_tokens, out_buffer); 432 | break; 433 | case TRITONSERVER_TYPE_UINT64: 434 | ConvertToRawPointer(out_tokens, out_buffer); 435 | break; 436 | case TRITONSERVER_TYPE_INT8: 437 | ConvertToRawPointer(out_tokens, out_buffer); 438 | break; 439 | case TRITONSERVER_TYPE_INT16: 440 | ConvertToRawPointer(out_tokens, out_buffer); 441 | break; 442 | case TRITONSERVER_TYPE_INT32: 443 | ConvertToRawPointer(out_tokens, out_buffer); 444 | break; 445 | case TRITONSERVER_TYPE_INT64: 446 | ConvertToRawPointer(out_tokens, out_buffer); 447 | break; 448 | 449 | case TRITONSERVER_TYPE_FP32: 450 | ConvertToRawPointer(out_tokens, out_buffer); 451 | break; 452 | case TRITONSERVER_TYPE_FP64: 453 | ConvertToRawPointer(out_tokens, out_buffer); 454 | break; 455 | default: 456 | return TRITONSERVER_ErrorNew( 457 | TRITONSERVER_ERROR_INVALID_ARG, 458 | std::string(std::string("class result not available for output due to " 459 | "unsupported type '") + 460 | std::string(TRITONSERVER_DataTypeString(datatype)) + "'") 461 | .c_str()); 462 | } 463 | 464 | return nullptr; 465 | } 466 | 467 | std::string 468 | TranslationOptionsToString(const ::ctranslate2::TranslationOptions &options) { 469 | std::stringstream ss; 470 | 471 | ss << "TranslationOptions(" 472 | << "beam_size=" << options.beam_size << ", " 473 | << "patience=" << options.patience << ", " 474 | << "length_penalty=" << options.length_penalty << ", " 475 | << "coverage_penalty=" << 
options.coverage_penalty << ", " 476 | << "repetition_penalty=" << options.repetition_penalty << ", " 477 | << "no_repeat_ngram_size=" << options.no_repeat_ngram_size << ", " 478 | << "disable_unk=" << options.disable_unk << ", " 479 | << "size(suppress_sequences)=" << options.suppress_sequences.size() << ", " 480 | << "prefix_bias_beta=" << options.prefix_bias_beta << ", "; 481 | 482 | if (std::holds_alternative(options.end_token)) { 483 | ss << "end_token=\"" << std::get(options.end_token) << "\", "; 484 | } else if (std::holds_alternative>( 485 | options.end_token)) { 486 | for (auto &end_token : 487 | std::get>(options.end_token)) { 488 | ss << "end_token[]=" << end_token << " "; 489 | } 490 | ss << ","; 491 | } else if (std::holds_alternative>(options.end_token)) { 492 | for (auto &end_token : 493 | std::get>(options.end_token)) { 494 | ss << "end_token[]=" << end_token << " "; 495 | } 496 | ss << ","; 497 | } 498 | 499 | ss << "max_input_length=" << options.max_input_length << ", " 500 | << "max_decoding_length=" << options.max_decoding_length << ", " 501 | << "min_decoding_length=" << options.min_decoding_length << ", " 502 | << "sampling_topk=" << options.sampling_topk << ", " 503 | << "sampling_temperature=" << options.sampling_temperature << ", " 504 | << "use_vmap=" << options.use_vmap << ", " 505 | << "num_hypotheses=" << options.num_hypotheses << ", " 506 | << "return_scores=" << options.return_scores << ", " 507 | << "return_attention=" << options.return_attention << ", " 508 | << "return_alternatives=" << options.return_alternatives << ", " 509 | << "min_alternative_expansion_prob=" 510 | << options.min_alternative_expansion_prob << ", " 511 | << "replace_unknowns=" << options.replace_unknowns << ")"; 512 | return ss.str(); 513 | } 514 | 515 | TRITONSERVER_Error *InputBufferToRaggedTokens( 516 | size_t total_batch_size, TRITONBACKEND_Request **requests, 517 | const uint32_t request_count, 518 | std::vector *responses, 519 | BackendInputCollector *collector, 520 | std::vector> *ragged_tokens, 521 | size_t *max_sequence_length, const std::string &input_name, 522 | bool is_ragged_input = true, bool supports_batching = true) { 523 | std::vector> tokens; 524 | tokens.reserve(request_count); 525 | 526 | const char *input_buffer; 527 | size_t batchn_byte_size; 528 | TRITONSERVER_MemoryType memory_type; 529 | int64_t memory_type_id; 530 | 531 | // TODO support data straight from GPU 532 | std::vector> alloc_preference = { 533 | {TRITONSERVER_MEMORY_CPU_PINNED, 0}, {TRITONSERVER_MEMORY_CPU, 0}}; 534 | 535 | RETURN_IF_ERROR(collector->ProcessTensor( 536 | input_name.c_str(), nullptr, 0, alloc_preference, &input_buffer, 537 | &batchn_byte_size, &memory_type, &memory_type_id)); 538 | 539 | // bool is_ragged = 540 | // 541 | size_t max_seq_length = 0; 542 | if (is_ragged_input) { 543 | int64_t total_elements = 0; 544 | for (size_t request_idx = 0; request_idx < request_count; request_idx++) { 545 | TRITONBACKEND_Input *input; 546 | RESPOND_AND_SET_NULL_IF_ERROR( 547 | &((*responses)[request_idx]), 548 | TRITONBACKEND_RequestInput(requests[request_idx], input_name.c_str(), 549 | &input)); 550 | 551 | TRITONSERVER_DataType input_dt; 552 | const int64_t *input_shape; 553 | uint32_t input_dims_count; 554 | RETURN_IF_ERROR( 555 | TRITONBACKEND_InputProperties(input, nullptr, &input_dt, &input_shape, 556 | &input_dims_count, nullptr, nullptr)); 557 | 558 | auto element_count = GetElementCount(input_shape, input_dims_count); 559 | LOG_MESSAGE(TRITONSERVER_LOG_VERBOSE, 560 | 
(std::string("Element count for request ") + 561 | std::to_string(request_idx) + std::string(": ") + 562 | std::to_string(element_count)) 563 | .c_str()); 564 | max_seq_length = 565 | std::max(max_seq_length, static_cast(element_count)); 566 | 567 | std::vector ids; 568 | ToIdVector(input_buffer, input_dt, &ids, total_elements, element_count); 569 | total_elements += element_count; 570 | tokens.emplace_back(ids); 571 | } 572 | } else { 573 | // input type is the same for all 574 | TRITONBACKEND_Input *input; 575 | RETURN_IF_ERROR( 576 | TRITONBACKEND_RequestInput(requests[0], input_name.c_str(), &input)); 577 | 578 | TRITONSERVER_DataType input_dt; 579 | const int64_t *input_shape; 580 | uint32_t input_dims_count; 581 | RETURN_IF_ERROR( 582 | TRITONBACKEND_InputProperties(input, nullptr, &input_dt, &input_shape, 583 | &input_dims_count, nullptr, nullptr)); 584 | 585 | if (input_dims_count > 2) { 586 | return TRITONSERVER_ErrorNew( 587 | TRITONSERVER_ERROR_INVALID_ARG, 588 | std::string("Inputs with more than two dimensions unsupported") 589 | .c_str()); 590 | } 591 | 592 | std::vector batchn_shape = 593 | std::vector(input_shape, input_shape + input_dims_count); 594 | if (supports_batching) { 595 | batchn_shape[0] = total_batch_size; 596 | } 597 | 598 | for (size_t vector_idx = 0; vector_idx < total_batch_size; vector_idx++) { 599 | std::vector ids; 600 | ToIdVector(input_buffer, input_dt, &ids, vector_idx * batchn_shape[1], 601 | (vector_idx + 1) * batchn_shape[1]); 602 | tokens.emplace_back(ids); 603 | } 604 | max_seq_length = static_cast(batchn_shape[1]); 605 | } 606 | 607 | *ragged_tokens = tokens; 608 | *max_sequence_length = max_seq_length; 609 | 610 | return nullptr; 611 | } 612 | ///////////// 613 | 614 | // 615 | // ModelInstanceState 616 | // 617 | // State associated with a model instance. An object of this class is 618 | // created and associated with each 619 | // TRITONBACKEND_ModelInstance. ModelInstanceState is derived from 620 | // BackendModelInstance class provided in the backend utilities that 621 | // provides many common functions. 622 | // 623 | class ModelInstanceState : public BackendModelInstance { 624 | public: 625 | static TRITONSERVER_Error * 626 | Create(ModelState *model_state, 627 | TRITONBACKEND_ModelInstance *triton_model_instance, 628 | ModelInstanceState **state); 629 | virtual ~ModelInstanceState() = default; 630 | 631 | // Get the state of the model that corresponds to this instance. 
632 | ModelState *StateForModel() const { return model_state_; } 633 | 634 | ModelInstanceState(ModelState *model_state, 635 | TRITONBACKEND_ModelInstance *triton_model_instance) 636 | : BackendModelInstance(model_state, triton_model_instance), 637 | model_state_(model_state) { 638 | if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) { 639 | #ifdef TRITON_ENABLE_GPU 640 | device_ = ::ctranslate2::Device::CUDA; 641 | #else 642 | throw triton::backend::BackendModelInstanceException( 643 | TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_NOT_FOUND, 644 | "Backend not built with GPU support")); 645 | #endif 646 | } else { 647 | device_ = ::ctranslate2::Device::CPU; 648 | } 649 | model_state->LoadModel(device_, DeviceId(), &model_); 650 | supports_batching_ = model_state_->MaxBatchSize() > 0; 651 | } 652 | 653 | TRITONSERVER_Error * 654 | CreateInput(size_t total_batch_size, TRITONBACKEND_Request **requests, 655 | const uint32_t request_count, 656 | std::vector *responses, 657 | BackendInputCollector *collector, 658 | const ::ctranslate2::Vocabulary &source_vocab, 659 | const ::ctranslate2::Vocabulary &target_vocab, 660 | std::vector> *input_tokens, 661 | std::vector> *input_target_prefix, 662 | size_t *max_sequence_length) { 663 | 664 | std::vector> input_token_ids; 665 | RETURN_IF_ERROR(InputBufferToRaggedTokens( 666 | total_batch_size, requests, request_count, responses, collector, 667 | &input_token_ids, max_sequence_length, 668 | StateForModel()->InputTensorName(), 669 | StateForModel()->IsInputRagged(StateForModel()->InputTensorName()), 670 | supports_batching_)); 671 | *input_tokens = source_vocab.to_tokens(input_token_ids); 672 | if (StateForModel()->TargetPrefixInputName()) { 673 | std::vector> target_prefix_token_ids; 674 | size_t discard_seq_length; 675 | RETURN_IF_ERROR(InputBufferToRaggedTokens( 676 | total_batch_size, requests, request_count, responses, collector, 677 | &target_prefix_token_ids, &discard_seq_length, 678 | *(StateForModel()->TargetPrefixInputName()), 679 | StateForModel()->IsInputRagged( 680 | *(StateForModel()->TargetPrefixInputName())), 681 | supports_batching_)); 682 | *input_target_prefix = target_vocab.to_tokens(target_prefix_token_ids); 683 | } 684 | return nullptr; 685 | } 686 | 687 | void ProcessRequests(TRITONBACKEND_Request **requests, 688 | const uint32_t request_count) { 689 | 690 | uint64_t exec_start_ns = 0; 691 | SET_TIMESTAMP(exec_start_ns); 692 | 693 | std::vector responses; 694 | responses.reserve(request_count); 695 | bool all_response_failed = false; 696 | 697 | for (size_t i = 0; i < request_count; i++) { 698 | TRITONBACKEND_Response *response; 699 | auto err = TRITONBACKEND_ResponseNew(&response, requests[i]); 700 | if (err == nullptr) { 701 | responses.emplace_back(response); 702 | } else { 703 | responses.emplace_back(nullptr); 704 | LOG_MESSAGE(TRITONSERVER_LOG_ERROR, "Fail to create response"); 705 | TRITONSERVER_ErrorDelete(err); 706 | } 707 | } 708 | 709 | const int max_batch_size = model_state_->MaxBatchSize(); 710 | 711 | size_t total_batch_size = 0; 712 | for (size_t i = 0; i < request_count; i++) { 713 | if (max_batch_size > 0) { 714 | // Retrieve the batch size from one of the inputs, if the model 715 | // supports batching, the first dimension size is batch size. 
716 | TRITONBACKEND_Input *input; 717 | TRITONSERVER_Error *err = TRITONBACKEND_RequestInput( 718 | requests[i], StateForModel()->InputTensorName().c_str(), &input); 719 | if (err == nullptr) { 720 | const int64_t *shape; 721 | err = TRITONBACKEND_InputProperties(input, nullptr, nullptr, &shape, 722 | nullptr, nullptr, nullptr); 723 | total_batch_size += shape[0]; 724 | } 725 | if (err != nullptr) { 726 | RESPOND_ALL_AND_SET_TRUE_IF_ERROR(responses, request_count, 727 | all_response_failed, err); 728 | } 729 | } else { 730 | total_batch_size += 1; 731 | } 732 | } 733 | 734 | // If there are no valid payloads then no need to run the inference. 735 | if (total_batch_size == 0) { 736 | return; 737 | } 738 | 739 | // Make sure the maximum batch size is not exceeded. The 740 | // total_batch_size must be 1 for models that don't support batching 741 | // (i.e. max_batch_size == 0). If max_batch_size is exceeded then 742 | // scheduler has done something badly wrong so fail and release all 743 | // requests. 744 | if (!all_response_failed) { 745 | if ((total_batch_size != 1) && 746 | (total_batch_size > (size_t)max_batch_size)) { 747 | RESPOND_ALL_AND_SET_TRUE_IF_ERROR( 748 | responses, request_count, all_response_failed, 749 | TRITONSERVER_ErrorNew( 750 | TRITONSERVER_ERROR_INTERNAL, 751 | std::string("batch size " + std::to_string(total_batch_size) + 752 | " for '" + Name() + "', max allowed is " + 753 | std::to_string(max_batch_size)) 754 | .c_str())); 755 | } 756 | } 757 | 758 | const ::ctranslate2::models::SequenceToSequenceModel *seq2seq_model = 759 | dynamic_cast( 760 | model_.get()); 761 | const auto source_vocab = seq2seq_model->get_source_vocabulary(); 762 | const auto target_vocab = seq2seq_model->get_target_vocabulary(); 763 | 764 | auto collector = std::make_unique( 765 | requests, request_count, &responses, 766 | model_state_->TritonMemoryManager(), 767 | model_state_->EnablePinnedInput() /* pinned_enabled */, 768 | nullptr /* stream*/); 769 | 770 | std::vector> inputs; 771 | std::vector> target_prefix; 772 | size_t max_input_seq_length; 773 | RESPOND_ALL_AND_SET_NULL_IF_ERROR( 774 | responses, request_count, 775 | CreateInput(total_batch_size, requests, request_count, &responses, 776 | collector.get(), source_vocab, target_vocab, &inputs, 777 | &target_prefix, &max_input_seq_length)); 778 | 779 | std::unique_ptr<::ctranslate2::models::SequenceToSequenceReplica> 780 | seq2seq_replica = model_->as_sequence_to_sequence(); 781 | 782 | // Finalize the collector. If 'true' is returned, 'input_buffer' 783 | // will not be valid until the backend synchronizes the CUDA 784 | // stream or event that was used when creating the collector. For 785 | // this backend, GPU is not supported and so no CUDA sync should 786 | // be needed; so if 'true' is returned simply log an error. 
787 | const bool need_cuda_input_sync = collector->Finalize(); 788 | if (need_cuda_input_sync) { 789 | LOG_MESSAGE(TRITONSERVER_LOG_ERROR, 790 | "backend: unexpected CUDA sync required by collector"); 791 | } 792 | 793 | uint64_t compute_start_ns = 0; 794 | SET_TIMESTAMP(compute_start_ns); 795 | ::ctranslate2::TranslationOptions options = 796 | StateForModel()->DefaultTranslationOptions(); 797 | auto max_decode_length_multiple = 798 | StateForModel()->MaxDecodeLengthMultiple(); 799 | if (max_decode_length_multiple) { 800 | options.max_decoding_length = 801 | *max_decode_length_multiple * max_input_seq_length; 802 | } 803 | LOG_MESSAGE(TRITONSERVER_LOG_VERBOSE, 804 | TranslationOptionsToString(options).c_str()); 805 | std::vector<::ctranslate2::TranslationResult> translation_results = 806 | seq2seq_replica->translate(inputs, target_prefix, options); 807 | 808 | uint64_t compute_end_ns = 0; 809 | SET_TIMESTAMP(compute_end_ns); 810 | 811 | // This backend supports models that batch along the first dimension 812 | // and those that don't batch. For non-batch models the output shape 813 | // will be [ 4 ]. For batch models the output shape will be [ -1, 4 814 | // ] and the backend "responder" utility below will set the 815 | // appropriate batch dimension value for each response. 816 | std::vector output_batch_shape; 817 | bool supports_first_dim_batching; 818 | RESPOND_ALL_AND_SET_NULL_IF_ERROR(responses, request_count, 819 | StateForModel()->SupportsFirstDimBatching( 820 | &supports_first_dim_batching)); 821 | size_t idx = 0; 822 | for (auto &translation : translation_results) { 823 | 824 | std::vector out_tokens = translation.output(); 825 | // only output best hypotheses 826 | std::vector out_ids = target_vocab.to_ids({out_tokens})[0]; 827 | 828 | TRITONBACKEND_Output *response_output; 829 | std::vector out_shape = {(std::int64_t)out_ids.size()}; 830 | if (supports_first_dim_batching) { 831 | out_shape.insert(out_shape.begin(), -1); 832 | } 833 | 834 | RESPOND_AND_SET_NULL_IF_ERROR( 835 | &responses[idx], TRITONBACKEND_ResponseOutput( 836 | responses[idx], &response_output, 837 | StateForModel()->OutputTensorName().c_str(), 838 | StateForModel()->OutputDataType(), 839 | out_shape.data(), out_shape.size())); 840 | if (responses[idx] != nullptr) { 841 | void *out_buffer; 842 | size_t out_buffer_size = 843 | TritonTypeSize(StateForModel()->OutputDataType()) * out_ids.size(); 844 | TRITONSERVER_MemoryType actual_memory_type = TRITONSERVER_MEMORY_CPU; 845 | int64_t actual_memory_type_id = 0; 846 | RESPOND_AND_SET_NULL_IF_ERROR( 847 | &responses[idx], TRITONBACKEND_OutputBuffer( 848 | response_output, &out_buffer, out_buffer_size, 849 | &actual_memory_type, &actual_memory_type_id)); 850 | ToOutBuffer(out_ids, StateForModel()->OutputDataType(), out_buffer); 851 | } 852 | idx += 1; 853 | } 854 | // Send all the responses that haven't already been sent because of 855 | // an earlier error. 856 | for (auto &response : responses) { 857 | if (response != nullptr) { 858 | LOG_IF_ERROR( 859 | TRITONBACKEND_ResponseSend( 860 | response, TRITONSERVER_RESPONSE_COMPLETE_FINAL, nullptr), 861 | "failed to send response"); 862 | } 863 | } 864 | 865 | // Done with the request objects so release them. 
866 | for (uint32_t r = 0; r < request_count; ++r) { 867 | auto &request = requests[r]; 868 | LOG_IF_ERROR(TRITONBACKEND_RequestRelease( 869 | request, TRITONSERVER_REQUEST_RELEASE_ALL), 870 | "failed releasing request"); 871 | } 872 | 873 | uint64_t exec_end_ns = 0; 874 | SET_TIMESTAMP(exec_end_ns); 875 | 876 | if (!all_response_failed) { 877 | #ifdef TRITON_ENABLE_STATS 878 | // Report batch statistics. 879 | LOG_IF_ERROR(TRITONBACKEND_ModelInstanceReportBatchStatistics( 880 | TritonModelInstance(), total_batch_size, exec_start_ns, 881 | compute_start_ns, compute_end_ns, exec_end_ns), 882 | "failed reporting batch request statistics"); 883 | #endif // TRITON_ENABLE_STATS 884 | } 885 | } 886 | 887 | private: 888 | ModelState *model_state_; 889 | ::ctranslate2::Device device_; 890 | std::shared_ptr model_; 891 | bool supports_batching_; 892 | }; 893 | 894 | TRITONSERVER_Error * 895 | ModelInstanceState::Create(ModelState *model_state, 896 | TRITONBACKEND_ModelInstance *triton_model_instance, 897 | ModelInstanceState **state) { 898 | try { 899 | *state = new ModelInstanceState(model_state, triton_model_instance); 900 | } catch (const BackendModelInstanceException &ex) { 901 | RETURN_ERROR_IF_TRUE( 902 | ex.err_ == nullptr, TRITONSERVER_ERROR_INTERNAL, 903 | std::string("unexpected nullptr in BackendModelInstanceException")); 904 | RETURN_IF_ERROR(ex.err_); 905 | } 906 | 907 | return nullptr; // success 908 | } 909 | 910 | extern "C" { 911 | 912 | // Triton calls TRITONBACKEND_ModelInstanceInitialize when a model 913 | // instance is created to allow the backend to initialize any state 914 | // associated with the instance. 915 | // 916 | TRITONSERVER_Error * 917 | TRITONBACKEND_ModelInstanceInitialize(TRITONBACKEND_ModelInstance *instance) { 918 | // Get the model state associated with this instance's model. 919 | TRITONBACKEND_Model *model; 920 | RETURN_IF_ERROR(TRITONBACKEND_ModelInstanceModel(instance, &model)); 921 | 922 | void *vmodelstate; 923 | RETURN_IF_ERROR(TRITONBACKEND_ModelState(model, &vmodelstate)); 924 | ModelState *model_state = reinterpret_cast(vmodelstate); 925 | 926 | // Create a ModelInstanceState object and associate it with the 927 | // TRITONBACKEND_ModelInstance. 928 | ModelInstanceState *instance_state; 929 | RETURN_IF_ERROR( 930 | ModelInstanceState::Create(model_state, instance, &instance_state)); 931 | RETURN_IF_ERROR(TRITONBACKEND_ModelInstanceSetState( 932 | instance, reinterpret_cast(instance_state))); 933 | 934 | return nullptr; // success 935 | } 936 | 937 | // Triton calls TRITONBACKEND_ModelInstanceFinalize when a model 938 | // instance is no longer needed. The backend should cleanup any state 939 | // associated with the model instance. 940 | // 941 | TRITONSERVER_Error * 942 | TRITONBACKEND_ModelInstanceFinalize(TRITONBACKEND_ModelInstance *instance) { 943 | void *vstate; 944 | RETURN_IF_ERROR(TRITONBACKEND_ModelInstanceState(instance, &vstate)); 945 | ModelInstanceState *instance_state = 946 | reinterpret_cast(vstate); 947 | delete instance_state; 948 | 949 | return nullptr; // success 950 | } 951 | 952 | } // extern "C" 953 | 954 | ///////////// 955 | 956 | extern "C" { 957 | 958 | // When Triton calls TRITONBACKEND_ModelInstanceExecute it is required 959 | // that a backend create a response for each request in the batch. A 960 | // response may be the output tensors required for that request or may 961 | // be an error that is returned in the response. 
962 | // 963 | TRITONSERVER_Error * 964 | TRITONBACKEND_ModelInstanceExecute(TRITONBACKEND_ModelInstance *instance, 965 | TRITONBACKEND_Request **requests, 966 | const uint32_t request_count) { 967 | 968 | // Triton will not call this function simultaneously for the same 969 | // 'instance'. But since this backend could be used by multiple 970 | // instances from multiple models the implementation needs to handle 971 | // multiple calls to this function at the same time (with different 972 | // 'instance' objects). Best practice for a high-performance 973 | // implementation is to avoid introducing mutex/lock and instead use 974 | // only function-local and model-instance-specific state. 975 | ModelInstanceState *instance_state; 976 | RETURN_IF_ERROR(TRITONBACKEND_ModelInstanceState( 977 | instance, reinterpret_cast(&instance_state))); 978 | ModelState *model_state = instance_state->StateForModel(); 979 | 980 | // 'responses' is initialized as a parallel array to 'requests', 981 | // with one TRITONBACKEND_Response object for each 982 | // TRITONBACKEND_Request object. If something goes wrong while 983 | // creating these response objects, the backend simply returns an 984 | // error from TRITONBACKEND_ModelInstanceExecute, indicating to 985 | // Triton that this backend did not create or send any responses and 986 | // so it is up to Triton to create and send an appropriate error 987 | // response for each request. RETURN_IF_ERROR is one of several 988 | // useful macros for error handling that can be found in 989 | // backend_common.h. 990 | 991 | LOG_MESSAGE(TRITONSERVER_LOG_INFO, 992 | (std::string("model ") + model_state->Name() + ", instance " + 993 | instance_state->Name() + ", executing " + 994 | std::to_string(request_count) + " requests") 995 | .c_str()); 996 | 997 | instance_state->ProcessRequests(requests, request_count); 998 | 999 | return nullptr; // success 1000 | } 1001 | 1002 | } // extern "C" 1003 | 1004 | } // namespace ctranslate2 1005 | } // namespace backend 1006 | } // namespace triton 1007 | --------------------------------------------------------------------------------