├── .gitignore
├── .clang-format
├── LICENSE
├── src
│   ├── libtriton_pytorch.ldscript
│   ├── naming_convention.hh
│   ├── libtorch.hh
│   ├── libtorch_utils.h
│   ├── string_utils.hh
│   ├── model_state.hh
│   ├── libtorch_utils.cc
│   ├── model_instance_state.hh
│   ├── libtorch.cc
│   ├── string_utils.cc
│   ├── model.py
│   ├── model_state.cc
│   └── model_instance_state.cc
├── .github
│   └── workflows
│       └── pre-commit.yml
├── cmake
│   └── TritonPyTorchBackendConfig.cmake.in
├── pyproject.toml
├── tools
│   └── gen_pb_exec_env.sh
├── .pre-commit-config.yaml
├── README.md
└── CMakeLists.txt
/.gitignore: -------------------------------------------------------------------------------- 1 | /build 2 | /.vscode 3 | *.so 4 | -------------------------------------------------------------------------------- /.clang-format: -------------------------------------------------------------------------------- 1 | --- 2 | BasedOnStyle: Google 3 | 4 | IndentWidth: 2 5 | ColumnLimit: 80 6 | ContinuationIndentWidth: 4 7 | UseTab: Never 8 | MaxEmptyLinesToKeep: 2 9 | 10 | SortIncludes: true 11 | CompactNamespaces: true 12 | ReflowComments: true 13 | 14 | DerivePointerAlignment: false 15 | PointerAlignment: Left 16 | 17 | AllowShortIfStatementsOnASingleLine: false 18 | AllowShortBlocksOnASingleLine: false 19 | AllowShortFunctionsOnASingleLine: Inline 20 | 21 | AlwaysBreakAfterReturnType: TopLevelDefinitions 22 | AlignAfterOpenBracket: AlwaysBreak 23 | BreakBeforeBraces: Custom 24 | BraceWrapping: 25 | AfterClass: false 26 | AfterControlStatement: false 27 | AfterEnum: false 28 | AfterFunction: true 29 | AfterNamespace: false 30 | AfterStruct: false 31 | AfterUnion: false 32 | BeforeCatch: true 33 | 34 | BinPackArguments: true 35 | BinPackParameters: true 36 | ConstructorInitializerAllOnOneLineOrOnePerLine: false 37 | 38 | IndentCaseLabels: true 39 | 40 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 2 | 3 | Redistribution and use in source and binary forms, with or without 4 | modification, are permitted provided that the following conditions 5 | are met: 6 | * Redistributions of source code must retain the above copyright 7 | notice, this list of conditions and the following disclaimer. 8 | * Redistributions in binary form must reproduce the above copyright 9 | notice, this list of conditions and the following disclaimer in the 10 | documentation and/or other materials provided with the distribution. 11 | * Neither the name of NVIDIA CORPORATION nor the names of its 12 | contributors may be used to endorse or promote products derived 13 | from this software without specific prior written permission. 14 | 15 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY 16 | EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 17 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 18 | PURPOSE ARE DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT OWNER OR 19 | CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 20 | EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 21 | PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 22 | PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY 23 | OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 24 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 26 | -------------------------------------------------------------------------------- /src/libtriton_pytorch.ldscript: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Redistribution and use in source and binary forms, with or without 4 | # modification, are permitted provided that the following conditions 5 | # are met: 6 | # * Redistributions of source code must retain the above copyright 7 | # notice, this list of conditions and the following disclaimer. 8 | # * Redistributions in binary form must reproduce the above copyright 9 | # notice, this list of conditions and the following disclaimer in the 10 | # documentation and/or other materials provided with the distribution. 11 | # * Neither the name of NVIDIA CORPORATION nor the names of its 12 | # contributors may be used to endorse or promote products derived 13 | # from this software without specific prior written permission. 14 | # 15 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY 16 | # EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 17 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 18 | # PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR 19 | # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 20 | # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 21 | # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 22 | # PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY 23 | # OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 24 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 26 | { 27 | global: 28 | TRITONBACKEND_*; 29 | local: *; 30 | }; 31 | -------------------------------------------------------------------------------- /.github/workflows/pre-commit.yml: -------------------------------------------------------------------------------- 1 | # Copyright 2023-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # Redistribution and use in source and binary forms, with or without 4 | # modification, are permitted provided that the following conditions 5 | # are met: 6 | # * Redistributions of source code must retain the above copyright 7 | # notice, this list of conditions and the following disclaimer. 8 | # * Redistributions in binary form must reproduce the above copyright 9 | # notice, this list of conditions and the following disclaimer in the 10 | # documentation and/or other materials provided with the distribution. 11 | # * Neither the name of NVIDIA CORPORATION nor the names of its 12 | # contributors may be used to endorse or promote products derived 13 | # from this software without specific prior written permission. 
14 | # 15 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY 16 | # EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 17 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 18 | # PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR 19 | # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 20 | # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 21 | # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 22 | # PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY 23 | # OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 24 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 26 | 27 | name: pre-commit 28 | 29 | on: 30 | pull_request: 31 | 32 | jobs: 33 | pre-commit: 34 | runs-on: ubuntu-latest 35 | steps: 36 | - uses: actions/checkout@v5.0.0 37 | - uses: actions/setup-python@v6.0.0 38 | - uses: pre-commit/action@v3.0.1 39 | -------------------------------------------------------------------------------- /src/naming_convention.hh: -------------------------------------------------------------------------------- 1 | // Copyright 2019-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // Redistribution and use in source and binary forms, with or without 4 | // modification, are permitted provided that the following conditions 5 | // are met: 6 | // * Redistributions of source code must retain the above copyright 7 | // notice, this list of conditions and the following disclaimer. 8 | // * Redistributions in binary form must reproduce the above copyright 9 | // notice, this list of conditions and the following disclaimer in the 10 | // documentation and/or other materials provided with the distribution. 11 | // * Neither the name of NVIDIA CORPORATION nor the names of its 12 | // contributors may be used to endorse or promote products derived 13 | // from this software without specific prior written permission. 14 | // 15 | // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY 16 | // EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 17 | // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 18 | // PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR 19 | // CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 20 | // EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 21 | // PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 22 | // PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY 23 | // OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 24 | // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 26 | 27 | #pragma once 28 | 29 | 30 | namespace triton::backend::pytorch { 31 | 32 | // The naming convention followed for inputs/outputs in the model configuration. 33 | // Outputs don't support FORWARD_ARGUMENT. 
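// As a quick illustration (the names below are examples of the convention,
// not taken from a specific model in this repository): NAMED_INDEX expects
// tensor names of the form "INPUT__<index>" / "OUTPUT__<index>";
// FORWARD_ARGUMENT matches input names against the argument names of the
// TorchScript forward() method; STRICT_CONFIG_ORDERING maps tensors
// positionally, in the order the inputs and outputs appear in config.pbtxt.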
34 | enum class NamingConvention { 35 | NAMED_INDEX, 36 | FORWARD_ARGUMENT, 37 | STRICT_CONFIG_ORDERING 38 | }; 39 | 40 | } // namespace triton::backend::pytorch 41 | -------------------------------------------------------------------------------- /cmake/TritonPyTorchBackendConfig.cmake.in: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Redistribution and use in source and binary forms, with or without 4 | # modification, are permitted provided that the following conditions 5 | # are met: 6 | # * Redistributions of source code must retain the above copyright 7 | # notice, this list of conditions and the following disclaimer. 8 | # * Redistributions in binary form must reproduce the above copyright 9 | # notice, this list of conditions and the following disclaimer in the 10 | # documentation and/or other materials provided with the distribution. 11 | # * Neither the name of NVIDIA CORPORATION nor the names of its 12 | # contributors may be used to endorse or promote products derived 13 | # from this software without specific prior written permission. 14 | # 15 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY 16 | # EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 17 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 18 | # PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR 19 | # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 20 | # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 21 | # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 22 | # PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY 23 | # OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 24 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 26 | 27 | include(CMakeFindDependencyMacro) 28 | 29 | get_filename_component( 30 | TRITONPYTORCHBACKEND_CMAKE_DIR "${CMAKE_CURRENT_LIST_FILE}" PATH 31 | ) 32 | 33 | list(APPEND CMAKE_MODULE_PATH ${TRITONPYTORCHBACKEND_CMAKE_DIR}) 34 | 35 | if(NOT TARGET TritonPyTorchBackend::triton-pytorch-backend) 36 | include("${TRITONPYTORCHBACKEND_CMAKE_DIR}/TritonPyTorchBackendTargets.cmake") 37 | endif() 38 | 39 | set(TRITONPYTORCHBACKEND_LIBRARIES TritonPyTorchBackend::triton-pytorch-backend) 40 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | # Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # Redistribution and use in source and binary forms, with or without 4 | # modification, are permitted provided that the following conditions 5 | # are met: 6 | # * Redistributions of source code must retain the above copyright 7 | # notice, this list of conditions and the following disclaimer. 8 | # * Redistributions in binary form must reproduce the above copyright 9 | # notice, this list of conditions and the following disclaimer in the 10 | # documentation and/or other materials provided with the distribution. 11 | # * Neither the name of NVIDIA CORPORATION nor the names of its 12 | # contributors may be used to endorse or promote products derived 13 | # from this software without specific prior written permission. 
14 | # 15 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY 16 | # EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 17 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 18 | # PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR 19 | # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 20 | # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 21 | # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 22 | # PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY 23 | # OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 24 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 26 | 27 | [tool.codespell] 28 | # note: pre-commit passes explicit lists of files here, which this skip file list doesn't override - 29 | # this is only to allow you to run codespell interactively 30 | skip = "./.git,./.github" 31 | # ignore short words, and typename parameters like OffsetT 32 | ignore-regex = "\\b(.{1,4}|[A-Z]\\w*T)\\b" 33 | # use the 'clear' dictionary for unambiguous spelling mistakes 34 | builtin = "clear" 35 | # disable warnings about binary files and wrong encoding 36 | quiet-level = 3 37 | 38 | [tool.isort] 39 | profile = "black" 40 | use_parentheses = true 41 | multi_line_output = 3 42 | include_trailing_comma = true 43 | force_grid_wrap = 0 44 | ensure_newline_before_comments = true 45 | line_length = 88 46 | balanced_wrapping = true 47 | indent = " " 48 | skip = ["build"] 49 | 50 | -------------------------------------------------------------------------------- /src/libtorch.hh: -------------------------------------------------------------------------------- 1 | // Copyright 2019-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // Redistribution and use in source and binary forms, with or without 4 | // modification, are permitted provided that the following conditions 5 | // are met: 6 | // * Redistributions of source code must retain the above copyright 7 | // notice, this list of conditions and the following disclaimer. 8 | // * Redistributions in binary form must reproduce the above copyright 9 | // notice, this list of conditions and the following disclaimer in the 10 | // documentation and/or other materials provided with the distribution. 11 | // * Neither the name of NVIDIA CORPORATION nor the names of its 12 | // contributors may be used to endorse or promote products derived 13 | // from this software without specific prior written permission. 14 | // 15 | // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY 16 | // EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 17 | // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 18 | // PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR 19 | // CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 20 | // EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 21 | // PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 22 | // PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY 23 | // OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 24 | // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
26 | 27 | #include "model_instance_state.hh" 28 | #include "model_state.hh" 29 | #include "naming_convention.hh" 30 | #include "string_utils.hh" 31 | 32 | // 33 | // PyTorch C++ (LibTorch) Backend that implements the TRITONBACKEND API. 34 | // 35 | 36 | namespace triton::backend::pytorch { 37 | 38 | extern "C" { 39 | 40 | TRITONSERVER_Error* TRITONBACKEND_Initialize(TRITONBACKEND_Backend* backend); 41 | 42 | TRITONSERVER_Error* TRITONBACKEND_ModelInitialize(TRITONBACKEND_Model* model); 43 | 44 | TRITONSERVER_Error* TRITONBACKEND_ModelFinalize(TRITONBACKEND_Model* model); 45 | 46 | TRITONSERVER_Error* TRITONBACKEND_ModelInstanceInitialize( 47 | TRITONBACKEND_ModelInstance* instance); 48 | 49 | TRITONSERVER_Error* TRITONBACKEND_ModelInstanceFinalize( 50 | TRITONBACKEND_ModelInstance* instance); 51 | 52 | TRITONSERVER_Error* TRITONBACKEND_ModelInstanceExecute( 53 | TRITONBACKEND_ModelInstance* instance, TRITONBACKEND_Request** requests, 54 | const uint32_t request_count); 55 | 56 | } // extern "C" 57 | 58 | 59 | } // namespace triton::backend::pytorch 60 | -------------------------------------------------------------------------------- /tools/gen_pb_exec_env.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2023-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions 6 | # are met: 7 | # * Redistributions of source code must retain the above copyright 8 | # notice, this list of conditions and the following disclaimer. 9 | # * Redistributions in binary form must reproduce the above copyright 10 | # notice, this list of conditions and the following disclaimer in the 11 | # documentation and/or other materials provided with the distribution. 12 | # * Neither the name of NVIDIA CORPORATION nor the names of its 13 | # contributors may be used to endorse or promote products derived 14 | # from this software without specific prior written permission. 15 | # 16 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY 17 | # EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 18 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 19 | # PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR 20 | # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 21 | # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 22 | # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 23 | # PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY 24 | # OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 25 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 26 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
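#
# Usage note, as a minimal sketch (the config snippet below is illustrative
# and not taken from this repository): the archive produced by this script,
# pb_exec_env_model.py.tar.gz, is intended to be referenced as a custom
# execution environment from a Python-backend model's config.pbtxt, e.g.:
#
#   parameters: {
#     key: "EXECUTION_ENV_PATH",
#     value: {string_value: "$$TRITON_MODEL_DIRECTORY/pb_exec_env_model.py.tar.gz"}
#   }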
27 | 28 | # install conda 29 | rm -rf ./miniconda 30 | wget https://repo.anaconda.com/miniconda/Miniconda3-py312_25.7.0-2-Linux-x86_64.sh 31 | bash Miniconda3-py312_25.7.0-2-Linux-x86_64.sh -p ./miniconda -b 32 | eval "$(./miniconda/bin/conda shell.bash hook)" 33 | 34 | # create conda environment 35 | conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/main 36 | conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/r 37 | conda create -n pt python=3.12 -y 38 | conda activate pt 39 | conda install -c conda-forge conda-pack -y 40 | conda install pip 41 | 42 | # pre install step 43 | export PYTHONNOUSERSITE=True 44 | conda install -c conda-forge libstdcxx-ng=15 -y 45 | 46 | # install PyTorch (torch from pip to avoid conda version issues) 47 | pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cu130 48 | conda install torchaudio pytorch-cuda=13.0 -c pytorch -c nvidia -y 49 | 50 | # pack environment 51 | rm -f pb_exec_env_model.py.tar.gz 52 | conda pack -o pb_exec_env_model.py.tar.gz 53 | 54 | # deactivate conda 55 | conda deactivate 56 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # Copyright 2023-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # Redistribution and use in source and binary forms, with or without 4 | # modification, are permitted provided that the following conditions 5 | # are met: 6 | # * Redistributions of source code must retain the above copyright 7 | # notice, this list of conditions and the following disclaimer. 8 | # * Redistributions in binary form must reproduce the above copyright 9 | # notice, this list of conditions and the following disclaimer in the 10 | # documentation and/or other materials provided with the distribution. 11 | # * Neither the name of NVIDIA CORPORATION nor the names of its 12 | # contributors may be used to endorse or promote products derived 13 | # from this software without specific prior written permission. 14 | # 15 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY 16 | # EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 17 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 18 | # PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR 19 | # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 20 | # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 21 | # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 22 | # PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY 23 | # OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 24 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
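#
# To run these hooks locally (standard pre-commit usage, shown here as a
# convenience sketch):
#
#   pip install pre-commit
#   pre-commit install        # run the hooks automatically on git commit
#   pre-commit run --all-files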
26 | 27 | repos: 28 | - repo: https://github.com/PyCQA/isort 29 | rev: 5.12.0 30 | hooks: 31 | - id: isort 32 | additional_dependencies: [toml] 33 | - repo: https://github.com/psf/black 34 | rev: 23.1.0 35 | hooks: 36 | - id: black 37 | types_or: [python, cython] 38 | - repo: https://github.com/PyCQA/flake8 39 | rev: 7.3.0 40 | hooks: 41 | - id: flake8 42 | args: [--max-line-length=88, --select=C,E,F,W,B,B950, --extend-ignore = E203,E501] 43 | types_or: [python, cython] 44 | - repo: https://github.com/pre-commit/mirrors-clang-format 45 | rev: v16.0.5 46 | hooks: 47 | - id: clang-format 48 | types_or: [c, c++, cuda, proto, textproto, java] 49 | args: ["-fallback-style=none", "-style=file", "-i"] 50 | - repo: https://github.com/codespell-project/codespell 51 | rev: v2.2.4 52 | hooks: 53 | - id: codespell 54 | additional_dependencies: [tomli] 55 | args: ["--toml", "pyproject.toml"] 56 | exclude: (?x)^(.*stemmer.*|.*stop_words.*|^CHANGELOG.md$) 57 | # More details about these pre-commit hooks here: 58 | # https://pre-commit.com/hooks.html 59 | - repo: https://github.com/pre-commit/pre-commit-hooks 60 | rev: v6.0.0 61 | hooks: 62 | - id: check-case-conflict 63 | - id: check-executables-have-shebangs 64 | - id: check-merge-conflict 65 | - id: check-json 66 | - id: check-toml 67 | - id: check-yaml 68 | - id: check-shebang-scripts-are-executable 69 | - id: end-of-file-fixer 70 | types_or: [c, c++, cuda, proto, textproto, java, python] 71 | - id: mixed-line-ending 72 | - id: requirements-txt-fixer 73 | - id: trailing-whitespace 74 | -------------------------------------------------------------------------------- /src/libtorch_utils.h: -------------------------------------------------------------------------------- 1 | // Copyright 2020-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // Redistribution and use in source and binary forms, with or without 4 | // modification, are permitted provided that the following conditions 5 | // are met: 6 | // * Redistributions of source code must retain the above copyright 7 | // notice, this list of conditions and the following disclaimer. 8 | // * Redistributions in binary form must reproduce the above copyright 9 | // notice, this list of conditions and the following disclaimer in the 10 | // documentation and/or other materials provided with the distribution. 11 | // * Neither the name of NVIDIA CORPORATION nor the names of its 12 | // contributors may be used to endorse or promote products derived 13 | // from this software without specific prior written permission. 14 | // 15 | // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY 16 | // EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 17 | // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 18 | // PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR 19 | // CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 20 | // EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 21 | // PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 22 | // PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY 23 | // OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 24 | // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
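//
// This header declares the Torch <-> Triton dtype conversions and the model
// configuration parameter helpers shared by the backend. As a small usage
// sketch (the parameter name is an example, not mandated by this header), a
// config.pbtxt entry such as
//
//   parameters: { key: "INFERENCE_MODE", value: { string_value: "true" } }
//
// can be read with ParseParameter(params, "INFERENCE_MODE", &flag); per the
// comments below, 'flag' is left untouched when the key is absent.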
26 | 27 | #pragma once 28 | 29 | #include "triton/backend/backend_common.h" 30 | #include "triton/core/tritonserver.h" 31 | 32 | // Suppress warnings in torch headers 33 | #pragma GCC diagnostic push 34 | #pragma GCC diagnostic ignored "-Wsign-compare" 35 | #pragma warning(push, 0) 36 | #include 37 | #include 38 | #include 39 | #include 40 | #include // One-stop header for TorchScript 41 | #pragma warning(pop) 42 | #pragma GCC diagnostic pop 43 | 44 | namespace triton { namespace backend { namespace pytorch { 45 | 46 | TRITONSERVER_DataType ConvertTorchTypeToDataType( 47 | const torch::ScalarType& ttype); 48 | std::pair ConvertDataTypeToTorchType( 49 | const TRITONSERVER_DataType dtype); 50 | std::pair ModelConfigDataTypeToTorchType( 51 | const std::string& data_type_str); 52 | 53 | #ifdef TRITON_ENABLE_GPU 54 | TRITONSERVER_Error* ConvertCUDAStatusToTritonError( 55 | cudaError_t cuda_error, TRITONSERVER_Error_Code code, const char* msg); 56 | #endif 57 | 58 | // If the key 'mkey' is present in 'params' then update 'value' with the 59 | // value associated with that key. If 'mkey' is not present in 'params' then 60 | // no update is made to 'value'. 61 | TRITONSERVER_Error* ParseParameter( 62 | triton::common::TritonJson::Value& params, const std::string& mkey, 63 | bool* value); 64 | 65 | // If the key 'mkey' is present in 'params' then update 'value' with the 66 | // value associated with that key. If 'mkey' is not present in 'params' then 67 | // 'value' is set to 'default_value'. 68 | TRITONSERVER_Error* ParseParameter( 69 | triton::common::TritonJson::Value& params, const std::string& mkey, 70 | int* value); 71 | 72 | }}} // namespace triton::backend::pytorch 73 | -------------------------------------------------------------------------------- /src/string_utils.hh: -------------------------------------------------------------------------------- 1 | // Copyright 2019-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // Redistribution and use in source and binary forms, with or without 4 | // modification, are permitted provided that the following conditions 5 | // are met: 6 | // * Redistributions of source code must retain the above copyright 7 | // notice, this list of conditions and the following disclaimer. 8 | // * Redistributions in binary form must reproduce the above copyright 9 | // notice, this list of conditions and the following disclaimer in the 10 | // documentation and/or other materials provided with the distribution. 11 | // * Neither the name of NVIDIA CORPORATION nor the names of its 12 | // contributors may be used to endorse or promote products derived 13 | // from this software without specific prior written permission. 14 | // 15 | // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY 16 | // EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 17 | // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 18 | // PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR 19 | // CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 20 | // EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 21 | // PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 22 | // PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY 23 | // OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 24 | // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
26 | 27 | #pragma once 28 | 29 | #include 30 | 31 | #include 32 | #include 33 | #include 34 | 35 | #include "libtorch_utils.h" 36 | #include "triton/backend/backend_common.h" 37 | #include "triton/backend/backend_input_collector.h" 38 | #include "triton/backend/backend_memory.h" 39 | #include "triton/backend/backend_model.h" 40 | #include "triton/backend/backend_model_instance.h" 41 | #include "triton/backend/backend_output_responder.h" 42 | #include "triton/common/nvtx.h" 43 | #include "triton/core/tritonbackend.h" 44 | 45 | #ifdef TRITON_PYTORCH_ENABLE_TORCHVISION 46 | // Suppress warnings in torch headers 47 | #pragma GCC diagnostic push 48 | #pragma GCC diagnostic ignored "-Wsign-compare" 49 | #pragma warning(push, 0) 50 | #include 51 | #include // Torchvision header 52 | #pragma warning(pop) 53 | #pragma GCC diagnostic pop 54 | #endif // TRITON_PYTORCH_ENABLE_TORCHVISION 55 | 56 | #ifdef TRITON_ENABLE_GPU 57 | #include 58 | #include 59 | #include 60 | #endif // TRITON_ENABLE_GPU 61 | 62 | // for thread control 63 | // https://pytorch.org/docs/stable/notes/cpu_threading_torchscript_inference.html#runtime-api 64 | // https://github.com/pytorch/pytorch/blob/v2.2.1-rc3/aten/src/ATen/Parallel.h#L133 65 | #include 66 | 67 | 68 | namespace triton::backend::pytorch { 69 | 70 | void FillStringTensor(torch::List* input_list, const size_t cnt); 71 | 72 | // This function will return a tensor's contents as a contiguous 73 | // chunk in system memory. In some cases this will require copying the data. 74 | // If that happens, 'contiguous_buffer' will be set to hold the contiguous 75 | // chunk and 'cuda_copy' will be set to indicate whether CUDA copy is 76 | // conducted. The data copy can be avoided if the input is already in 77 | // a contiguous chunk and the input is located in memory type and id 78 | // specified. 79 | TRITONSERVER_Error* GetContiguousInputContent( 80 | TRITONBACKEND_Input* rinput, const uint32_t buffer_count, 81 | const char** content, size_t* content_byte_size, 82 | std::vector* contiguous_buffer, cudaStream_t stream, bool* cuda_copy); 83 | 84 | bool SetStringBuffer( 85 | torch::List* tensor, TRITONBACKEND_Response** response, 86 | TRITONBACKEND_Output* response_output, TRITONBACKEND_State* response_state, 87 | const size_t tensor_element_count, cudaStream_t stream, 88 | std::string* serialized, bool state); 89 | 90 | bool SetStringInputTensor( 91 | torch::List* input_list, TRITONBACKEND_Input* input, 92 | const char* name, const uint32_t buffer_count, 93 | const size_t request_element_cnt, TRITONBACKEND_Response** response, 94 | cudaStream_t stream, const char* host_policy_name); 95 | 96 | bool SetStringOutputBuffer( 97 | torch::List* tensor, TRITONBACKEND_Response** response, 98 | TRITONBACKEND_Output* response_output, const size_t tensor_element_count, 99 | cudaStream_t stream, std::string* serialized); 100 | 101 | bool SetStringStateBuffer( 102 | torch::List* tensor, TRITONBACKEND_Response** response, 103 | TRITONBACKEND_State* response_state, const size_t tensor_element_count, 104 | cudaStream_t stream, std::string* serialized); 105 | 106 | } // namespace triton::backend::pytorch 107 | -------------------------------------------------------------------------------- /src/model_state.hh: -------------------------------------------------------------------------------- 1 | // Copyright 2019-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
2 | // 3 | // Redistribution and use in source and binary forms, with or without 4 | // modification, are permitted provided that the following conditions 5 | // are met: 6 | // * Redistributions of source code must retain the above copyright 7 | // notice, this list of conditions and the following disclaimer. 8 | // * Redistributions in binary form must reproduce the above copyright 9 | // notice, this list of conditions and the following disclaimer in the 10 | // documentation and/or other materials provided with the distribution. 11 | // * Neither the name of NVIDIA CORPORATION nor the names of its 12 | // contributors may be used to endorse or promote products derived 13 | // from this software without specific prior written permission. 14 | // 15 | // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY 16 | // EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 17 | // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 18 | // PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR 19 | // CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 20 | // EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 21 | // PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 22 | // PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY 23 | // OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 24 | // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 26 | 27 | #pragma once 28 | 29 | #include 30 | 31 | #include 32 | #include 33 | #include 34 | 35 | #include "libtorch_utils.h" 36 | #include "naming_convention.hh" 37 | #include "triton/backend/backend_common.h" 38 | #include "triton/backend/backend_input_collector.h" 39 | #include "triton/backend/backend_memory.h" 40 | #include "triton/backend/backend_model.h" 41 | #include "triton/backend/backend_model_instance.h" 42 | #include "triton/backend/backend_output_responder.h" 43 | #include "triton/common/nvtx.h" 44 | #include "triton/core/tritonbackend.h" 45 | 46 | // for thread control 47 | // https://pytorch.org/docs/stable/notes/cpu_threading_torchscript_inference.html#runtime-api 48 | // https://github.com/pytorch/pytorch/blob/v2.2.1-rc3/aten/src/ATen/Parallel.h#L133 49 | #include 50 | 51 | 52 | namespace triton::backend::pytorch { 53 | 54 | class ModelState : public triton::backend::BackendModel { 55 | private: 56 | // Flag to indicate whether optimized execution is enabled. Defaults to true. 57 | bool enable_optimized_execution_; 58 | 59 | // Flag to indicate whether inference mode is enabled. Defaults to false. 60 | bool enable_inference_mode_; 61 | 62 | // Flag to indicate whether cudnn is enabled. Defaults to true. 63 | bool enable_cudnn_; 64 | 65 | // Flag to indicate whether cache cleaning after each run is enabled. 66 | // Defaults to false. 67 | bool enable_cache_cleaning_; 68 | 69 | // Flag to indicate whether weight sharing is enabled. Defaults to false. 70 | bool enable_weight_sharing_; 71 | 72 | // Flag pairs to indicate if various JIT settings are set and 73 | // enabled respectively. Defaults to (false, true). Default behavior 74 | // is to do nothing if not explicitly set. 75 | std::pair enable_tensor_fuser_pair_; 76 | std::pair enable_jit_profiling_pair_; 77 | std::pair enable_jit_executor_pair_; 78 | 79 | // Model mapping for shared TorchScript model across all instances on the 80 | // same device. 
The key is a pair of isGPU and device index. 81 | std::map< 82 | std::pair, std::shared_ptr> 83 | torch_models_; 84 | 85 | // model_outputs is a map that contains unique outputs that the model must 86 | // provide. The first pair is the model output index and the second is 87 | // the index in the model state, -1 is used if one is not required. 88 | // In the model configuration, the output in the state configuration 89 | // can have intersection with the outputs section of the model. If an output 90 | // is specified both in the output section and state section, it indicates 91 | // that the backend must return the output state to the client too. 92 | std::map> model_outputs_; 93 | 94 | public: 95 | virtual ~ModelState() = default; 96 | 97 | static TRITONSERVER_Error* Create( 98 | TRITONBACKEND_Model* triton_model, ModelState** state); 99 | 100 | bool EnabledCacheCleaning(); 101 | 102 | bool EnabledCudnn(); 103 | 104 | bool EnabledInferenceMode(); 105 | 106 | const std::pair& EnabledJitExecutor() const; 107 | 108 | const std::pair& EnabledJitProfiling() const; 109 | 110 | bool EnabledOptimizedExecution(); 111 | 112 | const std::pair& EnabledTensorExprFuser() const; 113 | 114 | bool EnabledWeightSharing(); 115 | 116 | TRITONSERVER_Error* LoadModel( 117 | const std::string& artifact_name, const torch::Device device, 118 | std::string* model_path, const TRITONSERVER_InstanceGroupKind& kind, 119 | std::shared_ptr* torch_model); 120 | 121 | const std::map>& ModelOutputs(); 122 | 123 | private: 124 | ModelState(TRITONBACKEND_Model* triton_model); 125 | 126 | TRITONSERVER_Error* AutoCompleteConfig(); 127 | 128 | TRITONSERVER_Error* ParseParameters(); 129 | }; 130 | 131 | } // namespace triton::backend::pytorch 132 | -------------------------------------------------------------------------------- /src/libtorch_utils.cc: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2020-24 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // Redistribution and use in source and binary forms, with or without 4 | // modification, are permitted provided that the following conditions 5 | // are met: 6 | // * Redistributions of source code must retain the above copyright 7 | // notice, this list of conditions and the following disclaimer. 8 | // * Redistributions in binary form must reproduce the above copyright 9 | // notice, this list of conditions and the following disclaimer in the 10 | // documentation and/or other materials provided with the distribution. 11 | // * Neither the name of NVIDIA CORPORATION nor the names of its 12 | // contributors may be used to endorse or promote products derived 13 | // from this software without specific prior written permission. 14 | // 15 | // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY 16 | // EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 17 | // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 18 | // PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR 19 | // CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 20 | // EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 21 | // PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 22 | // PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY 23 | // OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 24 | // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 26 | 27 | #include "libtorch_utils.h" 28 | 29 | namespace triton { namespace backend { namespace pytorch { 30 | 31 | TRITONSERVER_DataType 32 | ConvertTorchTypeToDataType(const torch::ScalarType& stype) 33 | { 34 | switch (stype) { 35 | case torch::kBool: 36 | return TRITONSERVER_TYPE_BOOL; 37 | case torch::kByte: 38 | return TRITONSERVER_TYPE_UINT8; 39 | case torch::kChar: 40 | return TRITONSERVER_TYPE_INT8; 41 | case torch::kShort: 42 | return TRITONSERVER_TYPE_INT16; 43 | case torch::kInt: 44 | return TRITONSERVER_TYPE_INT32; 45 | case torch::kLong: 46 | return TRITONSERVER_TYPE_INT64; 47 | case torch::kHalf: 48 | return TRITONSERVER_TYPE_FP16; 49 | case torch::kFloat: 50 | return TRITONSERVER_TYPE_FP32; 51 | case torch::kDouble: 52 | return TRITONSERVER_TYPE_FP64; 53 | default: 54 | break; 55 | } 56 | 57 | return TRITONSERVER_TYPE_INVALID; 58 | } 59 | 60 | std::pair 61 | ConvertDataTypeToTorchType(const TRITONSERVER_DataType dtype) 62 | { 63 | torch::ScalarType type = torch::kInt; 64 | switch (dtype) { 65 | case TRITONSERVER_TYPE_BOOL: 66 | type = torch::kBool; 67 | break; 68 | case TRITONSERVER_TYPE_UINT8: 69 | type = torch::kByte; 70 | break; 71 | case TRITONSERVER_TYPE_INT8: 72 | type = torch::kChar; 73 | break; 74 | case TRITONSERVER_TYPE_INT16: 75 | type = torch::kShort; 76 | break; 77 | case TRITONSERVER_TYPE_INT32: 78 | type = torch::kInt; 79 | break; 80 | case TRITONSERVER_TYPE_INT64: 81 | type = torch::kLong; 82 | break; 83 | case TRITONSERVER_TYPE_FP16: 84 | type = torch::kHalf; 85 | break; 86 | case TRITONSERVER_TYPE_FP32: 87 | type = torch::kFloat; 88 | break; 89 | case TRITONSERVER_TYPE_FP64: 90 | type = torch::kDouble; 91 | break; 92 | case TRITONSERVER_TYPE_UINT16: 93 | case TRITONSERVER_TYPE_UINT32: 94 | case TRITONSERVER_TYPE_UINT64: 95 | case TRITONSERVER_TYPE_BYTES: 96 | default: 97 | return std::make_pair(false, type); 98 | } 99 | 100 | return std::make_pair(true, type); 101 | } 102 | 103 | std::pair 104 | ModelConfigDataTypeToTorchType(const std::string& data_type_str) 105 | { 106 | torch::ScalarType type = torch::kInt; 107 | 108 | // Must start with "TYPE_". 
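// For example, the configuration string "TYPE_FP32" maps to torch::kFloat
// in the chain below; strings without the "TYPE_" prefix are rejected.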
109 | if (data_type_str.rfind("TYPE_", 0) != 0) { 110 | return std::make_pair(false, type); 111 | } 112 | 113 | const std::string dtype = data_type_str.substr(strlen("TYPE_")); 114 | 115 | if (dtype == "BOOL") { 116 | type = torch::kBool; 117 | } else if (dtype == "UINT8") { 118 | type = torch::kByte; 119 | } else if (dtype == "INT8") { 120 | type = torch::kChar; 121 | } else if (dtype == "INT16") { 122 | type = torch::kShort; 123 | } else if (dtype == "INT32") { 124 | type = torch::kInt; 125 | } else if (dtype == "INT64") { 126 | type = torch::kLong; 127 | } else if (dtype == "FP16") { 128 | type = torch::kHalf; 129 | } else if (dtype == "FP32") { 130 | type = torch::kFloat; 131 | } else if (dtype == "FP64") { 132 | type = torch::kDouble; 133 | } else { 134 | return std::make_pair(false, type); 135 | } 136 | 137 | return std::make_pair(true, type); 138 | } 139 | 140 | TRITONSERVER_Error* 141 | ParseParameter( 142 | triton::common::TritonJson::Value& params, const std::string& mkey, 143 | bool* value) 144 | { 145 | std::string value_str; 146 | RETURN_IF_ERROR(GetParameterValue(params, mkey, &value_str)); 147 | RETURN_IF_ERROR(ParseBoolValue(value_str, value)); 148 | 149 | return nullptr; 150 | } 151 | 152 | TRITONSERVER_Error* 153 | ParseParameter( 154 | triton::common::TritonJson::Value& params, const std::string& mkey, 155 | int* value) 156 | { 157 | std::string value_str; 158 | RETURN_IF_ERROR(GetParameterValue(params, mkey, &value_str)); 159 | RETURN_IF_ERROR(ParseIntValue(value_str, value)); 160 | 161 | return nullptr; 162 | } 163 | 164 | 165 | #ifdef TRITON_ENABLE_GPU 166 | TRITONSERVER_Error* 167 | ConvertCUDAStatusToTritonError( 168 | cudaError_t cuda_error, TRITONSERVER_Error_Code code, const char* msg) 169 | { 170 | if (cuda_error != cudaSuccess) { 171 | return TRITONSERVER_ErrorNew( 172 | code, 173 | (std::string(msg) + ": " + cudaGetErrorString(cuda_error)).c_str()); 174 | } 175 | return nullptr; // success 176 | } 177 | #endif 178 | 179 | }}} // namespace triton::backend::pytorch 180 | -------------------------------------------------------------------------------- /src/model_instance_state.hh: -------------------------------------------------------------------------------- 1 | // Copyright 2019-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // Redistribution and use in source and binary forms, with or without 4 | // modification, are permitted provided that the following conditions 5 | // are met: 6 | // * Redistributions of source code must retain the above copyright 7 | // notice, this list of conditions and the following disclaimer. 8 | // * Redistributions in binary form must reproduce the above copyright 9 | // notice, this list of conditions and the following disclaimer in the 10 | // documentation and/or other materials provided with the distribution. 11 | // * Neither the name of NVIDIA CORPORATION nor the names of its 12 | // contributors may be used to endorse or promote products derived 13 | // from this software without specific prior written permission. 14 | // 15 | // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY 16 | // EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 17 | // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 18 | // PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR 19 | // CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 20 | // EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 21 | // PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 22 | // PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY 23 | // OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 24 | // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 26 | 27 | #pragma once 28 | 29 | #include 30 | 31 | #include 32 | #include 33 | #include 34 | #include 35 | #include 36 | 37 | #include "libtorch_utils.h" 38 | #include "model_state.hh" 39 | #include "naming_convention.hh" 40 | #include "triton/backend/backend_common.h" 41 | #include "triton/backend/backend_input_collector.h" 42 | #include "triton/backend/backend_memory.h" 43 | #include "triton/backend/backend_model.h" 44 | #include "triton/backend/backend_model_instance.h" 45 | #include "triton/backend/backend_output_responder.h" 46 | #include "triton/common/nvtx.h" 47 | #include "triton/core/tritonbackend.h" 48 | 49 | 50 | namespace triton::backend::pytorch { 51 | 52 | // 53 | // ModelInstanceState 54 | // 55 | // State associated with a model instance. An object of this class is 56 | // created and associated with each TRITONBACKEND_ModelInstance. 57 | // 58 | class ModelInstanceState : public BackendModelInstance { 59 | private: 60 | ModelState* model_state_; 61 | 62 | // The full path to the TorchScript model file. 63 | std::string model_path_; 64 | 65 | std::shared_ptr torch_model_; 66 | torch::Device device_; 67 | 68 | // Map from configuration name for an input to the index of 69 | // that input in the model. 70 | std::unordered_map input_index_map_; 71 | uint32_t batch_input_count_ = 0; 72 | 73 | // Map from configuration name for an output to the index of 74 | // that output in the model. 75 | std::unordered_map output_index_map_; 76 | std::unordered_map output_dtype_map_; 77 | 78 | // If the input to the tensor is a dictionary of tensors. 79 | bool is_dict_input_; 80 | 81 | // If the model supports batching. 82 | bool supports_batching_; 83 | 84 | cudaEvent_t compute_input_start_event_; 85 | cudaEvent_t compute_infer_start_event_; 86 | cudaEvent_t compute_output_start_event_; 87 | 88 | // Store the cuda streams created for the 'KIND_MODEL' instance group. 89 | std::vector stream_vec_; 90 | 91 | // The number of available devices. 92 | int device_cnt_; 93 | 94 | public: 95 | virtual ~ModelInstanceState(); 96 | 97 | // Clear CUDA cache 98 | void ClearCache(); 99 | 100 | static TRITONSERVER_Error* Create( 101 | ModelState* model_state, 102 | TRITONBACKEND_ModelInstance* triton_model_instance, 103 | ModelInstanceState** state); 104 | 105 | // Execute... 106 | void ProcessRequests( 107 | TRITONBACKEND_Request** requests, const uint32_t request_count); 108 | 109 | // Get the state of the model that corresponds to this instance. 110 | ModelState* StateForModel() const; 111 | 112 | private: 113 | ModelInstanceState( 114 | ModelState* model_state, 115 | TRITONBACKEND_ModelInstance* triton_model_instance); 116 | 117 | void AddInputToMap( 118 | NamingConvention naming_convention, 119 | const std::vector allowed_inputs, const std::string& io_name, 120 | const uint32_t index); 121 | 122 | // Create CUDA events for statistics collection. 
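// (These events, together with GetCudaEventElapsedTime() below, are used to
// time the input-gathering, inference, and output-copy phases on the GPU for
// the reported statistics.)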
123 | void CreateCudaEvents(const int32_t& device_id); 124 | 125 | void Execute( 126 | std::vector* responses, 127 | const uint32_t response_count, 128 | std::vector* input_tensors, 129 | std::vector* output_tensors); 130 | 131 | // Get the elapsed time between two CUDA events. 132 | float GetCudaEventElapsedTime( 133 | const cudaEvent_t& start_event, const cudaEvent_t& end_event); 134 | 135 | // Get the appropriate CUDA stream for input and output handling based on 136 | // the instance group type. 137 | cudaStream_t GetCudaStreamByInstanceKind(); 138 | 139 | // Get the naming convention for inputs/outputs from the model configuration 140 | TRITONSERVER_Error* GetNamingConvention( 141 | NamingConvention* naming_convention, 142 | const std::vector& allowed_io); 143 | 144 | TRITONSERVER_Error* ReadOutputTensors( 145 | size_t total_batch_size, 146 | const std::vector& output_tensors, 147 | TRITONBACKEND_Request** requests, const uint32_t request_count, 148 | std::vector* responses); 149 | 150 | TRITONSERVER_Error* RecordBackendTimestamp( 151 | uint64_t* timestamp, void* cuda_event); 152 | 153 | // Replace the default CUDA stream with the stream we created to ensure 154 | // proper cuda stream synchronization. 155 | void SetCurrentCudaStream( 156 | const cudaStream_t& stream, const int32_t& device_id); 157 | 158 | TRITONSERVER_Error* SetInputTensors( 159 | size_t total_batch_size, TRITONBACKEND_Request** requests, 160 | const uint32_t request_count, 161 | std::vector* responses, 162 | BackendInputCollector* collector, std::vector* input_names, 163 | std::vector* input_tensors, bool* cuda_copy); 164 | 165 | TRITONSERVER_Error* ValidateBooleanSequenceControl( 166 | triton::common::TritonJson::Value& sequence_batching, 167 | const std::string& control_kind, bool required, bool* have_control); 168 | 169 | TRITONSERVER_Error* ValidateInputs(const size_t expected_input_cnt); 170 | 171 | TRITONSERVER_Error* ValidateOutputs(); 172 | 173 | TRITONSERVER_Error* ValidateTypedSequenceControl( 174 | triton::common::TritonJson::Value& sequence_batching, 175 | const std::string& control_kind, bool required, bool* have_control); 176 | }; 177 | 178 | } // namespace triton::backend::pytorch 179 | -------------------------------------------------------------------------------- /src/libtorch.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2019-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // Redistribution and use in source and binary forms, with or without 4 | // modification, are permitted provided that the following conditions 5 | // are met: 6 | // * Redistributions of source code must retain the above copyright 7 | // notice, this list of conditions and the following disclaimer. 8 | // * Redistributions in binary form must reproduce the above copyright 9 | // notice, this list of conditions and the following disclaimer in the 10 | // documentation and/or other materials provided with the distribution. 11 | // * Neither the name of NVIDIA CORPORATION nor the names of its 12 | // contributors may be used to endorse or promote products derived 13 | // from this software without specific prior written permission. 14 | // 15 | // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY 16 | // EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 17 | // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 18 | // PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR 19 | // CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 20 | // EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 21 | // PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 22 | // PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY 23 | // OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 24 | // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 26 | 27 | #include "libtorch.hh" 28 | 29 | // 30 | // PyTorch C++ (LibTorch) Backend that implements the TRITONBACKEND API. 31 | // 32 | 33 | namespace triton::backend::pytorch { 34 | 35 | extern "C" { 36 | 37 | TRITONSERVER_Error* 38 | TRITONBACKEND_Initialize(TRITONBACKEND_Backend* backend) 39 | { 40 | const char* cname; 41 | RETURN_IF_ERROR(TRITONBACKEND_BackendName(backend, &cname)); 42 | std::string name(cname); 43 | 44 | LOG_MESSAGE( 45 | TRITONSERVER_LOG_INFO, 46 | (std::string("TRITONBACKEND_Initialize: ") + name).c_str()); 47 | 48 | // Check the backend API version that Triton supports vs. what this 49 | // backend was compiled against. 50 | uint32_t api_version_major, api_version_minor; 51 | RETURN_IF_ERROR( 52 | TRITONBACKEND_ApiVersion(&api_version_major, &api_version_minor)); 53 | 54 | LOG_MESSAGE( 55 | TRITONSERVER_LOG_INFO, 56 | (std::string("Triton TRITONBACKEND API version: ") + 57 | std::to_string(api_version_major) + "." + 58 | std::to_string(api_version_minor)) 59 | .c_str()); 60 | LOG_MESSAGE( 61 | TRITONSERVER_LOG_INFO, 62 | (std::string("'") + name + "' TRITONBACKEND API version: " + 63 | std::to_string(TRITONBACKEND_API_VERSION_MAJOR) + "." + 64 | std::to_string(TRITONBACKEND_API_VERSION_MINOR)) 65 | .c_str()); 66 | 67 | if ((api_version_major != TRITONBACKEND_API_VERSION_MAJOR) || 68 | (api_version_minor < TRITONBACKEND_API_VERSION_MINOR)) { 69 | return TRITONSERVER_ErrorNew( 70 | TRITONSERVER_ERROR_UNSUPPORTED, 71 | (std::string("Triton TRITONBACKEND API version: ") + 72 | std::to_string(api_version_major) + "." + 73 | std::to_string(api_version_minor) + " does not support '" + name + 74 | "' TRITONBACKEND API version: " + 75 | std::to_string(TRITONBACKEND_API_VERSION_MAJOR) + "." + 76 | std::to_string(TRITONBACKEND_API_VERSION_MINOR)) 77 | .c_str()); 78 | } 79 | 80 | return nullptr; // success 81 | } 82 | 83 | TRITONSERVER_Error* 84 | TRITONBACKEND_ModelInitialize(TRITONBACKEND_Model* model) 85 | { 86 | const char* cname; 87 | RETURN_IF_ERROR(TRITONBACKEND_ModelName(model, &cname)); 88 | std::string name(cname); 89 | 90 | uint64_t version; 91 | RETURN_IF_ERROR(TRITONBACKEND_ModelVersion(model, &version)); 92 | 93 | LOG_MESSAGE( 94 | TRITONSERVER_LOG_INFO, 95 | (std::string("TRITONBACKEND_ModelInitialize: ") + name + " (version " + 96 | std::to_string(version) + ")") 97 | .c_str()); 98 | 99 | // Create a ModelState object and associate it with the 100 | // TRITONBACKEND_Model. 
101 | ModelState* model_state; 102 | RETURN_IF_ERROR(ModelState::Create(model, &model_state)); 103 | RETURN_IF_ERROR( 104 | TRITONBACKEND_ModelSetState(model, reinterpret_cast(model_state))); 105 | 106 | return nullptr; // success 107 | } 108 | 109 | TRITONSERVER_Error* 110 | TRITONBACKEND_ModelFinalize(TRITONBACKEND_Model* model) 111 | { 112 | void* vstate; 113 | RETURN_IF_ERROR(TRITONBACKEND_ModelState(model, &vstate)); 114 | ModelState* model_state = reinterpret_cast(vstate); 115 | 116 | LOG_MESSAGE( 117 | TRITONSERVER_LOG_INFO, "TRITONBACKEND_ModelFinalize: delete model state"); 118 | 119 | delete model_state; 120 | 121 | return nullptr; // success 122 | } 123 | 124 | TRITONSERVER_Error* 125 | TRITONBACKEND_ModelInstanceInitialize(TRITONBACKEND_ModelInstance* instance) 126 | { 127 | const char* cname; 128 | RETURN_IF_ERROR(TRITONBACKEND_ModelInstanceName(instance, &cname)); 129 | std::string name(cname); 130 | 131 | int32_t device_id; 132 | RETURN_IF_ERROR(TRITONBACKEND_ModelInstanceDeviceId(instance, &device_id)); 133 | 134 | TRITONSERVER_InstanceGroupKind kind; 135 | RETURN_IF_ERROR(TRITONBACKEND_ModelInstanceKind(instance, &kind)); 136 | 137 | LOG_MESSAGE( 138 | TRITONSERVER_LOG_INFO, 139 | (std::string("TRITONBACKEND_ModelInstanceInitialize: ") + name + " (" + 140 | TRITONSERVER_InstanceGroupKindString(kind) + " device " + 141 | std::to_string(device_id) + ")") 142 | .c_str()); 143 | 144 | // Get the model state associated with this instance's model. 145 | TRITONBACKEND_Model* model; 146 | RETURN_IF_ERROR(TRITONBACKEND_ModelInstanceModel(instance, &model)); 147 | 148 | void* vmodelstate; 149 | RETURN_IF_ERROR(TRITONBACKEND_ModelState(model, &vmodelstate)); 150 | ModelState* model_state = reinterpret_cast(vmodelstate); 151 | 152 | // Create a ModelInstanceState object and associate it with the 153 | // TRITONBACKEND_ModelInstance. 154 | ModelInstanceState* instance_state; 155 | RETURN_IF_ERROR( 156 | ModelInstanceState::Create(model_state, instance, &instance_state)); 157 | RETURN_IF_ERROR(TRITONBACKEND_ModelInstanceSetState( 158 | instance, reinterpret_cast(instance_state))); 159 | 160 | return nullptr; // success 161 | } 162 | 163 | TRITONSERVER_Error* 164 | TRITONBACKEND_ModelInstanceFinalize(TRITONBACKEND_ModelInstance* instance) 165 | { 166 | void* vstate; 167 | RETURN_IF_ERROR(TRITONBACKEND_ModelInstanceState(instance, &vstate)); 168 | ModelInstanceState* instance_state = 169 | reinterpret_cast(vstate); 170 | 171 | LOG_MESSAGE( 172 | TRITONSERVER_LOG_INFO, 173 | "TRITONBACKEND_ModelInstanceFinalize: delete instance state"); 174 | 175 | delete instance_state; 176 | 177 | return nullptr; // success 178 | } 179 | 180 | TRITONSERVER_Error* 181 | TRITONBACKEND_ModelInstanceExecute( 182 | TRITONBACKEND_ModelInstance* instance, TRITONBACKEND_Request** requests, 183 | const uint32_t request_count) 184 | { 185 | // Triton will not call this function simultaneously for the same 186 | // 'instance'. But since this backend could be used by multiple 187 | // instances from multiple models the implementation needs to handle 188 | // multiple calls to this function at the same time (with different 189 | // 'instance' objects). Suggested practice for this is to use only 190 | // function-local and model-instance-specific state (obtained from 191 | // 'instance'), which is what we do here. 
192 | ModelInstanceState* instance_state; 193 | RETURN_IF_ERROR(TRITONBACKEND_ModelInstanceState( 194 | instance, reinterpret_cast<void**>(&instance_state))); 195 | ModelState* model_state = instance_state->StateForModel(); 196 | 197 | // This backend specifies BLOCKING execution policy. That means that 198 | // we should not return from this function until execution is 199 | // complete. Triton will automatically release 'instance' on return 200 | // from this function so that it is again available to be used for 201 | // another call to TRITONBACKEND_ModelInstanceExecute. 202 | 203 | LOG_MESSAGE( 204 | TRITONSERVER_LOG_VERBOSE, 205 | (std::string("model ") + model_state->Name() + ", instance " + 206 | instance_state->Name() + ", executing " + std::to_string(request_count) + 207 | " requests") 208 | .c_str()); 209 | 210 | // At this point we accept ownership of 'requests', which means that 211 | // even if something goes wrong we must still return success from 212 | // this function. If something does go wrong in processing a 213 | // particular request then we send an error response just for the 214 | // specific request. 215 | instance_state->ProcessRequests(requests, request_count); 216 | 217 | if (model_state->EnabledCacheCleaning()) { 218 | instance_state->ClearCache(); 219 | } 220 | 221 | return nullptr; // success 222 | }; 223 | 224 | } // extern "C" 225 | 226 | } // namespace triton::backend::pytorch 227 | -------------------------------------------------------------------------------- /src/string_utils.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2019-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // Redistribution and use in source and binary forms, with or without 4 | // modification, are permitted provided that the following conditions 5 | // are met: 6 | // * Redistributions of source code must retain the above copyright 7 | // notice, this list of conditions and the following disclaimer. 8 | // * Redistributions in binary form must reproduce the above copyright 9 | // notice, this list of conditions and the following disclaimer in the 10 | // documentation and/or other materials provided with the distribution. 11 | // * Neither the name of NVIDIA CORPORATION nor the names of its 12 | // contributors may be used to endorse or promote products derived 13 | // from this software without specific prior written permission. 14 | // 15 | // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY 16 | // EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 17 | // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 18 | // PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR 19 | // CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 20 | // EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 21 | // PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 22 | // PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY 23 | // OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 24 | // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 26 | 27 | #include "string_utils.hh" 28 | 29 | 30 | namespace triton::backend::pytorch { 31 | 32 | // This function will return a tensor's contents as a contiguous 33 | // chunk in system memory. In some cases this will require copying the data.
34 | // If that happens, 'contiguous_buffer' will be set to hold the contiguous 35 | // chunk and 'cuda_copy' will be set to indicate whether CUDA copy is 36 | // conducted. The data copy can be avoided if the input is already in 37 | // a contiguous chunk and the input is located in memory type and id 38 | // specified. 39 | TRITONSERVER_Error* 40 | GetContiguousInputContent( 41 | TRITONBACKEND_Input* rinput, const uint32_t buffer_count, 42 | const char** content, size_t* content_byte_size, 43 | std::vector<char>* contiguous_buffer, cudaStream_t stream, bool* cuda_copy) 44 | { 45 | *cuda_copy = false; 46 | 47 | // Check input buffers to see if data copy is necessary 48 | size_t chunk_count = 0; 49 | bool type_mismatch = false; 50 | uint64_t total_byte_size = 0; 51 | for (size_t idx = 0; idx < buffer_count; ++idx) { 52 | TRITONSERVER_MemoryType src_memory_type; 53 | int64_t src_memory_type_id; 54 | size_t src_byte_size; 55 | const void* src_ptr; 56 | 57 | RETURN_IF_ERROR(TRITONBACKEND_InputBuffer( 58 | rinput, idx, &src_ptr, &src_byte_size, &src_memory_type, 59 | &src_memory_type_id)); 60 | 61 | if (src_ptr != nullptr) { 62 | chunk_count++; 63 | total_byte_size += src_byte_size; 64 | type_mismatch |= (src_memory_type == TRITONSERVER_MEMORY_GPU); 65 | } 66 | } 67 | 68 | if (chunk_count == 0) { 69 | *content = nullptr; 70 | *content_byte_size = 0; 71 | } else if ((chunk_count == 1) && !type_mismatch) { 72 | TRITONSERVER_MemoryType src_memory_type; 73 | int64_t src_memory_type_id; 74 | RETURN_IF_ERROR(TRITONBACKEND_InputBuffer( 75 | rinput, 0, (const void**)content, content_byte_size, &src_memory_type, 76 | &src_memory_type_id)); 77 | } else { 78 | contiguous_buffer->resize(total_byte_size); 79 | 80 | size_t offset = 0; 81 | for (size_t i = 0; i < chunk_count; i++) { 82 | bool cuda_used; 83 | TRITONSERVER_MemoryType src_memory_type; 84 | int64_t src_memory_type_id; 85 | size_t src_byte_size; 86 | const void* src_ptr; 87 | 88 | RETURN_IF_ERROR(TRITONBACKEND_InputBuffer( 89 | rinput, i, &src_ptr, &src_byte_size, &src_memory_type, 90 | &src_memory_type_id)); 91 | RETURN_IF_ERROR(CopyBuffer( 92 | "Contiguous input", src_memory_type, src_memory_type_id, 93 | TRITONSERVER_MEMORY_CPU, 0, src_byte_size, src_ptr, 94 | contiguous_buffer->data() + offset, stream, &cuda_used)); 95 | *cuda_copy |= cuda_used; 96 | offset += src_byte_size; 97 | } 98 | 99 | *content = contiguous_buffer->data(); 100 | *content_byte_size = total_byte_size; 101 | } 102 | 103 | return nullptr; // success 104 | } 105 | 106 | void 107 | FillStringTensor(torch::List<std::string>* input_list, const size_t cnt) 108 | { 109 | for (size_t c = 0; c < cnt; ++c) { 110 | input_list->push_back(""); 111 | } 112 | } 113 | 114 | bool 115 | SetStringBuffer( 116 | torch::List<std::string>* tensor, TRITONBACKEND_Response** response, 117 | TRITONBACKEND_Output* response_output, TRITONBACKEND_State* response_state, 118 | const size_t tensor_element_count, cudaStream_t stream, 119 | std::string* serialized, bool state) 120 | { 121 | bool cuda_copy = false; 122 | 123 | // Serialize the output tensor strings. Each string is serialized as 124 | // a 4-byte length followed by the string itself with no 125 | // null-terminator.
126 | serialized->clear(); 127 | for (size_t e = 0; e < tensor_element_count; ++e) { 128 | std::string str = tensor->get(e).to<std::string>(); 129 | const char* cstr = str.c_str(); 130 | size_t len = str.length(); 131 | serialized->append(reinterpret_cast<const char*>(&len), sizeof(uint32_t)); 132 | if (len > 0) { 133 | serialized->append(cstr, len); 134 | } 135 | } 136 | 137 | // Allocate a buffer large enough to hold the serialized tensor. 138 | TRITONSERVER_MemoryType actual_memory_type = TRITONSERVER_MEMORY_CPU; 139 | int64_t actual_memory_type_id = 0; 140 | 141 | TRITONSERVER_Error* err; 142 | void* buffer; 143 | 144 | if (!state) { 145 | auto err = TRITONBACKEND_OutputBuffer( 146 | response_output, &buffer, serialized->size(), &actual_memory_type, 147 | &actual_memory_type_id); 148 | if (err != nullptr) { 149 | RESPOND_AND_SET_NULL_IF_ERROR(response, err); 150 | return cuda_copy; 151 | } 152 | } else { 153 | auto err = TRITONBACKEND_StateBuffer( 154 | response_state, &buffer, serialized->size(), &actual_memory_type, 155 | &actual_memory_type_id); 156 | if (err != nullptr) { 157 | RESPOND_AND_SET_NULL_IF_ERROR(response, err); 158 | return cuda_copy; 159 | } 160 | } 161 | // Copy the serialized tensor into the allocated buffer. 162 | bool cuda_used = false; 163 | err = CopyBuffer( 164 | "String output", TRITONSERVER_MEMORY_CPU /* src_memory_type */, 165 | 0 /* src_memory_type_id */, actual_memory_type, actual_memory_type_id, 166 | serialized->size(), reinterpret_cast<const void*>(serialized->c_str()), 167 | buffer, stream, &cuda_used); 168 | cuda_copy |= cuda_used; 169 | 170 | if (err != nullptr) { 171 | RESPOND_AND_SET_NULL_IF_ERROR(response, err); 172 | return cuda_copy; 173 | } 174 | 175 | if (state) { 176 | RESPOND_AND_SET_NULL_IF_ERROR( 177 | response, TRITONBACKEND_StateUpdate(response_state)); 178 | } 179 | 180 | return cuda_copy; 181 | } 182 | 183 | bool 184 | SetStringInputTensor( 185 | torch::List<std::string>* input_list, TRITONBACKEND_Input* input, 186 | const char* name, const uint32_t buffer_count, 187 | const size_t request_element_cnt, TRITONBACKEND_Response** response, 188 | cudaStream_t stream, const char* host_policy_name) 189 | { 190 | bool cuda_copy = false; 191 | 192 | // For string data type, we always need to have the data on CPU so 193 | // that we can read string length and construct the string 194 | // properly. So if the request's input tensor is not in CPU need to 195 | // copy it there. 196 | const char* content = nullptr; 197 | size_t content_byte_size = 0; 198 | 199 | std::vector<char> contiguous_buffer; 200 | auto err = GetContiguousInputContent( 201 | input, buffer_count, &content, &content_byte_size, &contiguous_buffer, 202 | stream, &cuda_copy); 203 | if (err != nullptr) { 204 | RESPOND_AND_SET_NULL_IF_ERROR(response, err); 205 | FillStringTensor(input_list, request_element_cnt); 206 | return cuda_copy; 207 | } 208 | 209 | #ifdef TRITON_ENABLE_GPU 210 | if (cuda_copy) { 211 | cudaStreamSynchronize(stream); 212 | cuda_copy = false; 213 | } 214 | #endif // TRITON_ENABLE_GPU 215 | 216 | std::vector<std::pair<const char*, const uint32_t>> str_list; 217 | err = ValidateStringBuffer( 218 | content, content_byte_size, request_element_cnt, name, &str_list); 219 | // Set string values.
220 | for (const auto& [addr, len] : str_list) { 221 | input_list->push_back(std::string(addr, len)); 222 | } 223 | 224 | size_t element_cnt = str_list.size(); 225 | if (err != nullptr) { 226 | RESPOND_AND_SET_NULL_IF_ERROR(response, err); 227 | FillStringTensor(input_list, request_element_cnt - element_cnt); 228 | } 229 | return cuda_copy; 230 | } 231 | 232 | bool 233 | SetStringOutputBuffer( 234 | torch::List<std::string>* tensor, TRITONBACKEND_Response** response, 235 | TRITONBACKEND_Output* response_output, const size_t tensor_element_count, 236 | cudaStream_t stream, std::string* serialized) 237 | { 238 | return SetStringBuffer( 239 | tensor, response, response_output, nullptr /* response_state */, 240 | tensor_element_count, stream, serialized, false /* state */); 241 | } 242 | 243 | bool 244 | SetStringStateBuffer( 245 | torch::List<std::string>* tensor, TRITONBACKEND_Response** response, 246 | TRITONBACKEND_State* response_state, const size_t tensor_element_count, 247 | cudaStream_t stream, std::string* serialized) 248 | { 249 | return SetStringBuffer( 250 | tensor, response, nullptr /* response_output */, response_state, 251 | tensor_element_count, stream, serialized, true /* state */); 252 | } 253 | 254 | } // namespace triton::backend::pytorch 255 | -------------------------------------------------------------------------------- /src/model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 4 | # 5 | # Redistribution and use in source and binary forms, with or without 6 | # modification, are permitted provided that the following conditions 7 | # are met: 8 | # * Redistributions of source code must retain the above copyright 9 | # notice, this list of conditions and the following disclaimer. 10 | # * Redistributions in binary form must reproduce the above copyright 11 | # notice, this list of conditions and the following disclaimer in the 12 | # documentation and/or other materials provided with the distribution. 13 | # * Neither the name of NVIDIA CORPORATION nor the names of its 14 | # contributors may be used to endorse or promote products derived 15 | # from this software without specific prior written permission. 16 | # 17 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY 18 | # EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 19 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 20 | # PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR 21 | # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 22 | # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 23 | # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 24 | # PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY 25 | # OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 26 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
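# This module is the Python-based runtime for the PyTorch backend: it loads either a TorchScript file ("model.pt") or a Python file defining a torch.nn.Module subclass, moves the model to the configured device, optionally wraps it with torch.compile(), and gathers/scatters request tensors across requests when batching is enabled (see TritonPythonModel below).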
28 | 29 | import importlib 30 | import json 31 | import os 32 | 33 | try: 34 | import torch 35 | except ModuleNotFoundError as error: 36 | raise RuntimeError("Missing/Incomplete PyTorch package installation") from error 37 | 38 | import triton_python_backend_utils as pb_utils 39 | 40 | 41 | def _get_model_path(config): 42 | # FIXME: Add support for torch.export IR models (.pt2) 43 | filenames = ["model.py", "model.pt"] 44 | if config["default_model_filename"]: 45 | filenames.insert(0, config["default_model_filename"]) 46 | for filename in filenames: 47 | model_path = os.path.join(pb_utils.get_model_dir(), filename) 48 | if os.path.exists(model_path): 49 | return model_path 50 | raise pb_utils.TritonModelException( 51 | "No model found in " + pb_utils.get_model_dir() + "/" + str(filenames) 52 | ) 53 | 54 | 55 | def _get_model_data_path(model_path): 56 | data_path_extensions = [".pt"] 57 | model_path_no_extension = model_path[: -(len(model_path.split(".")[-1]) + 1)] 58 | for extension in data_path_extensions: 59 | data_path = model_path_no_extension + extension 60 | if os.path.exists(data_path): 61 | return data_path 62 | # data file not provided 63 | return "" 64 | 65 | 66 | def _is_py_class_model(model_path): 67 | return model_path[-3:] == ".py" 68 | 69 | 70 | def _import_module_from_path(module_name, file_path): 71 | spec = importlib.util.spec_from_file_location(module_name, file_path) 72 | module = importlib.util.module_from_spec(spec) 73 | spec.loader.exec_module(module) 74 | return module 75 | 76 | 77 | def _get_model_class_from_module(module): 78 | names = dir(module) 79 | for name in names: 80 | attr = getattr(module, name) 81 | try: 82 | if issubclass(attr, torch.nn.Module): 83 | return attr 84 | except TypeError: 85 | # attr may not be a class 86 | pass 87 | raise pb_utils.TritonModelException("Cannot find a subclass of torch.nn.Module") 88 | 89 | 90 | def _parse_io_config(io_config): 91 | io = [] 92 | for conf in io_config: 93 | io.append({"name": conf["name"]}) 94 | return io 95 | 96 | 97 | def _get_device_name(kind, device_id): 98 | if kind == "GPU": 99 | return "cuda:" + device_id 100 | if kind == "CPU": 101 | return "cpu" 102 | # unspecified device 103 | return "" 104 | 105 | 106 | def _get_device(kind, device_id, model): 107 | device_name = _get_device_name(kind, device_id) 108 | if device_name == "": 109 | for param in model.parameters(): 110 | return param.device 111 | raise pb_utils.TritonModelException("Cannot determine model device") 112 | return torch.device(device_name) 113 | 114 | 115 | def _set_torch_parallelism(config): 116 | log_msg = "" 117 | parallelism_settings = ["NUM_THREADS", "NUM_INTEROP_THREADS"] 118 | for setting in parallelism_settings: 119 | val = "1" 120 | if setting in config["parameters"]: 121 | val = config["parameters"][setting]["string_value"] 122 | getattr(torch, "set_" + setting.lower())(int(val)) 123 | log_msg += setting + " = " + val + "; " 124 | return log_msg 125 | 126 | 127 | def _get_torch_compile_params(config): 128 | params = {} 129 | if "TORCH_COMPILE_OPTIONAL_PARAMETERS" in config["parameters"]: 130 | val = config["parameters"]["TORCH_COMPILE_OPTIONAL_PARAMETERS"]["string_value"] 131 | params = json.loads(val) 132 | if "model" in params: 133 | raise pb_utils.TritonModelException( 134 | "'model' is not an optional parameter for 'torch.compile'" 135 | ) 136 | return params 137 | 138 | 139 | def _gather_torch_tensors(scatter_tensors): 140 | gather_tensors = [] 141 | sections = [] 142 | for i in range(len(scatter_tensors)): 143 | tensors = 
scatter_tensors[i] 144 | for j in range(len(tensors)): 145 | tensor = tensors[j] 146 | if j < len(gather_tensors): 147 | # add to existing tensor 148 | gather_tensors[j] = torch.cat((gather_tensors[j], tensor), 0) 149 | else: 150 | # start a new tensor 151 | gather_tensors.append(tensor) 152 | # record section 153 | section_length = tensors[0].size()[0] 154 | sections.append(section_length) 155 | return gather_tensors, sections 156 | 157 | 158 | def _scatter_torch_tensors(gather_tensors, sections): 159 | scatter_tensors = [] 160 | for j in range(len(gather_tensors)): 161 | scatter_tensor = torch.split(gather_tensors[j], sections) 162 | for i in range(len(scatter_tensor)): 163 | tensor = scatter_tensor[i] 164 | if i < len(scatter_tensors): 165 | # add to existing response 166 | scatter_tensors[i].append(tensor) 167 | else: 168 | # start a new response 169 | scatter_tensors.append([tensor]) 170 | return scatter_tensors 171 | 172 | 173 | class TritonPythonModel: 174 | """Your Python model must use the same class name. Every Python model 175 | that is created must have "TritonPythonModel" as the class name. 176 | """ 177 | 178 | def initialize(self, args): 179 | """`initialize` is called only once when the model is being loaded. 180 | Implementing `initialize` function is optional. This function allows 181 | the model to initialize any state associated with this model. 182 | Parameters 183 | ---------- 184 | args : dict 185 | Both keys and values are strings. The dictionary keys and values are: 186 | * model_config: A JSON string containing the model configuration 187 | * model_instance_kind: A string containing model instance kind 188 | * model_instance_device_id: A string containing model instance device ID 189 | * model_repository: Model repository path 190 | * model_version: Model version 191 | * model_name: Model name 192 | """ 193 | self._model_name = args["model_name"] 194 | for_model = "for '" + self._model_name + "'" 195 | self._logger = pb_utils.Logger 196 | self._logger.log_info("Initializing model instance " + for_model) 197 | 198 | self._model_config = json.loads(args["model_config"]) 199 | self._kind = args["model_instance_kind"] 200 | self._device_id = args["model_instance_device_id"] 201 | self._support_batching = self._model_config["max_batch_size"] > 0 202 | self._inputs = _parse_io_config(self._model_config["input"]) 203 | self._outputs = _parse_io_config(self._model_config["output"]) 204 | 205 | setting_msg = _set_torch_parallelism(self._model_config) 206 | self._logger.log_verbose( 207 | "Torch parallelism settings " + for_model + ": " + setting_msg 208 | ) 209 | 210 | self._infer_mode = torch.inference_mode(mode=True) 211 | self._infer_mode.__enter__() 212 | 213 | params = _get_torch_compile_params(self._model_config) 214 | self._logger.log_verbose( 215 | "'torch.compile' optional parameter(s) " + for_model + ": " + str(params) 216 | ) 217 | if self._support_batching: 218 | self._gather = torch.compile(_gather_torch_tensors, **params) 219 | self._scatter = torch.compile(_scatter_torch_tensors, **params) 220 | 221 | model_path = _get_model_path(self._model_config) 222 | if not _is_py_class_model(model_path): 223 | self._logger.log_info("Loading '" + self._model_name + "' as TorchScript") 224 | self._model = torch.jit.load(model_path) 225 | self._device = _get_device(self._kind, self._device_id, self._model) 226 | self._model.to(self._device) 227 | self._model.eval() 228 | return 229 | 230 | self._model_module = _import_module_from_path(self._model_name, model_path) 231 | 
self._model_class = _get_model_class_from_module(self._model_module) 232 | self._raw_model = self._model_class() 233 | self._device = _get_device(self._kind, self._device_id, self._raw_model) 234 | data_path = _get_model_data_path(model_path) 235 | if data_path != "": 236 | self._raw_model.load_state_dict( 237 | torch.load(data_path, map_location=self._device) 238 | ) 239 | else: 240 | self._logger.log_info("Model parameter file not found " + for_model) 241 | self._raw_model.to(self._device) 242 | self._raw_model.eval() 243 | self._model = torch.compile(self._raw_model, **params) 244 | 245 | def execute(self, requests): 246 | """`execute` MUST be implemented in every Python model. `execute` 247 | function receives a list of pb_utils.InferenceRequest as the only 248 | argument. This function is called when an inference request is made 249 | for this model. Depending on the batching configuration (e.g. Dynamic 250 | Batching) used, `requests` may contain multiple requests. Every 251 | Python model, must create one pb_utils.InferenceResponse for every 252 | pb_utils.InferenceRequest in `requests`. If there is an error, you can 253 | set the error argument when creating a pb_utils.InferenceResponse 254 | Parameters 255 | ---------- 256 | requests : list 257 | A list of pb_utils.InferenceRequest 258 | Returns 259 | ------- 260 | list 261 | A list of pb_utils.InferenceResponse. The length of this list must 262 | be the same as `requests` 263 | """ 264 | 265 | responses = [] 266 | 267 | requests_tensors = [] 268 | for request in requests: 269 | tensors = [] 270 | for io in self._inputs: 271 | tensor = pb_utils.get_input_tensor_by_name( 272 | request, io["name"] 273 | ).to_dlpack() 274 | tensor = torch.from_dlpack(tensor).to(self._device) 275 | tensors.append(tensor) 276 | requests_tensors.append(tensors) 277 | 278 | sections = None 279 | if self._support_batching: 280 | requests_tensors, sections = self._gather(requests_tensors) 281 | requests_tensors = [requests_tensors] 282 | 283 | responses_tensors = [] 284 | for input_tensors in requests_tensors: 285 | output_tensors = self._model(*input_tensors) 286 | if not isinstance(output_tensors, tuple) and not isinstance( 287 | output_tensors, list 288 | ): 289 | output_tensors = [output_tensors] 290 | responses_tensors.append(output_tensors) 291 | 292 | if self._support_batching: 293 | responses_tensors = self._scatter(responses_tensors[0], sections) 294 | 295 | for response_tensors in responses_tensors: 296 | output_tensors = [] 297 | for i in range(len(self._outputs)): 298 | io = self._outputs[i] 299 | tensor = response_tensors[i].detach() 300 | tensor = pb_utils.Tensor.from_dlpack(io["name"], tensor) 301 | output_tensors.append(tensor) 302 | inference_response = pb_utils.InferenceResponse( 303 | output_tensors=output_tensors 304 | ) 305 | responses.append(inference_response) 306 | 307 | return responses 308 | 309 | def finalize(self): 310 | """`finalize` is called only once when the model is being unloaded. 311 | Implementing `finalize` function is OPTIONAL. This function allows 312 | the model to perform any necessary clean ups before exit. 
313 | """ 314 | self._logger.log_info("Removing model instance for '" + self._model_name + "'") 315 | self._infer_mode.__exit__(exc_type=None, exc_value=None, traceback=None) 316 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 28 | 29 | # PyTorch (LibTorch) Backend 30 | 31 | [![License](https://img.shields.io/badge/License-BSD3-lightgrey.svg)](https://opensource.org/licenses/BSD-3-Clause) 32 | 33 | The Triton backend for 34 | [PyTorch](https://github.com/pytorch/pytorch) 35 | is designed to run 36 | [TorchScript](https://pytorch.org/docs/stable/jit.html) 37 | models using the PyTorch C++ API. 38 | All models created in PyTorch using the Python API must be traced/scripted to produce a TorchScript model. 39 | 40 | You can learn more about Triton backends in the 41 | [Triton Backend](https://github.com/triton-inference-server/backend) 42 | repository. 43 | 44 | Ask questions or report problems using 45 | [Triton Server issues](https://github.com/triton-inference-server/server/issues). 46 | 47 | Be sure to read all the information below as well as the 48 | [general Triton documentation](https://github.com/triton-inference-server/server#triton-inference-server) 49 | available in the [Triton Server](https://github.com/triton-inference-server/server) repository. 50 | 51 | ## Build the PyTorch Backend 52 | 53 | Use a recent CMake to build. 54 | First install the required dependencies. 55 | 56 | ```bash 57 | apt-get install rapidjson-dev python3-dev python3-pip 58 | pip3 install patchelf==0.17.2 59 | ``` 60 | 61 | An appropriate PyTorch container from the [NVIDIA NGC Catalog](https://ngc.nvidia.com) must be used. 62 | For example, to build a backend that uses the 23.04 version of the PyTorch container from NGC: 63 | 64 | ```bash 65 | mkdir build 66 | cd build 67 | cmake -DCMAKE_INSTALL_PREFIX:PATH=`pwd`/install -DTRITON_PYTORCH_DOCKER_IMAGE="nvcr.io/nvidia/pytorch:23.04-py3" .. 68 | make install 69 | ``` 70 | 71 | The following required Triton repositories will be pulled and used in the build. 72 | By default, the `main` head will be used for each repository, but the listed CMake argument can be used to override the value. 73 | 74 | * triton-inference-server/backend: `-DTRITON_BACKEND_REPO_TAG=[tag]` 75 | * triton-inference-server/core: `-DTRITON_CORE_REPO_TAG=[tag]` 76 | * triton-inference-server/common: `-DTRITON_COMMON_REPO_TAG=[tag]` 77 | 78 | ## Build the PyTorch Backend With Custom PyTorch 79 | 80 | Currently, Triton requires that a specially patched version of PyTorch be used with the PyTorch backend. 81 | The full source for these PyTorch versions is available as Docker images from 82 | [NGC](https://ngc.nvidia.com). 83 | 84 | For example, the PyTorch version compatible with the 25.09 release of Triton is available as `nvcr.io/nvidia/pytorch:25.09-py3`, which supports PyTorch version `2.9.0a0`. 85 | 86 | > [!NOTE] 87 | > Additional details and version information can be found in the container's 88 | > [release notes](https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-25-09.html#rel-25-09). 89 | 90 | Copy over the LibTorch and TorchVision headers and libraries from the 91 | [PyTorch NGC container](https://ngc.nvidia.com/catalog/containers/nvidia:pytorch) 92 | into local directories. 93 | You can see which headers and libraries are needed/copied from the Docker image by inspecting the backend's `CMakeLists.txt`.
94 | 95 | ```bash 96 | mkdir build 97 | cd build 98 | cmake -DCMAKE_INSTALL_PREFIX:PATH=`pwd`/install -DTRITON_PYTORCH_INCLUDE_PATHS="/torch;/torch/torch/csrc/api/include;/torchvision" -DTRITON_PYTORCH_LIB_PATHS="" .. 99 | make install 100 | ``` 101 | 102 | ## Using the PyTorch Backend 103 | 104 | ### PyTorch 2.0 Models 105 | 106 | PyTorch 2.0 features are available. 107 | However, Triton's PyTorch backend requires a serialized representation of the model in the form of a `model.pt` file. 108 | The serialized representation of the model can be generated using PyTorch's 109 | [`torch.save()`](https://docs.pytorch.org/tutorials/beginner/saving_loading_models.html#id1) 110 | function. 111 | 112 | The model repository should look like: 113 | 114 | ```bash 115 | model_repository/ 116 | `-- model_directory 117 | |-- 1 118 | | `-- model.pt 119 | `-- config.pbtxt 120 | ``` 121 | 122 | Where `model.pt` is the serialized representation of the model. 123 | 124 | ### TorchScript Models 125 | 126 | The model repository should look like: 127 | 128 | ```bash 129 | model_repository/ 130 | `-- model_directory 131 | |-- 1 132 | | `-- model.pt 133 | `-- config.pbtxt 134 | ``` 135 | 136 | The `model.pt` is the TorchScript model file. 137 | 138 | ## Configuration 139 | 140 | Triton exposes some flags to control the execution mode of the TorchScript models through the `Parameters` section of the model's `config.pbtxt` file. 141 | 142 | ### Configuration Options 143 | 144 | * `default_model_filename`: 145 | Instructs the Triton PyTorch backend to load the model from a file of the given name. 146 | 147 | The model config specifying the option would look like: 148 | 149 | ```proto 150 | default_model_filename: "another_file_name.pt" 151 | ``` 152 | 153 | ### Parameters 154 | 155 | * `DISABLE_OPTIMIZED_EXECUTION`: 156 | Boolean flag to disable the optimized execution of TorchScript models. 157 | By default, the optimized execution is always enabled. 158 | 159 | The initial calls to a loaded TorchScript model take a significant amount of time. 160 | Due to this longer model warmup 161 | ([pytorch #57894](https://github.com/pytorch/pytorch/issues/57894)), 162 | Triton also allows execution of models without these optimizations. 163 | In some models, optimized execution does not benefit performance 164 | ([pytorch #19978](https://github.com/pytorch/pytorch/issues/19978)) 165 | and in other cases impacts performance negatively 166 | ([pytorch #53824](https://github.com/pytorch/pytorch/issues/53824)). 167 | 168 | The section of the model config file specifying this parameter will look like: 169 | 170 | ```proto 171 | parameters: { 172 | key: "DISABLE_OPTIMIZED_EXECUTION" 173 | value: { string_value: "true" } 174 | } 175 | ``` 176 | 177 | * `INFERENCE_MODE`: 178 | 179 | Boolean flag to enable the Inference Mode execution of TorchScript models. 180 | By default, the inference mode is enabled. 181 | 182 | [InferenceMode](https://pytorch.org/cppdocs/notes/inference_mode.html) is an RAII guard analogous to `NoGradMode` to be used when you are certain your operations will have no interactions with autograd. 183 | Compared to `NoGradMode`, code run under this mode gets better performance by additionally disabling view tracking and version counter bumps. 184 | 185 | Please note that in some models, InferenceMode might not benefit performance and in fewer cases might impact performance negatively.
186 | 187 | To enable inference mode, use the configuration example below: 188 | 189 | ```proto 190 | parameters: { 191 | key: "INFERENCE_MODE" 192 | value: { string_value: "true" } 193 | } 194 | ``` 195 | 196 | * `DISABLE_CUDNN`: 197 | 198 | Boolean flag to disable the cuDNN library. 199 | By default, cuDNN is enabled. 200 | 201 | [cuDNN](https://developer.nvidia.com/cudnn) is a GPU-accelerated library of primitives for deep neural networks. 202 | It provides highly tuned implementations for standard routines. 203 | 204 | Typically, models run with cuDNN enabled execute faster. 205 | However, there are some exceptions where using cuDNN can be slower, cause higher memory usage, or result in errors. 206 | 207 | To disable cuDNN, use the configuration example below: 208 | 209 | ```proto 210 | parameters: { 211 | key: "DISABLE_CUDNN" 212 | value: { string_value: "true" } 213 | } 214 | ``` 215 | 216 | * `ENABLE_WEIGHT_SHARING`: 217 | 218 | Boolean flag to enable model instances on the same device to share weights. 219 | This optimization should not be used with stateful models. 220 | If not specified, weight sharing is disabled. 221 | 222 | To enable weight sharing, use the configuration example below: 223 | 224 | ```proto 225 | parameters: { 226 | key: "ENABLE_WEIGHT_SHARING" 227 | value: { string_value: "true" } 228 | } 229 | ``` 230 | 231 | * `ENABLE_CACHE_CLEANING`: 232 | 233 | Boolean flag to enable CUDA cache cleaning after each model execution. 234 | If not specified, cache cleaning is disabled. 235 | This flag has no effect if the model is on CPU. 236 | 237 | Setting this flag to true will likely negatively impact performance due to the additional CUDA cache cleaning operation after each model execution. 238 | Therefore, you should only use this flag if you serve multiple models with Triton and encounter CUDA out-of-memory issues during model executions. 239 | 240 | To enable cleaning of the CUDA cache after every execution, use the configuration example below: 241 | 242 | ```proto 243 | parameters: { 244 | key: "ENABLE_CACHE_CLEANING" 245 | value: { string_value: "true" } 246 | } 247 | ``` 248 | 249 | * `INTER_OP_THREAD_COUNT`: 250 | 251 | PyTorch allows using multiple CPU threads during TorchScript model inference. 252 | One or more inference threads execute a model’s forward pass on the given inputs. 253 | Each inference thread invokes a JIT interpreter that executes the ops of a model inline, one by one. 254 | 255 | This parameter sets the size of this thread pool. 256 | The default value of this setting is the number of CPU cores. 257 | 258 | > [!TIP] 259 | > Refer to 260 | > [CPU Threading TorchScript](https://pytorch.org/docs/stable/notes/cpu_threading_torchscript_inference.html) 261 | > on how to set this parameter properly. 262 | 263 | To set the inter-op thread count, use the configuration example below: 264 | 265 | ```proto 266 | parameters: { 267 | key: "INTER_OP_THREAD_COUNT" 268 | value: { string_value: "1" } 269 | } 270 | ``` 271 | 272 | > [!NOTE] 273 | > This parameter is set globally for the PyTorch backend. 274 | > The value from the first model config file that specifies this parameter will be used. 275 | > Subsequent values from other model config files, if different, will be ignored. 276 | 277 | * `INTRA_OP_THREAD_COUNT`: 278 | 279 | In addition to the inter-op parallelism, PyTorch can also utilize multiple threads within the ops (intra-op parallelism).
280 | This can be useful in many cases, including element-wise ops on large tensors, convolutions, GEMMs, embedding lookups and others. 281 | 282 | The default value for this setting is the number of CPU cores. 283 | 284 | > [!TIP] 285 | > Refer to 286 | > [CPU Threading TorchScript](https://pytorch.org/docs/stable/notes/cpu_threading_torchscript_inference.html) 287 | > on how to set this parameter properly. 288 | 289 | To set the intra-op thread count, use the configuration example below: 290 | 291 | ```proto 292 | parameters: { 293 | key: "INTRA_OP_THREAD_COUNT" 294 | value: { string_value: "1" } 295 | } 296 | ``` 297 | 298 | * **Additional Optimizations**: 299 | 300 | Three additional boolean parameters are available to disable certain Torch optimizations that can sometimes cause latency regressions in models with complex execution modes and dynamic shapes. 301 | If not specified, all are enabled by default. 302 | 303 | `ENABLE_JIT_EXECUTOR` 304 | 305 | `ENABLE_JIT_PROFILING` `ENABLE_TENSOR_FUSER` 306 | 307 | ### Model Instance Group Kind 308 | 309 | The PyTorch backend supports the following kinds of 310 | [Model Instance Groups](https://github.com/triton-inference-server/server/blob/main/docs/user_guide/model_configuration.md#instance-groups) 311 | where the input tensors are placed as follows: 312 | 313 | * `KIND_GPU`: 314 | 315 | Inputs are prepared on the GPU device associated with the model instance. 316 | 317 | * `KIND_CPU`: 318 | 319 | Inputs are prepared on the CPU. 320 | 321 | * `KIND_MODEL`: 322 | 323 | Inputs are prepared on the CPU. 324 | When loading the model, the backend does not choose the GPU device for the model; 325 | instead, it respects the device(s) specified in the model and uses them as they are during inference. 326 | 327 | This is useful when the model internally utilizes multiple GPUs, as demonstrated in 328 | [this example model](https://github.com/triton-inference-server/server/blob/main/qa/L0_libtorch_instance_group_kind_model/gen_models.py). 329 | 330 | > [!IMPORTANT] 331 | > If a device is not specified in the model, the backend uses the first available GPU device. 332 | 333 | To set the model instance group, use the configuration example below: 334 | 335 | ```proto 336 | instance_group { 337 | count: 2 338 | kind: KIND_GPU 339 | } 340 | ``` 341 | 342 | ### Customization 343 | 344 | The following PyTorch settings may be customized by setting parameters on the 345 | `config.pbtxt`. 346 | 347 | [`torch.set_num_threads(int)`](https://pytorch.org/docs/stable/generated/torch.set_num_threads.html#torch.set_num_threads) 348 | 349 | * Key: `NUM_THREADS` 350 | * Value: The number of threads used for intra-op parallelism on CPU. 351 | 352 | [`torch.set_num_interop_threads(int)`](https://pytorch.org/docs/stable/generated/torch.set_num_interop_threads.html#torch.set_num_interop_threads) 353 | 354 | * Key: `NUM_INTEROP_THREADS` 355 | * Value: The number of threads used for interop parallelism (e.g. in JIT interpreter) on CPU. 356 | 357 | [`torch.compile()` parameters](https://pytorch.org/docs/stable/generated/torch.compile.html#torch-compile) 358 | 359 | * Key: `TORCH_COMPILE_OPTIONAL_PARAMETERS` 360 | * Value: Any of the following parameter(s) encoded as a JSON object. 361 | * `fullgraph` (`bool`): If `True`, require the entire model to be captured in a single graph; any graph break becomes an error. 362 | * `dynamic` (`bool`): Use dynamic shape tracing. 363 | * `backend` (`str`): The backend to be used. 364 | * `mode` (`str`): Can be either `"default"`, `"reduce-overhead"`, or `"max-autotune"`.
365 | * `options` (`dict`): A dictionary of options to pass to the backend. 366 | * `disable` (`bool`): Turn `torch.compile()` into a no-op for testing. 367 | 368 | For example: 369 | 370 | ```proto 371 | parameters: { 372 | key: "NUM_THREADS" 373 | value: { string_value: "4" } 374 | } 375 | parameters: { 376 | key: "TORCH_COMPILE_OPTIONAL_PARAMETERS" 377 | value: { string_value: "{\"disable\": true}" } 378 | } 379 | ``` 380 | 381 | ## Important Notes 382 | 383 | * The execution of a PyTorch model on GPU is asynchronous in nature. 384 | See 385 | [CUDA Asynchronous Execution](https://pytorch.org/docs/stable/notes/cuda.html#asynchronous-execution) 386 | for additional details. 387 | Consequently, an error in PyTorch model execution may be raised during the next few inference requests to the server. 388 | Setting the environment variable `CUDA_LAUNCH_BLOCKING=1` when launching the server will help in correctly debugging failing cases by forcing synchronous execution. 389 | 390 | * The PyTorch model in such cases may or may not recover from the failed state, and a restart of the server may be required to continue serving successfully. 391 | 392 | * PyTorch does not support Tensor of Strings, but it does support models that accept a List of Strings as input(s) / produce a List of Strings as output(s). 393 | For these models, Triton allows users to pass String input(s)/receive String output(s) using the String datatype. 394 | As a limitation of using List instead of Tensor for String I/O, only 1-dimensional input(s)/output(s) are supported for I/O of String type. 395 | 396 | * In a multi-GPU environment, a potential runtime issue can occur when using 397 | [Tracing](https://pytorch.org/docs/stable/generated/torch.jit.trace.html) 398 | to generate a 399 | [TorchScript](https://pytorch.org/docs/stable/jit.html) 400 | model. 401 | This issue arises due to a device mismatch between the model instance and the tensor. 402 | 403 | By default, Triton creates a single execution instance of the model for each available GPU. 404 | The runtime error occurs when a request is sent to a model instance with a different GPU device from the one used during the TorchScript generation process. 405 | 406 | To address this problem, it is highly recommended to use 407 | [Scripting](https://pytorch.org/docs/stable/generated/torch.jit.script.html#torch.jit.script) 408 | instead of Tracing for model generation in a multi-GPU environment (see the example below). 409 | Scripting avoids the device mismatch issue and ensures compatibility with different GPUs when used with Triton. 410 | 411 | However, if using Tracing is unavoidable, there is a workaround available. 412 | You can explicitly specify the GPU device for the model instance in the 413 | [model configuration](https://github.com/triton-inference-server/server/blob/main/docs/user_guide/model_configuration.md#instance-groups) 414 | to ensure that the model instance and the tensors used for inference are assigned to the same GPU device on which the model was traced. 415 | 416 | * When using `KIND_MODEL` as the model instance kind, the default device of the first parameter of the model is used. 417 | 418 | > [!WARNING] 419 | > 420 | > * Python functions optimizable by `torch.compile` may not be served directly in the `model.py` file; they need to be enclosed in a class extending 421 | [`torch.nn.Module`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module). 422 | > 423 | > * Model weights cannot be shared across multiple instances on the same GPU device.
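
The snippet below is a minimal sketch of producing a TorchScript `model.pt` through scripting, as recommended above for multi-GPU environments; `MyModel` and the repository path are placeholders rather than part of this backend:

```python
from pathlib import Path

import torch


class MyModel(torch.nn.Module):
    # Placeholder module; substitute your own model definition.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(x)


# Scripting (rather than tracing) keeps the exported TorchScript independent of
# the GPU device it was generated on.
model = MyModel().eval()
scripted = torch.jit.script(model)

version_dir = Path("model_repository/model_directory/1")
version_dir.mkdir(parents=True, exist_ok=True)
scripted.save(str(version_dir / "model.pt"))
```

The resulting layout matches the `model_repository` tree shown in the TorchScript Models section, with the model's `config.pbtxt` placed next to the version directory.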
424 | -------------------------------------------------------------------------------- /src/model_state.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2019-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // Redistribution and use in source and binary forms, with or without 4 | // modification, are permitted provided that the following conditions 5 | // are met: 6 | // * Redistributions of source code must retain the above copyright 7 | // notice, this list of conditions and the following disclaimer. 8 | // * Redistributions in binary form must reproduce the above copyright 9 | // notice, this list of conditions and the following disclaimer in the 10 | // documentation and/or other materials provided with the distribution. 11 | // * Neither the name of NVIDIA CORPORATION nor the names of its 12 | // contributors may be used to endorse or promote products derived 13 | // from this software without specific prior written permission. 14 | // 15 | // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY 16 | // EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 17 | // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 18 | // PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR 19 | // CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 20 | // EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 21 | // PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 22 | // PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY 23 | // OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 24 | // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 26 | 27 | #include "model_state.hh" 28 | 29 | #include <mutex> 30 | 31 | 32 | namespace { 33 | std::once_flag pytorch_interop_threads_flag; 34 | std::once_flag pytorch_intraop_threads_flag; 35 | } // namespace 36 | 37 | namespace triton::backend::pytorch { 38 | 39 | ModelState::ModelState(TRITONBACKEND_Model* triton_model) 40 | : BackendModel(triton_model), enable_optimized_execution_(true), 41 | enable_inference_mode_(true), enable_cudnn_(true), 42 | enable_cache_cleaning_(false), enable_weight_sharing_(false), 43 | enable_tensor_fuser_pair_({false, true}), 44 | enable_jit_profiling_pair_({false, true}), 45 | enable_jit_executor_pair_({false, true}) 46 | { 47 | } 48 | 49 | TRITONSERVER_Error* 50 | ModelState::AutoCompleteConfig() 51 | { 52 | // Auto-complete configuration is not supported since PyTorch does not 53 | // store/capture sufficient model metadata so just log error instead. 54 | LOG_MESSAGE( 55 | TRITONSERVER_LOG_WARN, 56 | (std::string("skipping model configuration auto-complete for '") + 57 | Name() + "': not supported for pytorch backend") 58 | .c_str()); 59 | 60 | return nullptr; // success 61 | } 62 | 63 | TRITONSERVER_Error* 64 | ModelState::Create(TRITONBACKEND_Model* triton_model, ModelState** state) 65 | { 66 | try { 67 | *state = new ModelState(triton_model); 68 | } 69 | catch (const BackendModelException& ex) { 70 | RETURN_ERROR_IF_TRUE( 71 | ex.err_ == nullptr, TRITONSERVER_ERROR_INTERNAL, 72 | std::string("unexpected nullptr in BackendModelException")); 73 | RETURN_IF_ERROR(ex.err_); 74 | } 75 | 76 | // Auto-complete the configuration if requested...
77 | bool auto_complete_config = false; 78 | RETURN_IF_ERROR(TRITONBACKEND_ModelAutoCompleteConfig( 79 | triton_model, &auto_complete_config)); 80 | if (auto_complete_config) { 81 | RETURN_IF_ERROR((*state)->AutoCompleteConfig()); 82 | RETURN_IF_ERROR((*state)->SetModelConfig()); 83 | } 84 | 85 | auto& model_outputs = (*state)->model_outputs_; 86 | // Parse the output states in the model configuration 87 | triton::common::TritonJson::Value sequence_batching; 88 | if ((*state)->ModelConfig().Find("sequence_batching", &sequence_batching)) { 89 | triton::common::TritonJson::Value states; 90 | if (sequence_batching.Find("state", &states)) { 91 | for (size_t i = 0; i < states.ArraySize(); i++) { 92 | triton::common::TritonJson::Value state; 93 | RETURN_IF_ERROR(states.IndexAsObject(i, &state)); 94 | std::string output_state_name; 95 | RETURN_IF_ERROR( 96 | state.MemberAsString("output_name", &output_state_name)); 97 | auto it = model_outputs.find(output_state_name); 98 | if (it == model_outputs.end()) { 99 | model_outputs.insert({output_state_name, std::make_pair(-1, i)}); 100 | } else { 101 | it->second.second = i; 102 | } 103 | } 104 | } 105 | } 106 | 107 | // Parse the output names in the model configuration 108 | triton::common::TritonJson::Value outputs; 109 | RETURN_IF_ERROR((*state)->ModelConfig().MemberAsArray("output", &outputs)); 110 | for (size_t i = 0; i < outputs.ArraySize(); i++) { 111 | triton::common::TritonJson::Value output; 112 | THROW_IF_BACKEND_INSTANCE_ERROR(outputs.IndexAsObject(i, &output)); 113 | 114 | // Use names from ModelConfig by reference since the model 115 | // config will persist longer than this inference execution. 116 | std::string output_name; 117 | THROW_IF_BACKEND_INSTANCE_ERROR( 118 | output.MemberAsString("name", &output_name)); 119 | 120 | auto it = model_outputs.find(output_name); 121 | if (it == model_outputs.end()) { 122 | model_outputs.insert({output_name, std::make_pair(i, -1)}); 123 | } else { 124 | it->second.first = i; 125 | } 126 | } 127 | 128 | RETURN_IF_ERROR((*state)->ParseParameters()); 129 | 130 | return nullptr; // success 131 | } 132 | 133 | bool 134 | ModelState::EnabledCacheCleaning() 135 | { 136 | return enable_cache_cleaning_; 137 | } 138 | 139 | bool 140 | ModelState::EnabledCudnn() 141 | { 142 | return enable_cudnn_; 143 | } 144 | 145 | bool 146 | ModelState::EnabledInferenceMode() 147 | { 148 | return enable_inference_mode_; 149 | } 150 | 151 | const std::pair<bool, bool>& 152 | ModelState::EnabledJitExecutor() const 153 | { 154 | return enable_jit_executor_pair_; 155 | } 156 | 157 | const std::pair<bool, bool>& 158 | ModelState::EnabledJitProfiling() const 159 | { 160 | return enable_jit_profiling_pair_; 161 | } 162 | 163 | bool 164 | ModelState::EnabledOptimizedExecution() 165 | { 166 | return enable_optimized_execution_; 167 | } 168 | 169 | const std::pair<bool, bool>& 170 | ModelState::EnabledTensorExprFuser() const 171 | { 172 | return enable_tensor_fuser_pair_; 173 | } 174 | 175 | bool 176 | ModelState::EnabledWeightSharing() 177 | { 178 | return enable_weight_sharing_; 179 | } 180 | 181 | TRITONSERVER_Error* 182 | ModelState::LoadModel( 183 | const std::string& artifact_name, const torch::Device device, 184 | std::string* model_path, const TRITONSERVER_InstanceGroupKind& kind, 185 | std::shared_ptr<torch::jit::Module>* torch_model) 186 | { 187 | // Find the TorchScript file that describes the model. If the model 188 | // configuration doesn't have an explicit model file specified then 189 | // use the default name ("model.pt").
190 | std::string cc_model_filename = artifact_name; 191 | if (cc_model_filename.empty()) { 192 | cc_model_filename = "model.pt"; 193 | } 194 | 195 | *model_path = JoinPath( 196 | {RepositoryPath(), std::to_string(Version()), cc_model_filename}); 197 | 198 | { 199 | bool exists; 200 | RETURN_IF_ERROR(FileExists(*model_path, &exists)); 201 | RETURN_ERROR_IF_FALSE( 202 | exists, TRITONSERVER_ERROR_UNAVAILABLE, 203 | std::string("unable to find '") + *model_path + 204 | "' for model instance '" + Name() + "'"); 205 | } 206 | 207 | // If weight sharing is enabled, skip loading model if 208 | // it is already available on the target device 209 | std::pair<bool, int64_t> device_pair; 210 | if (enable_weight_sharing_) { 211 | device_pair = std::make_pair(!device.is_cpu(), device.index()); 212 | auto mit = torch_models_.find(device_pair); 213 | if (mit != torch_models_.end()) { 214 | *torch_model = mit->second; 215 | LOG_MESSAGE( 216 | TRITONSERVER_LOG_INFO, 217 | (std::string("Reusing TorchScript model for instance '") + Name() + 218 | "'") 219 | .c_str()); 220 | return nullptr; // success 221 | } 222 | } 223 | 224 | // Serialize the torch model to string 225 | std::string model_data_str; 226 | RETURN_IF_ERROR(ReadTextFile(*model_path, &model_data_str)); 227 | 228 | // InferenceMode should be used to guard all tensors operations including 229 | // model loading: https://pytorch.org/cppdocs/notes/inference_mode.html 230 | torch::InferenceMode infer_guard(EnabledInferenceMode()); 231 | 232 | try { 233 | std::istringstream model_stream(model_data_str); 234 | if (kind == TRITONSERVER_INSTANCEGROUPKIND_MODEL) { 235 | // Load the model without selecting a device. 236 | torch_model->reset( 237 | new torch::jit::Module(torch::jit::load(model_stream))); 238 | } else { 239 | torch_model->reset( 240 | new torch::jit::Module(torch::jit::load(model_stream, device))); 241 | } 242 | } 243 | catch (const std::exception& ex) { 244 | return TRITONSERVER_ErrorNew( 245 | TRITONSERVER_ERROR_INTERNAL, 246 | ("failed to load model '" + Name() + "': " + ex.what()).c_str()); 247 | } 248 | 249 | if (enable_weight_sharing_) { 250 | if (!((torch_models_.emplace(device_pair, *torch_model)).second)) { 251 | std::string type = device.is_cpu() ? "CPU" : "GPU"; 252 | LOG_MESSAGE( 253 | TRITONSERVER_LOG_WARN, 254 | (std::string("Model already found on target ") + type + " device " + 255 | "(id " + std::to_string(device.index()) + ") for '" + Name() + "'") 256 | .c_str()); 257 | } 258 | } 259 | 260 | return nullptr; // success 261 | } 262 | 263 | const std::map<std::string, std::pair<int64_t, int64_t>>& 264 | ModelState::ModelOutputs() 265 | { 266 | return model_outputs_; 267 | } 268 | 269 | TRITONSERVER_Error* 270 | ModelState::ParseParameters() 271 | { 272 | triton::common::TritonJson::Value params; 273 | bool status = model_config_.Find("parameters", &params); 274 | if (status) { 275 | // If 'DISABLE_OPTIMIZED_EXECUTION' is not present in 'parameters' then no 276 | // update is made to 'enable_optimized_execution_'. 277 | bool disable_optimized_execution = false; 278 | TRITONSERVER_Error* err = ParseParameter( 279 | params, "DISABLE_OPTIMIZED_EXECUTION", &disable_optimized_execution); 280 | if (err != nullptr) { 281 | if (TRITONSERVER_ErrorCode(err) != TRITONSERVER_ERROR_NOT_FOUND) { 282 | return err; 283 | } else { 284 | TRITONSERVER_ErrorDelete(err); 285 | } 286 | } 287 | enable_optimized_execution_ = !disable_optimized_execution; 288 | 289 | LOG_MESSAGE( 290 | TRITONSERVER_LOG_INFO, 291 | (std::string("Optimized execution is ") + 292 | (enable_optimized_execution_ ?
"enabled" : "disabled") + 293 | " for model instance '" + Name() + "'") 294 | .c_str()); 295 | 296 | // If 'ENABLE_CACHE_CLEANING' is not present in 'parameters' then 297 | // no update is made to 'enable_cache_cleaning_'. 298 | err = ParseParameter( 299 | params, "ENABLE_CACHE_CLEANING", &enable_cache_cleaning_); 300 | if (err != nullptr) { 301 | if (TRITONSERVER_ErrorCode(err) != TRITONSERVER_ERROR_NOT_FOUND) { 302 | return err; 303 | } else { 304 | TRITONSERVER_ErrorDelete(err); 305 | } 306 | } 307 | 308 | LOG_MESSAGE( 309 | TRITONSERVER_LOG_INFO, 310 | (std::string("Cache Cleaning is ") + 311 | (enable_cache_cleaning_ ? "enabled" : "disabled") + 312 | " for model instance '" + Name() + "'") 313 | .c_str()); 314 | 315 | // If 'INFERENCE_MODE' is not present in 'parameters' then no update is made 316 | // to 'enable_inference_mode_'. 317 | err = ParseParameter(params, "INFERENCE_MODE", &enable_inference_mode_); 318 | if (err != nullptr) { 319 | if (TRITONSERVER_ErrorCode(err) != TRITONSERVER_ERROR_NOT_FOUND) { 320 | return err; 321 | } else { 322 | TRITONSERVER_ErrorDelete(err); 323 | } 324 | } 325 | LOG_MESSAGE( 326 | TRITONSERVER_LOG_INFO, 327 | (std::string("Inference Mode is ") + 328 | (enable_inference_mode_ ? "enabled" : "disabled") + 329 | " for model instance '" + Name() + "'") 330 | .c_str()); 331 | 332 | // If 'DISABLE_CUDNN' is not present in 'parameters' then no update is made 333 | // to 'enable_cudnn_'. 334 | bool disable_cudnn = false; 335 | err = ParseParameter(params, "DISABLE_CUDNN", &disable_cudnn); 336 | if (err != nullptr) { 337 | if (TRITONSERVER_ErrorCode(err) != TRITONSERVER_ERROR_NOT_FOUND) { 338 | return err; 339 | } else { 340 | TRITONSERVER_ErrorDelete(err); 341 | } 342 | } 343 | enable_cudnn_ = !disable_cudnn; 344 | LOG_MESSAGE( 345 | TRITONSERVER_LOG_INFO, 346 | (std::string("cuDNN is ") + (enable_cudnn_ ? "enabled" : "disabled") + 347 | " for model instance '" + Name() + "'") 348 | .c_str()); 349 | 350 | // If 'ENABLE_TENSOR_FUSER' is not present in 'parameters' then no 351 | // update is made to 'enable_tensor_fuser'. 352 | bool enable_tensor_fuser = false; 353 | err = ParseParameter(params, "ENABLE_TENSOR_FUSER", &enable_tensor_fuser); 354 | if (err != nullptr) { 355 | if (TRITONSERVER_ErrorCode(err) != TRITONSERVER_ERROR_NOT_FOUND) { 356 | return err; 357 | } else { 358 | TRITONSERVER_ErrorDelete(err); 359 | } 360 | } else { 361 | enable_tensor_fuser_pair_ = {true, enable_tensor_fuser}; 362 | LOG_MESSAGE( 363 | TRITONSERVER_LOG_INFO, 364 | (std::string("Tensor fuser is ") + 365 | (enable_tensor_fuser ? "enabled" : "disabled") + 366 | " for model instance '" + Name() + "'") 367 | .c_str()); 368 | } 369 | 370 | // If 'ENABLE_WEIGHT_SHARING' is not present in 'parameters' then no 371 | // update is made to 'enable_weight_sharing'. 372 | err = ParseParameter( 373 | params, "ENABLE_WEIGHT_SHARING", &enable_weight_sharing_); 374 | if (err != nullptr) { 375 | if (TRITONSERVER_ErrorCode(err) != TRITONSERVER_ERROR_NOT_FOUND) { 376 | return err; 377 | } else { 378 | TRITONSERVER_ErrorDelete(err); 379 | } 380 | } else { 381 | LOG_MESSAGE( 382 | TRITONSERVER_LOG_INFO, 383 | (std::string("Weight sharing is ") + 384 | (enable_weight_sharing_ ? "enabled" : "disabled") + 385 | " for model instance '" + Name() + "'") 386 | .c_str()); 387 | } 388 | 389 | // If 'ENABLE_JIT_PROFILING' is not present in 'parameters' then no update 390 | // is made to 'enable_jit_profiling'. 
391 | bool enable_jit_profiling = false; 392 | err = ParseParameter(params, "ENABLE_JIT_PROFILING", &enable_jit_profiling); 393 | if (err != nullptr) { 394 | if (TRITONSERVER_ErrorCode(err) != TRITONSERVER_ERROR_NOT_FOUND) { 395 | return err; 396 | } else { 397 | TRITONSERVER_ErrorDelete(err); 398 | } 399 | } else { 400 | enable_jit_profiling_pair_ = {true, enable_jit_profiling}; 401 | LOG_MESSAGE( 402 | TRITONSERVER_LOG_INFO, 403 | (std::string("Jit profiling is ") + 404 | (enable_jit_profiling ? "enabled" : "disabled") + 405 | " for model instance '" + Name() + "'") 406 | .c_str()); 407 | } 408 | 409 | // If 'ENABLE_JIT_EXECUTOR' is not present in 'parameters' then no update is 410 | // made to 'enable_jit_executor'. 411 | bool enable_jit_executor = false; 412 | err = ParseParameter(params, "ENABLE_JIT_EXECUTOR", &enable_jit_executor); 413 | if (err != nullptr) { 414 | if (TRITONSERVER_ErrorCode(err) != TRITONSERVER_ERROR_NOT_FOUND) { 415 | return err; 416 | } else { 417 | TRITONSERVER_ErrorDelete(err); 418 | } 419 | } else { 420 | enable_jit_executor_pair_ = {true, enable_jit_executor}; 421 | LOG_MESSAGE( 422 | TRITONSERVER_LOG_INFO, 423 | (std::string("Jit executor is ") + 424 | (enable_jit_executor ? "enabled" : "disabled") + 425 | " for model instance '" + Name() + "'") 426 | .c_str()); 427 | } 428 | 429 | // If 'INTRA_OP_THREAD_COUNT' is not present in 'parameters' then no update 430 | // is made to 'intra_op_thread_count', which by default will take all 431 | // threads 432 | int intra_op_thread_count = -1; 433 | err = 434 | ParseParameter(params, "INTRA_OP_THREAD_COUNT", &intra_op_thread_count); 435 | if (err != nullptr) { 436 | if (TRITONSERVER_ErrorCode(err) != TRITONSERVER_ERROR_NOT_FOUND) { 437 | return err; 438 | } else { 439 | TRITONSERVER_ErrorDelete(err); 440 | } 441 | } else { 442 | if (intra_op_thread_count > 0) { 443 | // at::set_num_threads() does not throw if called more than once, but 444 | // issues warnings. std::call_once() is useful to limit these. 445 | std::call_once(pytorch_intraop_threads_flag, [intra_op_thread_count]() { 446 | at::set_num_threads(intra_op_thread_count); 447 | }); 448 | LOG_MESSAGE( 449 | TRITONSERVER_LOG_INFO, 450 | (std::string("Intra op thread count is set to ") + 451 | std::to_string(at::get_num_threads()) + " for model instance '" + 452 | Name() + "'") 453 | .c_str()); 454 | } 455 | } 456 | 457 | // If 'INTER_OP_THREAD_COUNT' is not present in 'parameters' then no update 458 | // is made to 'inter_op_thread_count', which by default will take all 459 | // threads 460 | int inter_op_thread_count = -1; 461 | err = 462 | ParseParameter(params, "INTER_OP_THREAD_COUNT", &inter_op_thread_count); 463 | if (err != nullptr) { 464 | if (TRITONSERVER_ErrorCode(err) != TRITONSERVER_ERROR_NOT_FOUND) { 465 | return err; 466 | } else { 467 | TRITONSERVER_ErrorDelete(err); 468 | } 469 | } else { 470 | if (inter_op_thread_count > 0) { 471 | // at::set_num_interop_threads() throws if called more than once. 472 | // std::call_once() should prevent this, but try/catch is additionally 473 | // used for safety. 
474 | std::call_once(pytorch_interop_threads_flag, [inter_op_thread_count]() { 475 | try { 476 | at::set_num_interop_threads(inter_op_thread_count); 477 | } 478 | catch (const c10::Error& e) { 479 | // do nothing 480 | } 481 | }); 482 | LOG_MESSAGE( 483 | TRITONSERVER_LOG_INFO, 484 | (std::string("Inter op thread count is set to ") + 485 | std::to_string(at::get_num_interop_threads()) + 486 | " for model instance '" + Name() + "'") 487 | .c_str()); 488 | } 489 | } 490 | } 491 | 492 | return nullptr; 493 | } 494 | 495 | } // namespace triton::backend::pytorch 496 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # Copyright 2019-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # Redistribution and use in source and binary forms, with or without 4 | # modification, are permitted provided that the following conditions 5 | # are met: 6 | # * Redistributions of source code must retain the above copyright 7 | # notice, this list of conditions and the following disclaimer. 8 | # * Redistributions in binary form must reproduce the above copyright 9 | # notice, this list of conditions and the following disclaimer in the 10 | # documentation and/or other materials provided with the distribution. 11 | # * Neither the name of NVIDIA CORPORATION nor the names of its 12 | # contributors may be used to endorse or promote products derived 13 | # from this software without specific prior written permission. 14 | # 15 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY 16 | # EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 17 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 18 | # PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR 19 | # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 20 | # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 21 | # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 22 | # PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY 23 | # OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 24 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 26 | 27 | cmake_minimum_required (VERSION 3.31.8) 28 | 29 | project(tritonpytorchbackend LANGUAGES C CXX) 30 | 31 | # Use C++17 standard as Triton's minimum required. 32 | set(TRITON_MIN_CXX_STANDARD 17 CACHE STRING "The minimum C++ standard which features are requested to build this target.") 33 | 34 | # 35 | # Options 36 | # 37 | # To build the PyTorch backend you must either: 38 | # 39 | # - Point to the already built PyTorch and Torchvision using 40 | # TRITON_PYTORCH_INCLUDE_PATHS and TRITON_PYTORCH_LIB_PATHS 41 | # 42 | # or: 43 | # 44 | # - Set TRITON_PYTORCH_DOCKER_IMAGE to use the docker image of 45 | # PyTorch to base the build off. 46 | # 47 | 48 | option(TRITON_ENABLE_GPU "Enable GPU support in backend" ON) 49 | option(TRITON_ENABLE_STATS "Include statistics collections in backend" ON) 50 | option(TRITON_ENABLE_NVTX "Include nvtx markers collection in backend." 
OFF) 51 | option(TRITON_PYTORCH_ENABLE_TORCHTRT "Enable TorchTRT support" OFF) 52 | option(TRITON_PYTORCH_ENABLE_TORCHVISION "Enable Torchvision support" ON) 53 | option(TRITON_PYTORCH_NVSHMEM "Enable NVSHMEM support" ON) 54 | 55 | set(TRITON_PYTORCH_DOCKER_IMAGE "" CACHE STRING "Docker image containing the PyTorch build required by backend.") 56 | set(TRITON_PYTORCH_INCLUDE_PATHS "" CACHE PATH "Paths to Torch includes") 57 | set(TRITON_PYTORCH_LIB_PATHS "" CACHE PATH "Paths to Torch libraries") 58 | 59 | set(TRITON_REPO_ORGANIZATION "https://github.com/triton-inference-server" CACHE STRING "Git repository to pull from") 60 | set(TRITON_BACKEND_REPO_TAG "main" CACHE STRING "Tag for triton-inference-server/backend repo") 61 | set(TRITON_CORE_REPO_TAG "main" CACHE STRING "Tag for triton-inference-server/core repo") 62 | set(TRITON_COMMON_REPO_TAG "main" CACHE STRING "Tag for triton-inference-server/common repo") 63 | 64 | if(NOT CMAKE_BUILD_TYPE) 65 | set(CMAKE_BUILD_TYPE Release) 66 | endif() 67 | 68 | set(TRITON_PYTORCH_DOCKER_BUILD OFF) 69 | if(TRITON_PYTORCH_LIB_PATHS STREQUAL "") 70 | if(TRITON_PYTORCH_DOCKER_IMAGE STREQUAL "") 71 | message(FATAL_ERROR "Using the PyTorch docker based build requires TRITON_PYTORCH_DOCKER_IMAGE") 72 | endif() 73 | set(TRITON_PYTORCH_DOCKER_BUILD ON) 74 | message(STATUS "Using PyTorch docker: ${TRITON_PYTORCH_DOCKER_IMAGE}") 75 | else() 76 | # Look for installed Torch-TRT package in lib paths 77 | if(TRITON_PYTORCH_ENABLE_TORCHTRT AND NOT EXISTS "${TRITON_PYTORCH_LIB_PATHS}/libtorchtrt_runtime.so") 78 | message(WARNING "TRITON_PYTORCH_ENABLE_TORCHTRT is on, but TRITON_PYTORCH_LIB_PATHS does not contain Torch-TRT package") 79 | endif() 80 | 81 | # Look for installed TorchVision package in lib paths 82 | find_library(LIBTORCHVISION libtorchvision.so libtorchvision.so.1 PATHS ${TRITON_PYTORCH_LIB_PATHS}) 83 | if(NOT ${LIBTORCHVISION}) 84 | message(WARNING "TRITON_PYTORCH_ENABLE_TORCHVISION is on, but TRITON_PYTORCH_LIB_PATHS does not contain TorchVision package") 85 | endif(NOT ${LIBTORCHVISION}) 86 | endif() 87 | 88 | # Python.h needed by torch headers. 
89 | find_package(Python3 REQUIRED COMPONENTS Development.Module) 90 | 91 | set(RHEL_BUILD OFF) 92 | set(LIB_DIR "lib") 93 | set(LIBTORCH_LIBS_PATH "/usr/local/lib") 94 | set(PY_INSTALL_PATH "/usr/local/lib/python3.12/dist-packages") 95 | if(LINUX) 96 | file(STRINGS "/etc/os-release" DISTRO_ID_LIKE REGEX "ID_LIKE") 97 | if(${DISTRO_ID_LIKE} MATCHES "rhel|centos") 98 | set(RHEL_BUILD ON) 99 | set(LIB_DIR "lib64") 100 | set(PY_INSTALL_PATH "/opt/_internal/cpython-3.12.1/lib/python3.12/site-packages") 101 | if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "x86_64") 102 | set(LIBTORCH_LIBS_PATH "/opt/_internal/cpython-3.12.1/lib") 103 | endif(${CMAKE_SYSTEM_PROCESSOR} MATCHES "x86_64") 104 | endif(${DISTRO_ID_LIKE} MATCHES "rhel|centos") 105 | endif(LINUX) 106 | 107 | message(TRACE "CMAKE_HOST_SYSTEM_PROCESSOR: ${CMAKE_HOST_SYSTEM_PROCESSOR}") 108 | message(TRACE "TRITON_ENABLE_GPU: ${TRITON_ENABLE_GPU}") 109 | message(TRACE "TRITON_ENABLE_STATS: ${TRITON_ENABLE_STATS}") 110 | message(TRACE "TRITON_ENABLE_NVTX: ${TRITON_ENABLE_NVTX}") 111 | message(TRACE "TRITON_PYTORCH_ENABLE_TORCHTRT: ${TRITON_PYTORCH_ENABLE_TORCHTRT}") 112 | message(TRACE "TRITON_PYTORCH_ENABLE_TORCHVISION: ${TRITON_PYTORCH_ENABLE_TORCHVISION}") 113 | message(TRACE "TRITON_PYTORCH_NVSHMEM: ${TRITON_PYTORCH_NVSHMEM}") 114 | message(TRACE "TRITON_PYTORCH_DOCKER_IMAGE: ${TRITON_PYTORCH_DOCKER_IMAGE}") 115 | message(TRACE "TRITON_PYTORCH_INCLUDE_PATHS: ${TRITON_PYTORCH_INCLUDE_PATHS}") 116 | message(TRACE "TRITON_PYTORCH_LIB_PATHS: ${TRITON_PYTORCH_LIB_PATHS}") 117 | message(TRACE "TRITON_REPO_ORGANIZATION: ${TRITON_REPO_ORGANIZATION}") 118 | message(TRACE "TRITON_BACKEND_REPO_TAG: ${TRITON_BACKEND_REPO_TAG}") 119 | message(TRACE "TRITON_CORE_REPO_TAG: ${TRITON_CORE_REPO_TAG}") 120 | message(TRACE "TRITON_COMMON_REPO_TAG: ${TRITON_COMMON_REPO_TAG}") 121 | message(TRACE "TRITON_PYTORCH_DOCKER_BUILD: ${TRITON_PYTORCH_DOCKER_BUILD}") 122 | message(TRACE "RHEL_BUILD: ${RHEL_BUILD}") 123 | message(TRACE "LIB_DIR: ${LIB_DIR}") 124 | message(TRACE "LIBTORCH_LIBS_PATH: ${LIBTORCH_LIBS_PATH}") 125 | message(TRACE "PY_INSTALL_PATH: ${PY_INSTALL_PATH}") 126 | 127 | 128 | # 129 | # Dependencies 130 | # 131 | # FetchContent's composability isn't very good. We must include the 132 | # transitive closure of all repos so that we can override the tag. 
133 | # 134 | include(FetchContent) 135 | 136 | FetchContent_Declare( 137 | repo-common 138 | GIT_REPOSITORY ${TRITON_REPO_ORGANIZATION}/common.git 139 | GIT_TAG ${TRITON_COMMON_REPO_TAG} 140 | GIT_SHALLOW ON 141 | ) 142 | FetchContent_Declare( 143 | repo-core 144 | GIT_REPOSITORY ${TRITON_REPO_ORGANIZATION}/core.git 145 | GIT_TAG ${TRITON_CORE_REPO_TAG} 146 | GIT_SHALLOW ON 147 | ) 148 | FetchContent_Declare( 149 | repo-backend 150 | GIT_REPOSITORY ${TRITON_REPO_ORGANIZATION}/backend.git 151 | GIT_TAG ${TRITON_BACKEND_REPO_TAG} 152 | GIT_SHALLOW ON 153 | ) 154 | FetchContent_MakeAvailable(repo-common repo-core repo-backend) 155 | 156 | # 157 | # CUDA 158 | # 159 | if(${TRITON_ENABLE_GPU}) 160 | find_package(CUDAToolkit REQUIRED) 161 | else() 162 | if (${TRITON_PYTORCH_ENABLE_TORCHTRT}) 163 | message(FATAL_ERROR "TRITON_PYTORCH_ENABLE_TORCHTRT is ON when TRITON_ENABLE_GPU is OFF") 164 | endif() 165 | endif() # TRITON_ENABLE_GPU 166 | 167 | if(${TRITON_ENABLE_NVTX}) 168 | add_definitions(-DTRITON_ENABLE_NVTX=1) 169 | endif() # TRITON_ENABLE_NVTX 170 | 171 | # 172 | # Shared library implementing the Triton Backend API 173 | # 174 | configure_file(src/libtriton_pytorch.ldscript libtriton_pytorch.ldscript COPYONLY) 175 | 176 | set(PT_LIBS 177 | "libc10.so" 178 | "libc10_cuda.so" 179 | "libtorch.so" 180 | "libtorch_cpu.so" 181 | "libtorch_cuda.so" 182 | "libtorch_cuda_linalg.so" 183 | "libtorch_global_deps.so" 184 | "libjpeg.so.62" 185 | "libshm.so" 186 | "libbackend_with_compiler.so" 187 | "libaoti_custom_ops.so" 188 | "libtorch_python.so" 189 | "libcaffe2_nvrtc.so" 190 | ) 191 | 192 | if (${TRITON_PYTORCH_NVSHMEM}) 193 | set(PT_LIBS 194 | ${PT_LIBS} 195 | "libtorch_nvshmem.so" 196 | ) 197 | endif() # TRITON_PYTORCH_NVSHMEM 198 | 199 | if (${TRITON_PYTORCH_ENABLE_TORCHVISION}) 200 | set(PT_LIBS 201 | ${PT_LIBS} 202 | libtorchvision.so.1 203 | ) 204 | endif() # TRITON_PYTORCH_ENABLE_TORCHVISION 205 | 206 | if (${TRITON_PYTORCH_ENABLE_TORCHTRT}) 207 | set(PT_LIBS 208 | ${PT_LIBS} 209 | "libtorchtrt.so" 210 | "libtorchtrt_runtime.so" 211 | ) 212 | endif() # TRITON_PYTORCH_ENABLE_TORCHTRT 213 | 214 | if (CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "aarch64") 215 | set(LIBS_ARCH "aarch64") 216 | set(LIBTORCH_LIBS 217 | "libnvpl_blas_core.so.0" 218 | "libnvpl_blas_ilp64_gomp.so.0" 219 | "libnvpl_blas_ilp64_seq.so.0" 220 | "libnvpl_blas_lp64_gomp.so.0" 221 | "libnvpl_blas_lp64_seq.so.0" 222 | "libnvpl_lapack_core.so.0" 223 | "libnvpl_lapack_ilp64_gomp.so.0" 224 | "libnvpl_lapack_ilp64_seq.so.0" 225 | "libnvpl_lapack_lp64_gomp.so.0" 226 | "libnvpl_lapack_lp64_seq.so.0" 227 | ) 228 | else() 229 | set(LIBS_ARCH "x86_64") 230 | set(LIBTORCH_LIBS 231 | "libmkl_avx2.so.1" 232 | "libmkl_avx512.so.1" 233 | "libmkl_core.so.1" 234 | "libmkl_def.so.1" 235 | "libmkl_gnu_thread.so.1" 236 | "libmkl_intel_lp64.so.1" 237 | "libmkl_intel_thread.so.1" 238 | "libmkl_rt.so.1" 239 | "libmkl_sequential.so.1" 240 | "libmkl_vml_def.so.1" 241 | ) 242 | endif() 243 | set(TORCHVISION_LIBS 244 | $,libjpeg.so.62,libjpeg.so> 245 | $,libpng16.so.16,libpng16.so> 246 | ) 247 | 248 | message(TRACE "LIBS_ARCH: ${LIBS_ARCH}") 249 | message(TRACE "LIBTORCH_LIBS: ${LIBTORCH_LIBS}") 250 | 251 | # The patchelf commands ensure the MKL libraries are loaded correctly during runtime 252 | # Without these, the framework/backend complains of missing libraries / symbols and 253 | # in some cases leads to segmentation faults. 
254 | if (${TRITON_PYTORCH_DOCKER_BUILD}) 255 | string(REPLACE ";" " " LIBTORCH_LIBS_STR "${LIBTORCH_LIBS}") 256 | string(RANDOM 8 "abcdefghijklmnopqrstuvwxyz" random_id) 257 | 258 | add_custom_command( 259 | OUTPUT 260 | ${PT_LIBS} 261 | ${LIBTORCH_LIBS} 262 | ${TORCHVISION_LIBS} 263 | LICENSE.pytorch 264 | include/torch 265 | include/torchvision 266 | COMMAND ${CMAKE_COMMAND} -E make_directory "include/torchvision" 267 | COMMAND docker pull ${TRITON_PYTORCH_DOCKER_IMAGE} 268 | COMMAND docker rm pytorch_backend_ptlib || echo "error ignored..." || true 269 | COMMAND docker create --name pytorch_backend_ptlib ${TRITON_PYTORCH_DOCKER_IMAGE} 270 | COMMAND /bin/sh -c "for i in ${LIBTORCH_LIBS_STR} ; do echo copying $i && docker cp -L pytorch_backend_ptlib:${LIBTORCH_LIBS_PATH}/$i $i ; done" 271 | COMMAND docker cp pytorch_backend_ptlib:${PY_INSTALL_PATH}/torch/lib/libc10.so libc10.so 272 | COMMAND docker cp pytorch_backend_ptlib:${PY_INSTALL_PATH}/torch/lib/libc10_cuda.so libc10_cuda.so 273 | COMMAND docker cp pytorch_backend_ptlib:${PY_INSTALL_PATH}/torch/lib/libtorch.so libtorch.so 274 | COMMAND docker cp pytorch_backend_ptlib:${PY_INSTALL_PATH}/torch/lib/libtorch_cpu.so libtorch_cpu.so 275 | COMMAND docker cp pytorch_backend_ptlib:${PY_INSTALL_PATH}/torch/lib/libtorch_cuda.so libtorch_cuda.so 276 | COMMAND docker cp pytorch_backend_ptlib:${PY_INSTALL_PATH}/torch/lib/libtorch_cuda_linalg.so libtorch_cuda_linalg.so 277 | COMMAND docker cp pytorch_backend_ptlib:${PY_INSTALL_PATH}/torch/lib/libtorch_global_deps.so libtorch_global_deps.so 278 | COMMAND docker cp pytorch_backend_ptlib:${PY_INSTALL_PATH}/torch/lib/libcaffe2_nvrtc.so libcaffe2_nvrtc.so 279 | COMMAND docker cp pytorch_backend_ptlib:${PY_INSTALL_PATH}/torch/lib/libshm.so libshm.so 280 | COMMAND docker cp pytorch_backend_ptlib:${PY_INSTALL_PATH}/torch/lib/libbackend_with_compiler.so libbackend_with_compiler.so 281 | COMMAND docker cp pytorch_backend_ptlib:${PY_INSTALL_PATH}/torch/lib/libaoti_custom_ops.so libaoti_custom_ops.so 282 | COMMAND docker cp pytorch_backend_ptlib:${PY_INSTALL_PATH}/torch/lib/libtorch_python.so libtorch_python.so 283 | COMMAND /bin/sh -c "if [ ${TRITON_PYTORCH_NVSHMEM} = 'ON' ]; then docker cp pytorch_backend_ptlib:${PY_INSTALL_PATH}/torch/lib/libtorch_nvshmem.so libtorch_nvshmem.so; fi" 284 | COMMAND /bin/sh -c "if [ ${TRITON_PYTORCH_ENABLE_TORCHVISION} = 'ON' ]; then docker cp -a -L pytorch_backend_ptlib:/usr/local/${LIB_DIR}/libtorchvision.so.1 libtorchvision.so.1; fi;" 285 | COMMAND /bin/sh -c "if [ ${TRITON_PYTORCH_ENABLE_TORCHVISION} = 'ON' ]; then docker cp pytorch_backend_ptlib:/opt/pytorch/vision/torchvision/csrc include/torchvision/torchvision; fi" 286 | COMMAND /bin/sh -c "if [ ${TRITON_PYTORCH_ENABLE_TORCHTRT} = 'ON' ]; then docker cp pytorch_backend_ptlib:${PY_INSTALL_PATH}/torch_tensorrt/lib/libtorchtrt.so libtorchtrt.so; fi" 287 | COMMAND /bin/sh -c "if [ ${TRITON_PYTORCH_ENABLE_TORCHTRT} = 'ON' ]; then docker cp pytorch_backend_ptlib:${PY_INSTALL_PATH}/torch_tensorrt/lib/libtorchtrt_runtime.so libtorchtrt_runtime.so; fi" 288 | COMMAND /bin/sh -c "if [ ${TRITON_PYTORCH_ENABLE_TORCHTRT} = 'ON' ]; then docker cp pytorch_backend_ptlib:${PY_INSTALL_PATH}/torch_tensorrt/bin/torchtrtc torchtrtc; fi" 289 | COMMAND docker cp pytorch_backend_ptlib:/opt/pytorch/pytorch/LICENSE LICENSE.pytorch 290 | COMMAND docker cp pytorch_backend_ptlib:${PY_INSTALL_PATH}/torch/include include/torch 291 | COMMAND docker cp pytorch_backend_ptlib:/opt/pytorch/pytorch/torch/csrc/jit/codegen include/torch/torch/csrc/jit/. 
292 | 293 | COMMAND /bin/sh -c "if [ ${RHEL_BUILD} = 'ON' ]; then docker cp -L pytorch_backend_ptlib:/usr/lib64/libjpeg.so.62 libjpeg.so.62; else docker cp -L pytorch_backend_ptlib:/usr/local/lib/libjpeg.so.62 libjpeg.so.62 && docker cp pytorch_backend_ptlib:/usr/lib/${LIBS_ARCH}-linux-gnu/libjpeg.so.8.2.2 libjpeg.so; fi;" 294 | COMMAND /bin/sh -c "if [ ${RHEL_BUILD} = 'ON' ]; then docker cp -L pytorch_backend_ptlib:/usr/lib64/libpng16.so.16 libpng16.so.16; else docker cp -L pytorch_backend_ptlib:/usr/lib/${LIBS_ARCH}-linux-gnu/libpng16.so libpng16.so; fi;" 295 | COMMAND /bin/sh -c "if [ -f libmkl_def.so.1 ]; then patchelf --add-needed libmkl_gnu_thread.so.1 libmkl_def.so.1; fi" 296 | COMMAND /bin/sh -c "if [ -f libmkl_def.so.1 ]; then patchelf --add-needed libmkl_core.so.1 libmkl_def.so.1; fi" 297 | COMMAND /bin/sh -c "if [ -f libmkl_avx2.so.1 ]; then patchelf --add-needed libmkl_gnu_thread.so.1 libmkl_avx2.so.1; fi" 298 | COMMAND /bin/sh -c "if [ -f libmkl_avx2.so.1 ]; then patchelf --add-needed libmkl_core.so.1 libmkl_avx2.so.1; fi" 299 | COMMAND /bin/sh -c "if [ -f libmkl_avx512.so.1 ]; then patchelf --add-needed libmkl_gnu_thread.so.1 libmkl_avx512.so.1; fi" 300 | COMMAND /bin/sh -c "if [ -f libmkl_avx512.so.1 ]; then patchelf --add-needed libmkl_core.so.1 libmkl_avx512.so.1; fi" 301 | COMMAND /bin/sh -c "if [ -f libmkl_vml_def.so.1 ]; then patchelf --add-needed libmkl_gnu_thread.so.1 libmkl_vml_def.so.1; fi" 302 | COMMAND /bin/sh -c "if [ -f libmkl_vml_def.so.1 ]; then patchelf --add-needed libmkl_intel_thread.so.1 libmkl_vml_def.so.1; fi" 303 | COMMAND /bin/sh -c "if [ -f libmkl_vml_def.so.1 ]; then patchelf --add-needed libmkl_core.so.1 libmkl_vml_def.so.1; fi" 304 | COMMAND /bin/sh -c "if [ -f libmkl_intel_thread.so.1 ]; then patchelf --add-needed libmkl_intel_lp64.so.1 libmkl_intel_thread.so.1; fi" 305 | COMMAND /bin/sh -c "if [ ${TRITON_PYTORCH_ENABLE_TORCHVISION} = 'ON' ]; then ln -s libtorchvision.so.1 libtorchvision.so; fi;" 306 | COMMAND docker rm pytorch_backend_ptlib 307 | COMMENT "Extracting pytorch and torchvision libraries and includes from ${TRITON_PYTORCH_DOCKER_IMAGE}" 308 | VERBATIM 309 | ) 310 | add_custom_target(ptlib_target DEPENDS ${PT_LIBS} ${LIBTORCH_LIBS} ${TORCHVISION_LIBS}) 311 | add_library(ptlib SHARED IMPORTED GLOBAL) 312 | add_dependencies(ptlib ptlib_target) 313 | 314 | # Just one of the libs are enough to ensure the docker build 315 | set_target_properties( 316 | ptlib 317 | PROPERTIES 318 | IMPORTED_LOCATION libtorch.so 319 | ) 320 | endif() # TRITON_PYTORCH_DOCKER_BUILD 321 | 322 | add_library( 323 | triton-pytorch-backend SHARED 324 | src/libtorch.cc 325 | src/libtorch_utils.cc 326 | src/libtorch_utils.h 327 | src/model_instance_state.cc 328 | src/model_state.cc 329 | src/string_utils.cc 330 | ) 331 | 332 | add_library( 333 | TritonPyTorchBackend::triton-pytorch-backend ALIAS triton-pytorch-backend 334 | ) 335 | 336 | target_include_directories( 337 | triton-pytorch-backend 338 | PRIVATE 339 | ${CMAKE_CURRENT_SOURCE_DIR}/src 340 | ${Python3_INCLUDE_DIRS} 341 | ) 342 | 343 | if (${TRITON_PYTORCH_DOCKER_BUILD}) 344 | target_include_directories( 345 | triton-pytorch-backend 346 | PRIVATE 347 | ${CMAKE_CURRENT_BINARY_DIR}/include/torch 348 | ${CMAKE_CURRENT_BINARY_DIR}/include/torch/torch/csrc/api/include 349 | ${CMAKE_CURRENT_BINARY_DIR}/include/torchvision 350 | ) 351 | else() 352 | target_include_directories( 353 | triton-pytorch-backend 354 | PRIVATE ${TRITON_PYTORCH_INCLUDE_PATHS} 355 | ) 356 | endif() # TRITON_PYTORCH_DOCKER_BUILD 357 | 358 | # 
Need to turn off -Werror due to Torchvision vision.h extern initialization 359 | # Unfortunately gcc does not provide a specific flag to ignore the specific 360 | # warning: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=45977 361 | target_compile_features(triton-pytorch-backend PRIVATE cxx_std_${TRITON_MIN_CXX_STANDARD}) 362 | target_compile_options( 363 | triton-pytorch-backend PRIVATE 364 | $<$,$,$>: 365 | -Wall -Wextra -Wno-unused-parameter -Wno-type-limits> 366 | ) 367 | 368 | if(${TRITON_ENABLE_GPU}) 369 | target_compile_definitions( 370 | triton-pytorch-backend 371 | PRIVATE TRITON_ENABLE_GPU=1 372 | ) 373 | endif() # TRITON_ENABLE_GPU 374 | 375 | set_target_properties( 376 | triton-pytorch-backend 377 | PROPERTIES 378 | POSITION_INDEPENDENT_CODE ON 379 | OUTPUT_NAME triton_pytorch 380 | SKIP_BUILD_RPATH TRUE 381 | BUILD_WITH_INSTALL_RPATH TRUE 382 | INSTALL_RPATH_USE_LINK_PATH FALSE 383 | INSTALL_RPATH "$\{ORIGIN\}" 384 | LINK_DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/libtriton_pytorch.ldscript 385 | LINK_FLAGS "-Wl,--no-as-needed,--version-script libtriton_pytorch.ldscript" 386 | ) 387 | 388 | # Need to turn off unused-but-set-variable due to Torchvision 389 | # Need to turn off unknown-pragmas due to ATen OpenMP 390 | set_target_properties( 391 | triton-pytorch-backend 392 | PROPERTIES COMPILE_FLAGS 393 | "-Wno-unknown-pragmas -Wno-unused-but-set-variable" 394 | ) 395 | 396 | if (${TRITON_PYTORCH_DOCKER_BUILD}) 397 | add_dependencies( 398 | triton-pytorch-backend 399 | ptlib 400 | ) 401 | endif() # TRITON_PYTORCH_DOCKER_BUILD 402 | 403 | message(STATUS "Torchvision support is ${TRITON_PYTORCH_ENABLE_TORCHVISION}") 404 | message(STATUS "Torch-TRT support is ${TRITON_PYTORCH_ENABLE_TORCHTRT}") 405 | 406 | set(TRITON_PYTORCH_LDFLAGS "") 407 | if (${TRITON_PYTORCH_DOCKER_BUILD}) 408 | set(TRITON_PYTORCH_LIBS "${CMAKE_CURRENT_BINARY_DIR}/libtorch.so") 409 | 410 | if (${TRITON_PYTORCH_ENABLE_TORCHVISION}) 411 | set(TRITON_PYTORCH_LIBS 412 | ${TRITON_PYTORCH_LIBS} 413 | libtorchvision.so.1 ) 414 | endif() # TRITON_PYTORCH_ENABLE_TORCHVISION 415 | 416 | if (${TRITON_PYTORCH_ENABLE_TORCHTRT}) 417 | set(TRITON_PYTORCH_LIBS 418 | ${TRITON_PYTORCH_LIBS} 419 | "${CMAKE_CURRENT_BINARY_DIR}/libtorchtrt_runtime.so") 420 | endif() # TRITON_PYTORCH_ENABLE_TORCHTRT 421 | else() 422 | set (TRITON_PYTORCH_LIBS "-ltorch") 423 | 424 | if (${TRITON_PYTORCH_ENABLE_TORCHVISION}) 425 | set(TRITON_PYTORCH_LIBS 426 | ${TRITON_PYTORCH_LIBS} 427 | "-ltorchvision" 428 | ) 429 | endif() # TRITON_PYTORCH_ENABLE_TORCHVISION 430 | 431 | if (${TRITON_PYTORCH_ENABLE_TORCHTRT}) 432 | set(TRITON_PYTORCH_LIBS 433 | ${TRITON_PYTORCH_LIBS} 434 | "-ltorchtrt_runtime" 435 | ) 436 | endif() # TRITON_PYTORCH_ENABLE_TORCHTRT 437 | 438 | FOREACH(p ${TRITON_PYTORCH_LIB_PATHS}) 439 | set(TRITON_PYTORCH_LDFLAGS ${TRITON_PYTORCH_LDFLAGS} "-L${p}") 440 | ENDFOREACH(p) 441 | endif() # TRITON_PYTORCH_DOCKER_BUILD 442 | 443 | message(TRACE "TRITON_PYTORCH_LDFLAGS: ${TRITON_PYTORCH_LDFLAGS}") 444 | message(TRACE "TRITON_PYTORCH_LIBS: ${TRITON_PYTORCH_LIBS}") 445 | 446 | target_link_libraries( 447 | triton-pytorch-backend 448 | PRIVATE 449 | triton-core-serverapi # from repo-core 450 | triton-core-backendapi # from repo-core 451 | triton-core-serverstub # from repo-core 452 | triton-backend-utils # from repo-backend 453 | ${TRITON_PYTORCH_LDFLAGS} 454 | ${TRITON_PYTORCH_LIBS} 455 | ) 456 | 457 | if(${TRITON_ENABLE_GPU}) 458 | target_link_libraries( 459 | triton-pytorch-backend 460 | PRIVATE 461 | CUDA::cudart 462 | ) 463 | endif() # 
TRITON_ENABLE_GPU 464 | 465 | # 466 | # Install 467 | # 468 | include(GNUInstallDirs) 469 | set(INSTALL_CONFIGDIR ${CMAKE_INSTALL_LIBDIR}/cmake/TritonPyTorchBackend) 470 | message(TRACE "INSTALL_CONFIGDIR: ${INSTALL_CONFIGDIR}") 471 | 472 | install( 473 | TARGETS 474 | triton-pytorch-backend 475 | EXPORT 476 | triton-pytorch-backend-targets 477 | LIBRARY DESTINATION ${CMAKE_INSTALL_PREFIX}/backends/pytorch 478 | ARCHIVE DESTINATION ${CMAKE_INSTALL_PREFIX}/backends/pytorch 479 | ) 480 | 481 | if (${TRITON_PYTORCH_DOCKER_BUILD}) 482 | set(PT_LIB_PATHS "") 483 | FOREACH(plib ${PT_LIBS} ${LIBTORCH_LIBS} ${TORCHVISION_LIBS}) 484 | set(PT_LIB_PATHS ${PT_LIB_PATHS} "${CMAKE_CURRENT_BINARY_DIR}/${plib}") 485 | ENDFOREACH(plib) 486 | 487 | message(TRACE "PT_LIB_PATHS: ${PT_LIB_PATHS}") 488 | 489 | install( 490 | FILES 491 | ${PT_LIB_PATHS} 492 | ${CMAKE_CURRENT_BINARY_DIR}/LICENSE.pytorch 493 | DESTINATION ${CMAKE_INSTALL_PREFIX}/backends/pytorch 494 | ) 495 | 496 | if (${TRITON_PYTORCH_ENABLE_TORCHTRT}) 497 | install( 498 | FILES 499 | ${CMAKE_CURRENT_BINARY_DIR}/torchtrtc 500 | DESTINATION ${CMAKE_INSTALL_PREFIX}/backends/pytorch 501 | ) 502 | endif() # TRITON_PYTORCH_ENABLE_TORCHTRT 503 | 504 | FOREACH(plib ${PT_LIBS} ${LIBTORCH_LIBS} ${TORCHVISION_LIBS}) 505 | install( 506 | CODE 507 | "EXECUTE_PROCESS( 508 | COMMAND patchelf --set-rpath \$ORIGIN ${plib} 509 | RESULT_VARIABLE PATCHELF_STATUS 510 | WORKING_DIRECTORY ${CMAKE_INSTALL_PREFIX}/backends/pytorch) 511 | if(PATCHELF_STATUS AND NOT PATCHELF_STATUS EQUAL 0) 512 | message(FATAL_ERROR \"FAILED: to run patchelf\") 513 | endif()" 514 | ) 515 | ENDFOREACH(plib) 516 | 517 | install( 518 | CODE 519 | "EXECUTE_PROCESS( 520 | COMMAND ln -sf libpng16.so libpng16.so.16 521 | COMMAND ln -sf libjpeg.so libjpeg.so.8 522 | RESULT_VARIABLE LINK_STATUS 523 | WORKING_DIRECTORY ${CMAKE_INSTALL_PREFIX}/backends/pytorch) 524 | if(LINK_STATUS AND NOT LINK_STATUS EQUAL 0) 525 | message(FATAL_ERROR \"FAILED: to create links\") 526 | endif()" 527 | ) 528 | else() 529 | FOREACH(plib ${PT_LIBS}) 530 | set(PT_LIB_PATHS ${PT_LIB_PATHS} "${TRITON_PYTORCH_LIB_PATHS}/${plib}") 531 | ENDFOREACH(plib) 532 | 533 | message(TRACE "PT_LIB_PATHS: ${PT_LIB_PATHS}") 534 | 535 | install( 536 | FILES 537 | ${PT_LIB_PATHS} 538 | DESTINATION ${CMAKE_INSTALL_PREFIX}/backends/pytorch 539 | ) 540 | 541 | FOREACH(plib ${PT_LIBS}) 542 | install( 543 | CODE 544 | "EXECUTE_PROCESS( 545 | COMMAND patchelf --set-rpath \$ORIGIN ${plib} 546 | RESULT_VARIABLE PATCHELF_STATUS 547 | WORKING_DIRECTORY ${CMAKE_INSTALL_PREFIX}/backends/pytorch) 548 | if(PATCHELF_STATUS AND NOT PATCHELF_STATUS EQUAL 0) 549 | message(FATAL_ERROR \"FAILED: to run patchelf\") 550 | endif()" 551 | ) 552 | ENDFOREACH(plib) 553 | endif() # TRITON_PYTORCH_DOCKER_BUILD 554 | 555 | install( 556 | EXPORT 557 | triton-pytorch-backend-targets 558 | FILE 559 | TritonPyTorchBackendTargets.cmake 560 | NAMESPACE 561 | TritonPyTorchBackend:: 562 | DESTINATION 563 | ${INSTALL_CONFIGDIR} 564 | ) 565 | 566 | install( 567 | FILES 568 | src/model.py 569 | DESTINATION 570 | ${CMAKE_INSTALL_PREFIX}/backends/pytorch 571 | ) 572 | 573 | include(CMakePackageConfigHelpers) 574 | configure_package_config_file( 575 | ${CMAKE_CURRENT_LIST_DIR}/cmake/TritonPyTorchBackendConfig.cmake.in 576 | ${CMAKE_CURRENT_BINARY_DIR}/TritonPyTorchBackendConfig.cmake 577 | INSTALL_DESTINATION ${INSTALL_CONFIGDIR} 578 | ) 579 | 580 | install( 581 | FILES 582 | ${CMAKE_CURRENT_BINARY_DIR}/TritonPyTorchBackendConfig.cmake 583 | DESTINATION ${INSTALL_CONFIGDIR} 584 | 
) 585 | 586 | # 587 | # Export from build tree 588 | # 589 | export( 590 | EXPORT triton-pytorch-backend-targets 591 | FILE ${CMAKE_CURRENT_BINARY_DIR}/TritonPyTorchBackendTargets.cmake 592 | NAMESPACE TritonPyTorchBackend:: 593 | ) 594 | 595 | export(PACKAGE TritonPyTorchBackend) 596 | -------------------------------------------------------------------------------- /src/model_instance_state.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2019-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // Redistribution and use in source and binary forms, with or without 4 | // modification, are permitted provided that the following conditions 5 | // are met: 6 | // * Redistributions of source code must retain the above copyright 7 | // notice, this list of conditions and the following disclaimer. 8 | // * Redistributions in binary form must reproduce the above copyright 9 | // notice, this list of conditions and the following disclaimer in the 10 | // documentation and/or other materials provided with the distribution. 11 | // * Neither the name of NVIDIA CORPORATION nor the names of its 12 | // contributors may be used to endorse or promote products derived 13 | // from this software without specific prior written permission. 14 | // 15 | // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY 16 | // EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 17 | // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 18 | // PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR 19 | // CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 20 | // EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 21 | // PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 22 | // PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY 23 | // OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 24 | // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
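The ModelInstanceState constructor in this file obtains its torch::jit module from ModelState::LoadModel() on the resolved device. As a point of reference only, deserializing a TorchScript artifact with stock LibTorch looks roughly like the sketch below; the function name, path handling, and error text are illustrative, and this is not the backend's LoadModel() implementation.

```cpp
#include <torch/script.h>

#include <memory>
#include <stdexcept>
#include <string>

// Minimal TorchScript deserialization onto a target device.
std::shared_ptr<torch::jit::Module>
LoadTorchScriptModule(const std::string& model_path, const torch::Device& device)
{
  try {
    // torch::jit::load maps the serialized module's tensors onto 'device'.
    auto module = std::make_shared<torch::jit::Module>(
        torch::jit::load(model_path, device));
    module->eval();  // inference-only use, as in this backend
    return module;
  }
  catch (const c10::Error& e) {
    throw std::runtime_error(
        "failed to load TorchScript model '" + model_path + "': " + e.what());
  }
}
```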
26 | 27 | #include "model_instance_state.hh" 28 | 29 | #include "string_utils.hh" 30 | 31 | #ifdef TRITON_PYTORCH_ENABLE_TORCHVISION 32 | // Suppress warnings in torch headers 33 | #pragma GCC diagnostic push 34 | #pragma GCC diagnostic ignored "-Wsign-compare" 35 | #pragma warning(push, 0) 36 | #include 37 | #include // Torchvision header 38 | #pragma warning(pop) 39 | #pragma GCC diagnostic pop 40 | #endif // TRITON_PYTORCH_ENABLE_TORCHVISION 41 | 42 | #ifdef TRITON_ENABLE_GPU 43 | #include 44 | #include 45 | #include 46 | #endif // TRITON_ENABLE_GPU 47 | 48 | 49 | namespace triton::backend::pytorch { 50 | 51 | ModelInstanceState::ModelInstanceState( 52 | ModelState* model_state, TRITONBACKEND_ModelInstance* triton_model_instance) 53 | : BackendModelInstance(model_state, triton_model_instance), 54 | model_state_(model_state), device_(torch::kCPU), is_dict_input_(false), 55 | device_cnt_(0) 56 | { 57 | if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) { 58 | #ifdef TRITON_ENABLE_GPU 59 | device_ = torch::Device(torch::kCUDA, DeviceId()); 60 | CreateCudaEvents(DeviceId()); 61 | #endif 62 | } 63 | 64 | #ifdef TRITON_ENABLE_GPU 65 | device_cnt_ = torch::cuda::device_count(); 66 | #endif 67 | 68 | THROW_IF_BACKEND_INSTANCE_ERROR(model_state->LoadModel( 69 | ArtifactFilename(), device_, &model_path_, Kind(), &torch_model_)); 70 | 71 | if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL) { 72 | #ifdef TRITON_ENABLE_GPU 73 | // Since we cannot determine the exact devices used by the model, we create 74 | // a CUDA stream for every available device to ensure proper synchronization 75 | // of CUDA streams. This approach may have implications when a timestamp is 76 | // captured on a device that is not used by the model. Currently, this issue 77 | // is addressed by synchronizing the CUDA streams before recording 78 | // timestamps to prevent timestamp skewing. However, in the future, any 79 | // modifications to the CUDA stream synchronization logic should be handled 80 | // with caution. 81 | for (int i = 0; i < device_cnt_; i++) { 82 | cudaStream_t stream; 83 | THROW_IF_BACKEND_INSTANCE_ERROR( 84 | CreateCudaStream(i, 0 /* cuda_stream_priority */, &stream)); 85 | stream_vec_.push_back(stream); 86 | } 87 | if (!stream_vec_.empty()) { 88 | // Create CUDA events on the first device that will be used for collecting 89 | // inputs/outputs. 90 | CreateCudaEvents(0); 91 | } 92 | #endif 93 | } 94 | 95 | size_t expected_input_cnt = 0; 96 | { 97 | triton::common::TritonJson::Value inputs; 98 | if (model_state->ModelConfig().Find("input", &inputs)) { 99 | expected_input_cnt = inputs.ArraySize(); 100 | } 101 | 102 | triton::common::TritonJson::Value config_batch_inputs; 103 | if (model_state->ModelConfig().Find("batch_input", &config_batch_inputs)) { 104 | batch_input_count_ = config_batch_inputs.ArraySize(); 105 | expected_input_cnt += batch_input_count_; 106 | } 107 | } 108 | 109 | // If this is a sequence model then make sure that the required 110 | // inputs are present in the model and have the correct shape and 111 | // datatype. 
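The KIND_MODEL comment above explains why one CUDA stream is created per visible device and why those streams are synchronized before timestamps are recorded: work still in flight on a device the model does not use could otherwise skew the reported compute durations. A stand-alone sketch of that synchronize-then-stamp step with illustrative names (the backend's own timestamps come from its SET_TIMESTAMP macro); the sequence-batching validation continues right after this sketch.

```cpp
#include <cuda_runtime_api.h>

#include <chrono>
#include <cstdint>
#include <vector>

// Wait for all per-device streams before recording a host-side timestamp so
// the measurement is not skewed by work still queued on any device.
uint64_t
TimestampAfterAllStreams(const std::vector<cudaStream_t>& streams)
{
  for (cudaStream_t stream : streams) {
    cudaStreamSynchronize(stream);
  }
  return static_cast<uint64_t>(
      std::chrono::duration_cast<std::chrono::nanoseconds>(
          std::chrono::steady_clock::now().time_since_epoch())
          .count());
}
```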
112 | triton::common::TritonJson::Value sequence_batching; 113 | if (model_state->ModelConfig().Find( 114 | "sequence_batching", &sequence_batching)) { 115 | bool have_start, have_end, have_ready, have_corrid; 116 | THROW_IF_BACKEND_INSTANCE_ERROR(ValidateBooleanSequenceControl( 117 | sequence_batching, "CONTROL_SEQUENCE_START", false /* required */, 118 | &have_start)); 119 | THROW_IF_BACKEND_INSTANCE_ERROR(ValidateBooleanSequenceControl( 120 | sequence_batching, "CONTROL_SEQUENCE_END", false /* required */, 121 | &have_end)); 122 | THROW_IF_BACKEND_INSTANCE_ERROR(ValidateBooleanSequenceControl( 123 | sequence_batching, "CONTROL_SEQUENCE_READY", false /* required */, 124 | &have_ready)); 125 | THROW_IF_BACKEND_INSTANCE_ERROR(ValidateTypedSequenceControl( 126 | sequence_batching, "CONTROL_SEQUENCE_CORRID", false /* required */, 127 | &have_corrid)); 128 | if (have_start) { 129 | expected_input_cnt += 1; 130 | } 131 | if (have_end) { 132 | expected_input_cnt += 1; 133 | } 134 | if (have_ready) { 135 | expected_input_cnt += 1; 136 | } 137 | if (have_corrid) { 138 | expected_input_cnt += 1; 139 | } 140 | // Add the state inputs to the expected count 141 | triton::common::TritonJson::Value states; 142 | if (sequence_batching.Find("state", &states)) { 143 | expected_input_cnt += states.ArraySize(); 144 | } 145 | } 146 | supports_batching_ = model_state_->MaxBatchSize() > 0; 147 | 148 | THROW_IF_BACKEND_INSTANCE_ERROR(ValidateInputs(expected_input_cnt)); 149 | THROW_IF_BACKEND_INSTANCE_ERROR(ValidateOutputs()); 150 | } 151 | 152 | ModelInstanceState::~ModelInstanceState() 153 | { 154 | torch_model_.reset(); 155 | ClearCache(); 156 | 157 | if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL) { 158 | #ifdef TRITON_ENABLE_GPU 159 | for (size_t i = 0; i < stream_vec_.size(); i++) { 160 | LOG_IF_ERROR( 161 | ConvertCUDAStatusToTritonError( 162 | cudaSetDevice(i), TRITONSERVER_ERROR_INTERNAL, 163 | "Failed to set the device"), 164 | "Failed to set the device"); 165 | 166 | LOG_IF_ERROR( 167 | ConvertCUDAStatusToTritonError( 168 | cudaStreamDestroy(stream_vec_[i]), TRITONSERVER_ERROR_INTERNAL, 169 | "Failed to destroy cuda stream"), 170 | "~ModelInstanceState error: "); 171 | stream_vec_[i] = nullptr; 172 | } 173 | #endif 174 | } 175 | } 176 | 177 | void 178 | ModelInstanceState::AddInputToMap( 179 | NamingConvention naming_convention, 180 | const std::vector allowed_inputs, const std::string& io_name, 181 | const uint32_t index) 182 | { 183 | std::string deliminator = "__"; 184 | 185 | if (is_dict_input_) { 186 | // If dictionary, index is irrelevant but we use the map to store the 187 | // input names since they are the keys for the dictionary 188 | input_index_map_[io_name] = index; 189 | } else { 190 | switch (naming_convention) { 191 | case NamingConvention::FORWARD_ARGUMENT: { 192 | auto itr = 193 | std::find(allowed_inputs.begin(), allowed_inputs.end(), io_name); 194 | if (itr != allowed_inputs.end()) { 195 | input_index_map_[io_name] = 196 | std::distance(allowed_inputs.begin(), itr); 197 | } 198 | return; 199 | } 200 | case NamingConvention::NAMED_INDEX: { 201 | int start_pos = io_name.find(deliminator); 202 | int ip_index = std::atoi(io_name.substr(start_pos + 2).c_str()); 203 | input_index_map_[io_name] = ip_index; 204 | return; 205 | } 206 | case NamingConvention::STRICT_CONFIG_ORDERING: { 207 | input_index_map_[io_name] = index; 208 | return; 209 | } 210 | } 211 | } 212 | } 213 | 214 | void 215 | ModelInstanceState::ClearCache() 216 | { 217 | #ifdef TRITON_ENABLE_GPU 218 | if 
(device_.is_cuda() || 219 | ((Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL) && (device_cnt_ > 0))) { 220 | c10::cuda::CUDACachingAllocator::emptyCache(); 221 | } 222 | #endif // TRITON_ENABLE_GPU 223 | } 224 | 225 | TRITONSERVER_Error* 226 | ModelInstanceState::Create( 227 | ModelState* model_state, TRITONBACKEND_ModelInstance* triton_model_instance, 228 | ModelInstanceState** state) 229 | { 230 | try { 231 | *state = new ModelInstanceState(model_state, triton_model_instance); 232 | } 233 | catch (const BackendModelInstanceException& ex) { 234 | RETURN_ERROR_IF_TRUE( 235 | ex.err_ == nullptr, TRITONSERVER_ERROR_INTERNAL, 236 | std::string("unexpected nullptr in BackendModelInstanceException")); 237 | RETURN_IF_ERROR(ex.err_); 238 | } 239 | 240 | return nullptr; // success 241 | } 242 | 243 | void 244 | ModelInstanceState::CreateCudaEvents(const int32_t& device_id) 245 | { 246 | #ifdef TRITON_ENABLE_GPU 247 | // Need to set the CUDA context so that the context that events are 248 | // created on match with contexts that events are recorded with. 249 | THROW_IF_BACKEND_INSTANCE_ERROR(ConvertCUDAStatusToTritonError( 250 | cudaSetDevice(device_id), TRITONSERVER_ERROR_INTERNAL, 251 | "Failed to set the device")); 252 | THROW_IF_BACKEND_INSTANCE_ERROR(ConvertCUDAStatusToTritonError( 253 | cudaEventCreate(&compute_input_start_event_), TRITONSERVER_ERROR_INTERNAL, 254 | "Failed to create cuda event")); 255 | THROW_IF_BACKEND_INSTANCE_ERROR(ConvertCUDAStatusToTritonError( 256 | cudaEventCreate(&compute_infer_start_event_), TRITONSERVER_ERROR_INTERNAL, 257 | "Failed to create cuda event")); 258 | THROW_IF_BACKEND_INSTANCE_ERROR(ConvertCUDAStatusToTritonError( 259 | cudaEventCreate(&compute_output_start_event_), 260 | TRITONSERVER_ERROR_INTERNAL, "Failed to create cuda event")); 261 | #endif 262 | } 263 | 264 | void 265 | ModelInstanceState::Execute( 266 | std::vector* responses, 267 | const uint32_t response_count, 268 | std::vector* input_tensors, 269 | std::vector* output_tensors) 270 | { 271 | NVTX_RANGE(nvtx_, "Execute " + Name()); 272 | 273 | torch::jit::IValue model_outputs_; 274 | 275 | try { 276 | // enable/disable optimized execution 277 | torch::jit::setGraphExecutorOptimize( 278 | model_state_->EnabledOptimizedExecution()); 279 | 280 | // enable/disable inference mode - supersedes NoGradGuard 281 | torch::InferenceMode infer_guard(model_state_->EnabledInferenceMode()); 282 | 283 | // enable/disable cudnn 284 | at::globalContext().setUserEnabledCuDNN(model_state_->EnabledCudnn()); 285 | 286 | // JIT. No change is made unless parameter is explicitly set. 287 | if (std::get<0>(model_state_->EnabledJitProfiling())) { 288 | torch::jit::getProfilingMode() = 289 | std::get<1>(model_state_->EnabledJitProfiling()); 290 | } 291 | 292 | if (std::get<0>(model_state_->EnabledJitExecutor())) { 293 | torch::jit::getExecutorMode() = 294 | std::get<1>(model_state_->EnabledJitExecutor()); 295 | } 296 | 297 | // Fuser. No change is made unless fuser is explicitly set in 298 | // parameters. 299 | if (std::get<0>(model_state_->EnabledTensorExprFuser())) { 300 | torch::jit::setTensorExprFuserEnabled( 301 | std::get<1>(model_state_->EnabledTensorExprFuser())); 302 | } 303 | 304 | torch::NoGradGuard no_grad; 305 | 306 | // If input is a dictionary, prepare dictionary from 'input_tensors'. 
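A reading note before the dictionary branch that follows: this dump lost the template arguments on containers such as torch::Dict and std::vector, so declarations like "torch::Dict input_dict" originally carried explicit key/value types. The sketch below spells out the usual types for the dict-input path together with the guards Execute() sets up; it is a simplified illustration under those assumptions, not the backend's actual method.

```cpp
#include <torch/script.h>

#include <string>
#include <vector>

// Run a TorchScript module whose forward() takes a single Dict[str, Tensor].
torch::jit::IValue
RunDictForward(
    torch::jit::Module& module, const std::vector<std::string>& input_names,
    const std::vector<torch::Tensor>& input_tensors)
{
  // InferenceMode supersedes NoGradGuard; the backend additionally toggles the
  // graph executor optimization and JIT profiling/executor modes when configured.
  torch::InferenceMode infer_guard(true);
  torch::NoGradGuard no_grad;

  // Key each input tensor by its configured name, as the is_dict_input_ path does.
  c10::Dict<std::string, torch::Tensor> input_dict;
  for (size_t i = 0; i < input_names.size(); ++i) {
    input_dict.insert(input_names[i], input_tensors[i]);
  }

  std::vector<torch::jit::IValue> inputs = {input_dict};
  return module.forward(inputs);
}
```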
307 | if (is_dict_input_) { 308 | torch::Dict input_dict; 309 | for (auto& input_index : input_index_map_) { 310 | torch::jit::IValue ival = (*input_tensors)[input_index.second]; 311 | input_dict.insert(input_index.first, ival.toTensor()); 312 | } 313 | std::vector input_dict_ivalue = {input_dict}; 314 | model_outputs_ = torch_model_->forward(input_dict_ivalue); 315 | } else { 316 | model_outputs_ = torch_model_->forward(*input_tensors); 317 | } 318 | 319 | if (model_outputs_.isTuple()) { 320 | auto model_outputs_tuple = model_outputs_.toTuple(); 321 | size_t op_index = 0; 322 | for (auto& m_op : model_outputs_tuple->elements()) { 323 | if (m_op.isList()) { 324 | auto list_output = m_op.toList(); 325 | if (list_output.elementType()->kind() != c10::TypeKind::StringType) { 326 | throw std::invalid_argument( 327 | "output at index " + std::to_string(op_index) + 328 | " must be of type Tensor or List[str], received List[" + 329 | list_output.elementType()->str() + "]"); 330 | } 331 | output_tensors->push_back(m_op); 332 | } else { 333 | auto tensor_output = m_op.toTensor(); 334 | output_tensors->push_back(m_op); 335 | } 336 | op_index++; 337 | } 338 | } else if (model_outputs_.isTensor()) { 339 | output_tensors->push_back(model_outputs_); 340 | } else if (model_outputs_.isList()) { 341 | auto list_output = model_outputs_.toList(); 342 | if (list_output.elementType()->kind() != c10::TypeKind::StringType) { 343 | throw std::invalid_argument( 344 | "output must be of type Tensor or List[str], received List[" + 345 | list_output.elementType()->str() + "]"); 346 | } 347 | output_tensors->push_back(model_outputs_); 348 | } else { 349 | throw std::invalid_argument( 350 | "output must be of type Tensor, List[str] or Tuple containing one of " 351 | "these two types. It should not be a List / Dictionary of Tensors or " 352 | "a Scalar"); 353 | } 354 | } 355 | catch (std::exception& ex) { 356 | SendErrorForResponses( 357 | responses, response_count, 358 | TRITONSERVER_ErrorNew( 359 | TRITONSERVER_ERROR_INTERNAL, 360 | ("PyTorch execute failure: " + std::string(ex.what())).c_str())); 361 | } 362 | } 363 | 364 | float 365 | ModelInstanceState::GetCudaEventElapsedTime( 366 | const cudaEvent_t& start_event, const cudaEvent_t& end_event) 367 | { 368 | float duration = 0; 369 | #ifdef TRITON_ENABLE_GPU 370 | // [FIXME] in the case of cudaEventElapsedTime failure, should handle 371 | // stats reporting more gracefully as the durations are inaccurate 372 | LOG_IF_ERROR( 373 | ConvertCUDAStatusToTritonError( 374 | cudaEventElapsedTime(&duration, start_event, end_event), 375 | TRITONSERVER_ERROR_INTERNAL, "Failed to capture elapsed time"), 376 | "Failed to capture elapsed time"); 377 | #endif 378 | return duration; 379 | } 380 | 381 | 382 | cudaStream_t 383 | ModelInstanceState::GetCudaStreamByInstanceKind() 384 | { 385 | #ifdef TRITON_ENABLE_GPU 386 | if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) { 387 | return stream_; 388 | } else if ( 389 | (Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL) && 390 | !stream_vec_.empty()) { 391 | return stream_vec_[0]; 392 | } 393 | #endif 394 | return nullptr; 395 | } 396 | 397 | TRITONSERVER_Error* 398 | ModelInstanceState::GetNamingConvention( 399 | NamingConvention* naming_convention, 400 | const std::vector& allowed_ios) 401 | { 402 | // Rules for (non-Dictionary) input tensor names: 403 | // 1. Must be in 'allowed_inputs' (arguments in the forward function) 404 | // 2. Must follow the naming convention i.e. __ 405 | // 3. 
If neither of the above conditions are satisfied, enforce strict 406 | // ordering of model inputs. 407 | // 408 | // Rules for output tensor names: 409 | // 1. Must follow the naming convention i.e. __ 410 | // 2. If not, we enforce strict ordering of model outputs. 411 | std::string deliminator = "__"; 412 | std::string io_kind = "input"; 413 | *naming_convention = NamingConvention::FORWARD_ARGUMENT; 414 | 415 | // symbolizes output 416 | if (allowed_ios.size() == 0) { 417 | io_kind = "output"; 418 | *naming_convention = NamingConvention::NAMED_INDEX; 419 | } 420 | 421 | triton::common::TritonJson::Value ios; 422 | RETURN_IF_ERROR( 423 | model_state_->ModelConfig().MemberAsArray(io_kind.c_str(), &ios)); 424 | 425 | if (io_kind == "input") { 426 | for (size_t i = 0; i < ios.ArraySize(); i++) { 427 | triton::common::TritonJson::Value io; 428 | RETURN_IF_ERROR(ios.IndexAsObject(i, &io)); 429 | 430 | // Validate name 431 | std::string io_name; 432 | RETURN_IF_ERROR(io.MemberAsString("name", &io_name)); 433 | auto itr = std::find(allowed_ios.begin(), allowed_ios.end(), io_name); 434 | if (itr == allowed_ios.end()) { 435 | *naming_convention = NamingConvention::NAMED_INDEX; 436 | break; 437 | } 438 | } 439 | } 440 | 441 | // If not, check if inputs follow INDEX 442 | if (*naming_convention == NamingConvention::NAMED_INDEX) { 443 | for (size_t i = 0; i < ios.ArraySize(); i++) { 444 | triton::common::TritonJson::Value io; 445 | RETURN_IF_ERROR(ios.IndexAsObject(i, &io)); 446 | 447 | // Validate name 448 | std::string io_name; 449 | RETURN_IF_ERROR(io.MemberAsString("name", &io_name)); 450 | int start_pos = io_name.find(deliminator); 451 | if (start_pos == -1) { 452 | *naming_convention = NamingConvention::STRICT_CONFIG_ORDERING; 453 | break; 454 | } else { 455 | // check if the index part of the name is not an integer 456 | std::string index_str = io_name.substr(start_pos + 2); 457 | bool is_int = true; 458 | for (auto itr = index_str.begin(); itr != index_str.end(); itr++) { 459 | if (std::isdigit(*itr) == 0) { 460 | is_int = false; 461 | } 462 | } 463 | 464 | if (!is_int) { 465 | if (io_kind == "input") { 466 | LOG_MESSAGE( 467 | TRITONSERVER_LOG_WARN, 468 | ("input '" + io_name + 469 | "' or previous input(s) are neither an input argument to the " 470 | "model '" + 471 | model_state_->Name() + 472 | "' nor do they follow the __ naming convention. " 473 | "Falling back to enforcing strict ordering from model " 474 | "configuration.") 475 | .c_str()); 476 | } else { 477 | LOG_MESSAGE( 478 | TRITONSERVER_LOG_WARN, 479 | ("output '" + io_name + 480 | "' or previous output(s) of the model '" + 481 | model_state_->Name() + 482 | "' do not follow the __ naming convention. 
" 483 | "Falling back to enforcing strict ordering from model " 484 | "configuration.") 485 | .c_str()); 486 | } 487 | *naming_convention = NamingConvention::STRICT_CONFIG_ORDERING; 488 | break; 489 | } 490 | } 491 | } 492 | } 493 | 494 | triton::common::TritonJson::Value sequence_batching; 495 | if (model_state_->ModelConfig().Find( 496 | "sequence_batching", &sequence_batching)) { 497 | // If we need to manage state for the model, then we need to check 498 | // the naming of the state adheres to both the input and output conventions 499 | triton::common::TritonJson::Value states; 500 | if (sequence_batching.Find("state", &states)) { 501 | if (*naming_convention != NamingConvention::NAMED_INDEX) { 502 | return TRITONSERVER_ErrorNew( 503 | TRITONSERVER_ERROR_INVALID_ARG, 504 | ("PyTorch model '" + model_state_->Name() + 505 | "' is using sequence batching with state but not all inputs and " 506 | "outputs follow the __ naming convention. ") 507 | .c_str()); 508 | } 509 | } 510 | 511 | for (size_t i = 0; i < states.ArraySize(); i++) { 512 | triton::common::TritonJson::Value state; 513 | RETURN_IF_ERROR(states.IndexAsObject(i, &state)); 514 | std::string name_entry = 515 | io_kind == "input" ? "input_name" : "output_name"; 516 | std::string state_name; 517 | RETURN_IF_ERROR(state.MemberAsString(name_entry.c_str(), &state_name)); 518 | int start_pos = state_name.find(deliminator); 519 | if (start_pos == -1) { 520 | return TRITONSERVER_ErrorNew( 521 | TRITONSERVER_ERROR_INVALID_ARG, 522 | ("PyTorch model '" + model_state_->Name() + 523 | "' is using sequence batching with state but state '" + 524 | state_name + 525 | "' does not follow the __ naming convention. ") 526 | .c_str()); 527 | } else { 528 | // check if the index part of the name is not an integer 529 | std::string index_str = state_name.substr(start_pos + 2); 530 | bool is_int = true; 531 | for (auto itr = index_str.begin(); itr != index_str.end(); itr++) { 532 | if (std::isdigit(*itr) == 0) { 533 | is_int = false; 534 | } 535 | } 536 | if (!is_int) { 537 | return TRITONSERVER_ErrorNew( 538 | TRITONSERVER_ERROR_INVALID_ARG, 539 | ("PyTorch model '" + model_state_->Name() + 540 | "' is using sequence batching with state but state '" + 541 | state_name + 542 | "' does not follow the __ naming convention. ") 543 | .c_str()); 544 | } 545 | } 546 | } 547 | } 548 | 549 | return nullptr; // success 550 | } 551 | 552 | void 553 | ModelInstanceState::ProcessRequests( 554 | TRITONBACKEND_Request** requests, const uint32_t request_count) 555 | { 556 | LOG_MESSAGE( 557 | TRITONSERVER_LOG_VERBOSE, 558 | (std::string("TRITONBACKEND_ModelExecute: Running ") + Name() + " with " + 559 | std::to_string(request_count) + " requests") 560 | .c_str()); 561 | 562 | #ifdef TRITON_ENABLE_GPU 563 | if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) { 564 | SetCurrentCudaStream(stream_, DeviceId()); 565 | } else if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL) { 566 | // Replace the default stream of each device with the one we created. 567 | for (size_t i = 0; i < stream_vec_.size(); i++) { 568 | SetCurrentCudaStream(stream_vec_[i], i); 569 | } 570 | } 571 | #endif 572 | 573 | NVTX_RANGE(nvtx_, "ProcessRequests " + Name()); 574 | 575 | uint64_t exec_start_ns = 0; 576 | SET_TIMESTAMP(exec_start_ns); 577 | 578 | const int max_batch_size = model_state_->MaxBatchSize(); 579 | 580 | // For each request collect the total batch size for this inference 581 | // execution. 
The batch-size, number of inputs, and size of each 582 | // input has already been checked so don't need to do that here. 583 | size_t total_batch_size = 0; 584 | for (size_t i = 0; i < request_count; i++) { 585 | // If we get a nullptr request then something is badly wrong. Fail 586 | // and release all requests. 587 | if (requests[i] == nullptr) { 588 | RequestsRespondWithError( 589 | requests, request_count, 590 | TRITONSERVER_ErrorNew( 591 | TRITONSERVER_ERROR_INTERNAL, 592 | std::string( 593 | "null request given to PyTorch backend for '" + Name() + "'") 594 | .c_str())); 595 | return; 596 | } 597 | } 598 | 599 | // At this point we are committed to running inference with all 600 | // 'requests'. Create a response for each request. During input 601 | // processing if there is an error with any request that error will 602 | // be sent immediately with the corresponding response (and the 603 | // response unique_ptr will then be nullptr). The request object 604 | // itself will not be released until after all inferencing is done 605 | // (below) as we may need to access the request object when 606 | // determine how to process outputs (for example, even if we don't 607 | // need the outputs for a request that has an error, we do need to 608 | // know the size of those outputs associated with the request so we 609 | // can skip them in the output tensors). 610 | std::vector responses; 611 | responses.reserve(request_count); 612 | bool all_response_failed = false; 613 | 614 | for (size_t i = 0; i < request_count; i++) { 615 | TRITONBACKEND_Response* response; 616 | auto err = TRITONBACKEND_ResponseNew(&response, requests[i]); 617 | if (err == nullptr) { 618 | responses.emplace_back(response); 619 | } else { 620 | responses.emplace_back(nullptr); 621 | LOG_MESSAGE(TRITONSERVER_LOG_ERROR, "Fail to create response"); 622 | TRITONSERVER_ErrorDelete(err); 623 | } 624 | } 625 | 626 | for (size_t i = 0; i < request_count; i++) { 627 | if (max_batch_size > 0) { 628 | // Retrieve the batch size from one of the inputs, if the model 629 | // supports batching, the first dimension size is batch size. 630 | TRITONBACKEND_Input* input; 631 | TRITONSERVER_Error* err = 632 | TRITONBACKEND_RequestInputByIndex(requests[i], 0 /* index */, &input); 633 | if (err == nullptr) { 634 | const int64_t* shape; 635 | err = TRITONBACKEND_InputProperties( 636 | input, nullptr, nullptr, &shape, nullptr, nullptr, nullptr); 637 | total_batch_size += shape[0]; 638 | } 639 | if (err != nullptr) { 640 | RESPOND_ALL_AND_SET_TRUE_IF_ERROR( 641 | responses, request_count, all_response_failed, err); 642 | } 643 | } else { 644 | total_batch_size += 1; 645 | } 646 | } 647 | 648 | // If there are no valid payloads then no need to run the inference. 649 | if (total_batch_size == 0) { 650 | return; 651 | } 652 | 653 | // Make sure the maximum batch size is not exceeded. The 654 | // total_batch_size must be 1 for models that don't support batching 655 | // (i.e. max_batch_size == 0). If max_batch_size is exceeded then 656 | // scheduler has done something badly wrong so fail and release all 657 | // requests. 
658 | if (!all_response_failed) { 659 | if ((total_batch_size != 1) && 660 | (total_batch_size > (size_t)max_batch_size)) { 661 | RESPOND_ALL_AND_SET_TRUE_IF_ERROR( 662 | responses, request_count, all_response_failed, 663 | TRITONSERVER_ErrorNew( 664 | TRITONSERVER_ERROR_INTERNAL, 665 | std::string( 666 | "batch size " + std::to_string(total_batch_size) + " for '" + 667 | Name() + "', max allowed is " + 668 | std::to_string(max_batch_size)) 669 | .c_str())); 670 | } 671 | } 672 | 673 | std::vector input_names; 674 | std::vector input_tensors; 675 | bool cuda_copy = false; 676 | std::unique_ptr collector; 677 | 678 | // For 'KIND_MODEL', it's fine to use CUDA events to calculate the compute 679 | // input duration since only one stream will be used for input collection. 680 | if ((Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) || 681 | ((Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL) && (device_cnt_ > 0))) { 682 | #ifdef TRITON_ENABLE_GPU 683 | RESPOND_ALL_AND_SET_TRUE_IF_ERROR( 684 | responses, request_count, all_response_failed, 685 | ConvertCUDAStatusToTritonError( 686 | cudaEventRecord( 687 | compute_input_start_event_, GetCudaStreamByInstanceKind()), 688 | TRITONSERVER_ERROR_INTERNAL, "Failed to record the event.")); 689 | #endif 690 | } 691 | 692 | if (!all_response_failed) { 693 | collector.reset(new BackendInputCollector( 694 | requests, request_count, &responses, 695 | model_state_->TritonMemoryManager(), model_state_->EnablePinnedInput(), 696 | GetCudaStreamByInstanceKind(), nullptr, nullptr, 0, 697 | HostPolicyName().c_str())); 698 | RESPOND_ALL_AND_SET_TRUE_IF_ERROR( 699 | responses, request_count, all_response_failed, 700 | SetInputTensors( 701 | total_batch_size, requests, request_count, &responses, 702 | collector.get(), &input_names, &input_tensors, &cuda_copy)); 703 | } 704 | 705 | #ifdef TRITON_ENABLE_GPU 706 | if (cuda_copy) { 707 | cudaStreamSynchronize(GetCudaStreamByInstanceKind()); 708 | cuda_copy = false; 709 | } 710 | #endif 711 | 712 | std::vector output_tensors; 713 | uint64_t compute_start_ns = 0; 714 | uint64_t compute_infer_start = 0; 715 | 716 | RESPOND_ALL_AND_SET_TRUE_IF_ERROR( 717 | responses, request_count, all_response_failed, 718 | RecordBackendTimestamp( 719 | &compute_start_ns, 720 | reinterpret_cast(&compute_infer_start_event_))); 721 | 722 | // For 'KIND_MODEL', capture the timestamp for the compute infer duration. 723 | if ((Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL) && (device_cnt_ > 0)) { 724 | SET_TIMESTAMP(compute_infer_start); 725 | } 726 | 727 | // Run... 728 | if (!all_response_failed) { 729 | Execute(&responses, request_count, &input_tensors, &output_tensors); 730 | } 731 | 732 | // Verify output indices are valid with number of outputs after execution 733 | bool invalid_index = false; 734 | int max_index = output_tensors.size() - 1; 735 | 736 | if (!all_response_failed) { 737 | for (const auto& name : model_state_->ModelOutputs()) { 738 | int op_index = output_index_map_[name.first]; 739 | if ((op_index < 0) || (op_index > max_index)) { 740 | RESPOND_ALL_AND_SET_TRUE_IF_ERROR( 741 | responses, request_count, all_response_failed, 742 | TRITONSERVER_ErrorNew( 743 | TRITONSERVER_ERROR_INVALID_ARG, 744 | std::string( 745 | "The output " + std::string(name.first) + 746 | " in the model configuration refers to an output index " 747 | "which doesn't exist. 
This model has " + 748 | std::to_string(max_index + 1) + " outputs") 749 | .c_str())); 750 | invalid_index = true; 751 | break; 752 | } 753 | } 754 | } 755 | 756 | #ifdef TRITON_ENABLE_GPU 757 | if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL) { 758 | // For 'KIND_MODEL', multiple streams will be involved, so we need to call 759 | // 'cudaStreamSynchronize' before reading the output tensors. 760 | for (auto& stream : stream_vec_) { 761 | cudaStreamSynchronize(stream); 762 | } 763 | } 764 | #endif 765 | 766 | uint64_t compute_end_ns = 0; 767 | uint64_t compute_output_start = 0; 768 | 769 | if ((Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL) && (device_cnt_ > 0)) { 770 | #ifdef TRITON_ENABLE_GPU 771 | SET_TIMESTAMP(compute_output_start); 772 | #endif 773 | } else { 774 | RESPOND_ALL_AND_SET_TRUE_IF_ERROR( 775 | responses, request_count, all_response_failed, 776 | RecordBackendTimestamp( 777 | &compute_end_ns, 778 | reinterpret_cast(&compute_output_start_event_))); 779 | } 780 | 781 | if (!all_response_failed) { 782 | if (!invalid_index) { 783 | RESPOND_ALL_AND_SET_TRUE_IF_ERROR( 784 | responses, request_count, all_response_failed, 785 | ReadOutputTensors( 786 | total_batch_size, output_tensors, requests, request_count, 787 | &responses)); 788 | } 789 | } 790 | 791 | uint64_t exec_end_ns = 0; 792 | SET_TIMESTAMP(exec_end_ns); 793 | 794 | // Send all the responses that haven't already been sent because of 795 | // an earlier error. Note that the responses are not set to nullptr 796 | // here as we need that indication below to determine if the request 797 | // we successful or not. 798 | for (auto& response : responses) { 799 | if (response != nullptr) { 800 | LOG_IF_ERROR( 801 | TRITONBACKEND_ResponseSend( 802 | response, TRITONSERVER_RESPONSE_COMPLETE_FINAL, nullptr), 803 | "failed to send PyTorch backend response"); 804 | } 805 | } 806 | 807 | // We don't need an explicit CUDA syncrhonization here since we have already 808 | // synchronized the stream in the ReadOutputTensors function. 809 | if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) { 810 | #ifdef TRITON_ENABLE_GPU 811 | float compute_input_duration = GetCudaEventElapsedTime( 812 | compute_input_start_event_, compute_infer_start_event_); 813 | float compute_infer_duration = GetCudaEventElapsedTime( 814 | compute_infer_start_event_, compute_output_start_event_); 815 | 816 | compute_start_ns = exec_start_ns + (compute_input_duration * 1e6); 817 | compute_end_ns = compute_start_ns + (compute_infer_duration * 1e6); 818 | #endif 819 | } else if ( 820 | (Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL) && (device_cnt_ > 0)) { 821 | #ifdef TRITON_ENABLE_GPU 822 | float compute_input_duration = GetCudaEventElapsedTime( 823 | compute_input_start_event_, compute_infer_start_event_); 824 | uint64_t compute_infer_duration = 825 | compute_output_start - compute_infer_start; 826 | 827 | compute_start_ns = exec_start_ns + (compute_input_duration * 1e6); 828 | compute_end_ns = compute_start_ns + compute_infer_duration; 829 | #endif 830 | } 831 | 832 | // Report statistics for each request. 
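The duration arithmetic just above hinges on one unit detail: cudaEventElapsedTime() reports a float in milliseconds, while Triton's statistics timestamps are nanoseconds, hence the 1e6 scale factor. A compact sketch of the KIND_GPU derivation with illustrative names, before the per-request statistics reporting that follows.

```cpp
#include <cuda_runtime_api.h>

#include <cstdint>

// Derive nanosecond compute timestamps from CUDA event timings.
// cudaEventElapsedTime() yields milliseconds as a float, so multiplying by
// 1e6 converts each duration to nanoseconds.
void
DeriveComputeTimestamps(
    uint64_t exec_start_ns, cudaEvent_t input_start, cudaEvent_t infer_start,
    cudaEvent_t output_start, uint64_t* compute_start_ns,
    uint64_t* compute_end_ns)
{
  float input_ms = 0.0f;
  float infer_ms = 0.0f;
  cudaEventElapsedTime(&input_ms, input_start, infer_start);   // input gather
  cudaEventElapsedTime(&infer_ms, infer_start, output_start);  // forward pass

  *compute_start_ns = exec_start_ns + static_cast<uint64_t>(input_ms * 1e6);
  *compute_end_ns = *compute_start_ns + static_cast<uint64_t>(infer_ms * 1e6);
}
```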
833 | for (uint32_t r = 0; r < request_count; ++r) { 834 | auto& request = requests[r]; 835 | LOG_IF_ERROR( 836 | TRITONBACKEND_ModelInstanceReportStatistics( 837 | TritonModelInstance(), request, 838 | (responses[r] != nullptr) /* success */, exec_start_ns, 839 | compute_start_ns, compute_end_ns, exec_end_ns), 840 | "failed reporting request statistics"); 841 | 842 | LOG_IF_ERROR( 843 | TRITONBACKEND_RequestRelease(request, TRITONSERVER_REQUEST_RELEASE_ALL), 844 | "failed releasing request"); 845 | } 846 | 847 | if (!all_response_failed) { 848 | // Report the entire batch statistics. 849 | LOG_IF_ERROR( 850 | TRITONBACKEND_ModelInstanceReportBatchStatistics( 851 | TritonModelInstance(), total_batch_size, exec_start_ns, 852 | compute_start_ns, compute_end_ns, exec_end_ns), 853 | "failed reporting batch request statistics"); 854 | } 855 | } 856 | 857 | TRITONSERVER_Error* 858 | ModelInstanceState::ReadOutputTensors( 859 | size_t total_batch_size, 860 | const std::vector& output_tensors, 861 | TRITONBACKEND_Request** requests, const uint32_t request_count, 862 | std::vector* responses) 863 | { 864 | NVTX_RANGE(nvtx_, "ReadOutputTensors " + Name()); 865 | 866 | BackendOutputResponder responder( 867 | requests, request_count, responses, model_state_->TritonMemoryManager(), 868 | model_state_->MaxBatchSize() > 0, model_state_->EnablePinnedInput(), 869 | GetCudaStreamByInstanceKind()); 870 | 871 | bool cuda_copy = false; 872 | // The serialized string buffer must be valid until output copies are done 873 | std::vector> string_buffer; 874 | for (auto& output : model_state_->ModelOutputs()) { 875 | int op_index = output_index_map_[output.first]; 876 | auto name = output.first; 877 | auto output_tensor_pair = output.second; 878 | 879 | if (output_tensors[op_index].isTensor()) { 880 | torch::Tensor output_flat; 881 | try { 882 | output_flat = 883 | output_tensors[op_index].toTensor().contiguous().flatten(); 884 | } 885 | catch (std::exception& ex) { 886 | RETURN_IF_ERROR(TRITONSERVER_ErrorNew( 887 | TRITONSERVER_ERROR_INTERNAL, 888 | (std::string("output tensor '") + name + "' is not found") 889 | .c_str())); 890 | } 891 | 892 | // Verify output datatype matches datatype from model config 893 | TRITONSERVER_DataType output_dtype = 894 | ConvertTorchTypeToDataType(output_flat.scalar_type()); 895 | TRITONSERVER_DataType config_datatype = output_dtype_map_[name]; 896 | if (config_datatype != output_dtype) { 897 | RETURN_IF_ERROR(TRITONSERVER_ErrorNew( 898 | TRITONSERVER_ERROR_INVALID_ARG, 899 | (std::string("configuration expects datatype TYPE_") + 900 | TRITONSERVER_DataTypeString(config_datatype) + " for output '" + 901 | name + "', model provides TYPE_" + 902 | TRITONSERVER_DataTypeString(output_dtype)) 903 | .c_str())); 904 | } 905 | 906 | const char* output_buffer = 907 | static_cast(output_flat.data_ptr()); 908 | 909 | // Output tensors may not reside on the same device as model 910 | torch::Device tensor_device = output_flat.device(); 911 | const auto memory_type = (tensor_device.type() == torch::kCPU) 912 | ? TRITONSERVER_MEMORY_CPU 913 | : TRITONSERVER_MEMORY_GPU; 914 | const auto memory_id = 915 | (tensor_device.type() == torch::kCPU) ? 
0 : tensor_device.index(); 916 | 917 | // Batch output doesn't support string data type yet, as it is not trivial 918 | // to parse string output 919 | const BatchOutput* batch_output = StateForModel()->FindBatchOutput(name); 920 | if (batch_output == nullptr) { 921 | // Get output shape 922 | std::vector<int64_t> batchn_shape; 923 | auto shape = output_tensors[op_index].toTensor().sizes(); 924 | for (auto itr = shape.begin(); itr != shape.end(); itr++) { 925 | batchn_shape.push_back(*itr); 926 | } 927 | 928 | if (batchn_shape.size() == 0) { 929 | return TRITONSERVER_ErrorNew( 930 | TRITONSERVER_ERROR_INVALID_ARG, 931 | (std::string("output '") + name + 932 | "' is a scalar which is not supported.") 933 | .c_str()); 934 | } 935 | if (output_tensor_pair.first != -1) { 936 | responder.ProcessTensor( 937 | name, output_dtype, batchn_shape, output_buffer, memory_type, 938 | memory_id); 939 | } 940 | if (output_tensor_pair.second != -1) { 941 | std::vector<TRITONBACKEND_State*> states; 942 | states = responder.ProcessStateTensor( 943 | name, output_dtype, batchn_shape, output_buffer, memory_type, 944 | memory_id); 945 | // Update the states 946 | for (auto& state : states) { 947 | RETURN_IF_ERROR(TRITONBACKEND_StateUpdate(state)); 948 | } 949 | } 950 | 951 | } else { 952 | responder.ProcessBatchOutput( 953 | name, *batch_output, output_buffer, memory_type, memory_id); 954 | } 955 | } else if (output_tensors[op_index].isList()) { 956 | // Custom handling for string/bytes tensor... 957 | torch::List<torch::jit::IValue> output_list = 958 | output_tensors[op_index].toList(); 959 | 960 | // Get output shape 961 | std::vector<int64_t> batchn_shape{(int64_t)output_list.size()}; 962 | 963 | for (size_t idx = 0; idx < responses->size(); idx++) { 964 | auto& request = requests[idx]; 965 | auto& response = (*responses)[idx]; 966 | 967 | if (supports_batching_ != 0) { 968 | TRITONBACKEND_Input* input; 969 | TRITONBACKEND_RequestInputByIndex(request, 0 /* index*/, &input); 970 | const int64_t* shape; 971 | TRITONBACKEND_InputProperties( 972 | input, nullptr, nullptr, &shape, nullptr, nullptr, nullptr); 973 | batchn_shape[0] = shape[0]; 974 | } 975 | 976 | int64_t tensor_element_cnt = 0; 977 | RETURN_IF_ERROR(GetElementCount(batchn_shape, &tensor_element_cnt)); 978 | 979 | // Only need a response tensor for requested outputs. 
980 | if (response != nullptr) { 981 | if (output_tensor_pair.first != -1) { 982 | TRITONBACKEND_Output* response_output; 983 | RESPOND_AND_SET_NULL_IF_ERROR( 984 | &response, TRITONBACKEND_ResponseOutput( 985 | response, &response_output, name.c_str(), 986 | TRITONSERVER_TYPE_BYTES, batchn_shape.data(), 987 | batchn_shape.size())); 988 | string_buffer.emplace_back(new std::string()); 989 | cuda_copy |= SetStringOutputBuffer( 990 | &output_list, &response, response_output, tensor_element_cnt, 991 | GetCudaStreamByInstanceKind(), string_buffer.back().get()); 992 | } 993 | } 994 | if (output_tensor_pair.second != -1) { 995 | TRITONBACKEND_State* response_state; 996 | RESPOND_AND_SET_NULL_IF_ERROR( 997 | &response, TRITONBACKEND_StateNew( 998 | &response_state, request, name.c_str(), 999 | TRITONSERVER_TYPE_BYTES, batchn_shape.data(), 1000 | batchn_shape.size())); 1001 | 1002 | string_buffer.emplace_back(new std::string()); 1003 | cuda_copy |= SetStringStateBuffer( 1004 | &output_list, &response, response_state, tensor_element_cnt, 1005 | GetCudaStreamByInstanceKind(), string_buffer.back().get()); 1006 | } 1007 | } 1008 | } else { 1009 | return TRITONSERVER_ErrorNew( 1010 | TRITONSERVER_ERROR_INVALID_ARG, 1011 | (std::string("output '") + name + 1012 | "' must be of type Tensor or List[str].") 1013 | .c_str()); 1014 | } 1015 | } 1016 | 1017 | // Finalize and wait for any pending buffer copies. 1018 | cuda_copy |= responder.Finalize(); 1019 | 1020 | #ifdef TRITON_ENABLE_GPU 1021 | // We have to always synchronize the stream. This is to make sure that 1022 | // the events on the cuda stream are synchronized. Otherwise, the events 1023 | // are only guaranteed to be synchronized if the model provides the output 1024 | // on GPU. 1025 | cudaStreamSynchronize(GetCudaStreamByInstanceKind()); 1026 | #endif 1027 | 1028 | return nullptr; 1029 | } 1030 | 1031 | TRITONSERVER_Error* 1032 | ModelInstanceState::RecordBackendTimestamp( 1033 | uint64_t* timestamp, void* cuda_event) 1034 | { 1035 | if ((Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) || 1036 | ((Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL) && (device_cnt_ > 0))) { 1037 | #ifdef TRITON_ENABLE_GPU 1038 | cudaEvent_t* lcuda_event = reinterpret_cast(cuda_event); 1039 | RETURN_IF_ERROR(ConvertCUDAStatusToTritonError( 1040 | cudaEventRecord(*lcuda_event, GetCudaStreamByInstanceKind()), 1041 | TRITONSERVER_ERROR_INTERNAL, "Failed to record the event.")); 1042 | #endif 1043 | } else { 1044 | SET_TIMESTAMP(*timestamp); 1045 | } 1046 | return nullptr; 1047 | } 1048 | 1049 | void 1050 | ModelInstanceState::SetCurrentCudaStream( 1051 | const cudaStream_t& stream, const int& device_id) 1052 | { 1053 | #ifdef TRITON_ENABLE_GPU 1054 | at::cuda::CUDAStream torch_stream = 1055 | at::cuda::getStreamFromExternal(stream, device_id); 1056 | // This function replaces the default stream with the stream we created. It 1057 | // is not necessary to change the current device to the desired device when 1058 | // replacing the default stream for that device. 
See the documentation here: 1059 | // https://pytorch.org/cppdocs/api/function_namespacec10_1_1cuda_1a6ed50cc0fc16cc7014d9c2f4c3bd098d.html 1060 | at::cuda::setCurrentCUDAStream(torch_stream); 1061 | #endif 1062 | } 1063 | 1064 | TRITONSERVER_Error* 1065 | ModelInstanceState::SetInputTensors( 1066 | size_t total_batch_size, TRITONBACKEND_Request** requests, 1067 | const uint32_t request_count, 1068 | std::vector* responses, 1069 | BackendInputCollector* collector, std::vector* input_names, 1070 | std::vector* input_tensors, bool* cuda_copy) 1071 | { 1072 | // InferenceMode should be used to guard all tensors operations 1073 | torch::InferenceMode infer_guard(model_state_->EnabledInferenceMode()); 1074 | 1075 | // All requests must have equally-sized input tensors so use any 1076 | // request as the representative for the input tensors. 1077 | uint32_t input_count; 1078 | RETURN_IF_ERROR(TRITONBACKEND_RequestInputCount(requests[0], &input_count)); 1079 | 1080 | input_tensors->resize(input_count + batch_input_count_); 1081 | 1082 | // The inputs must be in contiguous CPU/GPU memory. 1083 | std::vector> alloc_perference; 1084 | if (device_.is_cpu()) { 1085 | alloc_perference = { 1086 | {TRITONSERVER_MEMORY_CPU_PINNED, 0}, {TRITONSERVER_MEMORY_CPU, 0}}; 1087 | } else { 1088 | alloc_perference = {{TRITONSERVER_MEMORY_GPU, device_.index()}}; 1089 | } 1090 | 1091 | for (uint32_t input_idx = 0; input_idx < input_count; input_idx++) { 1092 | TRITONBACKEND_Input* input; 1093 | RETURN_IF_ERROR( 1094 | TRITONBACKEND_RequestInputByIndex(requests[0], input_idx, &input)); 1095 | 1096 | const char* input_name; 1097 | TRITONSERVER_DataType input_datatype; 1098 | const int64_t* input_shape; 1099 | uint32_t input_dims_count; 1100 | RETURN_IF_ERROR(TRITONBACKEND_InputProperties( 1101 | input, &input_name, &input_datatype, &input_shape, &input_dims_count, 1102 | nullptr, nullptr)); 1103 | 1104 | input_names->emplace_back(input_name); 1105 | 1106 | // The shape for the entire input patch, 1107 | // [total_batch_size, ...] for non-ragged input and 1108 | // [total_element_count] for ragged input (non-nested tensor) 1109 | std::vector batchn_shape; 1110 | if (StateForModel()->IsInputRagged(input_name)) { 1111 | batchn_shape = std::vector{0}; 1112 | for (size_t idx = 0; idx < request_count; idx++) { 1113 | TRITONBACKEND_Input* input; 1114 | RESPOND_AND_SET_NULL_IF_ERROR( 1115 | &((*responses)[idx]), 1116 | TRITONBACKEND_RequestInput(requests[idx], input_name, &input)); 1117 | const int64_t* input_shape; 1118 | uint32_t input_dims_count; 1119 | RESPOND_AND_SET_NULL_IF_ERROR( 1120 | &((*responses)[idx]), TRITONBACKEND_InputProperties( 1121 | input, nullptr, nullptr, &input_shape, 1122 | &input_dims_count, nullptr, nullptr)); 1123 | 1124 | int64_t element_cnt = 0; 1125 | RESPOND_AND_SET_NULL_IF_ERROR( 1126 | &((*responses)[idx]), 1127 | GetElementCount(input_shape, input_dims_count, &element_cnt)); 1128 | batchn_shape[0] += element_cnt; 1129 | } 1130 | } else { 1131 | batchn_shape = 1132 | std::vector(input_shape, input_shape + input_dims_count); 1133 | if (supports_batching_) { 1134 | batchn_shape[0] = total_batch_size; 1135 | } 1136 | } 1137 | 1138 | // The input must be in contiguous CPU/GPU memory. 1139 | std::vector> alloc_perference; 1140 | // For 'KIND_MODEL', input will always be in CPU as we don't have a way to 1141 | // query the input types. 
1142 | if (device_.is_cpu() || (Kind() == TRITONSERVER_INSTANCEGROUPKIND_MODEL)) { 1143 | alloc_perference = { 1144 | {TRITONSERVER_MEMORY_CPU_PINNED, 0}, {TRITONSERVER_MEMORY_CPU, 0}}; 1145 | } else { 1146 | alloc_perference = {{TRITONSERVER_MEMORY_GPU, device_.index()}}; 1147 | } 1148 | 1149 | const char* input_buffer; 1150 | size_t batchn_byte_size; 1151 | TRITONSERVER_MemoryType memory_type; 1152 | int64_t memory_type_id; 1153 | RETURN_IF_ERROR(collector->ProcessTensor( 1154 | input_name, nullptr, 0, alloc_perference, &input_buffer, 1155 | &batchn_byte_size, &memory_type, &memory_type_id)); 1156 | 1157 | // Create Torch tensor 1158 | const auto torch_dtype = ConvertDataTypeToTorchType(input_datatype); 1159 | torch::TensorOptions options{torch_dtype.second}; 1160 | auto updated_options = (memory_type == TRITONSERVER_MEMORY_GPU) 1161 | ? options.device(torch::kCUDA, device_.index()) 1162 | : options.device(torch::kCPU); 1163 | 1164 | if (input_datatype == TRITONSERVER_TYPE_BYTES) { 1165 | // Create the PyTorch list to hold the strings. 1166 | torch::List input_list; 1167 | input_list.reserve(batchn_shape[0]); 1168 | 1169 | for (size_t idx = 0; idx < request_count; idx++) { 1170 | TRITONBACKEND_Input* input; 1171 | RESPOND_AND_SET_NULL_IF_ERROR( 1172 | &((*responses)[idx]), 1173 | TRITONBACKEND_RequestInput(requests[idx], input_name, &input)); 1174 | const int64_t* shape; 1175 | uint32_t dims_count; 1176 | uint32_t buffer_count; 1177 | RESPOND_AND_SET_NULL_IF_ERROR( 1178 | &((*responses)[idx]), 1179 | TRITONBACKEND_InputPropertiesForHostPolicy( 1180 | input, HostPolicyName().c_str(), nullptr, nullptr, &shape, 1181 | &dims_count, nullptr, &buffer_count)); 1182 | 1183 | int64_t batch_element_cnt = 0; 1184 | RESPOND_AND_SET_NULL_IF_ERROR( 1185 | &((*responses)[idx]), 1186 | GetElementCount(shape, dims_count, &batch_element_cnt)); 1187 | 1188 | *cuda_copy |= SetStringInputTensor( 1189 | &input_list, input, input_name, buffer_count, batch_element_cnt, 1190 | &((*responses)[idx]), GetCudaStreamByInstanceKind(), 1191 | HostPolicyName().c_str()); 1192 | } 1193 | 1194 | (*input_tensors)[input_index_map_[input_name]] = input_list; 1195 | } else { 1196 | if (batchn_byte_size) { 1197 | // Remove constness to align with the signature of torch::from_blob() 1198 | torch::Tensor input_tensor = torch::from_blob( 1199 | const_cast(input_buffer), batchn_shape, updated_options); 1200 | (*input_tensors)[input_index_map_[input_name]] = input_tensor; 1201 | } else { 1202 | // torch:from_blob seems not working when the input size is 0 1203 | // create zero-length inputs directly 1204 | torch::Tensor input_tensor = 1205 | torch::zeros(batchn_shape, updated_options); 1206 | (*input_tensors)[input_index_map_[input_name]] = input_tensor; 1207 | } 1208 | } 1209 | } 1210 | 1211 | for (const auto& batch_input : StateForModel()->BatchInputs()) { 1212 | std::vector shape; 1213 | collector->BatchInputShape(batch_input, &shape); 1214 | 1215 | for (const auto& input_name : batch_input.TargetNames()) { 1216 | input_names->emplace_back(input_name.c_str()); 1217 | 1218 | const char* dst_buffer; 1219 | size_t dst_buffer_byte_size; 1220 | TRITONSERVER_MemoryType dst_memory_type; 1221 | int64_t dst_memory_type_id; 1222 | 1223 | RESPOND_ALL_AND_SET_NULL_IF_ERROR( 1224 | (*responses), responses->size(), 1225 | collector->ProcessBatchInput( 1226 | batch_input, nullptr, 0, alloc_perference, &dst_buffer, 1227 | &dst_buffer_byte_size, &dst_memory_type, &dst_memory_type_id)); 1228 | 1229 | const auto torch_dtype = 1230 | 
ConvertDataTypeToTorchType(batch_input.DataType()); 1231 | torch::TensorOptions options{torch_dtype.second}; 1232 | auto updated_options = (dst_memory_type == TRITONSERVER_MEMORY_GPU) 1233 | ? options.device(torch::kCUDA, device_.index()) 1234 | : options.device(torch::kCPU); 1235 | 1236 | if (dst_buffer_byte_size) { 1237 | torch::Tensor input_tensor = torch::from_blob( 1238 | const_cast<char*>(dst_buffer), shape, updated_options); 1239 | (*input_tensors)[input_index_map_[input_name]] = input_tensor; 1240 | } else { 1241 | // special handling when input has zero size 1242 | torch::Tensor input_tensor = torch::zeros(shape, updated_options); 1243 | (*input_tensors)[input_index_map_[input_name]] = input_tensor; 1244 | } 1245 | } 1246 | } 1247 | 1248 | // Finalize... 1249 | *cuda_copy |= collector->Finalize(); 1250 | 1251 | return nullptr; 1252 | } 1253 | 1254 | ModelState* 1255 | ModelInstanceState::StateForModel() const 1256 | { 1257 | return model_state_; 1258 | } 1259 | 1260 | TRITONSERVER_Error* 1261 | ModelInstanceState::ValidateBooleanSequenceControl( 1262 | triton::common::TritonJson::Value& sequence_batching, 1263 | const std::string& control_kind, bool required, bool* have_control) 1264 | { 1265 | std::string tensor_name; 1266 | std::string tensor_datatype; 1267 | RETURN_IF_ERROR(GetBooleanSequenceControlProperties( 1268 | sequence_batching, model_state_->Name(), control_kind, required, 1269 | &tensor_name, &tensor_datatype, nullptr, nullptr, nullptr, nullptr, 1270 | nullptr, nullptr)); 1271 | *have_control = !tensor_name.empty(); 1272 | if (*have_control) { 1273 | std::string deliminator = "__"; 1274 | int ip_index = 0; 1275 | int start_pos = tensor_name.find(deliminator); 1276 | if (start_pos == -1) { 1277 | return TRITONSERVER_ErrorNew( 1278 | TRITONSERVER_ERROR_INTERNAL, 1279 | ("input '" + tensor_name + 1280 | "' does not follow <name>__<index> naming convention.") 1281 | .c_str()); 1282 | } 1283 | 1284 | // check if the index part of the name is not an integer 1285 | std::string index_str = tensor_name.substr(start_pos + 2); 1286 | for (auto itr = index_str.begin(); itr != index_str.end(); itr++) { 1287 | if (std::isdigit(*itr) == 0) { 1288 | return TRITONSERVER_ErrorNew( 1289 | TRITONSERVER_ERROR_INTERNAL, 1290 | ("input '" + tensor_name + 1291 | "' does not follow <name>__<index> naming convention.") 1292 | .c_str()); 1293 | } 1294 | } 1295 | 1296 | ip_index = std::atoi(tensor_name.substr(start_pos + 2).c_str()); 1297 | input_index_map_[tensor_name] = ip_index; 1298 | } 1299 | 1300 | return nullptr; // success 1301 | } 1302 | 1303 | TRITONSERVER_Error* 1304 | ModelInstanceState::ValidateInputs(const size_t expected_input_cnt) 1305 | { 1306 | // Collect all the expected input tensor names and validate that the model 1307 | // configuration specifies only those. 1308 | std::vector<std::string> allowed_inputs; 1309 | 1310 | const torch::jit::Method& method = torch_model_->get_method("forward"); 1311 | const auto& schema = method.function().getSchema(); 1312 | const std::vector<c10::Argument>& arguments = schema.arguments(); 1313 | 1314 | // Currently, only models with a single input of type Dict(str, Tensor) are 1315 | // supported. If the model expects more than one input then they must all 1316 | // be of type Tensor. 
1317 | // 1318 | // Ignore the argument at idx 0 if it is of Class type (self param in forward 1319 | // function) 1320 | size_t start_idx = 0; 1321 | if ((arguments.size() > 0) && 1322 | (arguments.at(0).type()->kind() == c10::TypeKind::ClassType)) { 1323 | start_idx = 1; 1324 | } 1325 | if ((arguments.size() == (1 + start_idx)) && 1326 | (arguments.at(start_idx).type()->kind() == c10::TypeKind::DictType)) { 1327 | is_dict_input_ = true; 1328 | } else if (arguments.size() > start_idx) { 1329 | // Return error if multiple inputs are of kind DictType 1330 | for (size_t i = start_idx + 1; i < arguments.size(); i++) { 1331 | if (arguments.at(i).type()->kind() == c10::TypeKind::DictType) { 1332 | return TRITONSERVER_ErrorNew( 1333 | TRITONSERVER_ERROR_INTERNAL, 1334 | "Multiple inputs of kind DictType were detected. Only a single " 1335 | "input of type Dict(str, Tensor) is supported."); 1336 | } 1337 | } 1338 | 1339 | // Return error if all inputs are not of type Tensor 1340 | for (size_t i = start_idx; i < arguments.size(); i++) { 1341 | if ((arguments.at(i).type()->kind() != c10::TypeKind::TensorType) && 1342 | (arguments.at(i).type()->kind() != c10::TypeKind::ListType)) { 1343 | return TRITONSERVER_ErrorNew( 1344 | TRITONSERVER_ERROR_INTERNAL, 1345 | (std::string("An input of type '") + arguments.at(i).type()->str() + 1346 | "' was detected in the model. Only a single input of type " 1347 | "Dict(str, Tensor) or input(s) of type Tensor are supported.") 1348 | .c_str()); 1349 | } 1350 | allowed_inputs.emplace_back(arguments.at(i).name()); 1351 | } 1352 | 1353 | // If all inputs are tensors, match number of expected inputs between model 1354 | // and configuration 1355 | if ((arguments.size() - start_idx) != expected_input_cnt) { 1356 | return TRITONSERVER_ErrorNew( 1357 | TRITONSERVER_ERROR_INVALID_ARG, 1358 | (std::string("unable to load model '") + model_state_->Name() + 1359 | "', configuration expects " + std::to_string(expected_input_cnt) + 1360 | " inputs, model provides " + 1361 | std::to_string(arguments.size() - start_idx)) 1362 | .c_str()); 1363 | } 1364 | } 1365 | 1366 | triton::common::TritonJson::Value ios; 1367 | RETURN_IF_ERROR(model_state_->ModelConfig().MemberAsArray("input", &ios)); 1368 | 1369 | if (ios.ArraySize() == 0) { 1370 | return TRITONSERVER_ErrorNew( 1371 | TRITONSERVER_ERROR_INTERNAL, 1372 | "model configuration must contain at least one input, none were " 1373 | "specified."); 1374 | } 1375 | 1376 | NamingConvention naming_convention; 1377 | RETURN_IF_ERROR(GetNamingConvention(&naming_convention, allowed_inputs)); 1378 | 1379 | for (size_t i = 0; i < ios.ArraySize(); i++) { 1380 | triton::common::TritonJson::Value io; 1381 | RETURN_IF_ERROR(ios.IndexAsObject(i, &io)); 1382 | 1383 | // Validate name 1384 | std::string io_name; 1385 | RETURN_IF_ERROR(io.MemberAsString("name", &io_name)); 1386 | AddInputToMap(naming_convention, allowed_inputs, io_name, i); 1387 | // Validate data type 1388 | std::string io_dtype; 1389 | RETURN_IF_ERROR(io.MemberAsString("data_type", &io_dtype)); 1390 | const auto pr = ModelConfigDataTypeToTorchType(io_dtype); 1391 | if (!pr.first && (io_dtype != "TYPE_STRING")) { 1392 | return TRITONSERVER_ErrorNew( 1393 | TRITONSERVER_ERROR_INTERNAL, 1394 | ("unsupported datatype " + io_dtype + " for input '" + io_name + 1395 | "' for model '" + model_state_->Name() + "'") 1396 | .c_str()); 1397 | } 1398 | 1399 | // Validate shape for String inputs. Only allow 1 dimension. 
1400 | if (io_dtype == "TYPE_STRING") { 1401 | // If a reshape is provided for the input then use that when 1402 | // validating the model shapes. 1403 | std::vector dims; 1404 | triton::common::TritonJson::Value reshape; 1405 | if (io.Find("reshape", &reshape)) { 1406 | RETURN_IF_ERROR(ParseShape(reshape, "shape", &dims)); 1407 | } else { 1408 | RETURN_IF_ERROR(ParseShape(io, "dims", &dims)); 1409 | } 1410 | 1411 | if ((dims.size() + (supports_batching_ ? 1 : 0)) > 1) { 1412 | return TRITONSERVER_ErrorNew( 1413 | TRITONSERVER_ERROR_INTERNAL, 1414 | ("Triton only supports 1 dimensional List of String as input for " 1415 | "'" + 1416 | std::string(io_name) + "' for model '" + model_state_->Name() + 1417 | "'") 1418 | .c_str()); 1419 | } 1420 | } 1421 | } 1422 | triton::common::TritonJson::Value sequence_batching; 1423 | if (model_state_->ModelConfig().Find( 1424 | "sequence_batching", &sequence_batching)) { 1425 | triton::common::TritonJson::Value states; 1426 | if (sequence_batching.Find("state", &states)) { 1427 | for (size_t i = 0; i < states.ArraySize(); i++) { 1428 | triton::common::TritonJson::Value state; 1429 | RETURN_IF_ERROR(states.IndexAsObject(i, &state)); 1430 | std::string state_name; 1431 | RETURN_IF_ERROR(state.MemberAsString("input_name", &state_name)); 1432 | AddInputToMap(naming_convention, allowed_inputs, state_name, i); 1433 | 1434 | // Validate data type 1435 | std::string state_dtype; 1436 | RETURN_IF_ERROR(state.MemberAsString("data_type", &state_dtype)); 1437 | const auto pr = ModelConfigDataTypeToTorchType(state_dtype); 1438 | if (!pr.first && (state_dtype != "TYPE_STRING")) { 1439 | return TRITONSERVER_ErrorNew( 1440 | TRITONSERVER_ERROR_INTERNAL, 1441 | ("unsupported datatype " + state_dtype + " for input state '" + 1442 | state_name + "' for model '" + model_state_->Name() + "'") 1443 | .c_str()); 1444 | } 1445 | 1446 | // Validate shape for String inputs. Only allow 1 dimension. 1447 | if (state_dtype == "TYPE_STRING") { 1448 | std::vector dims; 1449 | if ((dims.size() + (supports_batching_ ? 
1 : 0)) > 1) { 1450 | return TRITONSERVER_ErrorNew( 1451 | TRITONSERVER_ERROR_INTERNAL, 1452 | ("Triton only supports 1 dimensional List of String as input " 1453 | "for " 1454 | "'" + 1455 | std::string(state_name) + "' for model '" + 1456 | model_state_->Name() + "'") 1457 | .c_str()); 1458 | } 1459 | } 1460 | } 1461 | } 1462 | } 1463 | 1464 | triton::common::TritonJson::Value batch_inputs; 1465 | RETURN_IF_ERROR( 1466 | model_state_->ModelConfig().MemberAsArray("batch_input", &batch_inputs)); 1467 | size_t i = 0; 1468 | for (const auto& batch_input : StateForModel()->BatchInputs()) { 1469 | for (const auto& input_name : batch_input.TargetNames()) { 1470 | AddInputToMap( 1471 | naming_convention, allowed_inputs, input_name, i + ios.ArraySize()); 1472 | i++; 1473 | } 1474 | } 1475 | 1476 | return nullptr; // success 1477 | } 1478 | 1479 | TRITONSERVER_Error* 1480 | ModelInstanceState::ValidateOutputs() 1481 | { 1482 | triton::common::TritonJson::Value ios; 1483 | RETURN_IF_ERROR(model_state_->ModelConfig().MemberAsArray("output", &ios)); 1484 | std::string deliminator = "__"; 1485 | int op_index = 0; 1486 | 1487 | if (ios.ArraySize() == 0) { 1488 | return TRITONSERVER_ErrorNew( 1489 | TRITONSERVER_ERROR_INTERNAL, 1490 | "model configuration must contain at least one output, none were " 1491 | "specified."); 1492 | } 1493 | 1494 | NamingConvention naming_convention; 1495 | RETURN_IF_ERROR(GetNamingConvention(&naming_convention, {})); 1496 | 1497 | for (size_t i = 0; i < ios.ArraySize(); i++) { 1498 | triton::common::TritonJson::Value io; 1499 | RETURN_IF_ERROR(ios.IndexAsObject(i, &io)); 1500 | 1501 | // Validate name 1502 | std::string io_name; 1503 | RETURN_IF_ERROR(io.MemberAsString("name", &io_name)); 1504 | switch (naming_convention) { 1505 | case NamingConvention::NAMED_INDEX: { 1506 | int start_pos = io_name.find(deliminator); 1507 | op_index = std::atoi(io_name.substr(start_pos + 2).c_str()); 1508 | break; 1509 | } 1510 | case NamingConvention::STRICT_CONFIG_ORDERING: { 1511 | op_index = i; 1512 | break; 1513 | } 1514 | default: 1515 | break; 1516 | } 1517 | 1518 | // Validate data type 1519 | std::string io_dtype; 1520 | RETURN_IF_ERROR(io.MemberAsString("data_type", &io_dtype)); 1521 | const auto pr = ModelConfigDataTypeToTorchType(io_dtype); 1522 | if (!pr.first && (io_dtype != "TYPE_STRING")) { 1523 | return TRITONSERVER_ErrorNew( 1524 | TRITONSERVER_ERROR_INTERNAL, 1525 | ("unsupported datatype " + io_dtype + " for output '" + io_name + 1526 | "' for model '" + model_state_->Name() + "'") 1527 | .c_str()); 1528 | } 1529 | 1530 | // Validate shape for String outputs. Only allow 1 dimension. 1531 | if (io_dtype == "TYPE_STRING") { 1532 | // If a reshape is provided for the output then use that when 1533 | // validating the model shapes. 1534 | std::vector dims; 1535 | triton::common::TritonJson::Value reshape; 1536 | if (io.Find("reshape", &reshape)) { 1537 | RETURN_IF_ERROR(ParseShape(reshape, "shape", &dims)); 1538 | } else { 1539 | RETURN_IF_ERROR(ParseShape(io, "dims", &dims)); 1540 | } 1541 | 1542 | if ((dims.size() + (supports_batching_ ? 
1 : 0)) > 1) { 1543 | return TRITONSERVER_ErrorNew( 1544 | TRITONSERVER_ERROR_INTERNAL, 1545 | ("Triton only supports 1 dimensional List of String as output for " 1546 | "'" + 1547 | std::string(io_name) + "' for model '" + model_state_->Name() + 1548 | "'") 1549 | .c_str()); 1550 | } 1551 | } 1552 | 1553 | output_index_map_[io_name] = op_index; 1554 | output_dtype_map_[io_name] = ConvertTorchTypeToDataType(pr.second); 1555 | } 1556 | 1557 | triton::common::TritonJson::Value sequence_batching; 1558 | if (model_state_->ModelConfig().Find( 1559 | "sequence_batching", &sequence_batching)) { 1560 | triton::common::TritonJson::Value states; 1561 | if (sequence_batching.Find("state", &states)) { 1562 | for (size_t i = 0; i < states.ArraySize(); i++) { 1563 | triton::common::TritonJson::Value state; 1564 | RETURN_IF_ERROR(states.IndexAsObject(i, &state)); 1565 | std::string state_name; 1566 | RETURN_IF_ERROR(state.MemberAsString("output_name", &state_name)); 1567 | std::string state_dtype; 1568 | RETURN_IF_ERROR(state.MemberAsString("data_type", &state_dtype)); 1569 | std::vector dims; 1570 | RETURN_IF_ERROR(ParseShape(state, "dims", &dims)); 1571 | 1572 | // For state, naming convention is enforced to be NAMED_INDEX 1573 | int start_pos = state_name.find(deliminator); 1574 | op_index = std::atoi(state_name.substr(start_pos + 2).c_str()); 1575 | 1576 | const auto pr = ModelConfigDataTypeToTorchType(state_dtype); 1577 | if (!pr.first && (state_dtype != "TYPE_STRING")) { 1578 | return TRITONSERVER_ErrorNew( 1579 | TRITONSERVER_ERROR_INTERNAL, 1580 | ("unsupported datatype " + state_dtype + " for state '" + 1581 | state_name + "' for model '" + model_state_->Name() + "'") 1582 | .c_str()); 1583 | } 1584 | 1585 | // Validate shape for String outputs. Only allow 1 dimension. 1586 | if (state_dtype == "TYPE_STRING") { 1587 | if ((dims.size() + (supports_batching_ ? 
1 : 0)) > 1) { 1588 | return TRITONSERVER_ErrorNew( 1589 | TRITONSERVER_ERROR_INTERNAL, 1590 | ("Triton only supports 1 dimensional List of String as output " 1591 | "for " 1592 | "'" + 1593 | std::string(state_name) + "' for model '" + 1594 | model_state_->Name() + "'") 1595 | .c_str()); 1596 | } 1597 | } 1598 | 1599 | output_index_map_[state_name] = op_index; 1600 | output_dtype_map_[state_name] = ConvertTorchTypeToDataType(pr.second); 1601 | } 1602 | } 1603 | } 1604 | 1605 | return nullptr; // success 1606 | } 1607 | 1608 | TRITONSERVER_Error* 1609 | ModelInstanceState::ValidateTypedSequenceControl( 1610 | triton::common::TritonJson::Value& sequence_batching, 1611 | const std::string& control_kind, bool required, bool* have_control) 1612 | { 1613 | std::string tensor_name; 1614 | std::string tensor_datatype; 1615 | RETURN_IF_ERROR(GetTypedSequenceControlProperties( 1616 | sequence_batching, model_state_->Name(), control_kind, required, 1617 | &tensor_name, &tensor_datatype)); 1618 | *have_control = !tensor_name.empty(); 1619 | if (*have_control) { 1620 | std::string deliminator = "__"; 1621 | int ip_index = 0; 1622 | int start_pos = tensor_name.find(deliminator); 1623 | if (start_pos == -1) { 1624 | return TRITONSERVER_ErrorNew( 1625 | TRITONSERVER_ERROR_INTERNAL, 1626 | ("input '" + tensor_name + 1627 | "' does not follow <name>__<index> naming convention.") 1628 | .c_str()); 1629 | } 1630 | 1631 | // check if the index part of the name is not an integer 1632 | std::string index_str = tensor_name.substr(start_pos + 2); 1633 | for (auto itr = index_str.begin(); itr != index_str.end(); itr++) { 1634 | if (std::isdigit(*itr) == 0) { 1635 | return TRITONSERVER_ErrorNew( 1636 | TRITONSERVER_ERROR_INTERNAL, 1637 | ("input '" + tensor_name + 1638 | "' does not follow <name>__<index> naming convention.") 1639 | .c_str()); 1640 | } 1641 | } 1642 | 1643 | // check if the data type is supported by PyTorch 1644 | if (!ModelConfigDataTypeToTorchType(tensor_datatype).first) { 1645 | return TRITONSERVER_ErrorNew( 1646 | TRITONSERVER_ERROR_INTERNAL, 1647 | ("input '" + tensor_name + "' type '" + tensor_datatype + 1648 | "' is not supported by PyTorch.") 1649 | .c_str()); 1650 | } 1651 | 1652 | ip_index = std::atoi(tensor_name.substr(start_pos + 2).c_str()); 1653 | input_index_map_[tensor_name] = ip_index; 1654 | } 1655 | 1656 | return nullptr; // success 1657 | } 1658 | 1659 | 1660 | } // namespace triton::backend::pytorch 1661 | --------------------------------------------------------------------------------
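The validation routines in model_instance_state.cc (ValidateBooleanSequenceControl, ValidateTypedSequenceControl, ValidateOutputs) all depend on the <name>__<index> naming convention to map a tensor declared in the model configuration to a positional argument or return value of the TorchScript forward() method. The following is a minimal standalone sketch of that parsing, not part of the backend source; ParseTensorIndex is a hypothetical helper introduced only for illustration.

#include <cctype>
#include <cstdlib>
#include <iostream>
#include <string>

// Returns the positional index encoded in a tensor name such as "INPUT__0"
// or "OUTPUT__2", or -1 if the name does not follow the <name>__<index>
// convention (missing "__", empty index, or a non-numeric index).
static int
ParseTensorIndex(const std::string& tensor_name)
{
  const std::string deliminator = "__";
  const size_t start_pos = tensor_name.find(deliminator);
  if (start_pos == std::string::npos) {
    return -1;
  }
  const std::string index_str = tensor_name.substr(start_pos + 2);
  if (index_str.empty()) {
    return -1;
  }
  for (const char c : index_str) {
    if (std::isdigit(static_cast<unsigned char>(c)) == 0) {
      return -1;
    }
  }
  return std::atoi(index_str.c_str());
}

int
main()
{
  std::cout << ParseTensorIndex("INPUT__0") << "\n";   // 0
  std::cout << ParseTensorIndex("OUTPUT__2") << "\n";  // 2
  std::cout << ParseTensorIndex("INPUT0") << "\n";     // -1 (no "__")
  return 0;
}

Unlike this sketch, the backend reports a TRITONSERVER_Error when a name fails the check instead of returning a sentinel value, and it stores the parsed index in input_index_map_ or output_index_map_ so that request tensors can be placed at the matching position of the forward() call.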
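In the KIND_GPU statistics path of ProcessRequests, the durations returned by GetCudaEventElapsedTime are in milliseconds (the unit used by cudaEventElapsedTime, which that helper is presumed to wrap elsewhere in this file), so the code multiplies them by 1e6 to obtain the nanosecond timestamps that TRITONBACKEND_ModelInstanceReportStatistics expects. A minimal sketch of just that arithmetic, using made-up example values:

#include <cstdint>
#include <iostream>

int
main()
{
  // Hypothetical inputs: wall-clock start of the execution in nanoseconds and
  // CUDA-event elapsed times in milliseconds.
  const uint64_t exec_start_ns = 1000000000ULL;
  const float compute_input_duration_ms = 0.25f;  // input start -> infer start
  const float compute_infer_duration_ms = 3.5f;   // infer start -> output start

  // Milliseconds -> nanoseconds: multiply by 1e6, mirroring the backend code.
  const uint64_t compute_start_ns =
      exec_start_ns + static_cast<uint64_t>(compute_input_duration_ms * 1e6);
  const uint64_t compute_end_ns =
      compute_start_ns + static_cast<uint64_t>(compute_infer_duration_ms * 1e6);

  std::cout << "compute_start_ns=" << compute_start_ns
            << " compute_end_ns=" << compute_end_ns << std::endl;
  return 0;
}

This also explains the asymmetry in the KIND_MODEL branch: the input duration still comes from CUDA events (milliseconds, scaled by 1e6), while the infer duration is the difference of two SET_TIMESTAMP values that are already in nanoseconds and therefore added without scaling.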