├── demo ├── flux1.dev │ ├── __init__.py │ ├── models │ │ ├── __init__.py │ │ ├── flux_params.py │ │ └── flux_model.py │ ├── pipelines │ │ └── __init__.py │ ├── requirements.txt │ └── flux_demo.py ├── tests │ ├── requirements-test.txt │ ├── README.md │ ├── pytest.ini │ ├── conftest.py │ ├── test_cache_safety.py │ ├── test_cache_consistency.py │ ├── test_hierarchical_structure.py │ ├── test_consolidated.py │ ├── test_metadata_shapes.py │ ├── test_precision_changes.py │ ├── test_license_headers.py │ └── test_onnx_acquisition.py ├── utils │ ├── base_params.py │ ├── __init__.py │ ├── base_model.py │ ├── notebook_utils.py │ ├── memory_manager.py │ ├── timing_data.py │ ├── engine_metadata.py │ └── model_registry.py ├── setup.py └── README.md ├── .github ├── copy-pr-bot.yaml ├── workflows │ ├── .python-version │ ├── pyproject.toml │ ├── build.py │ ├── ci.yaml │ ├── utils.py │ └── test.py └── ISSUE_TEMPLATE │ └── bug_report.md ├── .mdformat.toml ├── samples ├── apiUsage │ ├── python │ │ └── requirements.txt │ ├── CMakeLists.txt │ ├── cpp │ │ └── CMakeLists.txt │ └── README.md ├── helloWorld │ ├── python │ │ ├── requirements.txt │ │ └── hello_world.py │ ├── helloWorld.onnx │ ├── helloWorldOnnx.png │ ├── CMakeLists.txt │ ├── cpp │ │ └── CMakeLists.txt │ └── README.md ├── README.md ├── CMakeLists.txt └── cmake │ └── modules │ └── get_version.cmake ├── .cmake-format.json ├── .gitignore ├── ruff.toml ├── .pre-commit-config.yaml ├── README.md ├── .clang-format ├── CONTRIBUTING.md └── LICENSE /demo/flux1.dev/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /demo/flux1.dev/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /demo/flux1.dev/pipelines/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.github/copy-pr-bot.yaml: -------------------------------------------------------------------------------- 1 | enabled: true 2 | -------------------------------------------------------------------------------- /.github/workflows/.python-version: -------------------------------------------------------------------------------- 1 | 3.9 2 | -------------------------------------------------------------------------------- /.mdformat.toml: -------------------------------------------------------------------------------- 1 | # .mdformat.toml 2 | number = true 3 | -------------------------------------------------------------------------------- /demo/tests/requirements-test.txt: -------------------------------------------------------------------------------- 1 | pytest>=6.0.0 2 | -------------------------------------------------------------------------------- /samples/apiUsage/python/requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | cuda-python<13.0.0 3 | -------------------------------------------------------------------------------- /samples/helloWorld/python/requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | cuda-python<13.0.0 3 | -------------------------------------------------------------------------------- /samples/helloWorld/helloWorld.onnx: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/TensorRT-RTX/HEAD/samples/helloWorld/helloWorld.onnx -------------------------------------------------------------------------------- /samples/helloWorld/helloWorldOnnx.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/TensorRT-RTX/HEAD/samples/helloWorld/helloWorldOnnx.png -------------------------------------------------------------------------------- /.github/workflows/pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "trt_rtx_ci" 3 | version = "0.1.0" 4 | requires-python = ">=3.9" 5 | dependencies = [ 6 | "requests==2.*" 7 | ] 8 | 9 | [dependency-groups] 10 | build = [ 11 | "cmake==3.*" 12 | ] 13 | -------------------------------------------------------------------------------- /.cmake-format.json: -------------------------------------------------------------------------------- 1 | { 2 | "format": { 3 | "line_width": 120, 4 | "tab_size": 2, 5 | "max_subgroups_hwrap": 5, 6 | "max_pargs_hwrap": 5 7 | }, 8 | "markup": { 9 | "first_comment_is_literal": true 10 | }, 11 | "lint": { 12 | "disabled_codes": [] 13 | } 14 | } 15 | -------------------------------------------------------------------------------- /demo/flux1.dev/requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate 2 | --extra-index-url https://download.pytorch.org/whl/cu129 3 | torch>=2.7.0 4 | transformers>=4.52.4 5 | diffusers>=0.33.1 6 | huggingface-hub>=0.32.4 7 | tqdm>=4.67.1 8 | pillow 9 | numpy 10 | cuda-python<13.0.0 11 | polygraphy>=0.49.24 12 | packaging 13 | protobuf 14 | sentencepiece 15 | tensorrt-rtx>=1.1.0 16 | 17 | # Jupyter dependencies 18 | requests>=2.25.0 19 | ipython>=8.0.0 20 | ipywidgets>=8.0.0 21 | jupyter 22 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Build directories 2 | **/build/ 3 | 4 | # Cache directories 5 | demo_cache/ 6 | 7 | # Cache directories 8 | demo_cache/ 9 | 10 | # Editor files 11 | *.swp 12 | *~ 13 | .history/ 14 | .vscode/ 15 | .vs/ 16 | 17 | # Python 18 | .venv/ 19 | __pycache__/ 20 | *.py[cod] 21 | *$py.class 22 | *.so 23 | .Python 24 | dist/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | 29 | # Jupyter Notebook 30 | .ipynb_checkpoints 31 | 32 | # Pre-commit 33 | .pre-commit-cache/ 34 | 35 | # Generated files 36 | *.png 37 | -------------------------------------------------------------------------------- /samples/README.md: -------------------------------------------------------------------------------- 1 | # TensorRT for RTX Samples 2 | 3 | This directory contains standalone samples demonstrating TensorRT for RTX functionality. Each sample focuses on different aspects of TensorRT for RTX usage. 4 | 5 | For detailed information about each sample's features and implementation, please refer to their individual README files and source code. 6 | 7 | ## Available Samples 8 | 9 | - [Hello World](helloWorld/README.md) - A basic example showing how to create and run a simple TensorRT-RTX network 10 | - [API Usage](apiUsage/README.md) - Demonstrates how to use TensorRT-RTX advanced APIs for fine-grained control of inference 11 | 12 | To build and run each sample, follow the instructions provided in their respective README files. 
13 | 14 | ## License 15 | 16 | This project is licensed under the Apache License 2.0 - see the LICENSE file for details. 17 | -------------------------------------------------------------------------------- /samples/apiUsage/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | cmake_minimum_required(VERSION 3.10) 17 | project(TensorRT_RTX_samples LANGUAGES CXX CUDA) 18 | 19 | add_subdirectory(cpp) 20 | -------------------------------------------------------------------------------- /samples/helloWorld/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | cmake_minimum_required(VERSION 3.10) 17 | project(TensorRT_RTX_samples LANGUAGES CXX CUDA) 18 | 19 | add_subdirectory(cpp) 20 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: "[BUG] " 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | 12 | 13 | **Steps to reproduce** 14 | 1. 15 | 2. 16 | 3. 17 | 18 | 19 | 20 | **Expected behavior** 21 | 22 | 23 | **Environment** 24 | - TensorRT-RTX version: 25 | - GPU: 26 | - Operating system: 27 | - CUDA version: 28 | - CPU architecture: 29 | 30 | **Screenshots** 31 | 32 | 33 | **Additional context** 34 | 35 | 36 | -------------------------------------------------------------------------------- /ruff.toml: -------------------------------------------------------------------------------- 1 | # Exclude a variety of commonly ignored directories. 
2 | exclude = [ 3 | ".bzr", 4 | ".direnv", 5 | ".eggs", 6 | ".git", 7 | ".git-rewrite", 8 | ".hg", 9 | ".ipynb_checkpoints", 10 | ".mypy_cache", 11 | ".nox", 12 | ".pants.d", 13 | ".pyenv", 14 | ".pytest_cache", 15 | ".pytype", 16 | ".ruff_cache", 17 | ".svn", 18 | ".tox", 19 | ".venv", 20 | ".vscode", 21 | "__pypackages__", 22 | "_build", 23 | "buck-out", 24 | "build", 25 | "dist", 26 | "node_modules", 27 | "site-packages", 28 | "venv", 29 | ] 30 | 31 | target-version = "py39" 32 | line-length = 120 33 | 34 | [lint] 35 | select = [ 36 | # pycodestyle 37 | "E", 38 | # Pyflakes 39 | "F", 40 | # pyupgrade 41 | "UP", 42 | # flake8-bugbear 43 | "B", 44 | # flake8-simplify 45 | "SIM", 46 | # isort 47 | "I", 48 | ] 49 | ignore = [ 50 | # Line too long 51 | "E501" 52 | ] 53 | -------------------------------------------------------------------------------- /demo/tests/README.md: -------------------------------------------------------------------------------- 1 | # TensorRT-RTX Demos Utility Tests 2 | 3 | This directory contains utility tests for the TensorRT-RTX Demos. 4 | 5 | ## Quick Start 6 | 7 | We recommend using Python versions between 3.9 and 3.12 inclusive due to supported versions for required dependencies. 8 | 9 | 1. **Clone and install** 10 | 11 | ```bash 12 | git clone https://github.com/NVIDIA/TensorRT-RTX.git 13 | cd TensorRT-RTX 14 | 15 | # Install TensorRT-RTX from the wheels located in the downloaded tarfile 16 | # Visit https://developer.nvidia.com/tensorrt-rtx to download 17 | # Example below is for Python 3.12 on Linux (customize with your Python version + OS) 18 | python -m pip install YOUR_TENSORRT_RTX_DOWNLOAD_DIR/python/tensorrt_rtx-1.0.0.20-cp312-none-linux_x86_64.whl 19 | 20 | # Install demo dependencies 21 | python -m pip install -r demo/flux1.dev/requirements.txt 22 | 23 | # Install test dependencies 24 | python -m pip install -r demo/tests/requirements-test.txt 25 | ``` 26 | 27 | 2. **Run tests** 28 | 29 | The tests are located in the `demo/tests` directory. 30 | 31 | ```bash 32 | python -m pytest demo/tests -v 33 | ``` 34 | -------------------------------------------------------------------------------- /demo/utils/base_params.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | 17 | from abc import ABC 18 | from dataclasses import dataclass 19 | 20 | 21 | @dataclass(frozen=True) 22 | class BaseModelParams(ABC): # noqa: B024 23 | """ 24 | Generic base class for model parameters. 25 | 26 | This class provides a template that specific model parameter classes can inherit from. 
27 | The only required parameters are batch size 28 | """ 29 | 30 | # Required batch size parameters 31 | MIN_BATCH_SIZE: int = 1 32 | MAX_BATCH_SIZE: int = 8 33 | -------------------------------------------------------------------------------- /demo/tests/pytest.ini: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | [pytest] 17 | minversion = 6.0 18 | addopts = -ra -q --strict-markers --strict-config 19 | testpaths = demo/tests/ 20 | python_files = test_*.py 21 | python_classes = Test* 22 | python_functions = test_* 23 | markers = 24 | unit: Unit tests 25 | integration: Integration tests 26 | slow: Tests that take longer to run 27 | gpu: Tests that require GPU 28 | cache: Tests related to cache functionality 29 | paths: Tests related to path management 30 | pipeline: Tests related to pipeline functionality 31 | -------------------------------------------------------------------------------- /.github/workflows/build.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | # SPDX-License-Identifier: Apache-2.0 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | from utils import BUILD_DIR, TRTRTX_INSTALL_DIR, run_command, setup_trt_rtx 18 | 19 | 20 | def build_samples(): 21 | """Build C++ samples.""" 22 | print("Building C++ samples...") 23 | run_command(f"cmake -B {BUILD_DIR} -S samples -DTRTRTX_INSTALL_DIR={TRTRTX_INSTALL_DIR}") 24 | run_command(f"cmake --build {BUILD_DIR}") 25 | 26 | 27 | def main(): 28 | # Setup TensorRT RTX 29 | setup_trt_rtx() 30 | 31 | # Build samples 32 | build_samples() 33 | 34 | print("Build completed successfully!") 35 | 36 | 37 | if __name__ == "__main__": 38 | main() 39 | -------------------------------------------------------------------------------- /demo/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | 17 | """ 18 | TRT-RTX Demos Utils - ONNX -> TensorRT-RTX Inference pipeline 19 | 20 | Core components: 21 | - Pipeline: Main pipeline class for ONNX -> TRT-RTX -> Inference flow 22 | - Engine: TensorRT engine wrapper 23 | - BaseModel: Base class for model implementations 24 | - BaseModelParams: Generic base class for model parameters 25 | - ModelRegistry: Common model definitions 26 | - PathManager: Path management for ONNX and TRT-RTX Engine Files 27 | """ 28 | 29 | from .base_model import BaseModel 30 | from .base_params import BaseModelParams 31 | from .engine import Engine 32 | from .model_registry import registry 33 | from .path_manager import PathManager 34 | from .pipeline import Pipeline 35 | 36 | __all__ = ["Pipeline", "Engine", "BaseModel", "BaseModelParams", "registry", "PathManager"] 37 | -------------------------------------------------------------------------------- /demo/flux1.dev/models/flux_params.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | 17 | from dataclasses import dataclass 18 | 19 | from utils.base_params import BaseModelParams 20 | 21 | 22 | @dataclass(frozen=True) 23 | class FluxParams(BaseModelParams): 24 | """Parameters for Flux models""" 25 | 26 | # Batch dimension 27 | MIN_BATCH_SIZE: int = 1 28 | MAX_BATCH_SIZE: int = 4 29 | 30 | # Image dimensions 31 | MIN_HEIGHT: int = 256 32 | MAX_HEIGHT: int = 1024 33 | MIN_WIDTH: int = 256 34 | MAX_WIDTH: int = 1024 35 | 36 | # Text sequence length 37 | CLIP_SEQUENCE_LENGTH: int = 77 38 | T5_SEQUENCE_LENGTH: int = 512 39 | 40 | # Inference steps 41 | MIN_NUM_INFERENCE_STEPS: int = 1 42 | MAX_NUM_INFERENCE_STEPS: int = 50 43 | 44 | # VAE spatial compression ratio 45 | VAE_SPATIAL_COMPRESSION_RATIO: int = 8 46 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | # Basic Python linting 3 | - repo: https://github.com/astral-sh/ruff-pre-commit 4 | # Ruff version. 
5 | rev: v0.11.13 6 | hooks: 7 | - id: ruff-check 8 | types_or: [ python, pyi ] 9 | args: [ --fix ] 10 | - id: ruff-format 11 | types_or: [ python, pyi ] 12 | 13 | # Remove trailing whitespace and fix line endings 14 | - repo: https://github.com/pre-commit/pre-commit-hooks 15 | rev: v4.4.0 16 | hooks: 17 | - id: trailing-whitespace 18 | - id: end-of-file-fixer 19 | - id: check-yaml 20 | - id: check-added-large-files 21 | args: [--maxkb=10000] 22 | - id: check-merge-conflict 23 | 24 | # C++ formatting 25 | - repo: https://github.com/pre-commit/mirrors-clang-format 26 | rev: v16.0.0 27 | hooks: 28 | - id: clang-format 29 | types_or: [c++, c] 30 | 31 | # CMake formatting 32 | - repo: https://github.com/cheshirekow/cmake-format-precommit 33 | rev: v0.6.10 34 | hooks: 35 | - id: cmake-format 36 | types_or: [cmake] 37 | - id: cmake-lint 38 | types_or: [cmake] 39 | 40 | # Markdown formatting 41 | - repo: https://github.com/executablebooks/mdformat 42 | rev: 1.0.0 43 | hooks: 44 | - id: mdformat 45 | exclude: ^.github/ 46 | additional_dependencies: 47 | - mdformat-gfm 48 | - mdformat-black 49 | - mdformat-beautysh 50 | types_or: [markdown] 51 | 52 | - repo: https://github.com/codespell-project/codespell 53 | rev: v2.4.1 54 | hooks: 55 | - id: codespell 56 | additional_dependencies: 57 | - tomli 58 | args: [--skip, "*.ipynb"] 59 | -------------------------------------------------------------------------------- /demo/setup.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from setuptools import find_packages, setup 17 | 18 | setup( 19 | name="rtx-demos", 20 | version="1.0.0", 21 | description="RTX Demos", 22 | packages=find_packages(), 23 | install_requires=[ 24 | "torch>=2.7.0", 25 | "transformers>=4.52.4", 26 | "diffusers>=0.33.1", 27 | "huggingface-hub>=0.32.4", 28 | "tqdm>=4.67.1", 29 | "pillow", 30 | "numpy", 31 | "cuda-python<13.0.0", 32 | "polygraphy>=0.49.24", 33 | "packaging", 34 | "tensorrt-rtx>=1.1.0", 35 | "accelerate", 36 | "protobuf", 37 | "sentencepiece", 38 | ], 39 | extras_require={ 40 | "dev": [ 41 | "pre-commit", 42 | "ruff", 43 | "flake8", 44 | ], 45 | "jupyter": [ 46 | "requests>=2.25.0", 47 | "ipython>=8.0.0", 48 | "ipywidgets>=8.0.0", 49 | "jupyter", 50 | "notebook", 51 | "jupyterlab", 52 | ], 53 | }, 54 | python_requires=">=3.9", 55 | ) 56 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TensorRT for RTX 2 | 3 | TensorRT for RTX builds on the proven performance of the NVIDIA TensorRT inference library, and simplifies the deployment of AI models on NVIDIA RTX GPUs across desktops, laptops, and workstations. 
4 | 5 | TensorRT for RTX is a drop-in replacement for NVIDIA TensorRT in applications targeting NVIDIA RTX GPUs from Turing through Blackwell generations. It introduces a Just-In-Time (JIT) optimizer in the runtime that compiles improved inference engines directly on the end-user's RTX-accelerated PC in under 30 seconds. This eliminates the need for lengthy pre-compilation steps and enables rapid engine generation, improved application portability, and cutting-edge inference performance. To support integration into lightweight applications and deployment in memory-constrained environments, TensorRT for RTX is compact under 200 MB. TensorRT for RTX makes real-time, responsive AI applications for image processing, speech synthesis, and generative AI practical and performant on consumer-grade devices. 6 | 7 | For detailed information on TensorRT-RTX features, software enhancements, and release updates, see the [official developer documentation](https://docs.nvidia.com/deeplearning/tensorrt-rtx/latest/index.html). 8 | To get the latest TensorRT-RTX SDK, visit the [developer download page](http://developer.nvidia.com/tensorrt-rtx) and follow the [installation guide](http://docs.nvidia.com/deeplearning/tensorrt-rtx/latest/installing-tensorrt-rtx/installing.html). 9 | 10 | This repository includes open source components that accompany the TensorRT-RTX SDK. If you'd like to contribute, please review our [contribution guide](CONTRIBUTING.md). 11 | 12 | # Quickstart Examples 13 | 14 | - [Samples](samples/README.md) that illustrate key TensorRT-RTX capabilities and API usage in C++ and Python. 15 | - [Demos](demo/README.md) that highlight practical deployment considerations and reference implementations of popular models. 16 | -------------------------------------------------------------------------------- /samples/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | cmake_minimum_required(VERSION 3.10) 17 | project(TensorRT_RTX_samples LANGUAGES CXX CUDA) 18 | 19 | # Options for TensorRT-RTX paths 20 | set(TRTRTX_INSTALL_DIR "" CACHE PATH "Path to TensorRT-RTX install directory") 21 | # Resolve absolute path for TRTRTX_INSTALL_DIR 22 | get_filename_component(TRTRTX_INSTALL_DIR "${TRTRTX_INSTALL_DIR}" REALPATH) 23 | set(TRTRTX_INCLUDE_DIR "${TRTRTX_INSTALL_DIR}/include" CACHE PATH "Path to TensorRT-RTX include directory") 24 | set(TRTRTX_LIB_DIR "${TRTRTX_INSTALL_DIR}/lib" CACHE PATH "Path to TensorRT-RTX library directory") 25 | 26 | if(${TRTRTX_INSTALL_DIR} STREQUAL "") 27 | message(FATAL_ERROR "TRTRTX_INSTALL_DIR must be specified") 28 | endif() 29 | 30 | # Set C++ standard 31 | set(CMAKE_CXX_STANDARD 17) 32 | set(CMAKE_CXX_STANDARD_REQUIRED ON) 33 | 34 | # Find required packages 35 | find_package(CUDAToolkit REQUIRED) 36 | 37 | # Set common include directories 38 | include_directories(${CUDA_INCLUDE_DIRS} ${TRTRTX_INCLUDE_DIR}) 39 | 40 | # Set common link directories 41 | link_directories(${CUDA_LIBRARY_DIRS} ${TRTRTX_LIB_DIR}) 42 | 43 | # Add subdirectories for each sample... 44 | add_subdirectory(helloWorld) 45 | add_subdirectory(apiUsage) 46 | 47 | # and add to custom target to run them. 48 | add_custom_target(runSamples COMMAND $ COMMAND $ 49 | COMMENT "Build and immediately run all samples.") 50 | -------------------------------------------------------------------------------- /samples/apiUsage/cpp/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | cmake_minimum_required(VERSION 3.10) 17 | project(TensorRT_RTX_ApiUsage LANGUAGES CXX CUDA) 18 | 19 | include(${CMAKE_CURRENT_LIST_DIR}/../../cmake/modules/get_version.cmake) 20 | 21 | # Options for TensorRT-RTX paths 22 | set(TRTRTX_INSTALL_DIR "" CACHE PATH "Path to TensorRT-RTX install directory") 23 | # Resolve absolute path for TRTRTX_INSTALL_DIR 24 | get_filename_component(TRTRTX_INSTALL_DIR "${TRTRTX_INSTALL_DIR}" REALPATH) 25 | set(TRTRTX_INCLUDE_DIR "${TRTRTX_INSTALL_DIR}/include" CACHE PATH "Path to TensorRT-RTX include directory") 26 | set(TRTRTX_LIB_DIR "${TRTRTX_INSTALL_DIR}/lib" CACHE PATH "Path to TensorRT-RTX library directory") 27 | 28 | if(${TRTRTX_INSTALL_DIR} STREQUAL "") 29 | message(FATAL_ERROR "TRTRTX_INSTALL_DIR must be specified") 30 | endif() 31 | 32 | set(CMAKE_CXX_STANDARD 17) 33 | set(CMAKE_CXX_STANDARD_REQUIRED ON) 34 | 35 | find_package(CUDAToolkit REQUIRED) 36 | 37 | add_executable(apiUsage apiUsage.cpp) 38 | 39 | # Add compiler flags to ignore deprecated API warnings from TensorRT headers. 
40 | if(NOT MSVC) 41 | target_compile_options(apiUsage PRIVATE -Wno-deprecated-declarations -Wall) 42 | else() 43 | target_compile_options(apiUsage PRIVATE /wd4996 /wd4100 /W4) 44 | endif() 45 | 46 | get_version(${TRTRTX_INCLUDE_DIR} TRT_RTX_VERSION TRT_RTX_SOVERSION) 47 | get_library_name(TRT_RTX_SOVERSION TRTRTX_LIB_NAME TRTRTX_ONNXPARSER_LIB_NAME) 48 | 49 | message( 50 | STATUS "Building ${PROJECT_NAME} for TensorRT-RTX version: ${TRT_RTX_VERSION}, library version: ${TRT_RTX_SOVERSION}") 51 | 52 | target_include_directories(apiUsage PRIVATE ${TRTRTX_INCLUDE_DIR}) 53 | target_link_directories(apiUsage PRIVATE ${TRTRTX_LIB_DIR}) 54 | target_link_libraries(apiUsage PRIVATE ${TRTRTX_LIB_NAME} ${TRTRTX_ONNXPARSER_LIB_NAME} CUDA::cudart) 55 | -------------------------------------------------------------------------------- /samples/helloWorld/cpp/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | cmake_minimum_required(VERSION 3.10) 17 | project(TensorRT_RTX_HelloWorld LANGUAGES CXX CUDA) 18 | include(${CMAKE_CURRENT_LIST_DIR}/../../cmake/modules/get_version.cmake) 19 | 20 | # Options for TensorRT-RTX paths 21 | set(TRTRTX_INSTALL_DIR "" CACHE PATH "Path to TensorRT-RTX install directory") 22 | # Resolve absolute path for TRTRTX_INSTALL_DIR 23 | get_filename_component(TRTRTX_INSTALL_DIR "${TRTRTX_INSTALL_DIR}" REALPATH) 24 | set(TRTRTX_INCLUDE_DIR "${TRTRTX_INSTALL_DIR}/include" CACHE PATH "Path to TensorRT-RTX include directory") 25 | set(TRTRTX_LIB_DIR "${TRTRTX_INSTALL_DIR}/lib" CACHE PATH "Path to TensorRT-RTX library directory") 26 | 27 | if(${TRTRTX_INSTALL_DIR} STREQUAL "") 28 | message(FATAL_ERROR "TRTRTX_INSTALL_DIR must be specified") 29 | endif() 30 | 31 | set(CMAKE_CXX_STANDARD 17) 32 | set(CMAKE_CXX_STANDARD_REQUIRED ON) 33 | 34 | find_package(CUDAToolkit REQUIRED) 35 | 36 | add_executable(helloWorld helloWorld.cpp) 37 | 38 | # Add compiler flags to ignore deprecated API warnings from TensorRT headers. 
39 | if(NOT MSVC) 40 | target_compile_options(helloWorld PRIVATE -Wno-deprecated-declarations -Wall) 41 | else() 42 | target_compile_options(helloWorld PRIVATE /wd4996 /wd4100 /W4) 43 | endif() 44 | 45 | get_version(${TRTRTX_INCLUDE_DIR} TRT_RTX_VERSION TRT_RTX_SOVERSION) 46 | get_library_name(TRT_RTX_SOVERSION TRTRTX_LIB_NAME TRTRTX_ONNXPARSER_LIB_NAME) 47 | 48 | message( 49 | STATUS "Building ${PROJECT_NAME} for TensorRT-RTX version: ${TRT_RTX_VERSION}, library version: ${TRT_RTX_SOVERSION}") 50 | 51 | target_include_directories(helloWorld PRIVATE ${TRTRTX_INCLUDE_DIR}) 52 | target_link_directories(helloWorld PRIVATE ${TRTRTX_LIB_DIR}) 53 | target_link_libraries(helloWorld PRIVATE ${TRTRTX_LIB_NAME} ${TRTRTX_ONNXPARSER_LIB_NAME} CUDA::cudart) 54 | -------------------------------------------------------------------------------- /demo/tests/conftest.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """ 17 | Pytest configuration and common fixtures. 18 | """ 19 | 20 | import sys 21 | import tempfile 22 | from collections.abc import Generator 23 | from pathlib import Path 24 | 25 | import pytest 26 | 27 | # Add the parent directory to the Python path for imports 28 | sys.path.insert(0, str(Path(__file__).parent.parent)) 29 | 30 | from utils.path_manager import PathManager 31 | 32 | 33 | @pytest.fixture 34 | def temp_cache_dir() -> Generator[Path, None, None]: 35 | """Create a temporary cache directory for testing.""" 36 | with tempfile.TemporaryDirectory() as temp_dir: 37 | cache_dir = Path(temp_dir) / "test_cache" 38 | yield cache_dir 39 | # Cleanup is automatic with tempfile.TemporaryDirectory 40 | 41 | 42 | @pytest.fixture 43 | def path_manager(temp_cache_dir: Path) -> PathManager: 44 | """Create a PathManager instance with temporary cache directory.""" 45 | return PathManager(cache_dir=str(temp_cache_dir)) 46 | 47 | 48 | @pytest.fixture 49 | def temp_source_dir() -> Generator[Path, None, None]: 50 | """Create a temporary source directory for test files.""" 51 | with tempfile.TemporaryDirectory() as temp_dir: 52 | source_dir = Path(temp_dir) / "source_models" 53 | source_dir.mkdir() 54 | yield source_dir 55 | # Cleanup is automatic with tempfile.TemporaryDirectory 56 | 57 | 58 | def create_dummy_onnx_file(file_path: Path, content: str = "# Dummy ONNX file for testing\n") -> None: 59 | """Create a dummy ONNX file for testing.""" 60 | file_path.parent.mkdir(parents=True, exist_ok=True) 61 | with open(file_path, "w") as f: 62 | f.write(content) 63 | 64 | # Also create a related file (like .onnx.data) 65 | data_file = file_path.with_suffix(".onnx.data") 66 | with open(data_file, "w") as f: 67 | f.write("# Dummy ONNX data file\n") 68 | 69 | 70 | # Make the helper function available to all tests 71 | 
pytest.create_dummy_onnx_file = create_dummy_onnx_file 72 | -------------------------------------------------------------------------------- /demo/utils/base_model.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | 17 | from abc import abstractmethod 18 | from typing import Any, Optional 19 | 20 | from utils.base_params import BaseModelParams 21 | from utils.model_registry import registry 22 | 23 | 24 | class BaseModel: 25 | """Base model for ONNX -> TRT-RTX pipeline""" 26 | 27 | def __init__( 28 | self, 29 | name: str, 30 | device: str = "cuda", 31 | model_params: Optional[BaseModelParams] = None, 32 | hf_token: Optional[str] = None, 33 | ): 34 | self.name = name 35 | assert registry.is_model_id(name), f"Model {name} not found in model registry" 36 | 37 | self.device = device 38 | self.model_params = model_params 39 | self.hf_token = hf_token 40 | 41 | def validate_io_names(self, precision: str, input_names: list, output_names: Optional[list] = None) -> None: 42 | """Validate input and output names""" 43 | expected_inputs, expected_outputs = registry.get_io_names(self.name, precision) 44 | 45 | assert expected_inputs, f"Model '{self.name}' not found for precision '{precision}'" 46 | 47 | # Validate input names 48 | assert set(input_names) == set(expected_inputs), ( 49 | f"Input name mismatch for {self.name}[{precision}]: expected {expected_inputs}, got {input_names}" 50 | ) 51 | 52 | # Validate output names if provided 53 | if output_names is not None: 54 | assert set(output_names) == set(expected_outputs), ( 55 | f"Output name mismatch for {self.name}[{precision}]: expected {expected_outputs}, got {output_names}" 56 | ) 57 | 58 | @abstractmethod 59 | def get_input_profile(self, *args, **kwargs) -> dict[str, Any]: 60 | """Return TensorRT input profile for dynamic shapes""" 61 | pass 62 | 63 | @abstractmethod 64 | def get_shape_dict(self, *args, **kwargs) -> dict[str, Any]: 65 | """Return shape dictionary for tensor allocation""" 66 | pass 67 | -------------------------------------------------------------------------------- /.clang-format: -------------------------------------------------------------------------------- 1 | --- 2 | AccessModifierOffset: -4 3 | AlignAfterOpenBracket: DontAlign 4 | AlignConsecutiveAssignments: false 5 | AlignConsecutiveDeclarations: false 6 | AlignEscapedNewlinesLeft: false 7 | AlignOperands: false 8 | AlignTrailingComments: true 9 | AllowAllParametersOfDeclarationOnNextLine: true 10 | AllowShortBlocksOnASingleLine: false 11 | AllowShortCaseLabelsOnASingleLine: true 12 | AllowShortFunctionsOnASingleLine: Empty 13 | AllowShortIfStatementsOnASingleLine: false 14 | AllowShortLoopsOnASingleLine: false 15 | AlwaysBreakAfterDefinitionReturnType: None 16 | AlwaysBreakAfterReturnType: None 17 | 
AlwaysBreakBeforeMultilineStrings: true 18 | AlwaysBreakTemplateDeclarations: true 19 | BasedOnStyle: None 20 | BinPackArguments: true 21 | BinPackParameters: true 22 | BreakBeforeBinaryOperators: All 23 | BreakBeforeBraces: Allman 24 | BreakBeforeTernaryOperators: true 25 | BreakConstructorInitializersBeforeComma: true 26 | ColumnLimit: 120 27 | CommentPragmas: '^ IWYU pragma:' 28 | ConstructorInitializerAllOnOneLineOrOnePerLine: false 29 | ConstructorInitializerIndentWidth: 4 30 | ContinuationIndentWidth: 4 31 | Cpp11BracedListStyle: true 32 | DerivePointerAlignment: false 33 | DisableFormat: false 34 | ExperimentalAutoDetectBinPacking: false 35 | ForEachMacros: [ foreach, Q_FOREACH, BOOST_FOREACH ] 36 | IncludeBlocks: Preserve 37 | IncludeCategories: 38 | - Regex: '^"(llvm|llvm-c|clang|clang-c)/' 39 | Priority: 2 40 | - Regex: '^(<|"(gtest|isl|json)/)' 41 | Priority: 3 42 | - Regex: '.*' 43 | Priority: 1 44 | IndentCaseLabels: false 45 | IndentWidth: 4 46 | IndentWrappedFunctionNames: false 47 | KeepEmptyLinesAtTheStartOfBlocks: true 48 | Language: Cpp 49 | MacroBlockBegin: '' 50 | MacroBlockEnd: '' 51 | MaxEmptyLinesToKeep: 1 52 | NamespaceIndentation: None 53 | ObjCBlockIndentWidth: 4 54 | ObjCSpaceAfterProperty: true 55 | ObjCSpaceBeforeProtocolList: true 56 | PenaltyBreakBeforeFirstCallParameter: 19 57 | PenaltyBreakComment: 300 58 | PenaltyBreakFirstLessLess: 120 59 | PenaltyBreakString: 1000 60 | PenaltyExcessCharacter: 1000000 61 | PenaltyReturnTypeOnItsOwnLine: 60 62 | PointerAlignment: Left 63 | PointerBindsToType: false 64 | ReflowComments: true 65 | SortIncludes: true 66 | SpaceAfterCStyleCast: true 67 | SpaceBeforeAssignmentOperators: true 68 | SpaceBeforeParens: ControlStatements 69 | SpaceInEmptyParentheses: false 70 | SpacesBeforeTrailingComments: 1 71 | SpacesInAngles: false 72 | SpacesInCStyleCastParentheses: false 73 | SpacesInContainerLiterals: true 74 | SpacesInParentheses: false 75 | SpacesInSquareBrackets: false 76 | Standard: Cpp11 77 | StatementMacros: [API_ENTRY_TRY,TRT_TRY] 78 | TabWidth: 4 79 | UseTab: Never 80 | -------------------------------------------------------------------------------- /.github/workflows/ci.yaml: -------------------------------------------------------------------------------- 1 | name: CI 2 | on: 3 | push: 4 | branches: 5 | - "pull-request/[0-9]+" 6 | env: 7 | TRT_RTX_FILENAME: TensorRT-RTX-1.2.0.54-Linux-x86_64-cuda-12.9-Release-external.tar.gz 8 | TRTRTX_INSTALL_DIR: /opt/tensorrt_rtx 9 | BUILD_DIR: build 10 | UV_PROJECT: .github/workflows/ 11 | jobs: 12 | build-linux: 13 | runs-on: ubuntu-24.04 14 | container: nvcr.io/nvidia/cuda:12.9.1-devel-ubuntu22.04 15 | steps: 16 | - uses: actions/checkout@v5 17 | 18 | - name: Install uv 19 | uses: astral-sh/setup-uv@v6 20 | with: 21 | version: "0.8.22" 22 | python-version: "3.9" 23 | 24 | - name: Set up Python 25 | run: uv python install 26 | 27 | - name: Cache TensorRT RTX 28 | id: cache-trt-rtx 29 | uses: actions/cache@v4 30 | with: 31 | path: ${{ env.TRTRTX_INSTALL_DIR }} 32 | key: tensorrt-rtx-${{ runner.os }}-${{ env.TRT_RTX_FILENAME }} 33 | 34 | - name: Run build script 35 | env: 36 | CACHE_DEB_HIT: ${{ steps.cache-deb.outputs.cache-hit }} 37 | CACHE_TRT_RTX_HIT: ${{ steps.cache-trt-rtx.outputs.cache-hit }} 38 | run: uv run --group build .github/workflows/build.py 39 | 40 | - name: Upload build artifacts 41 | uses: actions/upload-artifact@v4 42 | with: 43 | name: linux-build-artifacts 44 | path: ${{ env.BUILD_DIR }}/ 45 | 46 | lint: 47 | runs-on: ubuntu-22.04 48 | steps: 49 | - uses: 
actions/checkout@v5 50 | - uses: actions/setup-python@v6 51 | with: 52 | python-version: "3.11" 53 | - name: Run pre-commit 54 | uses: pre-commit/action@v3.0.1 55 | 56 | test-linux: 57 | needs: build-linux 58 | runs-on: [linux-amd64-gpu-rtx4090-latest-1] 59 | container: 60 | image: nvcr.io/nvidia/cuda:12.9.1-runtime-ubuntu22.04 61 | options: --gpus all 62 | steps: 63 | - uses: actions/checkout@v5 64 | 65 | - name: Install uv 66 | uses: astral-sh/setup-uv@v6 67 | with: 68 | version: "0.8.22" 69 | python-version: "3.9" 70 | 71 | - name: Set up Python 72 | run: uv python install 73 | 74 | - name: Cache TensorRT RTX 75 | id: cache-trt-rtx 76 | uses: actions/cache@v4 77 | with: 78 | path: ${{ env.TRTRTX_INSTALL_DIR }} 79 | key: tensorrt-rtx-${{ runner.os }}-${{ env.TRT_RTX_FILENAME }} 80 | 81 | - name: Download build artifacts 82 | uses: actions/download-artifact@v4 83 | with: 84 | name: linux-build-artifacts 85 | path: ${{ env.BUILD_DIR }}/ 86 | 87 | - name: Run test script 88 | env: 89 | CACHE_TRT_RTX_HIT: ${{ steps.cache-trt-rtx.outputs.cache-hit }} 90 | UV_PROJECT_ENVIRONMENT: .workspace/ci/ 91 | run: uv run .github/workflows/test.py 92 | -------------------------------------------------------------------------------- /.github/workflows/utils.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import io 17 | import os 18 | import subprocess 19 | import sys 20 | import tarfile 21 | from pathlib import Path 22 | 23 | import requests 24 | 25 | # Shared constants 26 | TRT_RTX_BASE_URL = "https://developer.nvidia.com/downloads/trt/rtx_sdk/secure/1.2/" 27 | TRT_RTX_FILENAME = os.environ.get( 28 | "TRT_RTX_FILENAME", "TensorRT-RTX-1.2.0.54-Linux-x86_64-cuda-12.9-Release-external.tar.gz" 29 | ) 30 | TRTRTX_INSTALL_DIR = os.environ.get("TRTRTX_INSTALL_DIR", "/opt/tensorrt_rtx") 31 | BUILD_DIR = os.environ.get("BUILD_DIR", "build") 32 | 33 | 34 | def run_command(cmd, check=True, shell=False, env=None): 35 | """Run a command and handle errors.""" 36 | try: 37 | subprocess.run(cmd if shell else cmd.split(), check=check, shell=shell, env=env) 38 | except subprocess.CalledProcessError as e: 39 | print(f"Error running command: {cmd}") 40 | print(f"Exit code: {e.returncode}") 41 | sys.exit(e.returncode) 42 | 43 | 44 | def setup_trt_rtx(): 45 | """Download and setup TensorRT RTX.""" 46 | if os.environ.get("CACHE_TRT_RTX_HIT") != "true": 47 | print("Cache miss for TensorRT RTX, downloading...") 48 | url = f"{TRT_RTX_BASE_URL}/{TRT_RTX_FILENAME}" 49 | 50 | if os.path.exists(TRTRTX_INSTALL_DIR): 51 | print(f"Error: {TRTRTX_INSTALL_DIR} already exists. 
Remove it or set CACHE_TRT_RTX_HIT=true to proceed.") 52 | exit(1) 53 | 54 | # Download the TRT RTX tar file 55 | response = requests.get(url, stream=True) 56 | response.raise_for_status() 57 | 58 | # Create a file-like object from the response content 59 | tar_bytes = io.BytesIO(response.content) 60 | 61 | # Extract tar file, stripping the first directory component 62 | os.makedirs(TRTRTX_INSTALL_DIR) 63 | with tarfile.open(fileobj=tar_bytes, mode="r:gz") as tar: 64 | members = [m for m in tar.getmembers() if len(Path(m.name).parts) > 1] 65 | for member in members: 66 | member.name = str(Path(*Path(member.name).parts[1:])) 67 | tar.extract(member, TRTRTX_INSTALL_DIR) 68 | else: 69 | print("Cache hit for TensorRT RTX") 70 | -------------------------------------------------------------------------------- /.github/workflows/test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | # SPDX-License-Identifier: Apache-2.0 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | import os 18 | from pathlib import Path 19 | 20 | from utils import BUILD_DIR, TRTRTX_INSTALL_DIR, run_command, setup_trt_rtx 21 | 22 | 23 | def install_python_deps(): 24 | """Install Python dependencies.""" 25 | print("Installing Python dependencies...") 26 | 27 | # Install TensorRT RTX wheel 28 | wheel_dir = Path(TRTRTX_INSTALL_DIR) / "python" 29 | wheel_file = next(wheel_dir.glob("tensorrt_rtx-*-cp39-none-linux_x86_64.whl")) 30 | run_command(f"uv pip install {wheel_file}") 31 | 32 | # Install sample requirements 33 | run_command("uv pip install -r samples/helloWorld/python/requirements.txt") 34 | run_command("uv pip install -r samples/apiUsage/python/requirements.txt") 35 | run_command("uv pip install --index-strategy unsafe-best-match -r demo/flux1.dev/requirements.txt") 36 | run_command("uv pip install -r demo/tests/requirements-test.txt") 37 | 38 | 39 | def run_cpp_tests(): 40 | """Run C++ sample tests.""" 41 | print("Running C++ tests...") 42 | BINARIES = [f"{BUILD_DIR}/helloWorld/cpp/helloWorld", f"{BUILD_DIR}/apiUsage/cpp/apiUsage"] 43 | for binary in BINARIES: 44 | # Add the executable permission if not on Windows 45 | if os.name != "nt": 46 | os.chmod(binary, os.stat(binary).st_mode | 0o111) 47 | run_command(binary) 48 | 49 | 50 | def run_python_tests(): 51 | """Run Python sample tests.""" 52 | print("Running Python tests...") 53 | # Set up environment for tests 54 | test_env = os.environ.copy() 55 | test_env["LD_LIBRARY_PATH"] = f"{TRTRTX_INSTALL_DIR}/lib:{test_env.get('LD_LIBRARY_PATH', '')}" 56 | 57 | run_command("uv run samples/helloWorld/python/hello_world.py", env=test_env) 58 | run_command("uv run samples/apiUsage/python/api_usage.py", env=test_env) 59 | run_command("uv run pytest demo/tests -v", env=test_env) 60 | 61 | 62 | def main(): 63 | # Setup TensorRT RTX 64 | setup_trt_rtx() 65 | 66 | # Install 
Python dependencies 67 | install_python_deps() 68 | 69 | # Run tests 70 | run_cpp_tests() 71 | run_python_tests() 72 | 73 | print("All tests completed successfully!") 74 | 75 | 76 | if __name__ == "__main__": 77 | main() 78 | -------------------------------------------------------------------------------- /samples/helloWorld/README.md: -------------------------------------------------------------------------------- 1 | # TensorRT for RTX Hello World Sample 2 | 3 | This sample demonstrates how to use TensorRT for RTX to create, compile, and 4 | run a simple neural network. The sample shows basic concepts such as: 5 | 6 | - Creating a TensorRT-RTX builder and network definition. 7 | - Building a simple fully connected neural network. 8 | - Parsing a provided ONNX model. 9 | - Performing ahead-of-time (AOT) compilation. 10 | - Running inference with the compiled engine. 11 | 12 | ## Building the Sample 13 | 14 | ### Prerequisites 15 | 16 | - CMake 3.10 or later 17 | - Python 3.9 or later 18 | - CUDA Toolkit 19 | - An installation of TensorRT for RTX 20 | 21 | ### Build Instructions 22 | 23 | On Windows, add the TensorRT for RTX `lib` directory to your `PATH` environment variable: 24 | 25 | ```powershell 26 | $Env:PATH += ";$Env:PATH_TO_TRT_RTX\lib" 27 | ``` 28 | 29 | On Linux, add the TensorRT for RTX `lib` directory to your `LD_LIBRARY_PATH` environment variable: 30 | 31 | ```bash 32 | export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:${PATH_TO_TRT_RTX}/lib 33 | ``` 34 | 35 | #### Build for C++ 36 | 37 | 1. Run CMake from the current or the `cpp` directory, pointing it to your TensorRT for RTX installation, to create artifacts in the `build` directory 38 | 39 | ```bash 40 | cmake -B build -S . -DTRTRTX_INSTALL_DIR=/path/to/tensorrt-rtx 41 | ``` 42 | 43 | 2. Build the sample: 44 | 45 | ```bash 46 | cmake --build build 47 | ``` 48 | 49 | #### Build for Python 50 | 51 | 1. Install the `tensorrt_rtx` wheel from your TensorRT for RTX directory: 52 | 53 | ```bash 54 | python -m pip install /path/to/tensorrt-rtx/python/tensorrt_rtx-${version}-cp${py3-ver}-none-${os-ver}_x86_64.whl 55 | ``` 56 | 57 | 2. Install `numpy` and `cuda-python` from the `python/requirements.txt` file: 58 | 59 | ```bash 60 | python -m pip install -r python/requirements.txt 61 | ``` 62 | 63 | ## Running the Sample 64 | 65 | After building, you can run the sample with: 66 | 67 | ```bash 68 | ./helloWorld 69 | ``` 70 | 71 | from the build directory. 72 | 73 | To build the network by parsing the provided ONNX model, run: 74 | 75 | ```bash 76 | ./helloWorld --onnx=/path/to/helloWorld.onnx 77 | ``` 78 | 79 | For the Python sample, run: 80 | 81 | ```bash 82 | python hello_world.py 83 | python hello_world.py --onnx=/path/to/helloWorld.onnx 84 | ``` 85 | 86 | ![helloWorld ONNX network](helloWorldOnnx.png) 87 | 88 | The sample will: 89 | 90 | 1. Create and compile a simple neural network. 91 | 2. Run inference with different input values. 92 | 3. Display the results. 93 | 94 | ## Code Overview 95 | 96 | The sample demonstrates several key concepts: 97 | 98 | - Network creation and configuration. 99 | - Network creation by parsing an ONNX model. 100 | - Engine serialization and deserialization. 101 | - Inference execution. 102 | 103 | For detailed comments explaining each step, please refer to the [helloWorld.cpp](cpp/helloWorld.cpp) and [hello_world.py](python/hello_world.py) source files. 
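Both the C++ and Python variants follow the same overall flow. The outline below is a minimal, illustrative sketch of that flow in Python; it assumes the `tensorrt_rtx` module exposes TensorRT-style builder, parser, and runtime classes, and the model path and error handling are simplified placeholders, so the shipped `hello_world.py` remains the authoritative reference.

```python
# Illustrative sketch only; assumes tensorrt_rtx mirrors the standard TensorRT Python API.
import tensorrt_rtx as trt

logger = trt.Logger(trt.Logger.WARNING)

# Build phase: parse the ONNX model and compile it ahead of time (AOT).
builder = trt.Builder(logger)
network = builder.create_network()
parser = trt.OnnxParser(network, logger)
with open("helloWorld.onnx", "rb") as f:  # placeholder path
    if not parser.parse(f.read()):
        raise RuntimeError("Failed to parse the ONNX model")
config = builder.create_builder_config()
serialized_engine = builder.build_serialized_network(network, config)

# Runtime phase: deserialize the engine and create an execution context.
runtime = trt.Runtime(logger)
engine = runtime.deserialize_cuda_engine(serialized_engine)
context = engine.create_execution_context()
# Allocate device buffers (for example with cuda-python), bind them to the
# engine's I/O tensors, and launch inference on a CUDA stream.
```

When no `--onnx` argument is given, the sample instead builds an equivalent network directly with the network definition API rather than parsing a model file.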
104 | -------------------------------------------------------------------------------- /samples/apiUsage/README.md: -------------------------------------------------------------------------------- 1 | # TensorRT for RTX API Usage Sample 2 | 3 | This sample demonstrates how to use TensorRT for RTX APIs to fine-tune engine 4 | compilation and inference. First, please refer to the [Hello World](../helloWorld) 5 | sample that goes over the basic concepts. In addition, this sample covers: 6 | 7 | - Creating a TensorRT-RTX builder and network definition with dynamic shapes and setting AoT compilation targets using the `setComputeCapability` and associated API. 8 | - Efficiently checking if an engine file is expected to work for the current platform/environment using the Engine Compatibility API. 9 | - Configuring and serializing a runtime cache via `setRuntimeCache` and associated API to store JIT-compiled kernels. 10 | - Setting, querying and running inference with dynamic shape information via various dynamic shape APIs. 11 | - Building weightless engines, and subsequently refitting weights on the deployed machines using the refit APIs. 12 | - Running inference for multiple input shapes with the compiled engine. 13 | 14 | ## Building the Sample 15 | 16 | ### Prerequisites 17 | 18 | - CMake 3.10 or later 19 | - Python 3.9 or later 20 | - CUDA Toolkit 21 | - An installation of TensorRT for RTX 22 | 23 | ### Build Instructions 24 | 25 | On Windows, add the TensorRT for RTX `lib` directory to your `PATH` environment variable: 26 | 27 | ```powershell 28 | $Env:PATH += ";$Env:PATH_TO_TRT_RTX\lib" 29 | ``` 30 | 31 | On Linux, add the TensorRT for RTX `lib` directory to your `LD_LIBRARY_PATH` environment variable: 32 | 33 | ```bash 34 | export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:${PATH_TO_TRT_RTX}/lib 35 | ``` 36 | 37 | #### Build for C++ 38 | 39 | 1. Run CMake from the current or the `cpp` directory, pointing it to your TensorRT for RTX installation, to create artifacts in the `build` directory: 40 | 41 | ```bash 42 | cmake -B build -S . -DTRTRTX_INSTALL_DIR=/path/to/tensorrt-rtx 43 | ``` 44 | 45 | 2. Build the sample: 46 | 47 | ```bash 48 | cmake --build build 49 | ``` 50 | 51 | #### Build for Python 52 | 53 | 1. Install the `tensorrt_rtx` wheel from your TensorRT for RTX directory: 54 | 55 | ```bash 56 | python -m pip install /path/to/tensorrt-rtx/python/tensorrt_rtx-${version}-cp${py3-ver}-none-${os-ver}_x86_64.whl 57 | ``` 58 | 59 | 2. Install `numpy` and `cuda-python` from the `python/requirements.txt` file: 60 | 61 | ```bash 62 | python -m pip install -r python/requirements.txt 63 | ``` 64 | 65 | ## Running the Sample 66 | 67 | After building, you can run the sample with: 68 | 69 | ```bash 70 | ./apiUsage 71 | ``` 72 | 73 | from the build directory. 74 | 75 | For the Python sample, run: 76 | 77 | ```bash 78 | python api_usage.py 79 | ``` 80 | 81 | The sample will: 82 | 83 | 1. Create and compile a simple neural network with dynamic shapes. 84 | 2. Build a weightless engine on the current device and then refit its weights. 85 | 3. Run inference with different batch sizes and input values. 86 | 4. Display the results. 87 | 88 | ## Code Overview 89 | 90 | The sample demonstrates several key concepts related to TensorRT for RTX APIs: 91 | 92 | - Network creation and configuration for dynamically-shaped input tensors. 93 | - Selecting deployment targets for AOT compilation. 94 | - Configuring a weightless engine and refitting weights during deployment. 95 | - Using runtime cache to store JIT-compiled kernels.
96 | - Inference execution with changing dynamic shapes. 97 | 98 | For detailed comments explaining each step, please refer to the [apiUsage.cpp](cpp/apiUsage.cpp) and [api_usage.py](python/api_usage.py) source files. 99 | -------------------------------------------------------------------------------- /demo/tests/test_cache_safety.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """ 17 | Test cache persistence and safety of local dev files. 18 | 19 | This test verifies: 20 | 1. Cache persistence - running twice doesn't re-copy files 21 | 2. Safety - delete_cached_files never deletes original source files 22 | """ 23 | 24 | from pathlib import Path 25 | 26 | import pytest 27 | from utils.path_manager import PathManager 28 | 29 | 30 | @pytest.mark.cache 31 | @pytest.mark.integration 32 | class TestCacheSafety: 33 | """Test cache persistence and safety of local dev files.""" 34 | 35 | def test_cache_persistence_and_safety(self, temp_cache_dir: Path, temp_source_dir: Path): 36 | """Test cache persistence and safety of local dev files.""" 37 | # Create original dev files 38 | original_onnx = temp_source_dir / "my_model_fp16.onnx" 39 | original_data = temp_source_dir / "my_model_fp16.onnx.data" 40 | original_config = temp_source_dir / "config.json" 41 | 42 | original_onnx.write_text("original onnx content") 43 | original_data.write_text("original data content") 44 | original_config.write_text("original config content") 45 | 46 | # Create PathManager with separate cache directory 47 | path_manager = PathManager(str(temp_cache_dir)) 48 | 49 | # Test 1: First run - should copy files 50 | success1 = path_manager.acquire_onnx_file("my_model", "fp16", str(original_onnx)) 51 | 52 | assert success1, "First acquisition should succeed" 53 | 54 | cached_onnx = path_manager.get_onnx_path("my_model", "fp16") 55 | assert cached_onnx.exists(), "Cached ONNX should exist after first acquisition" 56 | 57 | # Test 2: Second run - should skip copying (files already exist) 58 | success2 = path_manager.acquire_onnx_file("my_model", "fp16", str(original_onnx)) 59 | 60 | assert success2, "Second run should be successful (skip copy)" 61 | 62 | # Test 3: Cache deletion safety 63 | path_manager.delete_cached_files("my_model", "fp16", "dynamic") 64 | 65 | # CRITICAL TEST: Verify originals are STILL safe 66 | originals_still_safe = all(f.exists() for f in [original_onnx, original_data, original_config]) 67 | assert originals_still_safe, "SAFETY FAILURE: Original dev files were deleted!" 
68 | 69 | # Verify cache is cleaned up 70 | remaining_files = list(path_manager.shared_onnx_dir.rglob("*")) 71 | remaining_model_files = [f for f in remaining_files if f.is_file() and "my_model" in str(f)] 72 | assert len(remaining_model_files) == 0, "Cache should be cleaned up" 73 | 74 | # Test 4: Verify file paths are different 75 | canonical_path = path_manager.get_onnx_path("my_model", "fp16") 76 | assert original_onnx.parent != canonical_path.parent, "Original and cache directories should be different" 77 | -------------------------------------------------------------------------------- /demo/tests/test_cache_consistency.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """ 17 | Test cache consistency fix using pytest. 18 | 19 | This test simulates the problem where canonical files exist but pipeline links don't, 20 | and shows how the new PathManager methods fix this inconsistent state. 21 | """ 22 | 23 | from pathlib import Path 24 | 25 | import pytest 26 | from utils.path_manager import PathManager 27 | 28 | 29 | @pytest.mark.cache 30 | @pytest.mark.unit 31 | class TestCacheConsistency: 32 | """Test cache consistency functionality.""" 33 | 34 | def test_cache_consistency_fix(self, path_manager: PathManager, temp_source_dir: Path): 35 | """Test cache consistency issue detection and fix.""" 36 | # Create a source ONNX file 37 | source_onnx = temp_source_dir / "transformer.onnx" 38 | pytest.create_dummy_onnx_file(source_onnx) 39 | 40 | # Test parameters 41 | model_id = "test_transformer" 42 | precision = "fp16" 43 | 44 | # Step 1: Setup cache (before files exist) 45 | canonical_onnx = path_manager.get_onnx_path(model_id, precision) 46 | assert not canonical_onnx.exists() 47 | 48 | # Step 2: Acquire ONNX file 49 | success = path_manager.acquire_onnx_file(model_id, precision, str(source_onnx)) 50 | 51 | assert success, "ONNX acquisition should succeed" 52 | assert canonical_onnx.exists(), "Canonical ONNX should exist after acquisition" 53 | 54 | # Step 3: Check cache state (fix the key name) 55 | cache_status = path_manager.check_cached_files(model_id, precision, "dynamic") 56 | assert cache_status["onnx"], "ONNX should exist in cache" 57 | 58 | def test_deletion_methods(self, path_manager: PathManager, temp_source_dir: Path): 59 | """Test the deletion methods work correctly.""" 60 | # Create test models 61 | model_configs = [("test_transformer", "fp16"), ("test_vae", "fp16")] 62 | 63 | for model_id, precision in model_configs: 64 | source_onnx = temp_source_dir / f"{model_id}.onnx" 65 | pytest.create_dummy_onnx_file(source_onnx) 66 | 67 | success = path_manager.acquire_onnx_file(model_id, precision, str(source_onnx)) 68 | assert success, f"Should acquire {model_id}_{precision}" 69 | 70 | canonical_onnx = 
path_manager.get_onnx_path(model_id, precision) 71 | assert canonical_onnx.exists(), f"Canonical {model_id}_{precision} should exist" 72 | 73 | # Test deleting only ONNX files for one model 74 | path_manager.delete_cached_onnx_files("test_vae", "fp16") 75 | 76 | vae_onnx = path_manager.get_onnx_path("test_vae", "fp16") 77 | transformer_onnx = path_manager.get_onnx_path("test_transformer", "fp16") 78 | 79 | assert not vae_onnx.exists(), "VAE ONNX should be deleted" 80 | assert transformer_onnx.exists(), "Transformer ONNX should remain" 81 | 82 | # Test deleting all files for remaining model 83 | path_manager.delete_cached_files("test_transformer", "fp16", "dynamic") 84 | assert not transformer_onnx.exists(), "Transformer ONNX should be deleted" 85 | -------------------------------------------------------------------------------- /demo/utils/notebook_utils.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import base64 17 | import logging 18 | import mimetypes 19 | from pathlib import Path 20 | from typing import Optional 21 | from urllib.parse import urlparse 22 | 23 | import requests 24 | from IPython.display import HTML, Markdown, display 25 | 26 | # Initialize logger for this module 27 | logger = logging.getLogger("rtx_demo.utils.notebook_utils") 28 | 29 | 30 | def printmd(string: str, newlines: int = 0): 31 | """Display markdown text in Jupyter notebook.""" 32 | if newlines > 0: 33 | print("\n" * newlines) 34 | 35 | display(Markdown(string)) 36 | 37 | 38 | def markdown_bold_green_format(string: str) -> str: 39 | """Format string as bold green markdown.""" 40 | return f"{string}" 41 | 42 | 43 | def print_prompt(prompt: str): 44 | """Print prompt in markdown.""" 45 | printmd(f"**Prompt**: {prompt}\n") 46 | 47 | 48 | def display_image_from_path(path: str, width: Optional[int] = None, height: Optional[int] = None): 49 | """ 50 | Display an image from URL or local path in Jupyter notebook. 51 | Embeds image data as base64 for HTML export. 
52 | 53 | Parameters: 54 | ----------- 55 | path : str 56 | URL or local file path of the image 57 | width : int, optional 58 | Width in pixels 59 | height : int, optional 60 | Height in pixels 61 | 62 | Returns: 63 | -------- 64 | IPython.display.HTML 65 | HTML object with embedded image 66 | """ 67 | try: 68 | # Determine if path is URL or local file 69 | parsed = urlparse(str(path)) 70 | is_url = parsed.scheme in ("http", "https") 71 | 72 | if is_url: 73 | # Handle URL 74 | response = requests.get(path) 75 | response.raise_for_status() 76 | image_data = response.content 77 | content_type = response.headers.get("content-type", "image/png") 78 | else: 79 | # Handle local file 80 | file_path = Path(path) 81 | if not file_path.exists(): 82 | raise FileNotFoundError(f"Image file not found: {path}") 83 | 84 | image_data = file_path.read_bytes() 85 | content_type = mimetypes.guess_type(str(file_path))[0] or "image/png" 86 | 87 | # Encode as base64 88 | image_base64 = base64.b64encode(image_data).decode("utf-8") 89 | 90 | # Build style 91 | style_parts = [] 92 | if width: 93 | style_parts.append(f"width: {width}px") 94 | if height: 95 | style_parts.append(f"height: {height}px") 96 | style = "; ".join(style_parts) 97 | 98 | html = f'<img src="data:{content_type};base64,{image_base64}" style="{style}" />' 99 | result = HTML(html) 100 | 101 | except Exception as e: 102 | # Fallback with error message 103 | error_msg = f"Error loading image: {e}" 104 | fallback_html = f'
<div style="color: red;">{error_msg}</div>
' 105 | result = HTML(fallback_html) 106 | 107 | display(result) 108 | -------------------------------------------------------------------------------- /demo/tests/test_hierarchical_structure.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """ 17 | Test script to verify the new hierarchical folder structure. 18 | """ 19 | 20 | from pathlib import PurePosixPath 21 | 22 | import pytest 23 | from utils.path_manager import PathManager 24 | 25 | 26 | @pytest.mark.paths 27 | @pytest.mark.unit 28 | class TestHierarchicalStructure: 29 | """Test hierarchical folder structure functionality.""" 30 | 31 | @pytest.mark.parametrize("shape_mode", ["dynamic", "static"]) 32 | def test_shared_paths_structure(self, path_manager: PathManager, shape_mode: str): 33 | """Test that shared paths follow the correct hierarchical structure.""" 34 | shared_onnx = path_manager.get_onnx_path("t5_text_encoder", "fp16") 35 | shared_engine = path_manager.get_engine_path("t5_text_encoder", "fp16", shape_mode) 36 | shared_metadata = path_manager.get_metadata_path("t5_text_encoder", "fp16", shape_mode) 37 | 38 | # Verify structure by checking path components 39 | assert "shared/onnx/t5_text_encoder/fp16/t5_text_encoder.onnx" in str(PurePosixPath(shared_onnx)) 40 | assert f"shared/engines/t5_text_encoder/fp16/t5_text_encoder_{shape_mode}.engine" in str( 41 | PurePosixPath(shared_engine) 42 | ) 43 | assert f"shared/engines/t5_text_encoder/fp16/t5_text_encoder_{shape_mode}.metadata.json" in str( 44 | PurePosixPath(shared_metadata) 45 | ) 46 | 47 | def test_different_precisions_separated(self, path_manager: PathManager): 48 | """Test that different precisions for same model are properly separated.""" 49 | fp16_onnx = path_manager.get_onnx_path("flux_transformer", "fp16") 50 | fp8_onnx = path_manager.get_onnx_path("flux_transformer", "fp8") 51 | 52 | # Verify they're in different directories 53 | assert fp16_onnx.parent != fp8_onnx.parent, "Different precisions should be in different directories" 54 | assert "fp16" in str(fp16_onnx), "FP16 path should contain 'fp16'" 55 | assert "fp8" in str(fp8_onnx), "FP8 path should contain 'fp8'" 56 | 57 | @pytest.mark.parametrize("shape_mode", ["dynamic", "static"]) 58 | def test_directory_creation(self, path_manager: PathManager, shape_mode: str): 59 | """Test that directories are created correctly.""" 60 | shared_onnx = path_manager.get_onnx_path("t5_text_encoder", "fp16") 61 | shared_engine = path_manager.get_engine_path("t5_text_encoder", "fp16", shape_mode) 62 | 63 | assert shared_onnx.parent.exists(), "Shared ONNX directory should be created" 64 | assert shared_engine.parent.exists(), "Shared engine directory should be created" 65 | 66 | @pytest.mark.parametrize("shape_mode", ["dynamic", "static"]) 67 | def 
test_file_naming_consistency(self, path_manager: PathManager, shape_mode: str): 68 | """Test that file naming is consistent and clean.""" 69 | shared_onnx = path_manager.get_onnx_path("t5_text_encoder", "fp16") 70 | shared_engine = path_manager.get_engine_path("t5_text_encoder", "fp16", shape_mode) 71 | 72 | # All files should have clean names without precision suffixes 73 | assert shared_onnx.name == "t5_text_encoder.onnx", f"ONNX should have clean name, got: {shared_onnx.name}" 74 | assert shared_engine.name == f"t5_text_encoder_{shape_mode}.engine", ( 75 | f"Engine should have clean name, got: {shared_engine.name}" 76 | ) 77 | -------------------------------------------------------------------------------- /samples/cmake/modules/get_version.cmake: -------------------------------------------------------------------------------- 1 | # 2 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | # SPDX-License-Identifier: Apache-2.0 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | # ----------------------------------------------------------------------------- 18 | # cmake-format: off 19 | # Function: get_version 20 | # Retrieves the TensorRT-RTX version information from the specified include directory. 21 | # 22 | # Args: 23 | # include_dir - Path to the directory containing NvInferVersion.h 24 | # version_variable - Output variable; will be set to the version string in the format "MAJOR.MINOR.PATCH" 25 | # soversion_variable - Output variable; will be set to the SOVERSION string in the format "MAJOR_MINOR" 26 | # 27 | # Example: 28 | # get_version(${TRTRTX_INCLUDE_DIR} TRT_RTX_VERSION TRT_RTX_SOVERSION) 29 | # # TRT_RTX_VERSION will be set to e.g. "9.0.1" 30 | # # TRT_RTX_SOVERSION will be set to e.g. 
"9_0" 31 | # cmake-format: on 32 | # ----------------------------------------------------------------------------- 33 | function(get_version include_dir version_variable soversion_variable) 34 | set(header_file "${include_dir}/NvInferVersion.h") 35 | if(NOT EXISTS "${header_file}") 36 | message(FATAL_ERROR "TensorRT-RTX version header not found: ${header_file}") 37 | endif() 38 | 39 | file(STRINGS "${header_file}" VERSION_STRINGS REGEX "#define TRT_.*_RTX") 40 | if(NOT VERSION_STRINGS) 41 | message( 42 | FATAL_ERROR "No TRT_*_RTX version defines found in ${header_file}, please check if the path provided is correct.") 43 | endif() 44 | 45 | foreach(type MAJOR MINOR PATCH) 46 | set(trt_${type} "") 47 | foreach(version_line ${VERSION_STRINGS}) 48 | string(REGEX MATCH "TRT_${type}_RTX [0-9]+" trt_type_string "${version_line}") 49 | if(trt_type_string) 50 | string(REGEX MATCH "[0-9]+" trt_${type} "${trt_type_string}") 51 | break() 52 | endif() 53 | endforeach() 54 | if(NOT DEFINED trt_${type}) 55 | message(FATAL_ERROR "Failed to extract TRT_${type}_RTX from ${header_file}") 56 | endif() 57 | endforeach(type) 58 | set(${version_variable} ${trt_MAJOR}.${trt_MINOR}.${trt_PATCH} PARENT_SCOPE) 59 | set(${soversion_variable} ${trt_MAJOR}_${trt_MINOR} PARENT_SCOPE) 60 | endfunction() 61 | 62 | # ----------------------------------------------------------------------------- 63 | # cmake-format: off 64 | # Function: get_library_name 65 | # Retrieves the library name for TensorRT-RTX and TensorRT-ONNXParser-RTX. 66 | # 67 | # Args: 68 | # soversion_variable - Input variable; should contain the SOVERSION string in the format "MAJOR_MINOR" 69 | # lib_name_variable - Output variable; will be set to the library name 70 | # onnxparser_lib_name_variable - Output variable; will be set to the ONNXParser library name 71 | # 72 | # Example: 73 | # get_library_name(TRT_RTX_SOVERSION TRT_RTX_LIB_NAME TRT_RTX_ONNXPARSER_LIB_NAME) 74 | # # TRT_RTX_LIB_NAME will be set to e.g. "tensorrt_rtx" 75 | # # TRT_RTX_ONNXPARSER_LIB_NAME will be set to e.g. "tensorrt_onnxparser_rtx" 76 | # cmake-format: on 77 | # ----------------------------------------------------------------------------- 78 | function(get_library_name soversion_variable lib_name_variable onnxparser_lib_name_variable) 79 | set(trtrtx_lib_name "tensorrt_rtx") 80 | set(trtrtx_onnxparser_lib_name "tensorrt_onnxparser_rtx") 81 | if(WIN32) 82 | set(${lib_name_variable} "${trtrtx_lib_name}_${${soversion_variable}}" PARENT_SCOPE) 83 | set(${onnxparser_lib_name_variable} "${trtrtx_onnxparser_lib_name}_${${soversion_variable}}" PARENT_SCOPE) 84 | else() 85 | set(${lib_name_variable} "${trtrtx_lib_name}" PARENT_SCOPE) 86 | set(${onnxparser_lib_name_variable} "${trtrtx_onnxparser_lib_name}" PARENT_SCOPE) 87 | endif() 88 | endfunction() 89 | -------------------------------------------------------------------------------- /demo/utils/memory_manager.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | 17 | import logging 18 | import time 19 | 20 | import cuda.bindings.runtime as cudart 21 | 22 | # Initialize logger for this module 23 | logger = logging.getLogger("rtx_demo.utils.memory_manager") 24 | 25 | 26 | class ModelMemoryManager: 27 | """ 28 | Context manager for efficiently loading and unloading models to optimize VRAM usage. 29 | 30 | This class provides the following memory optimization mode: 31 | 32 | **Low VRAM Mode**: Just-in-time model loading 33 | - Load/allocate models only when needed 34 | - Deallocate immediately after use 35 | 36 | Args: 37 | pipeline: The pipeline instance containing engines and model instances 38 | model_names (list): List of model names to manage 39 | low_vram (bool): Whether to enable low VRAM mode 40 | """ 41 | 42 | def __init__(self, pipeline, model_name, low_vram=False): 43 | self.pipeline = pipeline 44 | self.model_name = model_name 45 | self.low_vram = low_vram 46 | self.timing = {} 47 | 48 | assert self.model_name in self.pipeline.engines, f"Model {self.model_name} not found in pipeline.engines" 49 | assert isinstance(self.model_name, str), "model_name must be a string" 50 | 51 | def __enter__(self): 52 | if not self.low_vram: 53 | return self 54 | else: 55 | return self._enter_low_vram() 56 | 57 | def __exit__(self, exc_type, exc_val, exc_tb): 58 | if not self.low_vram: 59 | return 60 | else: 61 | return self._exit_low_vram() 62 | 63 | def _enter_low_vram(self): 64 | """Low VRAM mode entry - load and allocate specified models""" 65 | logger.debug(f"[MEMORY] Low VRAM: Loading models {self.model_name}...") 66 | 67 | if self.model_name not in self.pipeline.shape_dicts: 68 | raise RuntimeError(f"Model {self.model_name} not found in pipeline.shape_dicts") 69 | 70 | engine = self.pipeline.engines[self.model_name] 71 | shape_dict = self.pipeline.shape_dicts[self.model_name] 72 | 73 | start_time = time.time() 74 | 75 | # Load engine 76 | engine.load() 77 | 78 | # Allocate device memory for this model 79 | device_memory_size = engine.engine.device_memory_size_v2 80 | _, device_memory = cudart.cudaMalloc(device_memory_size) 81 | self.pipeline.shared_device_memory = device_memory 82 | 83 | # Activate engine with allocated memory 84 | engine.activate(device_memory=device_memory) 85 | 86 | # Allocate buffers 87 | engine.allocate_buffers(shape_dict, device=self.pipeline.device) 88 | 89 | setup_time = time.time() - start_time 90 | self.timing[f"{self.model_name}_setup"] = setup_time 91 | 92 | return 93 | 94 | def _exit_low_vram(self): 95 | """Handle low VRAM mode exit - deallocate specified models""" 96 | logger.debug(f"[MEMORY] Low VRAM: Deallocating models {self.model_name}...") 97 | 98 | engine = self.pipeline.engines[self.model_name] 99 | 100 | start_time = time.time() 101 | 102 | # Deallocate buffers 103 | engine.deallocate_buffers() 104 | 105 | # Deactivate engine 106 | engine.deactivate() 107 | 108 | # Unload engine 109 | engine.unload() 110 | 111 | # Free workspace memory 112 | cudart.cudaFree(self.pipeline.shared_device_memory) 113 | self.pipeline.shared_device_memory = None 114 | 115 | 
logger.debug("[MEMORY] Low VRAM: Freed workspace memory") 116 | 117 | cleanup_time = time.time() - start_time 118 | self.timing[f"{self.model_name}_cleanup"] = cleanup_time 119 | 120 | logger.debug(f"[MEMORY] Low VRAM: {self.model_name} deallocated in {cleanup_time:.3f}s") 121 | 122 | def get_timing_summary(self): 123 | """Get timing summary for memory operations""" 124 | return self.timing.copy() 125 | -------------------------------------------------------------------------------- /demo/tests/test_consolidated.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """ 17 | Test the consolidated structure with proper separation of concerns. 18 | 19 | This test verifies: 20 | - PathManager handles all path operations 21 | - ModelRegistry contains only model definitions 22 | - Pipeline uses both correctly 23 | """ 24 | 25 | from pathlib import Path, PurePosixPath 26 | 27 | import pytest 28 | from utils.model_registry import registry 29 | from utils.path_manager import PathManager 30 | 31 | 32 | @pytest.mark.integration 33 | class TestConsolidated: 34 | """Test separation of concerns between PathManager and ModelRegistry.""" 35 | 36 | def test_separation_of_concerns(self, temp_cache_dir: Path): 37 | """Test that PathManager and ModelRegistry work together properly.""" 38 | # Test PathManager 39 | path_manager = PathManager(str(temp_cache_dir)) 40 | 41 | # Test model registry 42 | available_pipelines = list(registry.pipelines.keys()) 43 | assert len(available_pipelines) > 0, "Should have available pipelines" 44 | assert "flux_1_dev" in available_pipelines, "Flux.1 [dev] pipeline should exist" 45 | 46 | # Test path operations - use a valid pipeline name 47 | model_id = registry.get_model_id("flux_1_dev", "transformer") 48 | assert model_id is not None, "Should get model ID for transformer" 49 | 50 | # Test path generation 51 | canonical_path = path_manager.get_engine_path(model_id, "fp16", "static") 52 | assert canonical_path is not None, "Should generate canonical path" 53 | assert canonical_path.parent.exists(), "Parent directories should be created" 54 | 55 | # Test that path contains expected structure 56 | assert "shared/engines" in str(PurePosixPath(canonical_path)), "Should use shared engines directory" 57 | assert model_id in str(canonical_path), "Should contain model ID" 58 | assert "fp16" in str(canonical_path), "Should contain precision" 59 | 60 | def test_model_registry_functionality(self): 61 | """Test core model registry functionality.""" 62 | # Test pipeline existence 63 | pipelines = list(registry.pipelines.keys()) 64 | assert "flux_1_dev" in pipelines, "Flux.1 [dev] pipeline should exist" 65 | 66 | # Test model ID retrieval 67 | model_id = registry.get_model_id("flux_1_dev", "transformer") 68 | assert 
isinstance(model_id, str), "Model ID should be a string" 69 | 70 | # Test precision options 71 | precisions = registry.get_available_precisions("flux_1_dev", "transformer") 72 | assert isinstance(precisions, list), "Precisions should be a list" 73 | assert len(precisions) > 0, "Should have available precisions" 74 | 75 | # Test default precision 76 | default_precision = registry.get_default_precision("flux_1_dev", "transformer") 77 | assert default_precision in precisions, "Default precision should be in available precisions" 78 | 79 | def test_pathmanager_functionality(self, temp_cache_dir: Path): 80 | """Test core PathManager functionality.""" 81 | path_manager = PathManager(str(temp_cache_dir)) 82 | 83 | # Test path generation for different file types 84 | model_id = "test_model" 85 | precision = "fp16" 86 | 87 | onnx_path = path_manager.get_onnx_path(model_id, precision) 88 | engine_path = path_manager.get_engine_path(model_id, precision, "dynamic") 89 | metadata_path = path_manager.get_metadata_path(model_id, precision, "dynamic") 90 | 91 | # Verify proper file extensions 92 | assert onnx_path.suffix == ".onnx", "ONNX path should have .onnx extension" 93 | assert engine_path.suffix == ".engine", "Engine path should have .engine extension" 94 | assert metadata_path.suffix == ".json", "Metadata path should have .json extension" 95 | 96 | # Verify hierarchical structure 97 | assert "shared/onnx" in str(PurePosixPath(onnx_path)), "ONNX should be in shared/onnx" 98 | assert "shared/engines" in str(PurePosixPath(engine_path)), "Engine should be in shared/engines" 99 | assert "shared/engines" in str(PurePosixPath(metadata_path)), "Metadata should be in shared/engines" 100 | -------------------------------------------------------------------------------- /demo/utils/timing_data.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | 17 | import logging 18 | from collections import defaultdict 19 | from dataclasses import dataclass, field 20 | 21 | import numpy as np 22 | 23 | # Initialize logger for this module 24 | logger = logging.getLogger("rtx_demo.utils.timing_data") 25 | 26 | 27 | @dataclass 28 | class InferenceTimingData: 29 | """Container for detailed inference timing data using CUDA events.""" 30 | 31 | # Pipeline timings (in milliseconds) - maps model type to list of execution times 32 | pipeline_times: dict[str, list[float]] = field(default_factory=lambda: defaultdict(list)) 33 | 34 | # Total times 35 | total_inference_time: float = 0.0 36 | 37 | # Metadata 38 | num_inference_steps: int = 0 39 | height: int = 0 40 | width: int = 0 41 | batch_size: int = 0 42 | guidance_scale: float = 0.0 43 | 44 | def to_dict(self) -> dict: 45 | """Return timing data as a parseable dictionary.""" 46 | # Calculate total end-to-end runtime 47 | total_e2e = self.total_inference_time 48 | if total_e2e == 0 and self.pipeline_times: 49 | # Calculate total from pipeline_times if total_inference_time is not set 50 | total_e2e = sum(np.sum(times) for times in self.pipeline_times.values() if times) 51 | 52 | # Extract transformer runtimes 53 | transformer_times = [] 54 | if "transformer" in self.pipeline_times: 55 | transformer_raw = self.pipeline_times["transformer"] 56 | for time_entry in transformer_raw: 57 | transformer_times.append(float(time_entry)) 58 | 59 | # Calculate throughput 60 | throughput = 0.0 61 | if total_e2e > 0: 62 | throughput = 1000.0 / total_e2e # images per second 63 | 64 | return { 65 | "total_e2e_runtime_ms": total_e2e, 66 | "transformer_runtimes_ms": transformer_times, 67 | "throughput_images_per_sec": throughput, 68 | "metadata": { 69 | "height": self.height, 70 | "width": self.width, 71 | "batch_size": self.batch_size, 72 | "num_inference_steps": self.num_inference_steps, 73 | "guidance_scale": self.guidance_scale, 74 | }, 75 | } 76 | 77 | def print_summary(self): 78 | """Print a summary of timing data.""" 79 | logger.info("\n" + "=" * 60) 80 | logger.info("INFERENCE TIMING SUMMARY") 81 | logger.info("=" * 60) 82 | logger.info(f"Configuration: {self.height}x{self.width}") 83 | logger.info(f"Batch size: {self.batch_size}, Guidance scale: {self.guidance_scale}") 84 | logger.info(f"Inference steps: {self.num_inference_steps}") 85 | logger.info("-" * 60) 86 | 87 | # Pipeline timing data 88 | if self.pipeline_times: 89 | for model_type, times in self.pipeline_times.items(): 90 | if times: # Only show models with timing data 91 | times_array = np.array(times) 92 | mean_time = float(np.mean(times_array)) 93 | std_time = float(np.std(times_array)) 94 | total_time = float(np.sum(times_array)) 95 | count = len(times) 96 | 97 | logger.info(f"{model_type.replace('_', ' ').title()}:") 98 | if count > 1: 99 | logger.info(f" - Total: {total_time:8.2f} ms ({count} executions)") 100 | logger.info(f" - Average: {mean_time:8.2f} ± {std_time:.2f} ms") 101 | logger.info( 102 | f" - Range: {float(np.min(times_array)):.2f} - {float(np.max(times_array)):.2f} ms" 103 | ) 104 | logger.info(f" - Entries: {times_array}") 105 | else: 106 | logger.info(f" - Time: {total_time:8.2f} ms") 107 | 108 | logger.info("-" * 60) 109 | if self.total_inference_time > 0: 110 | logger.info(f"Total Inference: {self.total_inference_time:8.2f} ms") 111 | elif self.pipeline_times: 112 | # Calculate total from pipeline_times 113 | total_pipeline_time = sum(np.sum(times) for times in self.pipeline_times.values() if times) 114 | if 
total_pipeline_time > 0: 115 | logger.info(f"Total Pipeline: {total_pipeline_time:8.2f} ms") 116 | 117 | # Calculate throughput 118 | effective_total = self.total_inference_time 119 | if effective_total == 0 and self.pipeline_times: 120 | effective_total = sum(np.sum(times) for times in self.pipeline_times.values() if times) 121 | 122 | if effective_total > 0: 123 | img_per_sec = 1000.0 / effective_total 124 | logger.info(f"Throughput: {img_per_sec:8.2f} images/second") 125 | 126 | logger.info("=" * 60) 127 | -------------------------------------------------------------------------------- /demo/flux1.dev/flux_demo.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | # SPDX-License-Identifier: Apache-2.0 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """ 18 | A basic demonstration of the Flux text-to-image pipeline using the TensorRT-RTX framework. 19 | This demo focuses on the happy-path workflow for generating images from text prompts. 20 | """ 21 | 22 | import argparse 23 | import logging 24 | import sys 25 | from pathlib import Path 26 | 27 | try: 28 | from pipelines.flux_pipeline import FluxPipeline 29 | except ImportError: 30 | sys.path.append(str(Path(__file__).parent.parent)) 31 | from pipelines.flux_pipeline import FluxPipeline 32 | 33 | 34 | logger = logging.getLogger("rtx_demo.flux1.dev.flux_demo") 35 | 36 | 37 | def main(): 38 | parser = argparse.ArgumentParser(description="Simple Flux text-to-image demo") 39 | # Required arguments 40 | parser.add_argument("--hf-token", type=str, required=True, help="Hugging Face token") 41 | 42 | # Image Generation Parameters 43 | parser.add_argument( 44 | "--prompt", 45 | type=str, 46 | default="A serene lake at sunset with mountains in the background", 47 | help="Text prompt for image generation", 48 | ) 49 | parser.add_argument("--height", type=int, default=512, help="Image height (default: 512)") 50 | parser.add_argument("--width", type=int, default=512, help="Image width (default: 512)") 51 | parser.add_argument("--batch-size", type=int, default=1, help="Batch size (default: 1)") 52 | parser.add_argument("--seed", type=int, default=0, help="Seed for random number generator (default: 0)") 53 | parser.add_argument("--num-inference-steps", type=int, default=50, help="Number of denoising steps (default: 50)") 54 | parser.add_argument("--guidance-scale", type=float, default=3.5, help="Guidance scale (default: 3.5)") 55 | 56 | # Engine Generation, Caching, Memory Management, and Verbosity 57 | parser.add_argument( 58 | "--precision", 59 | type=str, 60 | default="fp8", 61 | help="Precision for the transformer model (default: fp8)", 62 | choices=["bf16", "fp8", "fp4"], 63 | ) 64 | parser.add_argument("--enable-runtime-cache", action="store_true", help="Enable runtime caching") 65 | parser.add_argument( 66 | "--cuda-graph-strategy", 67 | 
type=str, 68 | default="disabled", 69 | help="Cuda graph strategy (default: disabled)", 70 | choices=["disabled", "whole_graph_capture"], 71 | ) 72 | parser.add_argument("--low-vram", action="store_true", help="Enable low VRAM mode") 73 | parser.add_argument("--dynamic-shape", action="store_true", default=False, help="Enable dynamic-shape engines") 74 | parser.add_argument("--verbose", action="store_true", help="Enable verbose logging") 75 | parser.add_argument( 76 | "--cache-dir", type=str, default="./demo_cache", help="Cache directory for models (default: ./demo_cache)" 77 | ) 78 | parser.add_argument( 79 | "--cache-mode", type=str, default="full", help="Cache mode (default: full)", choices=["full", "lean"] 80 | ) 81 | args = parser.parse_args() 82 | 83 | try: 84 | pipeline = FluxPipeline( 85 | cache_dir=args.cache_dir, 86 | device="cuda", 87 | verbose=args.verbose, 88 | cache_mode=args.cache_mode, 89 | guidance_scale=args.guidance_scale, 90 | num_inference_steps=args.num_inference_steps, 91 | hf_token=args.hf_token, 92 | low_vram=args.low_vram, 93 | cuda_graph_strategy=args.cuda_graph_strategy, 94 | enable_runtime_cache=args.enable_runtime_cache, 95 | ) 96 | 97 | # Print header and configuration 98 | logger.info("=" * 50) 99 | logger.info("Simple Flux Text-to-Image Demo") 100 | logger.info("=" * 50) 101 | logger.info(f"Prompt: '{args.prompt}'") 102 | logger.info(f"Transformer Precision: {args.precision}") 103 | logger.info(f"Resolution: {args.width}x{args.height}") 104 | logger.info(f"Batch size: {args.batch_size}") 105 | logger.info(f"Seed: {args.seed}") 106 | logger.info(f"Inference steps: {args.num_inference_steps}") 107 | logger.info(f"Guidance scale: {args.guidance_scale}") 108 | logger.info(f"Cache directory: {args.cache_dir}") 109 | logger.info(f"Low VRAM mode: {args.low_vram}") 110 | logger.info(f"Cudagraphs: {args.cuda_graph_strategy}") 111 | logger.info(f"Dynamic shape: {args.dynamic_shape}") 112 | logger.info(f"Runtime caching: {args.enable_runtime_cache}") 113 | logger.info(f"Cache mode: {args.cache_mode}") 114 | logger.info("") 115 | 116 | jit_times = pipeline.load_engines( 117 | transformer_precision=args.precision, 118 | opt_batch_size=args.batch_size, 119 | opt_height=args.height, 120 | opt_width=args.width, 121 | shape_mode="dynamic" if args.dynamic_shape else "static", 122 | ) 123 | 124 | for model, jit_time in jit_times.items(): 125 | logger.info(f"JIT Compilation + Execution Context Creation Time for {model}: {round(jit_time, 2)} seconds") 126 | 127 | pipeline.load_resources( 128 | batch_size=args.batch_size, 129 | height=args.height, 130 | width=args.width, 131 | ) 132 | 133 | # Print memory usage summary 134 | if not args.low_vram: 135 | pipeline.print_gpu_vram_summary() 136 | 137 | # Run inference 138 | logger.info("Generating image...") 139 | logger.info(f"Running {args.num_inference_steps} denoising steps...") 140 | 141 | save_dir = "." 
142 | pipeline.infer( 143 | prompt=args.prompt, 144 | save_path=save_dir, 145 | height=args.height, 146 | width=args.width, 147 | seed=args.seed, 148 | batch_size=args.batch_size, 149 | num_inference_steps=args.num_inference_steps, 150 | guidance_scale=args.guidance_scale, 151 | ) 152 | finally: 153 | # Cleanup 154 | if "pipeline" in locals(): 155 | pipeline.cleanup() 156 | 157 | return 0 158 | 159 | 160 | if __name__ == "__main__": 161 | sys.exit(main()) 162 | -------------------------------------------------------------------------------- /demo/tests/test_metadata_shapes.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """ 17 | Test script to verify the fixed _shapes_fit_profile method. 18 | 19 | This script tests the different input shape profile formats: 20 | 1. Static shapes (single tuple) 21 | 2. Dynamic shapes (list of 3 tuples) 22 | 3. Effectively static shapes (all 3 tuples are the same) 23 | """ 24 | 25 | import pytest 26 | from utils.engine_metadata import EngineMetadata 27 | 28 | 29 | @pytest.mark.unit 30 | class TestMetadataShapes: 31 | """Test _shapes_fit_profile method with different profile formats.""" 32 | 33 | def test_static_shapes(self): 34 | """Test static shapes (single tuple).""" 35 | metadata_static = EngineMetadata( 36 | model_name="test_model", 37 | precision="fp16", 38 | onnx_path="/path/to/model.onnx", 39 | onnx_hash="abcd1234", 40 | input_shapes={ 41 | "input_ids": (1, 77), # Static shape: single tuple 42 | "attention_mask": (1, 77), 43 | }, 44 | extra_args=set(), 45 | build_timestamp=1234567890, 46 | ) 47 | 48 | # Test exact match (should pass) 49 | new_shapes_exact = {"input_ids": (1, 77), "attention_mask": (1, 77)} 50 | result = metadata_static._shapes_fit_profile(new_shapes_exact) 51 | assert result is True, "Static exact match should pass" 52 | 53 | # Test different shape (should fail) 54 | new_shapes_different = {"input_ids": (1, 128), "attention_mask": (1, 77)} 55 | result = metadata_static._shapes_fit_profile(new_shapes_different) 56 | assert result is False, "Static different shape should fail" 57 | 58 | def test_dynamic_shapes(self): 59 | """Test dynamic shapes (list of 3 tuples).""" 60 | metadata_dynamic = EngineMetadata( 61 | model_name="test_model", 62 | precision="fp16", 63 | onnx_path="/path/to/model.onnx", 64 | onnx_hash="abcd1234", 65 | input_shapes={ 66 | "latents": [(1, 4, 32, 32), (1, 4, 64, 64), (1, 4, 128, 128)], # min, opt, max 67 | "timestep": [(1,), (1,), (1,)], # Effectively static 68 | }, 69 | extra_args=set(), 70 | build_timestamp=1234567890, 71 | ) 72 | 73 | # Test shape within range (should pass) 74 | new_shapes_within = {"latents": (1, 4, 48, 48), "timestep": (1,)} 75 | result = metadata_dynamic._shapes_fit_profile(new_shapes_within) 76 | assert result is 
True, "Dynamic within range should pass" 77 | 78 | # Test shape at min boundary (should pass) 79 | new_shapes_min = {"latents": (1, 4, 32, 32), "timestep": (1,)} 80 | result = metadata_dynamic._shapes_fit_profile(new_shapes_min) 81 | assert result is True, "Dynamic at min boundary should pass" 82 | 83 | # Test shape at max boundary (should pass) 84 | new_shapes_max = {"latents": (1, 4, 128, 128), "timestep": (1,)} 85 | result = metadata_dynamic._shapes_fit_profile(new_shapes_max) 86 | assert result is True, "Dynamic at max boundary should pass" 87 | 88 | # Test shape below min (should fail) 89 | new_shapes_below = {"latents": (1, 4, 16, 16), "timestep": (1,)} 90 | result = metadata_dynamic._shapes_fit_profile(new_shapes_below) 91 | assert result is False, "Dynamic below min should fail" 92 | 93 | # Test shape above max (should fail) 94 | new_shapes_above = {"latents": (1, 4, 256, 256), "timestep": (1,)} 95 | result = metadata_dynamic._shapes_fit_profile(new_shapes_above) 96 | assert result is False, "Dynamic above max should fail" 97 | 98 | def test_effectively_static_shapes(self): 99 | """Test effectively static shapes (all 3 tuples are the same).""" 100 | metadata_effectively_static = EngineMetadata( 101 | model_name="test_model", 102 | precision="fp16", 103 | onnx_path="/path/to/model.onnx", 104 | onnx_hash="abcd1234", 105 | input_shapes={ 106 | "embeddings": [(1, 77, 768), (1, 77, 768), (1, 77, 768)], # All the same = static 107 | "mask": [(1, 77), (1, 77), (1, 77)], 108 | }, 109 | extra_args=set(), 110 | build_timestamp=1234567890, 111 | ) 112 | 113 | # Test exact match (should pass) 114 | new_shapes_exact = {"embeddings": (1, 77, 768), "mask": (1, 77)} 115 | result = metadata_effectively_static._shapes_fit_profile(new_shapes_exact) 116 | assert result is True, "Effectively static exact match should pass" 117 | 118 | # Test different shape (should fail) 119 | new_shapes_different = {"embeddings": (1, 77, 512), "mask": (1, 77)} 120 | result = metadata_effectively_static._shapes_fit_profile(new_shapes_different) 121 | assert result is False, "Effectively static different shape should fail" 122 | 123 | def test_mixed_static_and_dynamic(self): 124 | """Test mixed static and dynamic shapes.""" 125 | metadata_mixed = EngineMetadata( 126 | model_name="test_model", 127 | precision="fp16", 128 | onnx_path="/path/to/model.onnx", 129 | onnx_hash="abcd1234", 130 | input_shapes={ 131 | "static_input": (1, 768), # Static 132 | "dynamic_input": [(1, 1, 512), (1, 4, 512), (1, 8, 512)], # Dynamic 133 | }, 134 | extra_args=set(), 135 | build_timestamp=1234567890, 136 | ) 137 | 138 | # Test valid combination (should pass) 139 | new_shapes_valid = {"static_input": (1, 768), "dynamic_input": (1, 6, 512)} 140 | result = metadata_mixed._shapes_fit_profile(new_shapes_valid) 141 | assert result is True, "Mixed valid should pass" 142 | 143 | # Test invalid static (should fail) 144 | new_shapes_invalid_static = {"static_input": (1, 512), "dynamic_input": (1, 6, 512)} 145 | result = metadata_mixed._shapes_fit_profile(new_shapes_invalid_static) 146 | assert result is False, "Mixed invalid static should fail" 147 | 148 | # Test invalid dynamic (should fail) 149 | new_shapes_invalid_dynamic = {"static_input": (1, 768), "dynamic_input": (1, 16, 512)} 150 | result = metadata_mixed._shapes_fit_profile(new_shapes_invalid_dynamic) 151 | assert result is False, "Mixed invalid dynamic should fail" 152 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: 
-------------------------------------------------------------------------------- 1 | ## TensorRT-RTX OSS Contribution Rules 2 | 3 | #### Issue Tracking 4 | 5 | - All enhancement, bugfix, or change requests must begin with the creation of a [TensorRT-RTX Issue Request](https://github.com/nvidia/TensorRT-RTX/issues). 6 | - The issue request must be reviewed by TensorRT-RTX engineers and approved prior to code review. 7 | 8 | #### Coding Guidelines 9 | 10 | - TensorRT-RTX follows the same coding guidelines as TensorRT. Therefore, all source code contributions 11 | must strictly adhere to the [TensorRT Coding Guidelines](https://github.com/NVIDIA/TensorRT/blob/HEAD/CODING-GUIDELINES.md). 12 | 13 | - In addition, please follow the existing conventions in the relevant file, submodule, module, and project when you add new code or when you extend/fix existing functionality. 14 | 15 | - To maintain consistency in code formatting and style, you should also run `clang-format` on the modified sources with the provided configuration file. This applies TensorRT-RTX code formatting rules to: 16 | 17 | - class, function/method, and variable/field naming 18 | - comment style 19 | - indentation 20 | - line length 21 | 22 | - Format git changes: 23 | 24 | ```bash 25 | # Commit ID is optional - if unspecified, run format on staged changes. 26 | git-clang-format --style file [commit ID/reference] 27 | ``` 28 | 29 | - Format individual source files: 30 | 31 | ```bash 32 | # -style=file : Obtain the formatting rules from .clang-format 33 | # -i : In-place modification of the processed file 34 | clang-format -style=file -i -fallback-style=none <file(s)> 35 | ``` 36 | 37 | - Format entire codebase (for project maintainers only): 38 | 39 | ```bash 40 | find samples -iname *.h -o -iname *.c -o -iname *.cpp -o -iname *.hpp \ 41 | | xargs clang-format -style=file -i -fallback-style=none 42 | ``` 43 | 44 | - Avoid introducing unnecessary complexity into existing code so that maintainability and readability are preserved. 45 | 46 | - Try to keep pull requests (PRs) as concise as possible: 47 | 48 | - Avoid committing commented-out code. 49 | - Wherever possible, each PR should address a single concern. If there are several otherwise-unrelated things that should be fixed to reach a desired endpoint, our recommendation is to open several PRs and indicate the dependencies in the description. The more complex the changes are in a single PR, the more time it will take to review those changes. 50 | 51 | - Write commit titles using imperative mood and [these rules](https://chris.beams.io/posts/git-commit/), and reference the Issue number corresponding to the PR. Following is the recommended format for commit texts: 52 | 53 | ``` 54 | #<Issue Number> - <Commit Title> 55 | 56 | <Commit Body> 57 | ``` 58 | 59 | - Ensure that the build log is clean, meaning no warnings or errors should be present. 60 | 61 | - Ensure that all `sample` and `demo` tests pass prior to submitting your code. Build and run samples and demos using the instructions in the [sample](./samples/README.md) and [demo](./demo/README.md) README files, and run demo tests using the instructions in the [test README](./demo/tests/README.md). 62 | 63 | - All OSS components must contain accompanying documentation (READMEs) describing the functionality, dependencies, and known issues. 64 | 65 | - All OSS components must have an accompanying test. 66 | 67 | - If introducing a new component, provide a test sample to verify the functionality.
68 | 69 | - To add or disable functionality: 70 | 71 | - Add a CMake option with a default value that matches the existing behavior. 72 | - Where entire files can be included/excluded based on the value of this option, selectively include/exclude the relevant files from compilation by modifying `CMakeLists.txt` rather than using `#if` guards around the entire body of each file. 73 | - Where the functionality involves minor changes to existing files, use `#if` guards. 74 | 75 | - Make sure that you can contribute your work to open source (no license and/or patent conflict is introduced by your code). You will need to [`sign`](#signing-your-work) your commit. 76 | 77 | - Thanks in advance for your patience as we review your contributions; we do appreciate them! 78 | 79 | #### Pull Requests 80 | 81 | Developer workflow for code contributions is as follows: 82 | 83 | 1. Developers must first [fork](https://help.github.com/en/articles/fork-a-repo) the [upstream](https://github.com/nvidia/TensorRT-RTX) TensorRT-RTX OSS repository. 84 | 85 | 2. Git clone the forked repository and push changes to the personal fork. 86 | 87 | ```bash 88 | git clone https://github.com/YOUR_USERNAME/YOUR_FORK.git TensorRT-RTX 89 | # Checkout the targeted branch and commit changes 90 | # Push the commits to a branch on the fork (remote). 91 | git push -u origin <local-branch>:<remote-branch> 92 | ``` 93 | 94 | 3. Once the code changes are staged on the fork and ready for review, a [Pull Request](https://help.github.com/en/articles/about-pull-requests) (PR) can be [requested](https://help.github.com/en/articles/creating-a-pull-request) to merge the changes from a branch of the fork into a selected branch of upstream. 95 | 96 | - Exercise caution when selecting the source and target branches for the PR. 97 | Note that versioned releases of TensorRT-RTX OSS are posted to `release/` branches of the upstream repo. 98 | - Creation of a PR kicks off the code review process. 99 | - At least one TensorRT-RTX engineer will be assigned for the review. 100 | - While under review, mark your PRs as work-in-progress by prefixing the PR title with [WIP]. 101 | 102 | 4. Since there is no CI/CD process in place yet, the PR will be accepted and the corresponding issue closed only after adequate testing has been completed manually by the developer and/or the TensorRT-RTX engineer reviewing the code. 103 | 104 | #### Signing Your Work 105 | 106 | - We require that all contributors "sign-off" on their commits. This certifies that the contribution is your original work, or you have rights to submit it under the same license, or a compatible license. 107 | 108 | - Any contribution which contains commits that are not Signed-Off will not be accepted. 109 | 110 | - To sign off on a commit, simply use the `--signoff` (or `-s`) option when committing your changes: 111 | 112 | ```bash 113 | $ git commit -s -m "Add cool feature." 114 | ``` 115 | 116 | This will append the following to your commit message: 117 | 118 | ``` 119 | Signed-off-by: Your Name <your.email@example.com> 120 | ``` 121 | 122 | - Full text of the DCO: 123 | 124 | ``` 125 | Developer Certificate of Origin 126 | Version 1.1 127 | 128 | Copyright (C) 2004, 2006 The Linux Foundation and its contributors. 129 | 1 Letterman Drive 130 | Suite D4700 131 | San Francisco, CA, 94129 132 | 133 | Everyone is permitted to copy and distribute verbatim copies of this license document, but changing it is not allowed.
134 | ``` 135 | 136 | ``` 137 | Developer's Certificate of Origin 1.1 138 | 139 | By making a contribution to this project, I certify that: 140 | 141 | (a) The contribution was created in whole or in part by me and I have the right to submit it under the open source license indicated in the file; or 142 | 143 | (b) The contribution is based upon previous work that, to the best of my knowledge, is covered under an appropriate open source license and I have the right under that license to submit that work with modifications, whether created in whole or in part by me, under the same open source license (unless I am permitted to submit under a different license), as indicated in the file; or 144 | 145 | (c) The contribution was provided directly to me by some other person who certified (a), (b) or (c) and I have not modified it. 146 | 147 | (d) I understand and agree that this project and the contribution are public and that a record of the contribution (including all personal information I submit with it, including my sign-off) is maintained indefinitely and may be redistributed consistent with this project or the open source license(s) involved. 148 | ``` 149 | -------------------------------------------------------------------------------- /demo/README.md: -------------------------------------------------------------------------------- 1 | # TensorRT-RTX Demos 2 | 3 | A collection of demos showcasing key [TensorRT-RTX](https://developer.nvidia.com/tensorrt-rtx) features through model pipelines. 4 | 5 | ## Quick Start 6 | 7 | 1. **Clone and install** 8 | 9 | We recommend using Python versions between 3.9 and 3.12 inclusive due to supported versions for required dependencies. 10 | 11 | ```bash 12 | git clone https://github.com/NVIDIA/TensorRT-RTX.git 13 | cd TensorRT-RTX 14 | 15 | # Install TensorRT-RTX from the wheels located in the downloaded tarfile 16 | # Visit https://developer.nvidia.com/tensorrt-rtx to download 17 | # Example below is for Python 3.12 on Linux (customize with your Python version + OS) 18 | python -m pip install YOUR_TENSORRT_RTX_DOWNLOAD_DIR/python/tensorrt_rtx-1.0.0.20-cp312-none-linux_x86_64.whl 19 | 20 | # Install demo dependencies (example: Flux 1.dev) 21 | python -m pip install -r demo/flux1.dev/requirements.txt 22 | ``` 23 | 24 | 2. **Run demo** 25 | 26 | ```bash 27 | # Standalone Python script 28 | python demo/flux1.dev/flux_demo.py -h 29 | 30 | # Interactive Jupyter notebook 31 | jupyter notebook demo/flux1.dev/flux_demo.ipynb 32 | ``` 33 | 34 | ## Python Script Usage Examples 35 | 36 | The standalone script provides extensive configuration options for various use cases. For detailed walkthroughs, interactive exploration, and comprehensive documentation, see the [Flux.1 [dev] Demo Notebook](./flux1.dev/flux_demo.ipynb) which offers in-depth coverage of TensorRT-RTX features. 37 | 38 | > **GPU Compatibility**: This demo is verified on Ada and Blackwell GPUs. See [Transformer Precision Options](#transformer-precision-options) for more compatibility details. 39 | 40 | ### Required Parameters 41 | 42 | To download model checkpoints for the FLUX.1 [dev] pipeline, obtain a `read` access token to the model repository on HuggingFace Hub. See [instructions](https://huggingface.co/docs/hub/security-tokens). 
43 | 44 | ```bash 45 | --hf-token YOUR_HF_TOKEN # Hugging Face token with read access to the Flux.1 [dev] model 46 | ``` 47 | 48 | ### Image Generation Parameters 49 | 50 | ```bash 51 | --prompt "Your text prompt" # Text prompt for generation 52 | --height 512 # Image height (default: 512) 53 | --width 512 # Image width (default: 512) 54 | --batch-size 1 # Batch size (default: 1) 55 | --seed 0 # Random seed (default: 0) 56 | --num-inference-steps 50 # Denoising steps (default: 50) 57 | --guidance-scale 3.5 # Guidance scale (default: 3.5) 58 | ``` 59 | 60 | ### Engine & Performance Options 61 | 62 | ```bash 63 | --precision {bf16,fp8,fp4} # Transformer precision (default: fp8) 64 | --dynamic-shape # Enable dynamic shape engines 65 | --enable-runtime-cache # Enable runtime caching 66 | --low-vram # Enable low VRAM mode 67 | --verbose # Enable verbose logging 68 | ``` 69 | 70 | ### Cache Management 71 | 72 | ```bash 73 | --cache-dir ./demo_cache # Cache directory (default: ./demo_cache) 74 | --cache-mode {full,lean} # Cache mode (default: full) 75 | ``` 76 | 77 | ### Example Commands for Flux.1 [dev] Pipeline 78 | 79 | **Default Parameters Image Generation:** 80 | 81 | ```bash 82 | python demo/flux1.dev/flux_demo.py --hf-token YOUR_TOKEN 83 | ``` 84 | 85 | **Large Image Generation (1024x1024):** 86 | 87 | ```bash 88 | python demo/flux1.dev/flux_demo.py --hf-token YOUR_TOKEN --height 1024 --width 1024 --prompt "A detailed cityscape at golden hour" 89 | ``` 90 | 91 | **Faster JIT Compilation Times with Runtime Caching:** 92 | 93 | ```bash 94 | python demo/flux1.dev/flux_demo.py --hf-token YOUR_TOKEN --enable-runtime-cache --prompt "A cat meanders down a dimly lit alleyway in a large city." 95 | ``` 96 | 97 | **Dynamic-Shape Engines with Shape-Specialized Kernels:** 98 | 99 | ```bash 100 | python demo/flux1.dev/flux_demo.py --hf-token YOUR_TOKEN --dynamic-shape --prompt "A dramatic cityscape from a dazzling angle" 101 | ``` 102 | 103 | **Low VRAM + FP4 Quantized (for Blackwell GPUs with memory constraints):** 104 | 105 | ```bash 106 | python demo/flux1.dev/flux_demo.py --hf-token YOUR_TOKEN --low-vram --precision fp4 --prompt "A serene forest scene" 107 | ``` 108 | 109 | > **Tip**: The [Jupyter notebook](./flux1.dev/flux_demo.ipynb) provides interactive parameter exploration, detailed explanations of each feature, and additional use cases. 110 | 111 | ## Key Features 112 | 113 | - **Smart Caching**: Shared models across pipelines with intelligent cleanup 114 | - **Cross-Platform**: Works on Windows and Linux 115 | - **Flexible Precision**: Configure transformer model precision (bf16, fp8, fp4) 116 | - **Memory Management**: Low-VRAM mode for memory-constrained GPUs 117 | - **Dynamic Shapes**: Support for flexible input dimensions with runtime optimization 118 | 119 | ## Notable Configuration Options 120 | 121 | ### Transformer Precision Options 122 | 123 | Choose based on your GPU architecture and VRAM requirements: 124 | 125 | 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | 135 | 136 | 137 | 138 | 139 | 140 | 141 | 142 | 143 | 144 | 145 | 146 | 147 | 148 | 149 | 150 | 151 | 152 | 153 | 154 | 155 | 156 | 157 |
| Precision | Supported GPU Architecture | Approx. Max VRAM Usage (Default) | Approx. Max VRAM Usage (`--low-vram`) |
| --------- | -------------------------- | -------------------------------- | ------------------------------------- |
| BF16      | Ampere, Ada, Blackwell     | 32.1 GB                          | 23.1 GB                                |
| FP8       | Ada, Blackwell             | 21.6 GB                          | 12.0 GB                                |
| FP4       | Blackwell                  | 20.5 GB                          | 11.0 GB                                |
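If you are unsure which precision your GPU supports, the sketch below maps the detected CUDA compute capability to a suggested `--precision` value. It assumes PyTorch (already installed via the demo requirements) and is only a heuristic; defer to the table above and the TensorRT-RTX support matrix.

```python
# Heuristic precision picker based on CUDA compute capability (illustrative only).
import torch

major, minor = torch.cuda.get_device_capability()
if major >= 10:                    # Blackwell (SM 10.x / 12.x)
    suggested = "fp4"
elif (major, minor) == (8, 9):     # Ada
    suggested = "fp8"
else:                              # Ampere and other architectures
    suggested = "bf16"
print(f"Suggested --precision: {suggested}")
```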
158 | 159 | ```python 160 | # Configure precision when loading engines 161 | pipeline.load_engines(transformer_precision="fp8") # Default: fp8 162 | ``` 163 | 164 | ### Input Shape Modes 165 | 166 | ```python 167 | # Static shapes (default) 168 | pipeline.load_engines(opt_height=512, opt_width=512, shape_mode="static") 169 | 170 | # Dynamic shapes (flexible resolutions without recompilation) 171 | pipeline.load_engines(opt_height=512, opt_width=512, shape_mode="dynamic") 172 | ``` 173 | 174 | ### GPU Memory Management 175 | 176 | ```python 177 | # Default (fastest, more VRAM usage) 178 | pipeline = Pipeline(..., low_vram=False) 179 | 180 | # Low VRAM mode (slower, less VRAM usage) 181 | pipeline = Pipeline(..., low_vram=True) 182 | ``` 183 | 184 | ### Disk Memory Management 185 | 186 | - **`full`** (default): Keep all cached models 187 | - **`lean`**: Auto-cleanup unused models to save disk space 188 | 189 | #### Cache Structure 190 | 191 | Models and engines are stored in a shared cache by `model_id` and `precision`: 192 | 193 | ``` 194 | demo_cache/ 195 | ├── shared/ 196 | │ ├── onnx/{model_id}/{precision}/ # ONNX models 197 | │ └── engines/{model_id}/{precision}/ # TensorRT engines 198 | ├── runtime.cache # JIT compilation cache 199 | └── .cache_state.json # Usage tracking 200 | ``` 201 | 202 | ## Troubleshooting 203 | 204 | **Image Quality Issues** 205 | 206 | - Ensure the dimensions are multiples of 16 207 | - Try altering the `seed` and `guidance_scale` parameters 208 | - See [Flux.1 [dev] Demo Notebook](./flux1.dev/flux_demo.ipynb) for more tips and examples 209 | 210 | **GPU Out of Memory** 211 | 212 | - Use `low_vram=True` to reduce VRAM usage 213 | - Use `enable_runtime_cache=False` or omit the `--enable-runtime-cache` flag 214 | - Try lower precision: `fp8` (Ada/Blackwell) or `fp4` (Blackwell only) 215 | - Reduce batch size or image resolution 216 | 217 | **Disk Space Issues** 218 | 219 | - Use `cache_mode="lean"` to reduce disk usage by automatically cleaning up unused models 220 | - Manually delete demo cache directory 221 | 222 | **Build Errors** 223 | 224 | - Verify TensorRT-RTX and dependencies are installed (see [Quick Start](#quick-start)) 225 | - Ensure the precision being used is supported by the GPU architecture (see [Support Matrix](https://docs.nvidia.com/deeplearning/tensorrt-rtx/latest/getting-started/support-matrix.html)) 226 | 227 | ## Running Tests 228 | 229 | To configure the test environment and run demo tests, refer to the [test README](./tests/README.md). 230 | -------------------------------------------------------------------------------- /demo/tests/test_precision_changes.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """ 17 | Test script to verify precision changes are handled correctly across runs. 18 | """ 19 | 20 | import json 21 | from pathlib import Path 22 | 23 | import pytest 24 | from utils.path_manager import ModelConfig, PathManager 25 | 26 | 27 | @pytest.mark.integration 28 | class TestPrecisionChanges: 29 | """Test precision changes across multiple runs.""" 30 | 31 | def test_precision_changes_lean_mode(self, temp_cache_dir: Path): 32 | """Test precision changes across multiple runs in lean mode.""" 33 | # Test 1: Initial run with specific precisions 34 | path_manager = PathManager(cache_dir=str(temp_cache_dir), cache_mode="lean") 35 | 36 | initial_config = { 37 | "text_encoder": ModelConfig("flux_t5_text_encoder", "fp16", "dynamic"), 38 | "transformer": ModelConfig("flux_transformer", "fp16", "dynamic"), # fp16 initially 39 | "vae": ModelConfig("flux_vae_decoder", "fp16", "dynamic"), 40 | } 41 | 42 | path_manager.set_pipeline_models("flux_1_dev", initial_config) 43 | 44 | # Create dummy files for initial models 45 | for _, (model_id, precision, shape_mode) in initial_config.items(): 46 | onnx_path = path_manager.get_onnx_path(model_id, precision) 47 | engine_path = path_manager.get_engine_path(model_id, precision, shape_mode) 48 | metadata_path = path_manager.get_metadata_path(model_id, precision, shape_mode) 49 | 50 | onnx_path.touch() 51 | engine_path.touch() 52 | metadata_path.touch() 53 | 54 | # Verify state file was created 55 | state_file = temp_cache_dir / ".cache_state.json" 56 | assert state_file.exists(), "State file should be created" 57 | 58 | with open(state_file) as f: 59 | state = json.load(f) 60 | 61 | expected_state = { 62 | "flux_1_dev": { 63 | "text_encoder": {"model_id": "flux_t5_text_encoder", "precision": "fp16", "shape_mode": "dynamic"}, 64 | "transformer": {"model_id": "flux_transformer", "precision": "fp16", "shape_mode": "dynamic"}, 65 | "vae": {"model_id": "flux_vae_decoder", "precision": "fp16", "shape_mode": "dynamic"}, 66 | } 67 | } 68 | assert state == expected_state, f"State mismatch: {state} != {expected_state}" 69 | 70 | def test_precision_change_cleanup(self, temp_cache_dir: Path): 71 | """Test that changing precision cleans up old files in lean mode.""" 72 | path_manager = PathManager(cache_dir=str(temp_cache_dir), cache_mode="lean") 73 | 74 | # Initial configuration 75 | initial_config = { 76 | "text_encoder": ModelConfig("flux_t5_text_encoder", "fp16", "dynamic"), 77 | "transformer": ModelConfig("flux_transformer", "fp16", "dynamic"), 78 | "vae": ModelConfig("flux_vae_decoder", "fp16", "dynamic"), 79 | } 80 | 81 | path_manager.set_pipeline_models("flux_1_dev", initial_config) 82 | 83 | # Create initial files 84 | for _, (model_id, precision, _) in initial_config.items(): 85 | onnx_path = path_manager.get_onnx_path(model_id, precision) 86 | onnx_path.touch() 87 | 88 | # Change transformer precision 89 | changed_config = { 90 | "text_encoder": ModelConfig("flux_t5_text_encoder", "fp16", "dynamic"), # Same 91 | "transformer": ModelConfig("flux_transformer", "fp8", "dynamic"), # Changed fp16 -> fp8 92 | "vae": ModelConfig("flux_vae_decoder", "fp16", "dynamic"), # Same 93 | } 94 | 95 | # Check that old fp16 transformer exists before change 96 | old_transformer_onnx = path_manager.get_onnx_path("flux_transformer", "fp16") 97 | assert old_transformer_onnx.exists(), "Old fp16 transformer should exist before change" 98 | 99 | path_manager.set_pipeline_models("flux_1_dev", changed_config) 100 | 101 | # Create new fp8 transformer file 102 | 
new_transformer_onnx = path_manager.get_onnx_path("flux_transformer", "fp8") 103 | new_transformer_onnx.touch() 104 | 105 | # Verify cleanup happened 106 | assert not old_transformer_onnx.exists(), "Old fp16 transformer should be deleted" 107 | assert new_transformer_onnx.exists(), "New fp8 transformer should exist" 108 | 109 | # Verify shared models are preserved 110 | shared_text_encoder = path_manager.get_onnx_path("flux_t5_text_encoder", "fp16") 111 | shared_vae = path_manager.get_onnx_path("flux_vae_decoder", "fp16") 112 | assert shared_text_encoder.exists(), "Shared text encoder should be preserved" 113 | assert shared_vae.exists(), "Shared VAE should be preserved" 114 | 115 | def test_lean_mode_single_active_pipeline(self, temp_cache_dir: Path): 116 | """Test that lean mode only keeps the currently active pipeline, cleaning up models from previous pipelines.""" 117 | path_manager = PathManager(cache_dir=str(temp_cache_dir), cache_mode="lean") 118 | 119 | # First pipeline configuration 120 | pipeline1_config = { 121 | "text_encoder": ModelConfig("flux_t5_text_encoder", "fp16", "dynamic"), 122 | "transformer": ModelConfig("flux_transformer", "fp16", "dynamic"), 123 | "vae": ModelConfig("flux_vae_decoder", "fp16", "dynamic"), 124 | } 125 | 126 | path_manager.set_pipeline_models("flux_1_dev", pipeline1_config) 127 | 128 | # Create files 129 | for _, (model_id, precision, _) in pipeline1_config.items(): 130 | onnx_path = path_manager.get_onnx_path(model_id, precision) 131 | onnx_path.touch() 132 | 133 | # Switch to another pipeline that uses some shared models 134 | pipeline2_config = { 135 | "text_encoder": ModelConfig("flux_t5_text_encoder", "fp16", "dynamic"), # Shared with pipeline1 136 | "transformer": ModelConfig("flux_transformer", "fp16", "dynamic"), # Shared with pipeline1 137 | "vae": ModelConfig("sdxl_vae", "fp16", "dynamic"), # Different VAE 138 | } 139 | 140 | path_manager.set_pipeline_models("another_pipeline", pipeline2_config) 141 | 142 | # Create new VAE file 143 | sdxl_vae_path = path_manager.get_onnx_path("sdxl_vae", "fp16") 144 | sdxl_vae_path.touch() 145 | 146 | # Verify pipeline1 models that are shared are preserved, but flux_vae_decoder is cleaned up 147 | text_encoder_path = path_manager.get_onnx_path("flux_t5_text_encoder", "fp16") 148 | transformer_path = path_manager.get_onnx_path("flux_transformer", "fp16") 149 | old_vae_path = path_manager.get_onnx_path("flux_vae_decoder", "fp16") 150 | new_vae_path = path_manager.get_onnx_path("sdxl_vae", "fp16") 151 | 152 | assert text_encoder_path.exists(), "Shared text encoder should be preserved" 153 | assert transformer_path.exists(), "Shared transformer should be preserved" 154 | assert not old_vae_path.exists(), "Old VAE not used by current pipeline should be cleaned up" 155 | assert new_vae_path.exists(), "New VAE should exist" 156 | 157 | # Verify only the current pipeline is tracked 158 | assert "flux_1_dev" not in path_manager.pipeline_states, "Previous pipeline should not be tracked" 159 | assert "another_pipeline" in path_manager.pipeline_states, "Current pipeline should be tracked" 160 | 161 | # Now change the current pipeline to use completely different models 162 | completely_different_config = { 163 | "text_encoder": ModelConfig("different_text_encoder", "fp8", "static"), 164 | "transformer": ModelConfig("different_transformer", "fp8", "static"), 165 | "vae": ModelConfig("different_vae", "fp8", "static"), 166 | } 167 | 168 | # Create new model files 169 | for _, (model_id, precision, _) in 
completely_different_config.items(): 170 | onnx_path = path_manager.get_onnx_path(model_id, precision) 171 | onnx_path.touch() 172 | 173 | path_manager.set_pipeline_models("another_pipeline", completely_different_config) 174 | 175 | # All previous models should be cleaned up since none are shared with the new config 176 | assert not text_encoder_path.exists(), "Previous text encoder should be cleaned up" 177 | assert not transformer_path.exists(), "Previous transformer should be cleaned up" 178 | assert not new_vae_path.exists(), "Previous VAE should be cleaned up" 179 | 180 | # New models should exist 181 | new_text_encoder = path_manager.get_onnx_path("different_text_encoder", "fp8") 182 | new_transformer = path_manager.get_onnx_path("different_transformer", "fp8") 183 | new_vae = path_manager.get_onnx_path("different_vae", "fp8") 184 | 185 | assert new_text_encoder.exists(), "New text encoder should exist" 186 | assert new_transformer.exists(), "New transformer should exist" 187 | assert new_vae.exists(), "New VAE should exist" 188 | 189 | def test_full_mode_keeps_all_models(self, temp_cache_dir: Path): 190 | """Test that full mode never deletes models regardless of precision changes.""" 191 | path_manager = PathManager(cache_dir=str(temp_cache_dir), cache_mode="full") 192 | 193 | # Initial configuration 194 | initial_config = { 195 | "transformer": ("test_transformer", "fp16"), 196 | } 197 | 198 | path_manager.set_pipeline_models("test_pipeline", initial_config) 199 | 200 | # Create initial file 201 | fp16_path = path_manager.get_onnx_path("test_transformer", "fp16") 202 | fp16_path.touch() 203 | 204 | # Change precision 205 | changed_config = { 206 | "transformer": ("test_transformer", "fp8"), 207 | } 208 | 209 | path_manager.set_pipeline_models("test_pipeline", changed_config) 210 | 211 | # Create new precision file 212 | fp8_path = path_manager.get_onnx_path("test_transformer", "fp8") 213 | fp8_path.touch() 214 | 215 | # In full mode, both should exist 216 | assert fp16_path.exists(), "Full mode should keep fp16 model" 217 | assert fp8_path.exists(), "Full mode should have fp8 model" 218 | -------------------------------------------------------------------------------- /samples/helloWorld/python/hello_world.py: -------------------------------------------------------------------------------- 1 | # 2 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | # SPDX-License-Identifier: Apache-2.0 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # 17 | 18 | import argparse 19 | 20 | import numpy as np 21 | import tensorrt_rtx as trt 22 | from cuda.bindings import runtime as cudart 23 | 24 | 25 | def cuda_assert(call: tuple) -> object: 26 | res = None 27 | err = call[0] 28 | if len(call) > 1: 29 | res = call[1] 30 | if err != cudart.cudaError_t.cudaSuccess: 31 | raise RuntimeError(f"CUDA error: {err}") 32 | return res 33 | 34 | 35 | # These sizes are arbitrary. 
36 | k_input_size = 3 37 | k_hidden_size = 10 38 | k_output_size = 2 39 | k_bytes_per_float = 4 40 | 41 | # Set --onnx=/path/to/helloWorld.onnx to parse the provided ONNX model. 42 | k_onnx_model_path = "" 43 | 44 | logger = trt.Logger(trt.Logger.VERBOSE) 45 | 46 | 47 | # Create a simple fully connected network with one input, one hidden layer, and one output. 48 | def create_network(builder: trt.Builder, fc1_weights: trt.Weights, fc2_weights: trt.Weights) -> trt.INetworkDefinition: 49 | # Specify network creation options. 50 | # Note: TensorRT-RTX only supports strongly typed networks, explicitly specify this to avoid warning. 51 | flags = 1 << int(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED) 52 | 53 | # Create an empty network graph. 54 | network = builder.create_network(flags) 55 | 56 | # Add the network input. 57 | input = network.add_input(name="input", dtype=trt.float32, shape=(1, k_input_size)) 58 | 59 | # Create constant layers containing weights for fc1/fc2. 60 | fc1_weights_layer = network.add_constant(trt.Dims2(k_input_size, k_hidden_size), fc1_weights) 61 | fc1_weights_layer.name = "fully connected layer 1 weights" 62 | 63 | fc2_weights_layer = network.add_constant(trt.Dims2(k_hidden_size, k_output_size), fc2_weights) 64 | fc2_weights_layer.name = "fully connected layer 2 weights" 65 | 66 | # Add a fully connected layer, fc1. 67 | fc1 = network.add_matrix_multiply( 68 | input, trt.MatrixOperation.NONE, fc1_weights_layer.get_output(0), trt.MatrixOperation.NONE 69 | ) 70 | fc1.name = "fully connected layer 1" 71 | 72 | # Add a relu layer. 73 | relu = network.add_activation(fc1.get_output(0), type=trt.ActivationType.RELU) 74 | relu.name = "relu activation" 75 | 76 | # Add a fully connected layer, fc2. 77 | fc2 = network.add_matrix_multiply( 78 | relu.get_output(0), trt.MatrixOperation.NONE, fc2_weights_layer.get_output(0), trt.MatrixOperation.NONE 79 | ) 80 | fc2.name = "fully connected layer 2" 81 | 82 | # Mark the network output tensor. 83 | fc2.get_output(0).name = "output" 84 | network.mark_output(fc2.get_output(0)) 85 | 86 | return network 87 | 88 | 89 | # Create a network by parsing the included "helloWorld.onnx" model. 90 | # The ONNX model contains the same layers and weights as the custom network. 91 | def create_network_from_onnx(builder: trt.Builder, onnx_file_path: str) -> trt.INetworkDefinition: 92 | print("Parsing ONNX file: ", onnx_file_path) 93 | 94 | # Specify network creation options. 95 | # Note: TensorRT-RTX only supports strongly typed networks, explicitly specify this to avoid warning. 96 | flags = 1 << int(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED) 97 | 98 | # Create an empty network graph. 99 | network = builder.create_network(flags) 100 | 101 | # Parse the network from the ONNX model. 102 | parser = trt.OnnxParser(network, logger) 103 | if not parser: 104 | raise RuntimeError("Failed to create parser!") 105 | if not parser.parse_from_file(onnx_file_path): 106 | raise RuntimeError("Failed to parse ONNX file!") 107 | 108 | # Check input and output dimensions to ensure that the selected model is what we expect. 109 | input = network.get_input(0) 110 | if input.shape != (1, k_input_size): 111 | raise ValueError(f"Invalid ONNX input dimension, expected [1, {k_input_size}]!") 112 | output = network.get_output(0) 113 | if output.shape != (1, k_output_size): 114 | raise ValueError(f"Invalid ONNX output dimension, expected [1, {k_output_size}]!") 115 | 116 | return network 117 | 118 | 119 | # Build the serialized engine. 
120 | # In TensorRT-RTX, we often refer to this stage as "Ahead-of-Time" (AOT) 121 | # compilation. This stage tends to be slower than the "Just-in-Time" (JIT) 122 | # compilation stage. For this reason, you should perform this operation at 123 | # installation time or first run, and then save the resulting engine. 124 | # 125 | # You may choose to build the engine once and then deploy it to end-users; 126 | # it is OS-independent and by default supports Ampere and later GPUs. But 127 | # be aware that the engine does not guarantee forward compatibility, so 128 | # you must build a new engine for each new TensorRT-RTX version. 129 | def create_serialized_engine() -> trt.IHostMemory: 130 | # The weights in this example are initialized to 1.0f, but typically would 131 | # be loaded from a file or other source. 132 | # The data backing IConstantLayers must remain valid until the engine has 133 | # been built; therefore we create weights_data here. 134 | fc1_weights_data = np.ones(k_input_size * k_hidden_size, dtype=np.float32) 135 | fc2_weights_data = np.ones(k_hidden_size * k_output_size, dtype=np.float32) 136 | 137 | # Create a builder object. 138 | builder = trt.Builder(logger) 139 | if not builder: 140 | raise RuntimeError("Failed to create builder!") 141 | 142 | # Create a builder configuration to specify optional settings. 143 | builder_config = builder.create_builder_config() 144 | if not builder_config: 145 | raise RuntimeError("Failed to create builder configuration!") 146 | 147 | # Create a simple fully connected network. 148 | if k_onnx_model_path: 149 | network = create_network_from_onnx(builder, k_onnx_model_path) 150 | else: 151 | fc1_weights = trt.Weights(fc1_weights_data) 152 | fc2_weights = trt.Weights(fc2_weights_data) 153 | network = create_network(builder, fc1_weights, fc2_weights) 154 | 155 | # Perform AOT optimizations on the network graph and generate an engine. 156 | serialized_engine = builder.build_serialized_network(network, builder_config) 157 | 158 | return serialized_engine 159 | 160 | 161 | # Create an engine execution context out of the serialized engine, then perform inference. 162 | def run_inference(serialized_engine: trt.IHostMemory) -> None: 163 | runtime = trt.Runtime(logger) 164 | if not runtime: 165 | raise RuntimeError("Failed to create runtime!") 166 | 167 | # Deserialize the engine. 168 | inference_engine = runtime.deserialize_cuda_engine(serialized_engine) 169 | if not inference_engine: 170 | raise RuntimeError("Failed to deserialize engine!") 171 | 172 | # Optional settings to configure the behavior of the inference runtime. 173 | runtime_config = inference_engine.create_runtime_config() 174 | if not runtime_config: 175 | raise RuntimeError("Failed to create runtime config!") 176 | 177 | # Create an engine execution context out of the deserialized engine. 178 | # TRT-RTX performs "Just-in-Time" (JIT) optimization here, targeting the current GPU. 179 | # JIT phase is faster than AOT phase, and typically completes in under 15 seconds. 180 | context = inference_engine.create_execution_context(runtime_config) 181 | if not context: 182 | raise RuntimeError("Failed to create execution context!") 183 | 184 | # Create a stream for asynchronous execution. 185 | stream = cuda_assert(cudart.cudaStreamCreate()) 186 | 187 | # Allocate GPU memory for input and output bindings. 
188 | input_binding = cuda_assert(cudart.cudaMallocAsync(k_input_size * k_bytes_per_float, stream)) 189 | output_binding = cuda_assert(cudart.cudaMallocAsync(k_output_size * k_bytes_per_float, stream)) 190 | 191 | input_buffer = np.zeros(k_input_size, dtype=np.float32) 192 | output_buffer = np.zeros(k_output_size, dtype=np.float32) 193 | 194 | # Specify the tensor addresses. 195 | context.set_tensor_address("input", input_binding) 196 | context.set_tensor_address("output", output_binding) 197 | 198 | try: 199 | for i in range(5): 200 | input_buffer.fill(i) 201 | 202 | # Copy input data into the GPU input buffer and execute inference. 203 | cuda_assert( 204 | cudart.cudaMemcpyAsync( 205 | input_binding, 206 | input_buffer.ctypes.data, 207 | len(input_buffer) * k_bytes_per_float, 208 | cudart.cudaMemcpyKind.cudaMemcpyHostToDevice, 209 | stream, 210 | ) 211 | ) 212 | 213 | status = context.execute_async_v3(stream_handle=stream) 214 | if not status: 215 | raise RuntimeError("Failed to execute inference!") 216 | 217 | # Read the results back from GPU output buffer. 218 | cuda_assert( 219 | cudart.cudaMemcpyAsync( 220 | output_buffer.ctypes.data, 221 | output_binding, 222 | len(output_buffer) * k_bytes_per_float, 223 | cudart.cudaMemcpyKind.cudaMemcpyDeviceToHost, 224 | stream, 225 | ) 226 | ) 227 | cuda_assert(cudart.cudaStreamSynchronize(stream)) 228 | print("Input: ", input_buffer) 229 | print("Output: ", output_buffer) 230 | 231 | finally: 232 | cuda_assert(cudart.cudaFreeAsync(input_binding, stream)) 233 | cuda_assert(cudart.cudaFreeAsync(output_binding, stream)) 234 | cuda_assert(cudart.cudaStreamSynchronize(stream)) 235 | cuda_assert(cudart.cudaStreamDestroy(stream)) 236 | 237 | print("Successfully ran the network.") 238 | 239 | 240 | def main() -> None: 241 | serialized_engine = create_serialized_engine() 242 | if not serialized_engine: 243 | raise RuntimeError("Failed to build serialized engine!") 244 | print(f"Successfully built the network. Engine size: {serialized_engine.nbytes} bytes.") 245 | 246 | run_inference(serialized_engine) 247 | 248 | 249 | if __name__ == "__main__": 250 | # Set --onnx=/path/to/helloWorld.onnx to parse the provided ONNX model. 251 | parser = argparse.ArgumentParser(description="TensorRT-RTX hello world sample") 252 | parser.add_argument("--onnx", type=str, help="Path to ONNX model file") 253 | args = parser.parse_args() 254 | k_onnx_model_path = args.onnx 255 | main() 256 | -------------------------------------------------------------------------------- /demo/tests/test_license_headers.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | import re 17 | from datetime import datetime 18 | from pathlib import Path 19 | 20 | 21 | class TestLicenseHeaders: 22 | """Test suite to verify that all relevant files have the required NVIDIA license header.""" 23 | 24 | # The expected license text (without comment markers) 25 | COPYRIGHT_YEAR_PREFIX = "SPDX-FileCopyrightText: Copyright (c) " 26 | COPYRIGHT_YEAR_SUFFIX = " NVIDIA CORPORATION & AFFILIATES. All rights reserved." 27 | EXPECTED_LICENSE_LINES = [ 28 | f"{COPYRIGHT_YEAR_PREFIX}2025{COPYRIGHT_YEAR_SUFFIX}", 29 | "SPDX-License-Identifier: Apache-2.0", 30 | "", 31 | 'Licensed under the Apache License, Version 2.0 (the "License");', 32 | "you may not use this file except in compliance with the License.", 33 | "You may obtain a copy of the License at", 34 | "", 35 | "http://www.apache.org/licenses/LICENSE-2.0", 36 | "", 37 | "Unless required by applicable law or agreed to in writing, software", 38 | 'distributed under the License is distributed on an "AS IS" BASIS,', 39 | "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.", 40 | "See the License for the specific language governing permissions and", 41 | "limitations under the License.", 42 | ] 43 | 44 | @classmethod 45 | def get_project_root(cls): 46 | """Get the project root directory.""" 47 | # Go up from demo/tests to the project root 48 | current_dir = Path(__file__).parent.parent.parent 49 | return current_dir 50 | 51 | @classmethod 52 | def find_files_by_pattern(cls, root_path, patterns): 53 | """Find all files matching the given patterns, excluding build and temporary directories within the repo.""" 54 | # Directories to exclude from license header checks (only within the repository) 55 | exclude_dirs = { 56 | "build", 57 | "site-packages", 58 | "activate_this.py", 59 | } 60 | 61 | files = [] 62 | for pattern in patterns: 63 | for file_path in root_path.rglob(pattern): 64 | # Get the relative path from root_path to check only dirs within the repo 65 | try: 66 | relative_path = file_path.relative_to(root_path) 67 | # Check if any directory in the relative path is in the exclude list 68 | if not any(part in exclude_dirs for part in relative_path.parts): 69 | files.append(file_path) 70 | except ValueError: 71 | # If file_path is not relative to root_path, skip it 72 | continue 73 | return files 74 | 75 | @classmethod 76 | def extract_license_with_octothorpe_comments(cls, file_path, skip_shebang): 77 | """Extract license header from Python or CMake file (lines starting with #).""" 78 | license_lines = [] 79 | 80 | try: 81 | with open(file_path, encoding="utf-8") as f: 82 | for i, line in enumerate(f): 83 | stripped = line.strip() 84 | 85 | # Skip the shebang line at the beginning, if present (e.g. 
Python) 86 | if skip_shebang and i == 0 and stripped.startswith("#!"): 87 | continue 88 | 89 | if stripped.startswith("#"): 90 | license_lines.append(stripped.removeprefix("#").strip()) 91 | else: 92 | # First non-comment line - stop processing 93 | break 94 | except UnicodeDecodeError: 95 | # Skip binary files 96 | return [] 97 | 98 | return cls.normalize_license_lines(license_lines) 99 | 100 | @classmethod 101 | def normalize_license_lines(cls, lines): 102 | """Normalize license lines by removing empty lines at start/end and extra whitespace.""" 103 | # Remove empty lines from the beginning and end 104 | while lines and lines[0] == "": 105 | lines.pop(0) 106 | while lines and lines[-1] == "": 107 | lines.pop() 108 | 109 | return lines 110 | 111 | @classmethod 112 | def extract_license_from_cpp_file(cls, file_path): 113 | """Extract license header from C++ file (block comment /* ... */).""" 114 | license_lines = [] 115 | 116 | try: 117 | with open(file_path, encoding="utf-8") as f: 118 | for line in f: 119 | content = line.strip() 120 | 121 | if content == "/*": 122 | # Start of block comment 123 | continue 124 | elif content == "*/": 125 | # End of block comment 126 | break 127 | elif content.startswith("*"): 128 | # License content line 129 | license_lines.append(content[1:].strip()) 130 | else: 131 | # First non-comment line - stop processing 132 | break 133 | except UnicodeDecodeError: 134 | # Skip binary files 135 | return [] 136 | 137 | return cls.normalize_license_lines(license_lines) 138 | 139 | @classmethod 140 | def validate_copyright_year(cls, copyright_line): 141 | """Validate that copyright line has exact structure with current year""" 142 | current_year = datetime.now().year 143 | 144 | # Allowed formats: 145 | # "SPDX-FileCopyrightText: Copyright (c) YYYY NVIDIA CORPORATION & AFFILIATES. All rights reserved." 146 | # or 147 | # "SPDX-FileCopyrightText: Copyright (c) YYYY-YYYY NVIDIA CORPORATION & AFFILIATES. All rights reserved." 
148 | 149 | if not copyright_line.startswith(cls.COPYRIGHT_YEAR_PREFIX): 150 | return False 151 | if not copyright_line.endswith(cls.COPYRIGHT_YEAR_SUFFIX): 152 | return False 153 | 154 | # Extract the year part between prefix and suffix 155 | year_part = copyright_line[len(cls.COPYRIGHT_YEAR_PREFIX) : -len(cls.COPYRIGHT_YEAR_SUFFIX)] 156 | 157 | # Validate year format and current year requirement 158 | if "-" in year_part: 159 | # Range like "2023-2025" 160 | if not re.match(r"^\d{4}-\d{4}$", year_part): 161 | return False 162 | start_year, end_year = year_part.split("-") 163 | return int(start_year) < current_year == int(end_year) 164 | else: 165 | # Single year like "2025" 166 | if not re.match(r"^\d{4}$", year_part): 167 | return False 168 | return int(year_part) == current_year 169 | 170 | @classmethod 171 | def is_license_compatible(cls, extracted_lines, expected_lines): 172 | """Check if the extracted license is compatible with expected format.""" 173 | if len(extracted_lines) != len(expected_lines): 174 | return False 175 | 176 | for i, (extracted, expected) in enumerate(zip(extracted_lines, expected_lines)): 177 | if i == 0 and not cls.validate_copyright_year(extracted): 178 | # First line - validate copyright with exact structure and flexible years 179 | return False 180 | elif i != 0 and extracted != expected: 181 | # All other lines must match exactly 182 | return False 183 | 184 | return True 185 | 186 | @classmethod 187 | def check_license_header(cls, file_path, file_type): 188 | """Check if file has the correct license header.""" 189 | if file_type == "Python": 190 | extracted_lines = cls.extract_license_with_octothorpe_comments(file_path, skip_shebang=True) 191 | elif file_type == "C++": 192 | extracted_lines = cls.extract_license_from_cpp_file(file_path) 193 | elif file_type == "CMake": 194 | extracted_lines = cls.extract_license_with_octothorpe_comments(file_path, skip_shebang=False) 195 | else: 196 | raise ValueError(f"Unknown file type: {file_type}") 197 | 198 | return ( 199 | cls.is_license_compatible(extracted_lines, cls.EXPECTED_LICENSE_LINES), 200 | extracted_lines, 201 | ) 202 | 203 | @classmethod 204 | def test_all_files_have_license_header(self): 205 | """Test that all relevant files have the required license header.""" 206 | root_path = self.get_project_root() 207 | 208 | # Define file types with their patterns and type names 209 | file_checks = [ 210 | (["*.py"], "Python"), 211 | (["CMakeLists.txt", "CMakeLists*.txt"], "CMake"), 212 | (["*.cpp", "*.c", "*.h", "*.hpp"], "C++"), 213 | ] 214 | 215 | all_missing = [] 216 | all_incorrect = [] 217 | 218 | for patterns, file_type in file_checks: 219 | files = self.find_files_by_pattern(root_path, patterns) 220 | 221 | for file_path in files: 222 | # Skip empty __init__.py files 223 | if file_type == "Python" and file_path.name == "__init__.py": 224 | try: 225 | with open(file_path, encoding="utf-8") as f: 226 | content = f.read().strip() 227 | if not content: 228 | continue 229 | except Exception: 230 | pass # If any issues are encountered, continue with normal processing 231 | 232 | has_correct_license, extracted = self.check_license_header(file_path, file_type) 233 | if not extracted: 234 | all_missing.append((file_path, file_type)) 235 | elif not has_correct_license: 236 | all_incorrect.append((file_path, file_type, extracted)) 237 | 238 | # Create consolidated error message 239 | error_msg = [] 240 | if all_missing: 241 | error_msg.append(f"Files missing license header ({len(all_missing)}):") 242 | for file_path, 
file_type in all_missing: 243 | error_msg.append(f" - {file_path} ({file_type})") 244 | 245 | if all_incorrect: 246 | error_msg.append(f"\nFiles with incorrect license header ({len(all_incorrect)}):") 247 | for file_path, file_type, extracted in all_incorrect: 248 | error_msg.append(f" - {file_path} ({file_type})") 249 | error_msg.append(f" Found: {extracted[:3]}...") 250 | 251 | assert not all_missing and not all_incorrect, "\n".join(error_msg) 252 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. 
For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright 2025 NVIDIA CORPORATION & AFFILIATES 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /demo/utils/engine_metadata.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | 17 | """ 18 | Engine Metadata Management 19 | 20 | Tracks engine compilation parameters to determine when recompilation is needed. 21 | """ 22 | 23 | import hashlib 24 | import json 25 | import logging 26 | import time 27 | from dataclasses import asdict, dataclass 28 | from pathlib import Path 29 | from typing import Any, Optional 30 | 31 | # Initialize logger for this module 32 | logger = logging.getLogger("rtx_demo.utils.engine_metadata") 33 | 34 | 35 | @dataclass 36 | class EngineMetadata: 37 | """Metadata for a compiled TensorRT engine""" 38 | 39 | model_name: str 40 | precision: str 41 | onnx_path: str 42 | onnx_hash: str 43 | input_shapes: dict[str, Any] 44 | extra_args: set[str] 45 | build_timestamp: float 46 | 47 | def to_dict(self) -> dict[str, Any]: 48 | """Convert to dictionary for JSON serialization""" 49 | return asdict(self) 50 | 51 | @classmethod 52 | def from_dict(cls, data: dict[str, Any]) -> "EngineMetadata": 53 | """Create from dictionary""" 54 | return cls(**data) 55 | 56 | def is_compatible_with(self, new_shapes: dict[str, Any], new_extra_args: Optional[set[str]] = None) -> bool: 57 | """Check if engine is compatible with new shapes and extra args""" 58 | new_extra_args = new_extra_args if new_extra_args is not None else set() 59 | if self.extra_args != new_extra_args: 60 | return False 61 | return self._shapes_fit_profile(new_shapes) 62 | 63 | def _shapes_fit_profile(self, new_shapes: dict[str, Any]) -> bool: 64 | """Check if new shapes fit within dynamic profile""" 65 | for input_name, new_shape in new_shapes.items(): 66 | if input_name not in self.input_shapes: 67 | return False 68 | 69 | profile = self.input_shapes[input_name] 70 | 71 | # Check if this is a dynamic profile (list/tuple of 3 tuples) 72 | is_dynamic_profile = ( 73 | isinstance(profile, (list, tuple)) 74 | and len(profile) == 3 75 | and all(isinstance(shape, (list, tuple)) for shape in profile) 76 | ) 77 | 78 | if is_dynamic_profile: 79 | min_shape, opt_shape, max_shape = profile 80 | 81 | # Check if it's effectively static (all shapes are the same) 82 | if min_shape == opt_shape == max_shape: 83 | if tuple(new_shape) != tuple(min_shape): 84 | return False 85 | else: 86 | # True dynamic profile - check if new shape fits within range 87 | if len(new_shape) != len(min_shape) or len(new_shape) != len(max_shape): 88 | return False 89 | 90 | for new_dim, min_dim, max_dim in zip(new_shape, min_shape, max_shape): 91 | if new_dim < min_dim or new_dim > max_dim: 92 | return False 93 | else: 94 | # Static profile: profile is a single tuple 95 | if tuple(new_shape) != tuple(profile): 96 | return False 97 | 98 | return True 99 | 100 | 101 | class EngineMetadataManager: 102 | """Manages engine metadata files""" 103 | 104 | def _get_metadata_path(self, engine_path: Path) -> Path: 105 | """Get metadata file path for an engine""" 106 | return engine_path.with_suffix(".metadata.json") 107 | 108 | def _compute_onnx_hash(self, onnx_path: str) -> str: 109 | """Compute hash of ONNX file""" 110 | try: 111 | with open(onnx_path, "rb") as f: 112 | return
hashlib.sha256(f.read()).hexdigest()[:16] 113 | except Exception: 114 | return "unknown" 115 | 116 | def save_metadata( 117 | self, 118 | engine_path: Path, 119 | model_name: str, 120 | precision: str, 121 | onnx_path: str, 122 | input_shapes: dict, 123 | extra_args: Optional[set[str]] = None, 124 | ) -> None: 125 | """Save engine metadata to file.""" 126 | metadata = { 127 | "model_name": model_name, 128 | "precision": precision, 129 | "onnx_path": str(onnx_path), 130 | "input_shapes": input_shapes, 131 | "extra_args": list(extra_args) if extra_args else [], 132 | "build_timestamp": time.time(), 133 | "tensorrt_version": self._get_tensorrt_version(), 134 | } 135 | 136 | metadata_path = engine_path.with_suffix(".metadata.json") 137 | 138 | try: 139 | with open(metadata_path, "w") as f: 140 | json.dump(metadata, f, indent=2) 141 | 142 | logger.debug(f"Saved engine metadata: {metadata_path}") 143 | 144 | except Exception as e: 145 | logger.warning(f"Failed to save metadata for {engine_path}: {e}, unnecessary recompilations may occur.") 146 | 147 | def load_metadata(self, metadata_path: Path) -> Optional[dict]: 148 | """Load engine metadata from file.""" 149 | try: 150 | with open(metadata_path) as f: 151 | return json.load(f) 152 | except Exception as e: 153 | logger.warning(f"Failed to load metadata from {metadata_path}: {e}") 154 | return None 155 | 156 | def check_engine_compatibility( 157 | self, 158 | engine_path: Path, 159 | target_shapes: dict, 160 | static_shape: bool, 161 | extra_args: Optional[set[str]] = None, 162 | ) -> tuple[bool, str]: 163 | """Check if an existing engine is compatible with current requirements.""" 164 | if not engine_path.exists(): 165 | return False, "No cached engine found" 166 | 167 | metadata_path = engine_path.with_suffix(".metadata.json") 168 | if not metadata_path.exists(): 169 | return False, "No cached engine metadata found" 170 | 171 | try: 172 | metadata = self.load_metadata(metadata_path) 173 | if not metadata: 174 | return False, "Engine metadata corrupted or unreadable" 175 | 176 | # Check TensorRT version compatibility 177 | saved_trt_version = metadata.get("tensorrt_version", "unknown") 178 | current_trt_version = self._get_tensorrt_version() 179 | if saved_trt_version != current_trt_version: 180 | return ( 181 | False, 182 | f"TensorRT version changed: engine built with {saved_trt_version}, current system has {current_trt_version}", 183 | ) 184 | 185 | # Determine shape mode from engine filename 186 | is_static_engine = "_static." 
in str(engine_path) 187 | if is_static_engine != static_shape: 188 | return ( 189 | False, 190 | f"Shape mode mismatch: cached engine is static:{is_static_engine}, user requested static:{static_shape}", 191 | ) 192 | 193 | # Check shape compatibility based on engine type 194 | saved_shapes = metadata.get("input_shapes", {}) 195 | 196 | if static_shape: 197 | # For static shapes, shapes must match exactly 198 | # Normalize both saved and target shapes to handle list/tuple format differences 199 | normalized_saved = self._normalize_shapes_for_comparison(saved_shapes) 200 | normalized_target = self._normalize_shapes_for_comparison(target_shapes) 201 | 202 | if normalized_saved != normalized_target: 203 | return ( 204 | False, 205 | f"Shape profile incompatible: cached engine expects {saved_shapes}, user requested {target_shapes}", 206 | ) 207 | else: 208 | # For dynamic shapes, check that target shapes fall within saved ranges 209 | for input_name, target_shape in target_shapes.items(): 210 | if input_name not in saved_shapes: 211 | return False, f"Missing input '{input_name}' in saved shapes" 212 | 213 | saved_shape = saved_shapes[input_name] 214 | 215 | # Saved shape should be (min_shape, opt_shape, max_shape) for dynamic 216 | if not isinstance(saved_shape, (list, tuple)) or len(saved_shape) != 3: 217 | return ( 218 | False, 219 | f"Invalid saved dynamic shape format for '{input_name}': {saved_shape}", 220 | ) 221 | 222 | min_shape, opt_shape, max_shape = saved_shape 223 | 224 | # Target shape should be a single shape that fits within min/max bounds 225 | if ( 226 | isinstance(target_shape, (list, tuple)) 227 | and isinstance(target_shape[0], (list, tuple)) 228 | and len(target_shape) == 3 229 | ): 230 | # Target is also min/opt/max format - check compatibility 231 | target_min, target_opt, target_max = target_shape 232 | if target_min != min_shape or target_opt != opt_shape or target_max != max_shape: 233 | return False, f"Dynamic shape range mismatch for '{input_name}'" 234 | else: 235 | # Target is single shape - check if it fits within bounds 236 | target_shape = tuple(target_shape) if isinstance(target_shape, list) else target_shape 237 | min_shape = tuple(min_shape) if isinstance(min_shape, list) else min_shape 238 | max_shape = tuple(max_shape) if isinstance(max_shape, list) else max_shape 239 | 240 | if len(target_shape) != len(min_shape): 241 | return False, f"Shape dimension mismatch for '{input_name}'" 242 | 243 | for i, (target_dim, min_dim, max_dim) in enumerate(zip(target_shape, min_shape, max_shape)): 244 | if target_dim < min_dim or target_dim > max_dim: 245 | return ( 246 | False, 247 | f"Target shape dimension {i} ({target_dim}) outside bounds [{min_dim}, {max_dim}] for '{input_name}'", 248 | ) 249 | 250 | # Check extra args compatibility 251 | saved_args = set(metadata.get("extra_args", [])) 252 | target_args = extra_args or set() 253 | if saved_args != target_args: 254 | return ( 255 | False, 256 | f"Build configuration changed: cached engine used {saved_args or 'default settings'}, user requested {target_args or 'default settings'}", 257 | ) 258 | 259 | return True, "Compatible" 260 | 261 | except Exception as e: 262 | return False, f"Error checking compatibility: {e}" 263 | 264 | def cleanup_metadata(self, engine_path: Path) -> None: 265 | """Remove metadata file""" 266 | metadata_path = self._get_metadata_path(engine_path) 267 | if metadata_path.exists(): 268 | metadata_path.unlink() 269 | logger.debug(f"Removed metadata: {metadata_path}") 270 | 271 | def 
_get_tensorrt_version(self) -> str: 272 | """Get TensorRT version string""" 273 | try: 274 | import tensorrt_rtx as trt 275 | 276 | return trt.__version__ 277 | except Exception: 278 | return "unknown" 279 | 280 | def _normalize_shapes_for_comparison(self, shapes: dict) -> dict: 281 | """Normalize shapes for comparison by converting lists to tuples recursively""" 282 | normalized = {} 283 | for input_name, shape in shapes.items(): 284 | if isinstance(shape, list): 285 | # Convert list to tuple, handling nested lists as well 286 | normalized[input_name] = ( 287 | tuple(tuple(s) if isinstance(s, list) else s for s in shape) 288 | if isinstance(shape[0], (list, tuple)) 289 | else tuple(shape) 290 | ) 291 | elif isinstance(shape, tuple): 292 | # Ensure nested tuples are also normalized 293 | normalized[input_name] = ( 294 | tuple(tuple(s) if isinstance(s, list) else s for s in shape) 295 | if len(shape) > 0 and isinstance(shape[0], (list, tuple)) 296 | else shape 297 | ) 298 | else: 299 | normalized[input_name] = shape 300 | return normalized 301 | 302 | 303 | # Global metadata manager instance 304 | metadata_manager = EngineMetadataManager() 305 | -------------------------------------------------------------------------------- /demo/tests/test_onnx_acquisition.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """ 17 | Test ONNX acquisition functionality using pytest. 18 | 19 | This test demonstrates how the PathManager can acquire ONNX files from: 20 | 1. Local file paths (for development) 21 | 2. 
Remote URLs (for production) with proper temporary directory handling 22 | """ 23 | 24 | import os 25 | import tempfile 26 | from pathlib import Path 27 | from unittest.mock import patch 28 | 29 | import pytest 30 | from utils.path_manager import PathManager 31 | 32 | 33 | @pytest.mark.integration 34 | @pytest.mark.paths 35 | class TestONNXAcquisition: 36 | """Test ONNX file acquisition functionality.""" 37 | 38 | def test_local_file_acquisition_copies_all_files(self, path_manager: PathManager, temp_source_dir: Path): 39 | """Test that local file acquisition copies ALL files from the directory.""" 40 | # Create main ONNX file 41 | test_onnx = temp_source_dir / "test_model_fp16.onnx" 42 | test_onnx.write_text("fake onnx content") 43 | 44 | # Create multiple additional files to test that ALL files are copied 45 | additional_files = [ 46 | "test_model_fp16.onnx.data", # Standard external data 47 | "another_model.onnx", # Different model 48 | "config.json", # Configuration file 49 | "weights.bin", # Weights file 50 | "metadata.txt", # Metadata file 51 | "readme.md", # Documentation 52 | ] 53 | 54 | for additional_file in additional_files: 55 | file_path = temp_source_dir / additional_file 56 | file_path.write_text(f"fake content for {additional_file}") 57 | 58 | # Test local file acquisition 59 | success = path_manager.acquire_onnx_file("test_model", "fp16", str(test_onnx), "", "", None) 60 | 61 | assert success, "Local file acquisition should succeed" 62 | 63 | canonical_onnx = path_manager.get_onnx_path("test_model", "fp16") 64 | assert canonical_onnx.exists(), "Main ONNX file should exist" 65 | 66 | # Check ALL files from the directory were copied 67 | shared_files = list(canonical_onnx.parent.iterdir()) 68 | shared_file_names = [f.name for f in shared_files] 69 | 70 | # Should have the main ONNX file with canonical name plus all additional files 71 | expected_files = ["test_model.onnx"] + additional_files # Canonical name, not original 72 | for expected_file in expected_files: 73 | assert expected_file in shared_file_names, f"Expected file {expected_file} should be copied" 74 | 75 | assert len(shared_files) == len(expected_files), ( 76 | f"Expected {len(expected_files)} files, got {len(shared_files)}" 77 | ) 78 | 79 | @patch("utils.path_manager.snapshot_download") 80 | def test_remote_download_with_nested_structure(self, mock_snapshot_download, path_manager: PathManager): 81 | """Test remote download handles nested directory structure correctly.""" 82 | 83 | def mock_download_side_effect(repo_id, allow_patterns, local_dir, token=None): 84 | """Mock snapshot_download to create nested structure like HuggingFace does.""" 85 | temp_dir = Path(local_dir) 86 | 87 | # Create the nested structure that HuggingFace would create 88 | nested_dir = temp_dir / "models" / "onnx" / "test_model" / "fp16" 89 | nested_dir.mkdir(parents=True, exist_ok=True) 90 | 91 | # Create mock ONNX files in the nested directory 92 | (nested_dir / "test_model.onnx").write_text("fake onnx content from remote") 93 | (nested_dir / "test_model.onnx.data").write_text("fake onnx data from remote") 94 | (nested_dir / "config.json").write_text('{"model": "test"}') 95 | 96 | mock_snapshot_download.side_effect = mock_download_side_effect 97 | 98 | # Test remote acquisition with non-existent local path (triggers remote download) 99 | fake_remote_path = "nonexistent/path/to/model.onnx" 100 | 101 | success = path_manager.acquire_onnx_file( 102 | "test_model", 103 | "fp16", 104 | fake_remote_path, 105 | "test_pipeline", 106 | 
"models/onnx/test_model/fp16", 107 | None, 108 | ) 109 | 110 | assert success, "Remote download should succeed" 111 | 112 | # Verify snapshot_download was called with correct parameters 113 | mock_snapshot_download.assert_called_once_with( 114 | repo_id="test_pipeline", 115 | allow_patterns=os.path.join("models/onnx/test_model/fp16", "*"), 116 | local_dir=mock_snapshot_download.call_args[1]["local_dir"], # temp directory 117 | token=None, 118 | ) 119 | 120 | # Check that files were moved to the correct final location 121 | canonical_onnx = path_manager.get_onnx_path("test_model", "fp16") 122 | assert canonical_onnx.exists(), "Main ONNX file should exist after download" 123 | 124 | # Check all files were moved correctly 125 | target_dir = canonical_onnx.parent 126 | expected_files = ["test_model.onnx", "test_model.onnx.data", "config.json"] 127 | 128 | for expected_file in expected_files: 129 | file_path = target_dir / expected_file 130 | assert file_path.exists(), f"File {expected_file} should exist after download" 131 | 132 | # Verify content was preserved 133 | assert canonical_onnx.read_text() == "fake onnx content from remote" 134 | 135 | @patch("utils.path_manager.snapshot_download") 136 | def test_remote_download_with_token(self, mock_snapshot_download, path_manager: PathManager): 137 | """Test remote download passes HuggingFace token correctly.""" 138 | 139 | def mock_download_side_effect(repo_id, allow_patterns, local_dir, token=None): 140 | temp_dir = Path(local_dir) 141 | nested_dir = temp_dir / "models" / "onnx" / "private_model" / "fp16" 142 | nested_dir.mkdir(parents=True, exist_ok=True) 143 | (nested_dir / "private_model.onnx").write_text("private model content") 144 | 145 | mock_snapshot_download.side_effect = mock_download_side_effect 146 | 147 | # Test with HF token 148 | test_token = "hf_test_token_12345" 149 | success = path_manager.acquire_onnx_file( 150 | "private_model", 151 | "fp16", 152 | "nonexistent/path", 153 | "private_pipeline", 154 | "models/onnx/private_model/fp16", 155 | test_token, 156 | ) 157 | 158 | assert success, "Remote download with token should succeed" 159 | 160 | # Verify token was passed correctly 161 | mock_snapshot_download.assert_called_once() 162 | call_args = mock_snapshot_download.call_args 163 | assert call_args[1]["token"] == test_token, "HF token should be passed to snapshot_download" 164 | 165 | @patch("utils.path_manager.snapshot_download") 166 | def test_remote_download_error_handling(self, mock_snapshot_download, path_manager: PathManager): 167 | """Test remote download handles errors gracefully.""" 168 | 169 | # Mock snapshot_download to raise an exception 170 | mock_snapshot_download.side_effect = Exception("Network error") 171 | 172 | success = path_manager.acquire_onnx_file( 173 | "error_model", 174 | "fp16", 175 | "nonexistent/path", 176 | "error_pipeline", 177 | "models/onnx/error_model/fp16", 178 | None, 179 | ) 180 | 181 | assert not success, "Remote download should fail gracefully on error" 182 | 183 | # Verify the target ONNX file doesn't exist after error 184 | canonical_onnx = path_manager.get_onnx_path("error_model", "fp16") 185 | assert not canonical_onnx.exists(), "ONNX file should not exist after failed download" 186 | 187 | @patch("utils.path_manager.snapshot_download") 188 | def test_remote_download_temporary_directory_cleanup(self, mock_snapshot_download, path_manager: PathManager): 189 | """Test that temporary directories are properly cleaned up.""" 190 | created_temp_dirs = [] 191 | original_tempdir = 
tempfile.TemporaryDirectory 192 | 193 | def track_temp_dirs(*args, **kwargs): 194 | temp_dir_obj = original_tempdir(*args, **kwargs) 195 | created_temp_dirs.append(Path(temp_dir_obj.name)) 196 | return temp_dir_obj 197 | 198 | with patch("tempfile.TemporaryDirectory", side_effect=track_temp_dirs): 199 | # Mock download that succeeds 200 | def mock_download_side_effect(repo_id, allow_patterns, local_dir, token=None): 201 | temp_dir = Path(local_dir) 202 | nested_dir = temp_dir / "models" / "onnx" / "cleanup_model" / "fp16" 203 | nested_dir.mkdir(parents=True, exist_ok=True) 204 | (nested_dir / "cleanup_model.onnx").write_text("test content") 205 | 206 | mock_snapshot_download.side_effect = mock_download_side_effect 207 | 208 | # Test remote download with mocked response 209 | success = path_manager.acquire_onnx_file( 210 | "cleanup_model", 211 | "fp16", 212 | "nonexistent/path", 213 | "cleanup_pipeline", 214 | "models/onnx/cleanup_model/fp16", 215 | None, 216 | ) 217 | 218 | assert success, "Download should succeed" 219 | 220 | # After successful completion, temp directories should be cleaned up 221 | for temp_dir in created_temp_dirs: 222 | assert not temp_dir.exists(), f"Temporary directory {temp_dir} should be cleaned up" 223 | 224 | def test_local_path_with_nonexistent_source_fails_gracefully(self, path_manager: PathManager): 225 | """Test that local path acquisition fails gracefully when source doesn't exist.""" 226 | nonexistent_path = "/nonexistent/path/to/model.onnx" 227 | 228 | success = path_manager.acquire_onnx_file("nonexistent_model", "fp16", nonexistent_path, "", "", None) 229 | 230 | assert not success, "Acquisition should fail when source doesn't exist" 231 | 232 | # Verify no files were created 233 | canonical_onnx = path_manager.get_onnx_path("nonexistent_model", "fp16") 234 | assert not canonical_onnx.exists(), "ONNX file should not exist after failed acquisition" 235 | 236 | def test_local_path_with_subdirectories_copies_all_files(self, path_manager: PathManager, temp_source_dir: Path): 237 | """Test that local path acquisition copies files even with complex directory structure.""" 238 | # Create a complex source directory structure 239 | model_dir = temp_source_dir / "complex_model" 240 | model_dir.mkdir() 241 | 242 | # Create subdirectory with additional files 243 | sub_dir = model_dir / "weights" 244 | sub_dir.mkdir() 245 | 246 | # Create main ONNX file 247 | main_onnx = model_dir / "complex_model_fp16.onnx" 248 | main_onnx.write_text("main onnx content") 249 | 250 | # Create files in main directory 251 | (model_dir / "complex_model_fp16.onnx.data").write_text("onnx data") 252 | (model_dir / "config.json").write_text('{"model": "complex"}') 253 | (model_dir / "tokenizer.json").write_text('{"vocab_size": 1000}') 254 | 255 | # Create files in subdirectory (should NOT be copied since we only copy from parent dir) 256 | (sub_dir / "weights.bin").write_text("weights content") 257 | (sub_dir / "optimizer.bin").write_text("optimizer content") 258 | 259 | # Test acquisition 260 | success = path_manager.acquire_onnx_file("complex_model", "fp16", str(main_onnx), "", "", None) 261 | 262 | assert success, "Complex directory acquisition should succeed" 263 | 264 | canonical_onnx = path_manager.get_onnx_path("complex_model", "fp16") 265 | assert canonical_onnx.exists(), "Main ONNX file should exist" 266 | 267 | # Check that files from the main directory were copied 268 | target_dir = canonical_onnx.parent 269 | copied_files = [f.name for f in target_dir.iterdir()] 270 | 271 | # 
Should have main ONNX file (renamed) and other files from same directory 272 | expected_files = [ 273 | "complex_model.onnx", # Renamed from complex_model_fp16.onnx 274 | "complex_model_fp16.onnx.data", 275 | "config.json", 276 | "tokenizer.json", 277 | ] 278 | 279 | for expected_file in expected_files: 280 | assert expected_file in copied_files, f"Expected file {expected_file} should be copied" 281 | 282 | # Should NOT have subdirectory files (only copies from parent directory) 283 | assert "weights.bin" not in copied_files, "Subdirectory files should not be copied" 284 | assert "optimizer.bin" not in copied_files, "Subdirectory files should not be copied" 285 | 286 | # Verify content preservation 287 | assert canonical_onnx.read_text() == "main onnx content" 288 | -------------------------------------------------------------------------------- /demo/flux1.dev/models/flux_model.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | 17 | import logging 18 | from typing import Any, Optional 19 | 20 | from diffusers import AutoencoderKL, FluxTransformer2DModel 21 | from diffusers.configuration_utils import FrozenDict 22 | from models.flux_params import FluxParams 23 | from transformers import AutoConfig 24 | from utils.base_model import BaseModel 25 | 26 | # Initialize logger for this module 27 | logger = logging.getLogger("rtx_demo.flux1.dev.models.flux_model") 28 | 29 | 30 | class FluxTransformerModel(BaseModel): 31 | """Flux Transformer model for text-to-image generation""" 32 | 33 | def __init__( 34 | self, 35 | name: str, 36 | device: str = "cuda", 37 | model_params: Optional[FluxParams] = None, 38 | model_id: str = "black-forest-labs/FLUX.1-dev", 39 | hf_token: Optional[str] = None, 40 | ): 41 | super().__init__(name, device, model_params, hf_token) 42 | self.model_id = model_id 43 | 44 | # Load configuration from HuggingFace 45 | logger.debug(f"Loading Flux Transformer config from {model_id}/transformer") 46 | self.config = FrozenDict( 47 | FluxTransformer2DModel.load_config( 48 | model_id, 49 | subfolder="transformer", 50 | token=self.hf_token, 51 | ) 52 | ) 53 | 54 | def get_input_profile( 55 | self, use_static_shape: bool, batch_size: int = 1, height: int = 512, width: int = 512 56 | ) -> dict[str, Any]: 57 | """Return TensorRT input profile for dynamic shapes""" 58 | latent_height = height // self.model_params.VAE_SPATIAL_COMPRESSION_RATIO 59 | latent_width = width // self.model_params.VAE_SPATIAL_COMPRESSION_RATIO 60 | opt_latent_dim = (latent_height // 2) * (latent_width // 2) 61 | 62 | # Build static input profile 63 | if use_static_shape: 64 | input_profile = { 65 | "hidden_states": (batch_size, opt_latent_dim, self.config.in_channels), 66 | "encoder_hidden_states": ( 67 | batch_size, 68 | 
self.model_params.T5_SEQUENCE_LENGTH, 69 | self.config.joint_attention_dim, 70 | ), 71 | "pooled_projections": (batch_size, self.config.pooled_projection_dim), 72 | "timestep": (batch_size,), 73 | "img_ids": (opt_latent_dim, 3), 74 | "txt_ids": (self.model_params.T5_SEQUENCE_LENGTH, 3), 75 | "guidance": (batch_size,), 76 | } 77 | else: 78 | min_latent_height = self.model_params.MIN_HEIGHT // self.model_params.VAE_SPATIAL_COMPRESSION_RATIO 79 | min_latent_width = self.model_params.MIN_WIDTH // self.model_params.VAE_SPATIAL_COMPRESSION_RATIO 80 | max_latent_height = self.model_params.MAX_HEIGHT // self.model_params.VAE_SPATIAL_COMPRESSION_RATIO 81 | max_latent_width = self.model_params.MAX_WIDTH // self.model_params.VAE_SPATIAL_COMPRESSION_RATIO 82 | min_latent_dim = (min_latent_height // 2) * (min_latent_width // 2) 83 | max_latent_dim = (max_latent_height // 2) * (max_latent_width // 2) 84 | 85 | input_profile = { 86 | "hidden_states": [ 87 | ( 88 | self.model_params.MIN_BATCH_SIZE, 89 | min_latent_dim, 90 | self.config.in_channels, 91 | ), # min 92 | (batch_size, opt_latent_dim, self.config.in_channels), # opt 93 | ( 94 | self.model_params.MAX_BATCH_SIZE, 95 | max_latent_dim, 96 | self.config.in_channels, 97 | ), # max 98 | ], 99 | "encoder_hidden_states": [ 100 | ( 101 | self.model_params.MIN_BATCH_SIZE, 102 | self.model_params.T5_SEQUENCE_LENGTH, 103 | self.config.joint_attention_dim, 104 | ), 105 | ( 106 | batch_size, 107 | self.model_params.T5_SEQUENCE_LENGTH, 108 | self.config.joint_attention_dim, 109 | ), 110 | ( 111 | self.model_params.MAX_BATCH_SIZE, 112 | self.model_params.T5_SEQUENCE_LENGTH, 113 | self.config.joint_attention_dim, 114 | ), 115 | ], 116 | "pooled_projections": [ 117 | (self.model_params.MIN_BATCH_SIZE, self.config.pooled_projection_dim), 118 | (batch_size, self.config.pooled_projection_dim), 119 | (self.model_params.MAX_BATCH_SIZE, self.config.pooled_projection_dim), 120 | ], 121 | "timestep": [ 122 | (self.model_params.MIN_BATCH_SIZE,), 123 | (batch_size,), 124 | (self.model_params.MAX_BATCH_SIZE,), 125 | ], 126 | "img_ids": [ 127 | (min_latent_dim, 3), 128 | (opt_latent_dim, 3), 129 | (max_latent_dim, 3), 130 | ], 131 | "txt_ids": [ 132 | (self.model_params.T5_SEQUENCE_LENGTH, 3), 133 | (self.model_params.T5_SEQUENCE_LENGTH, 3), 134 | (self.model_params.T5_SEQUENCE_LENGTH, 3), 135 | ], 136 | "guidance": [ 137 | (self.model_params.MIN_BATCH_SIZE,), 138 | (batch_size,), 139 | (self.model_params.MAX_BATCH_SIZE,), 140 | ], 141 | } 142 | 143 | return input_profile 144 | 145 | def get_shape_dict(self, batch_size: int = 1, height: int = 512, width: int = 512) -> dict[str, Any]: 146 | """Return shape dictionary for tensor allocation""" 147 | latent_height = height // self.model_params.VAE_SPATIAL_COMPRESSION_RATIO 148 | latent_width = width // self.model_params.VAE_SPATIAL_COMPRESSION_RATIO 149 | latent_dim = (latent_height // 2) * (latent_width // 2) 150 | 151 | return { 152 | "hidden_states": (batch_size, latent_dim, self.config.in_channels), 153 | "encoder_hidden_states": ( 154 | batch_size, 155 | self.model_params.T5_SEQUENCE_LENGTH, 156 | self.config.joint_attention_dim, 157 | ), 158 | "pooled_projections": (batch_size, self.config.pooled_projection_dim), 159 | "timestep": (batch_size,), 160 | "img_ids": (latent_dim, 3), 161 | "txt_ids": (self.model_params.T5_SEQUENCE_LENGTH, 3), 162 | "guidance": (batch_size,), 163 | "latent": ( 164 | batch_size, 165 | latent_dim, 166 | self.config.in_channels, 167 | ), # Use in_channels for output too 168 | } 169 | 170 | 171 | 
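# NOTE: the model classes in this file share one get_input_profile() convention:
# with use_static_shape=True each input maps to a single shape tuple, while for
# dynamic engines each input maps to a [min, opt, max] list of shape tuples, where
# opt follows the requested batch_size/height/width and min/max come from the bounds
# in FluxParams. As an illustration (assuming the usual factor-of-8 VAE spatial
# compression, so a 512x512 image yields a 64x64 latent and a packed latent_dim of
# (64 // 2) * (64 // 2) = 1024), the static transformer profile for batch_size=1
# contains entries such as "hidden_states": (1, 1024, 64) and "img_ids": (1024, 3),
# 64 being the in_channels listed for the transformer in utils/model_registry.py.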
class FluxTextEncoderModel(BaseModel): 172 | """Flux CLIP Text Encoder model""" 173 | 174 | def __init__( 175 | self, 176 | name: str, 177 | device: str = "cuda", 178 | model_params: Optional[FluxParams] = None, 179 | model_id: str = "black-forest-labs/FLUX.1-dev", 180 | hf_token: Optional[str] = None, 181 | ): 182 | super().__init__(name, device, model_params, hf_token) 183 | self.model_id = model_id 184 | 185 | # Load configuration from HuggingFace 186 | logger.debug(f"Loading CLIP Text Encoder config from {model_id}") 187 | self.config = AutoConfig.from_pretrained(model_id, subfolder="text_encoder", token=self.hf_token) 188 | 189 | def get_input_profile(self, use_static_shape: bool, batch_size: int = 1, **kwargs) -> dict[str, Any]: 190 | """Return TensorRT input profile for dynamic shapes""" 191 | if use_static_shape: 192 | return { 193 | "input_ids": (batch_size, self.model_params.CLIP_SEQUENCE_LENGTH), 194 | } 195 | return { 196 | "input_ids": [ 197 | (self.model_params.MIN_BATCH_SIZE, self.model_params.CLIP_SEQUENCE_LENGTH), 198 | (batch_size, self.model_params.CLIP_SEQUENCE_LENGTH), 199 | (self.model_params.MAX_BATCH_SIZE, self.model_params.CLIP_SEQUENCE_LENGTH), 200 | ] 201 | } 202 | 203 | def get_shape_dict(self, batch_size: int = 1, **kwargs) -> dict[str, Any]: 204 | """Return shape dictionary for tensor allocation""" 205 | # Handle both direct config and nested text_config structures 206 | hidden_size = getattr(self.config, "hidden_size", None) 207 | if hidden_size is None and hasattr(self.config, "text_config"): 208 | hidden_size = self.config.text_config.hidden_size 209 | projection_dim = getattr(self.config, "projection_dim", hidden_size) 210 | 211 | return { 212 | "input_ids": (batch_size, self.model_params.CLIP_SEQUENCE_LENGTH), 213 | "text_embeddings": (batch_size, self.model_params.CLIP_SEQUENCE_LENGTH, hidden_size), 214 | "pooled_embeddings": (batch_size, projection_dim), 215 | } 216 | 217 | 218 | class FluxT5EncoderModel(BaseModel): 219 | """Flux T5 Text Encoder model""" 220 | 221 | def __init__( 222 | self, 223 | name: str, 224 | device: str = "cuda", 225 | model_params: Optional[FluxParams] = None, 226 | model_id: str = "black-forest-labs/FLUX.1-dev", 227 | hf_token: Optional[str] = None, 228 | ): 229 | super().__init__(name, device, model_params, hf_token) 230 | self.model_id = model_id 231 | 232 | # Load configuration from HuggingFace 233 | logger.debug(f"Loading T5 Text Encoder config from {model_id}") 234 | self.config = AutoConfig.from_pretrained(model_id, subfolder="text_encoder_2", token=self.hf_token) 235 | 236 | def get_input_profile(self, use_static_shape: bool, batch_size: int = 1) -> dict[str, Any]: 237 | """Return TensorRT input profile for dynamic shapes""" 238 | if use_static_shape: 239 | return { 240 | "input_ids": (batch_size, self.model_params.T5_SEQUENCE_LENGTH), 241 | } 242 | return { 243 | "input_ids": [ 244 | (self.model_params.MIN_BATCH_SIZE, self.model_params.T5_SEQUENCE_LENGTH), 245 | (batch_size, self.model_params.T5_SEQUENCE_LENGTH), 246 | (self.model_params.MAX_BATCH_SIZE, self.model_params.T5_SEQUENCE_LENGTH), 247 | ] 248 | } 249 | 250 | def get_shape_dict(self, batch_size: int = 1) -> dict[str, Any]: 251 | """Return shape dictionary for tensor allocation""" 252 | # T5 uses 'd_model' for hidden size 253 | hidden_size = getattr(self.config, "d_model", 4096) 254 | 255 | return { 256 | "input_ids": (batch_size, self.model_params.T5_SEQUENCE_LENGTH), 257 | "text_embeddings": (batch_size, self.model_params.T5_SEQUENCE_LENGTH, hidden_size), 258 
| } 259 | 260 | 261 | class FluxVAEModel(BaseModel): 262 | """Flux VAE Decoder model""" 263 | 264 | def __init__( 265 | self, 266 | name: str, 267 | device: str = "cuda", 268 | model_params: Optional[FluxParams] = None, 269 | model_id: str = "black-forest-labs/FLUX.1-dev", 270 | hf_token: Optional[str] = None, 271 | ): 272 | super().__init__(name, device, model_params, hf_token) 273 | self.model_id = model_id 274 | 275 | # Load configuration from HuggingFace 276 | logger.debug(f"Loading VAE config from {model_id}/vae") 277 | self.config = FrozenDict(AutoencoderKL.load_config(model_id, subfolder="vae", token=self.hf_token)) 278 | 279 | def get_input_profile( 280 | self, use_static_shape: bool, batch_size: int = 1, height: int = 512, width: int = 512 281 | ) -> dict[str, Any]: 282 | """Return TensorRT input profile for dynamic shapes""" 283 | latent_height = height // self.model_params.VAE_SPATIAL_COMPRESSION_RATIO 284 | latent_width = width // self.model_params.VAE_SPATIAL_COMPRESSION_RATIO 285 | 286 | min_latent_height = self.model_params.MIN_HEIGHT // self.model_params.VAE_SPATIAL_COMPRESSION_RATIO 287 | min_latent_width = self.model_params.MIN_WIDTH // self.model_params.VAE_SPATIAL_COMPRESSION_RATIO 288 | max_latent_height = self.model_params.MAX_HEIGHT // self.model_params.VAE_SPATIAL_COMPRESSION_RATIO 289 | max_latent_width = self.model_params.MAX_WIDTH // self.model_params.VAE_SPATIAL_COMPRESSION_RATIO 290 | 291 | if use_static_shape: 292 | return { 293 | "latent": (batch_size, self.config.latent_channels, latent_height, latent_width), 294 | } 295 | 296 | return { 297 | "latent": [ 298 | ( 299 | self.model_params.MIN_BATCH_SIZE, 300 | self.config.latent_channels, 301 | min_latent_height, 302 | min_latent_width, 303 | ), 304 | (batch_size, self.config.latent_channels, latent_height, latent_width), 305 | ( 306 | self.model_params.MAX_BATCH_SIZE, 307 | self.config.latent_channels, 308 | max_latent_height, 309 | max_latent_width, 310 | ), 311 | ] 312 | } 313 | 314 | def get_shape_dict(self, batch_size: int = 1, height: int = 512, width: int = 512) -> dict[str, Any]: 315 | """Return shape dictionary for tensor allocation""" 316 | latent_height = height // self.model_params.VAE_SPATIAL_COMPRESSION_RATIO 317 | latent_width = width // self.model_params.VAE_SPATIAL_COMPRESSION_RATIO 318 | 319 | return { 320 | "latent": (batch_size, self.config.latent_channels, latent_height, latent_width), 321 | "images": (batch_size, self.config.out_channels, height, width), 322 | } 323 | -------------------------------------------------------------------------------- /demo/utils/model_registry.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | 17 | """ 18 | Model Registry with shared model definitions. 
19 | 20 | Contains ONLY model definitions, configurations, and metadata. 21 | Path management is handled by PathManager. 22 | """ 23 | 24 | import logging 25 | from typing import Any, Optional 26 | 27 | import tensorrt_rtx as trt 28 | import torch 29 | 30 | # Initialize logger for this module 31 | logger = logging.getLogger("rtx_demo.utils.model_registry") 32 | 33 | # Shared model definitions - each model defined once with all its variants 34 | MODELS: dict[str, dict[str, dict[str, Any]]] = { 35 | "flux_clip_text_encoder": { 36 | "bf16": { 37 | "onnx_repository": "black-forest-labs/FLUX.1-dev-onnx", 38 | "onnx_subfolder": "clip.opt", 39 | "input_shapes": {"input_ids": ("B", 77)}, 40 | "input_dtypes": {"input_ids": trt.DataType.INT32}, 41 | "output_shapes": { 42 | "text_embeddings": ("B", 77, 768), 43 | "pooled_embeddings": ("B", 768), 44 | }, 45 | }, 46 | }, 47 | "flux_t5_text_encoder": { 48 | "bf16": { 49 | "onnx_repository": "black-forest-labs/FLUX.1-dev-onnx", 50 | "onnx_subfolder": "t5.opt", 51 | "input_shapes": {"input_ids": ("B", 512)}, 52 | "input_dtypes": {"input_ids": trt.DataType.INT32}, 53 | "output_shapes": { 54 | "text_embeddings": ("B", 512, 4096), 55 | }, 56 | }, 57 | "fp8": { 58 | "onnx_repository": "black-forest-labs/FLUX.1-dev-onnx", 59 | "onnx_subfolder": "t5-fp8.opt", 60 | "input_shapes": {"input_ids": ("B", 512)}, 61 | "input_dtypes": {"input_ids": trt.DataType.INT32}, 62 | "output_shapes": { 63 | "text_embeddings": ("B", 512, 4096), 64 | }, 65 | }, 66 | }, 67 | "flux_transformer": { 68 | "bf16": { 69 | "onnx_repository": "black-forest-labs/FLUX.1-dev-onnx", 70 | "onnx_subfolder": "transformer.opt/bf16", 71 | "input_shapes": { 72 | "hidden_states": ("B", "latent_dim", 64), 73 | "encoder_hidden_states": ("B", 512, 4096), 74 | "pooled_projections": ("B", 768), 75 | "timestep": ("B",), 76 | "img_ids": ("latent_dim", 3), 77 | "txt_ids": (512, 3), 78 | "guidance": ("B",), 79 | }, 80 | "input_dtypes": { 81 | "hidden_states": trt.DataType.BF16, 82 | "encoder_hidden_states": trt.DataType.BF16, 83 | "pooled_projections": trt.DataType.BF16, 84 | "timestep": trt.DataType.BF16, 85 | "img_ids": trt.DataType.FLOAT, 86 | "txt_ids": trt.DataType.FLOAT, 87 | "guidance": trt.DataType.FLOAT, 88 | }, 89 | "output_shapes": { 90 | "latent": ("B", "latent_dim", 64), 91 | }, 92 | }, 93 | "fp8": { 94 | "onnx_repository": "black-forest-labs/FLUX.1-dev-onnx", 95 | "onnx_subfolder": "transformer.opt/fp8", 96 | "input_shapes": { 97 | "hidden_states": ("B", "latent_dim", 64), 98 | "encoder_hidden_states": ("B", 512, 4096), 99 | "pooled_projections": ("B", 768), 100 | "timestep": ("B",), 101 | "img_ids": ("latent_dim", 3), 102 | "txt_ids": (512, 3), 103 | "guidance": ("B",), 104 | }, 105 | "input_dtypes": { 106 | "hidden_states": trt.DataType.BF16, 107 | "encoder_hidden_states": trt.DataType.BF16, 108 | "pooled_projections": trt.DataType.BF16, 109 | "timestep": trt.DataType.BF16, 110 | "img_ids": trt.DataType.FLOAT, 111 | "txt_ids": trt.DataType.FLOAT, 112 | "guidance": trt.DataType.FLOAT, 113 | }, 114 | "output_shapes": { 115 | "latent": ("B", "latent_dim", 64), 116 | }, 117 | }, 118 | "fp4": { 119 | "onnx_repository": "black-forest-labs/FLUX.1-dev-onnx", 120 | "onnx_subfolder": "transformer.opt/fp4", 121 | "input_shapes": { 122 | "hidden_states": ("B", "latent_dim", 64), 123 | "encoder_hidden_states": ("B", 512, 4096), 124 | "pooled_projections": ("B", 768), 125 | "timestep": ("B",), 126 | "img_ids": ("latent_dim", 3), 127 | "txt_ids": (512, 3), 128 | "guidance": ("B",), 129 | }, 130 | 
"input_dtypes": { 131 | "hidden_states": trt.DataType.BF16, 132 | "encoder_hidden_states": trt.DataType.BF16, 133 | "pooled_projections": trt.DataType.BF16, 134 | "timestep": trt.DataType.BF16, 135 | "img_ids": trt.DataType.FLOAT, 136 | "txt_ids": trt.DataType.FLOAT, 137 | "guidance": trt.DataType.FLOAT, 138 | }, 139 | "output_shapes": { 140 | "latent": ("B", "latent_dim", 64), 141 | }, 142 | }, 143 | }, 144 | "flux_vae_decoder": { 145 | "bf16": { 146 | "onnx_repository": "black-forest-labs/FLUX.1-dev-onnx", 147 | "onnx_subfolder": "vae.opt", 148 | "input_shapes": {"latent": ("B", 16, "latent_height", "latent_width")}, 149 | "input_dtypes": {"latent": trt.DataType.BF16}, 150 | "output_shapes": { 151 | "images": ("B", 3, "height", "width"), 152 | }, 153 | }, 154 | }, 155 | } 156 | 157 | # Pipeline compositions - map pipeline roles to actual models 158 | PIPELINES: dict[str, dict[str, str]] = { 159 | "flux_1_dev": { 160 | "clip_text_encoder": "flux_clip_text_encoder", 161 | "t5_text_encoder": "flux_t5_text_encoder", 162 | "transformer": "flux_transformer", 163 | "vae_decoder": "flux_vae_decoder", 164 | }, 165 | } 166 | 167 | # Default precisions per pipeline 168 | DEFAULT_PRECISIONS: dict[str, dict[str, str]] = { 169 | # Flux default precisions 170 | "flux_1_dev": { 171 | "clip_text_encoder": "bf16", 172 | "t5_text_encoder": "bf16", 173 | "transformer": "fp8", 174 | "vae_decoder": "bf16", 175 | }, 176 | } 177 | 178 | 179 | # Short form precision mapping 180 | SHORT_FORM_PRECISIONS: dict[str, torch.dtype] = { 181 | "fp16": torch.float16, 182 | "fp8": torch.float8_e4m3fn, 183 | "bf16": torch.bfloat16, 184 | "fp32": torch.float32, 185 | "int32": torch.int32, 186 | "int64": torch.int64, 187 | "bool": torch.bool, 188 | } 189 | 190 | 191 | class ModelRegistry: 192 | """Registry with model definitions and configurations""" 193 | 194 | def __init__(self): 195 | self.models = MODELS 196 | self.pipelines = PIPELINES 197 | self.defaults = DEFAULT_PRECISIONS 198 | self.short_form_precisions = SHORT_FORM_PRECISIONS 199 | self._validate() 200 | 201 | def _validate(self): 202 | """Validate registry consistency""" 203 | # Validate input shapes/dtypes keys match 204 | for model_id, precisions in self.models.items(): 205 | for precision, config in precisions.items(): 206 | shape_keys = set(config.get("input_shapes", {}).keys()) 207 | dtype_keys = set(config.get("input_dtypes", {}).keys()) 208 | if shape_keys != dtype_keys: 209 | raise ValueError(f"{model_id}[{precision}]: shape/dtype key mismatch") 210 | 211 | # Validate pipeline references and precision consistency 212 | for pipeline, roles in self.pipelines.items(): 213 | if pipeline not in self.defaults: 214 | raise ValueError(f"Pipeline {pipeline} missing in DEFAULT_PRECISIONS") 215 | if set(roles.keys()) != set(self.defaults[pipeline].keys()): 216 | raise ValueError(f"Pipeline {pipeline}: role mismatch with defaults") 217 | for role, model_id in roles.items(): 218 | if model_id not in self.models: 219 | raise ValueError(f"Model {model_id} not found in MODELS") 220 | precision = self.defaults[pipeline][role] 221 | if precision not in self.short_form_precisions: 222 | raise ValueError(f"Invalid precision {precision}") 223 | if precision not in self.models[model_id]: 224 | raise ValueError(f"Precision {precision} not available for {model_id}") 225 | 226 | def get_torch_dtype(self, precision: str) -> torch.dtype: 227 | """Get the torch.dtype for a precision""" 228 | return self.short_form_precisions.get(precision) 229 | 230 | def get_model_id(self, 
pipeline_name: str, role: str) -> Optional[str]: 231 | """Get the actual model ID for a pipeline role""" 232 | return self.pipelines.get(pipeline_name, {}).get(role) 233 | 234 | def is_model_id(self, model_id: str) -> bool: 235 | """Check if a model ID is valid""" 236 | return model_id in self.models 237 | 238 | def get_model_config(self, pipeline_name: str, role: str, precision: str) -> Optional[dict[str, Any]]: 239 | """Get configuration for a model in a pipeline role""" 240 | model_id = self.get_model_id(pipeline_name, role) 241 | if not model_id: 242 | return None 243 | 244 | return self.models.get(model_id, {}).get(precision) 245 | 246 | def get_available_precisions(self, pipeline_name: str, role: str) -> list: 247 | """Get available precisions for a model role""" 248 | model_id = self.get_model_id(pipeline_name, role) 249 | if not model_id: 250 | return [] 251 | 252 | return list(self.models.get(model_id, {}).keys()) 253 | 254 | def get_pipeline_roles(self, pipeline_name: str) -> list: 255 | """Get all roles in a pipeline""" 256 | return list(self.pipelines.get(pipeline_name, {}).keys()) 257 | 258 | def get_pipeline_config(self, pipeline_name: str) -> dict[str, str]: 259 | """Get all configs in a pipeline""" 260 | return self.pipelines.get(pipeline_name, {}) 261 | 262 | def get_default_precision(self, pipeline_name: str, role: str) -> str: 263 | """Get default precision for a role""" 264 | return self.defaults.get(pipeline_name, {}).get(role, "fp16") 265 | 266 | def get_default_precisions(self, pipeline_name: str) -> dict[str, str]: 267 | """Get default precisions for a pipeline""" 268 | return self.defaults.get(pipeline_name, {}) 269 | 270 | def get_onnx_path(self, pipeline_name: str, role: str, precision: str) -> Optional[str]: 271 | """ 272 | Get ONNX source (local path) for a model role, if present. 273 | 274 | Note: When copying from local paths, ALL files in the source directory are copied. 
275 | """ 276 | config = self.get_model_config(pipeline_name, role, precision) 277 | 278 | if not config: 279 | return None 280 | 281 | return config.get("onnx_path") 282 | 283 | def get_onnx_repository(self, pipeline_name: str, role: str, precision: str) -> Optional[str]: 284 | """Get ONNX repository for a model role.""" 285 | config = self.get_model_config(pipeline_name, role, precision) 286 | if not config: 287 | return None 288 | return config.get("onnx_repository") 289 | 290 | def get_onnx_subfolder(self, pipeline_name: str, role: str, precision: str) -> Optional[str]: 291 | """Get ONNX subfolder for a model role.""" 292 | config = self.get_model_config(pipeline_name, role, precision) 293 | if not config: 294 | return None 295 | return config.get("onnx_subfolder") 296 | 297 | def validate_precision_config(self, pipeline_name: str, precision_config: dict[str, str]) -> dict[str, str]: 298 | """Validate and fill in missing precisions with defaults""" 299 | validated_config = {} 300 | pipeline_roles = self.get_pipeline_roles(pipeline_name) 301 | 302 | for role in pipeline_roles: 303 | if role in precision_config: 304 | available_precisions = self.get_available_precisions(pipeline_name, role) 305 | if precision_config[role] in available_precisions: 306 | validated_config[role] = precision_config[role] 307 | else: 308 | logger.warning(f"Precision '{precision_config[role]}' not available for {role}, using default") 309 | validated_config[role] = self.get_default_precision(pipeline_name, role) 310 | else: 311 | validated_config[role] = self.get_default_precision(pipeline_name, role) 312 | 313 | return validated_config 314 | 315 | def print_available_models(self, pipeline_name: str): 316 | """Print available models and precisions for a pipeline""" 317 | logger.info(f"\nAvailable models for {pipeline_name}:") 318 | roles = self.get_pipeline_roles(pipeline_name) 319 | 320 | for role in roles: 321 | model_id = self.get_model_id(pipeline_name, role) 322 | precisions = self.get_available_precisions(pipeline_name, role) 323 | default = self.get_default_precision(pipeline_name, role) 324 | logger.info(f" {role} ({model_id}): {precisions} (default: {default})") 325 | 326 | def print_sharing_info(self): 327 | """Print model sharing information""" 328 | logger.info("\nModel Sharing Analysis:") 329 | 330 | # Find which models are shared across pipelines 331 | model_usage = {} 332 | for pipeline_name, roles in self.pipelines.items(): 333 | for role, model_id in roles.items(): 334 | if model_id not in model_usage: 335 | model_usage[model_id] = [] 336 | model_usage[model_id].append(f"{pipeline_name}:{role}") 337 | 338 | shared_models = {k: v for k, v in model_usage.items() if len(v) > 1} 339 | unique_models = {k: v for k, v in model_usage.items() if len(v) == 1} 340 | 341 | if shared_models: 342 | logger.info("Shared models (cached once, linked across pipelines):") 343 | for model_id, usage in shared_models.items(): 344 | logger.info(f" {model_id}: {', '.join(usage)}") 345 | 346 | if unique_models: 347 | logger.info("Pipeline-specific models:") 348 | for model_id, usage in unique_models.items(): 349 | logger.info(f" {model_id}: {usage[0]}") 350 | 351 | def get_io_names(self, model_id: str, precision: str) -> tuple[list, list]: 352 | """Get input and output names for a model and precision""" 353 | if model_id not in self.models or precision not in self.models[model_id]: 354 | return [], [] 355 | 356 | config = self.models[model_id][precision] 357 | input_names = list(config.get("input_shapes", {}).keys()) 
358 | output_names = list(config.get("output_shapes", {}).keys()) 359 | return input_names, output_names 360 | 361 | 362 | # Global registry instance 363 | registry = ModelRegistry() 364 | --------------------------------------------------------------------------------
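Usage sketch (not part of the repository sources above): a minimal illustration of the engine metadata cache, assuming demo/utils is importable as the `utils` package (as in the tests) and using a hypothetical engine path under the demo cache directory. The global `metadata_manager` records how an engine was built and later decides whether the cached engine can be reused:

from pathlib import Path

from utils.engine_metadata import metadata_manager

# Hypothetical cache location; the "_static." marker in the filename is what
# check_engine_compatibility() uses to recognize a static-shape engine.
engine_path = Path("demo_cache/engines/flux_transformer_static.engine")

# After a successful build, persist how the engine was produced.
metadata_manager.save_metadata(
    engine_path,
    model_name="flux_transformer",
    precision="fp8",
    onnx_path="flux_transformer.onnx",
    input_shapes={"hidden_states": (1, 1024, 64)},
    extra_args=None,
)

# On a later run, decide whether the cached engine can be reused as-is.
compatible, reason = metadata_manager.check_engine_compatibility(
    engine_path,
    target_shapes={"hidden_states": (1, 1024, 64)},
    static_shape=True,
    extra_args=None,
)
if not compatible:
    print(f"Rebuilding engine: {reason}")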
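The global `registry` from model_registry.py can be queried in the same way to resolve pipeline roles to concrete models and to validate precision overrides; importing it requires torch and tensorrt-rtx, since the module references both when declaring its tables. A minimal sketch using only calls defined above:

from utils.model_registry import registry

# Resolve the concrete model backing a pipeline role.
model_id = registry.get_model_id("flux_1_dev", "transformer")  # -> "flux_transformer"

# Default precision for that role, and where its ONNX files live on the Hub.
precision = registry.get_default_precision("flux_1_dev", "transformer")  # -> "fp8"
repo = registry.get_onnx_repository("flux_1_dev", "transformer", precision)
subfolder = registry.get_onnx_subfolder("flux_1_dev", "transformer", precision)

# Fill in missing roles and fall back to defaults for unavailable precisions.
precisions = registry.validate_precision_config("flux_1_dev", {"transformer": "fp4"})

# Input/output tensor names for engine I/O binding.
input_names, output_names = registry.get_io_names(model_id, precision)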