├── .circleci └── config.yml ├── .gitattributes ├── .github └── ISSUE_TEMPLATE │ ├── bug_report.md │ └── feature_request.md ├── .gitignore ├── CODE_OF_CONDUCT.md ├── Cargo.toml ├── Dockerfile ├── README.md ├── assets ├── efficientnet.onnx └── images │ └── meme.jpg ├── tensorrt-sys ├── Cargo.toml ├── README.md ├── build.rs ├── src │ └── lib.rs └── trt-sys │ ├── CMakeLists.txt │ ├── TRTBuilder │ ├── TRTBuilder.cpp │ └── TRTBuilder.h │ ├── TRTContext │ ├── TRTContext.cpp │ └── TRTContext.h │ ├── TRTCudaEngine │ ├── TRTCudaEngine.cpp │ └── TRTCudaEngine.h │ ├── TRTDims │ ├── TRTDims.cpp │ └── TRTDims.h │ ├── TRTEnums.h │ ├── TRTHostMemory │ ├── TRTHostMemory.cpp │ └── TRTHostMemory.h │ ├── TRTLayer │ ├── TRTActivationLayer.cpp │ ├── TRTActivationLayer.h │ ├── TRTElementWiseLayer.cpp │ ├── TRTElementWiseLayer.h │ ├── TRTGatherLayer.cpp │ ├── TRTGatherLayer.h │ ├── TRTLayer.cpp │ ├── TRTLayer.h │ ├── TRTPoolingLayer.cpp │ └── TRTPoolingLayer.h │ ├── TRTLogger │ ├── TRTLogger.cpp │ ├── TRTLogger.h │ └── TRTLoggerInternal.hpp │ ├── TRTNetworkDefinition │ ├── TRTNetworkDefinition.cpp │ └── TRTNetworkDefinition.h │ ├── TRTOnnxParser │ ├── TRTOnnxParser.cpp │ └── TRTOnnxParser.h │ ├── TRTProfiler │ ├── TRTProfiler.cpp │ ├── TRTProfiler.h │ └── TRTProfilerInternal.hpp │ ├── TRTRuntime │ ├── TRTRuntime.cpp │ └── TRTRuntime.h │ ├── TRTTensor │ ├── TRTTensor.cpp │ └── TRTTensor.h │ ├── TRTUffParser │ ├── TRTUffParser.cpp │ └── TRTUffParser.h │ ├── TRTUtils.hpp │ └── tensorrt_api.h ├── tensorrt ├── Cargo.toml ├── README.md ├── examples │ ├── README.md │ ├── basic │ │ ├── README.md │ │ └── main.rs │ ├── mnist_uff │ │ └── main.rs │ ├── onnx │ │ └── main.rs │ └── ssd_uff │ │ └── main.rs └── src │ ├── builder │ ├── mod.rs │ └── tests.rs │ ├── context.rs │ ├── data_size.rs │ ├── dims.rs │ ├── engine.rs │ ├── lib.rs │ ├── network │ ├── layer │ │ ├── activation_layer.rs │ │ ├── element_wise_layer.rs │ │ ├── gather_layer.rs │ │ ├── identity_layer.rs │ │ ├── mod.rs │ │ └── pooling_layer.rs 
│ └── mod.rs │ ├── onnx.rs │ ├── profiler.rs │ ├── runtime.rs │ ├── uff.rs │ └── utils.rs └── tensorrt_rs_derive ├── Cargo.toml └── src └── lib.rs /.circleci/config.yml: -------------------------------------------------------------------------------- 1 | # Use the latest 2.1 version of CircleCI pipeline process engine. See: https://circleci.com/docs/2.0/configuration-reference 2 | version: 2.1 3 | jobs: 4 | build-trt5: 5 | docker: 6 | - image: mstallmo/tensorrt-rs:0.4 7 | steps: 8 | - checkout 9 | - run: cd tensorrt && cargo build --release --no-default-features --features "trt-5" 10 | - persist_to_workspace: 11 | root: /root 12 | paths: 13 | - project 14 | build-trt7: 15 | docker: 16 | - image: mstallmo/tensorrt-rs:0.5 17 | steps: 18 | - checkout 19 | - run: cd tensorrt && cargo build --release 20 | - persist_to_workspace: 21 | root: /root 22 | paths: 23 | - project 24 | test-tensorrt-sys: 25 | docker: 26 | - image: mstallmo/tensorrt-rs:0.4 27 | steps: 28 | - attach_workspace: 29 | at: /root 30 | - run: cd tensorrt-sys && cargo test 31 | # Orchestrate or schedule a set of jobs 32 | workflows: 33 | build-docker: 34 | jobs: 35 | - build-trt5 36 | - build-trt7 37 | - test-tensorrt-sys: 38 | requires: 39 | - build-trt7 -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mstallmo/tensorrt-rs/1b9b0d80e2bcde72365ca2420547116daecd5f17/.gitattributes -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 
12 | 13 | **To Reproduce** 14 | Steps to reproduce the behavior: 15 | 1. Go to '...' 16 | 2. Click on '....' 17 | 3. Scroll down to '....' 18 | 4. See error 19 | 20 | **Expected behavior** 21 | A clear and concise description of what you expected to happen. 22 | 23 | **Screenshots** 24 | If applicable, add screenshots to help explain your problem. 25 | 26 | **Desktop (please complete the following information):** 27 | - OS: [e.g. Ubuntu 18.04] 28 | 29 | **GPU (please complete the following information):** 30 | - Model: [e.g. RTX 2080] 31 | - OS: [e.g. Ubuntu 18.04] 32 | - CUDA Version [e.g. 10.1] 33 | - CUDNN Version [e.g. 7.5] 34 | - TensorRT Version [e.g. 5.0.1] 35 | 36 | **Additional context** 37 | Add any other context about the problem here. 38 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here. 
21 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | **/*.rs.bk 3 | Cargo.lock 4 | **/cmake-build-debug 5 | .idea 6 | *.engine 7 | assets/ 8 | tensorrt-sys/src/bindings.rs 9 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | We as members, contributors, and leaders pledge to make participation in our 6 | community a harassment-free experience for everyone, regardless of age, body 7 | size, visible or invisible disability, ethnicity, sex characteristics, gender 8 | identity and expression, level of experience, education, socio-economic status, 9 | nationality, personal appearance, race, religion, or sexual identity 10 | and orientation. 11 | 12 | We pledge to act and interact in ways that contribute to an open, welcoming, 13 | diverse, inclusive, and healthy community. 
14 | 15 | ## Our Standards 16 | 17 | Examples of behavior that contributes to a positive environment for our 18 | community include: 19 | 20 | * Demonstrating empathy and kindness toward other people 21 | * Being respectful of differing opinions, viewpoints, and experiences 22 | * Giving and gracefully accepting constructive feedback 23 | * Accepting responsibility and apologizing to those affected by our mistakes, 24 | and learning from the experience 25 | * Focusing on what is best not just for us as individuals, but for the 26 | overall community 27 | 28 | Examples of unacceptable behavior include: 29 | 30 | * The use of sexualized language or imagery, and sexual attention or 31 | advances of any kind 32 | * Trolling, insulting or derogatory comments, and personal or political attacks 33 | * Public or private harassment 34 | * Publishing others' private information, such as a physical or email 35 | address, without their explicit permission 36 | * Other conduct which could reasonably be considered inappropriate in a 37 | professional setting 38 | 39 | ## Enforcement Responsibilities 40 | 41 | Community leaders are responsible for clarifying and enforcing our standards of 42 | acceptable behavior and will take appropriate and fair corrective action in 43 | response to any behavior that they deem inappropriate, threatening, offensive, 44 | or harmful. 45 | 46 | Community leaders have the right and responsibility to remove, edit, or reject 47 | comments, commits, code, wiki edits, issues, and other contributions that are 48 | not aligned to this Code of Conduct, and will communicate reasons for moderation 49 | decisions when appropriate. 50 | 51 | ## Scope 52 | 53 | This Code of Conduct applies within all community spaces, and also applies when 54 | an individual is officially representing the community in public spaces. 
55 | Examples of representing our community include using an official e-mail address, 56 | posting via an official social media account, or acting as an appointed 57 | representative at an online or offline event. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported to the community leaders responsible for enforcement at: 63 | 64 | * Mason Stallmo or themoose5#6218 on Discord 65 | 66 | All complaints will be reviewed and investigated promptly and fairly. 67 | 68 | All community leaders are obligated to respect the privacy and security of the 69 | reporter of any incident. 70 | 71 | ## Enforcement Guidelines 72 | 73 | Community leaders will follow these Community Impact Guidelines in determining 74 | the consequences for any action they deem in violation of this Code of Conduct: 75 | 76 | ### 1. Correction 77 | 78 | **Community Impact**: Use of inappropriate language or other behavior deemed 79 | unprofessional or unwelcome in the community. 80 | 81 | **Consequence**: A private, written warning from community leaders, providing 82 | clarity around the nature of the violation and an explanation of why the 83 | behavior was inappropriate. A public apology may be requested. 84 | 85 | ### 2. Warning 86 | 87 | **Community Impact**: A violation through a single incident or series 88 | of actions. 89 | 90 | **Consequence**: A warning with consequences for continued behavior. No 91 | interaction with the people involved, including unsolicited interaction with 92 | those enforcing the Code of Conduct, for a specified period of time. This 93 | includes avoiding interactions in community spaces as well as external channels 94 | like social media. Violating these terms may lead to a temporary or 95 | permanent ban. 96 | 97 | ### 3. Temporary Ban 98 | 99 | **Community Impact**: A serious violation of community standards, including 100 | sustained inappropriate behavior. 
101 | 102 | **Consequence**: A temporary ban from any sort of interaction or public 103 | communication with the community for a specified period of time. No public or 104 | private interaction with the people involved, including unsolicited interaction 105 | with those enforcing the Code of Conduct, is allowed during this period. 106 | Violating these terms may lead to a permanent ban. 107 | 108 | ### 4. Permanent Ban 109 | 110 | **Community Impact**: Demonstrating a pattern of violation of community 111 | standards, including sustained inappropriate behavior, harassment of an 112 | individual, or aggression toward or disparagement of classes of individuals. 113 | 114 | **Consequence**: A permanent ban from any sort of public interaction within 115 | the community. 116 | 117 | ## Attribution 118 | 119 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], 120 | version 2.0, available at 121 | https://www.contributor-covenant.org/version/2/0/code_of_conduct.html. 122 | 123 | Community Impact Guidelines were inspired by [Mozilla's code of conduct 124 | enforcement ladder](https://github.com/mozilla/diversity). 125 | 126 | [homepage]: https://www.contributor-covenant.org 127 | 128 | For answers to common questions about this code of conduct, see the FAQ at 129 | https://www.contributor-covenant.org/faq. Translations are available at 130 | https://www.contributor-covenant.org/translations. 
131 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [workspace] 2 | members = [ 3 | "tensorrt", 4 | "tensorrt-sys", 5 | "tensorrt_rs_derive", 6 | ] -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvcr.io/nvidia/tensorrt:20.11-py3 2 | 3 | RUN apt-get update 4 | RUN apt install -y software-properties-common 5 | RUN add-apt-repository ppa:ubuntu-toolchain-r/test 6 | RUN apt-get update 7 | RUN apt-get install g++-7 -y 8 | RUN update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-7 60 \ 9 | --slave /usr/bin/g++ g++ /usr/bin/g++-7 10 | RUN update-alternatives --config gcc 11 | RUN gcc --version 12 | RUN g++ --version 13 | RUN apt-get update 14 | RUN apt-get install clang-6.0 -y 15 | 16 | # Download and install Rust 17 | RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y 18 | 19 | #Add Cargo executables to path 20 | ENV PATH="/root/.cargo/bin:${PATH}" 21 | 22 | # Check version 23 | RUN cargo --version 24 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TensorRT-RS 2 | Rust Bindings For Nvidia's TensorRT Deep Learning Library. 
3 | 4 | See [tensorrt/README.md](./tensorrt/README.md) for information on the Rust library 5 | See [tensorrt-sys/README.md](./tensorrt-sys/README.md) for information on the wrapper library for TensorRT 6 | -------------------------------------------------------------------------------- /assets/efficientnet.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mstallmo/tensorrt-rs/1b9b0d80e2bcde72365ca2420547116daecd5f17/assets/efficientnet.onnx -------------------------------------------------------------------------------- /assets/images/meme.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mstallmo/tensorrt-rs/1b9b0d80e2bcde72365ca2420547116daecd5f17/assets/images/meme.jpg -------------------------------------------------------------------------------- /tensorrt-sys/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "tensorrt-sys" 3 | version = "0.4.0" 4 | authors = ["Mason Stallmo "] 5 | license = "MIT" 6 | edition = "2018" 7 | build = "build.rs" 8 | repository = "https://github.com/mstallmo/tensorrt-rs" 9 | description = "Low level wrapper around Nvidia's TensorRT library" 10 | 11 | [features] 12 | default = [] 13 | 14 | trt-5 = [] 15 | 16 | trt-6 = [] 17 | 18 | trt-7 = [] 19 | 20 | [dependencies] 21 | libc = "0.*" 22 | 23 | [build-dependencies] 24 | cmake = "0.1" 25 | bindgen = "0.55.1" 26 | -------------------------------------------------------------------------------- /tensorrt-sys/README.md: -------------------------------------------------------------------------------- 1 | # TensorRT-sys 2 | ![Crates.io](https://img.shields.io/crates/v/tensorrt-sys) 3 | 4 | :warning: __This crate currently only supports Linux__ :warning: 5 | 6 | C++ wrapper and Rust bindings to the TensorRT C++ library. 
Check 7 | [here](https://docs.nvidia.com/deeplearning/tensorrt/archives/tensorrt-515/tensorrt-api/c_api/classnvinfer1_1_1_i_builder.html) 8 | for documentation on the C++ TensorRT library. 9 | 10 | ### Prerequisites 11 | CUDA 10.1 12 | 13 | TensorRT 5.1.5 14 | 15 | CMake > 3.10 16 | 17 | 18 | 19 | TensorRT-sys' bindings depend on TensorRT 5.1.5 for the bindings to work correctly. While other versions of 20 | TensorRT *may* work with the bindings there are no guarantees, as functions that are bound to may have been deprecated, 21 | removed, or changed in future versions of TensorRT. 22 | 23 | The prerequisites enumerated above are expected to be installed in their default location on Linux. See the [nvidia 24 | documentation](https://docs.nvidia.com/deeplearning/tensorrt/install-guide/index.html#installing) around TensorRT for 25 | further install information. 26 | 27 | __Note:__ The tarball installation method described in the TensorRT documentation is likely to cause major headaches with 28 | getting everything to link correctly. It is highly recommended to use the package manager method if possible. 29 | 30 | If there are issues with loading default TensorRT plugins, it seems that there are some missing plugins that were not 31 | provided with the 5.1.5 binary installation at one point in time. To get these plugins follow the instructions for building 32 | and installing the TensorRT OSS components from [here](https://github.com/NVIDIA/TensorRT). 33 | 34 | Windows is not currently supported, but support should be coming soon! 35 | 36 | ### Support Matrix for TensorRT Classes 37 | Anything not listed below currently does not have any support. 
38 | 39 | | Class Name| Status| 40 | |------------------| ---| 41 | | nvinfer1::ILogger| Complete| 42 | |nvinfer1::IBuilder| Complete | 43 | |nvinfer1::IExecutionContext| Complete | 44 | |nvinfer1::IRuntime| Partial| 45 | |nvinfer1::ICudaEngine| Partial| 46 | |nvinfer1::INetworkDefinition| Partial| 47 | |nvinfer1::IHostMemory| Partial| 48 | |nvinfer1::IDims (and all sub-dims)| Complete| 49 | |nvinfer1::ILayer | Partial | 50 | |nvinfer1::IProfiler | Complete | 51 | |nvuffparser::IUffParser| Partial| 52 | 53 | 54 | 55 | 56 | 57 | ### Structure 58 | All of the C++ code that is used to communicate between Rust and TensorRT itself is contained in the `trt-sys` sub-folder. 59 | This code exposes a C interface that can be consumed by Rust and translates between said interface and the API exposed by 60 | the TensorRT C++ library. 61 | 62 | Bindings to the C++ wrapper library are generated using the bindgen command 63 | `bindgen --size_t-is-usize trt-sys/tensorrt_api.h -o src/bindings.rs`. All headers that make up the C API are included in 64 | the file `tensorrt_api.h` that bindgen then consumes to create the Rust bindings. These bindings are saved in the `src/` 65 | folder and imported by `lib.rs` to create the crate that is used by tensorrt-rs. 
66 | -------------------------------------------------------------------------------- /tensorrt-sys/build.rs: -------------------------------------------------------------------------------- 1 | use bindgen::builder; 2 | use cmake::Config; 3 | 4 | fn cuda_configuration() { 5 | let cudadir = match option_env!("CUDA_INSTALL_DIR") { 6 | Some(cuda_dir) => cuda_dir, 7 | None => "/usr/local/cuda", 8 | }; 9 | 10 | println!("cargo:rustc-link-search={}/lib64", cudadir); 11 | println!("cargo:rustc-link-lib=dylib=cudart"); 12 | } 13 | 14 | fn tensorrt_configuration() { 15 | match option_env!("TRT_INSTALL_DIR") { 16 | Some(trt_lib_dir) => { 17 | println!("cargo:rustc-link-search={}/lib", trt_lib_dir); 18 | } 19 | None => (), 20 | } 21 | println!("cargo:rustc-link-lib=dylib=nvinfer"); 22 | println!("cargo:rustc-link-lib=dylib=nvonnxparser"); 23 | println!("cargo:rustc-link-lib=dylib=nvparsers"); 24 | println!("cargo:rustc-link-lib=dylib=nvinfer_plugin"); 25 | } 26 | 27 | // Not sure if I love this solution but I think it's relatively robust enough for now on Unix systems. 28 | // Still have to thoroughly test what happens with a TRT library installed that's not done by the 29 | // dpkg. It's possible that we'll just have to fall back to only supporting one system library and assuming that 30 | // the user has the correct library installed and is viewable via ldconfig. 31 | // 32 | // Hopefully something like this will work for Windows installs as well, not having a default library 33 | // install location will make that significantly harder. 
34 | // 35 | fn main() -> Result<(), ()> { 36 | let mut cfg = Config::new("trt-sys"); 37 | 38 | #[cfg(feature = "trt-5")] 39 | { 40 | println!("Setting Config to TRT5"); 41 | cfg.define("TRT5", ""); 42 | let bindings = builder() 43 | .clang_args(&["-x", "c++"]) 44 | .header("trt-sys/tensorrt_api.h") 45 | .size_t_is_usize(true) 46 | .generate()?; 47 | 48 | bindings.write_to_file("src/bindings.rs").unwrap(); 49 | } 50 | 51 | #[cfg(feature = "trt-6")] 52 | { 53 | println!("Setting Config to TRT6"); 54 | cfg.define("TRT6", ""); 55 | let bindings = builder() 56 | .clang_arg("-DTRT6") 57 | .clang_args(&["-x", "c++"]) 58 | .header("trt-sys/tensorrt_api.h") 59 | .size_t_is_usize(true) 60 | .generate()?; 61 | 62 | bindings.write_to_file("src/bindings.rs").unwrap(); 63 | } 64 | 65 | #[cfg(feature = "trt-7")] 66 | { 67 | println!("Setting Config to TRT7"); 68 | cfg.define("TRT7", ""); 69 | let bindings = builder() 70 | .clang_arg("-DTRT7") 71 | .clang_args(&["-x", "c++"]) 72 | .header("trt-sys/tensorrt_api.h") 73 | .size_t_is_usize(true) 74 | .generate()?; 75 | 76 | bindings.write_to_file("src/bindings.rs").unwrap(); 77 | } 78 | 79 | let dst = cfg.build(); 80 | println!("cargo:rustc-link-search=native={}", dst.display()); 81 | println!("cargo:rustc-link-lib=static=trt-sys"); 82 | println!("cargo:rustc-link-lib=dylib=stdc++"); 83 | 84 | tensorrt_configuration(); 85 | cuda_configuration(); 86 | 87 | Ok(()) 88 | } 89 | -------------------------------------------------------------------------------- /tensorrt-sys/src/lib.rs: -------------------------------------------------------------------------------- 1 | #![allow(non_upper_case_globals)] 2 | #![allow(non_camel_case_types)] 3 | #![allow(non_snake_case)] 4 | 5 | include!("bindings.rs"); 6 | -------------------------------------------------------------------------------- /tensorrt-sys/trt-sys/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.10) 
2 | project(LibTRT LANGUAGES CXX CUDA) 3 | 4 | if(DEFINED TRT7) 5 | message(STATUS "TRT7 is defined") 6 | add_compile_definitions(TRT7) 7 | elseif(DEFINED TRT6) 8 | message(STATUS "TRT6 is defined") 9 | add_compile_definitions(TRT6) 10 | elseif(DEFINED TRT5) 11 | message(STATUS "Falling back to TRT5") 12 | add_compile_definitions(TRT5) 13 | endif() 14 | 15 | set(CMAKE_CXX_STANDARD 17) 16 | 17 | set(CMAKE_CXX_FLAGS "-fPIC -O3 -Wall -Wextra -Werror -Wno-unknown-pragmas -Wno-deprecated -Wno-deprecated-declarations") 18 | 19 | file(GLOB source_files 20 | "TRTLogger/*.cpp" 21 | "TRTRuntime/*cpp" 22 | "TRTCudaEngine/*.cpp" 23 | "TRTContext/*.cpp" 24 | "TRTUffParser/*.cpp" 25 | "TRTOnnxParser/*.cpp" 26 | "TRTDims/*.cpp" 27 | "TRTBuilder/*.cpp" 28 | "TRTNetworkDefinition/*.cpp" 29 | "TRTHostMemory/*.cpp" 30 | "TRTLayer/*.cpp" 31 | "TRTTensor/*.cpp" 32 | "TRTProfiler/*.cpp" 33 | ) 34 | 35 | add_library(trt-sys STATIC ${source_files}) 36 | include_directories(${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) 37 | 38 | if(DEFINED ENV{TRT_INSTALL_DIR}) 39 | target_include_directories(trt-sys PRIVATE $ENV{TRT_INSTALL_DIR}/include) 40 | endif() 41 | 42 | install(TARGETS trt-sys DESTINATION .) 43 | -------------------------------------------------------------------------------- /tensorrt-sys/trt-sys/TRTBuilder/TRTBuilder.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Created by mason on 11/27/19. 
3 | // 4 | #include 5 | #include 6 | #include 7 | 8 | #include "TRTBuilder.h" 9 | #include "../TRTLogger/TRTLoggerInternal.hpp" 10 | 11 | void builder_set_max_batch_size(nvinfer1::IBuilder* builder, int32_t batch_size) { 12 | builder->setMaxBatchSize(batch_size); 13 | } 14 | 15 | int32_t builder_get_max_batch_size(nvinfer1::IBuilder* builder) { 16 | return builder->getMaxBatchSize(); 17 | } 18 | 19 | void builder_set_max_workspace_size(nvinfer1::IBuilder* builder, size_t workspace_size) { 20 | builder->setMaxWorkspaceSize(workspace_size); 21 | } 22 | 23 | size_t builder_get_max_workspace_size(nvinfer1::IBuilder* builder) { 24 | return builder->getMaxWorkspaceSize(); 25 | } 26 | 27 | void builder_set_half2_mode(nvinfer1::IBuilder* builder, bool mode) { 28 | builder->setHalf2Mode(mode); 29 | } 30 | 31 | bool builder_get_half2_mode(nvinfer1::IBuilder* builder) { 32 | return builder->getHalf2Mode(); 33 | } 34 | 35 | void builder_set_debug_sync(nvinfer1::IBuilder* builder, bool sync) { 36 | builder->setDebugSync(sync); 37 | } 38 | 39 | bool builder_get_debug_sync(nvinfer1::IBuilder* builder) { 40 | return builder->getDebugSync(); 41 | } 42 | 43 | void builder_set_min_find_iterations(nvinfer1::IBuilder* builder, int min_find) { 44 | builder->setMinFindIterations(min_find); 45 | } 46 | 47 | int builder_get_min_find_iterations(nvinfer1::IBuilder* builder) { 48 | return builder->getMinFindIterations(); 49 | } 50 | 51 | void builder_set_average_find_iterations(nvinfer1::IBuilder* builder, int avg_find) { 52 | builder->setAverageFindIterations(avg_find); 53 | } 54 | 55 | int builder_get_average_find_iterations(nvinfer1::IBuilder* builder) { 56 | return builder->getAverageFindIterations(); 57 | } 58 | 59 | bool builder_platform_has_fast_fp16(nvinfer1::IBuilder* builder){ 60 | return builder->platformHasFastFp16(); 61 | } 62 | 63 | bool builder_platform_has_fast_int8(nvinfer1::IBuilder* builder) { 64 | return builder->platformHasFastInt8(); 65 | } 66 | 67 | void 
builder_set_int8_mode(nvinfer1::IBuilder* builder, bool mode) { 68 | builder->setInt8Mode(mode); 69 | } 70 | 71 | bool builder_get_int8_mode(nvinfer1::IBuilder* builder) { 72 | return builder->getInt8Mode(); 73 | } 74 | 75 | void builder_set_fp16_mode(nvinfer1::IBuilder* builder, bool mode) { 76 | builder->setFp16Mode(mode); 77 | } 78 | 79 | bool builder_get_fp16_mode(nvinfer1::IBuilder* builder) { 80 | return builder->getFp16Mode(); 81 | } 82 | 83 | void builder_set_device_type(nvinfer1::IBuilder* builder, nvinfer1::ILayer* layer, DeviceType_t deviceType) { 84 | builder->setDeviceType(layer, static_cast(deviceType)); 85 | } 86 | 87 | DeviceType_t builder_get_device_type(nvinfer1::IBuilder* builder, nvinfer1::ILayer* layer) { 88 | return static_cast(builder->getDeviceType(layer)); 89 | } 90 | 91 | bool builder_is_device_type_set(nvinfer1::IBuilder* builder, nvinfer1::ILayer* layer) { 92 | return builder->isDeviceTypeSet(layer); 93 | } 94 | 95 | void builder_set_default_device_type(nvinfer1::IBuilder* builder, DeviceType_t deviceType) { 96 | builder->setDefaultDeviceType(static_cast(deviceType)); 97 | } 98 | 99 | DeviceType_t builder_get_default_device_type(nvinfer1::IBuilder *builder) { 100 | return static_cast(builder->getDefaultDeviceType()); 101 | } 102 | 103 | void builder_reset_device_type(nvinfer1::IBuilder* builder, nvinfer1::ILayer* layer) { 104 | builder->resetDeviceType(layer); 105 | } 106 | 107 | bool builder_can_run_on_dla(nvinfer1::IBuilder* builder, nvinfer1::ILayer* layer) { 108 | return builder->canRunOnDLA(layer); 109 | } 110 | 111 | int builder_get_max_dla_batch_size(nvinfer1::IBuilder* builder) { 112 | return builder->getMaxBatchSize(); 113 | } 114 | 115 | void builder_allow_gpu_fallback(nvinfer1::IBuilder* builder, bool set_fallback_mode) { 116 | builder->allowGPUFallback(set_fallback_mode); 117 | } 118 | 119 | int builder_get_nb_dla_cores(nvinfer1::IBuilder* builder) { 120 | return builder->getNbDLACores(); 121 | } 122 | 123 | void 
builder_set_dla_core(nvinfer1::IBuilder* builder, int dla_core) { 124 | builder->setDLACore(dla_core); 125 | } 126 | 127 | int builder_get_dla_core(nvinfer1::IBuilder* builder) { 128 | return builder->getDLACore(); 129 | } 130 | 131 | void builder_set_strict_type_constraints(nvinfer1::IBuilder* builder, bool mode) { 132 | builder->setStrictTypeConstraints(mode); 133 | } 134 | 135 | bool builder_get_strict_type_constraints(nvinfer1::IBuilder* builder) { 136 | return builder->getStrictTypeConstraints(); 137 | } 138 | 139 | void builder_set_refittable(nvinfer1::IBuilder* builder, bool can_refit) { 140 | builder->setRefittable(can_refit); 141 | } 142 | 143 | bool builder_get_refittable(nvinfer1::IBuilder* builder) { 144 | return builder->getRefittable(); 145 | } 146 | 147 | void builder_set_engine_capability(nvinfer1::IBuilder* builder, EngineCapabiliy_t engine_capability) { 148 | builder->setEngineCapability(static_cast(engine_capability)); 149 | } 150 | 151 | EngineCapabiliy_t builder_get_engine_capability(nvinfer1::IBuilder* builder) { 152 | return static_cast(builder->getEngineCapability()); 153 | } 154 | 155 | nvinfer1::IBuilder *create_infer_builder(Logger_t *logger) { 156 | initLibNvInferPlugins(&logger->getLogger(), ""); 157 | return nvinfer1::createInferBuilder(logger->getLogger()); 158 | } 159 | 160 | 161 | void destroy_builder(nvinfer1::IBuilder* builder) { 162 | builder->destroy(); 163 | } 164 | 165 | #if defined(TRT6) || defined(TRT7) 166 | nvinfer1::INetworkDefinition *create_network_v2(nvinfer1::IBuilder *builder, uint32_t flags) { 167 | return builder->createNetworkV2(flags); 168 | } 169 | #else 170 | nvinfer1::INetworkDefinition *create_network(nvinfer1::IBuilder *builder) { 171 | return builder->createNetwork(); 172 | } 173 | #endif 174 | 175 | nvinfer1::ICudaEngine *build_cuda_engine(nvinfer1::IBuilder *builder, nvinfer1::INetworkDefinition *network) { 176 | return builder->buildCudaEngine(*network); 177 | } 178 | 179 | void 
builder_reset(nvinfer1::IBuilder* builder, nvinfer1::INetworkDefinition* network) { 180 | builder->reset(*network); 181 | } 182 | -------------------------------------------------------------------------------- /tensorrt-sys/trt-sys/TRTBuilder/TRTBuilder.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by mason on 11/27/19. 3 | // 4 | 5 | #ifndef LIBTRT_TRTBUILDER_H 6 | #define LIBTRT_TRTBUILDER_H 7 | 8 | #include 9 | #include "../TRTLogger/TRTLogger.h" 10 | #include "../TRTEnums.h" 11 | 12 | #include 13 | #include 14 | 15 | nvinfer1::ICudaEngine *build_cuda_engine(nvinfer1::IBuilder *builder, nvinfer1::INetworkDefinition *network); 16 | nvinfer1::IBuilder *create_infer_builder(Logger_t *logger); 17 | void destroy_builder(nvinfer1::IBuilder* builder); 18 | void builder_set_max_batch_size(nvinfer1::IBuilder* builder, int32_t batch_size); 19 | int32_t builder_get_max_batch_size(nvinfer1::IBuilder* builder); 20 | void builder_set_max_workspace_size(nvinfer1::IBuilder* builder, size_t batch_size); 21 | size_t builder_get_max_workspace_size(nvinfer1::IBuilder* builder); 22 | void builder_set_half2_mode(nvinfer1::IBuilder* builder, bool mode); 23 | bool builder_get_half2_mode(nvinfer1::IBuilder* builder); 24 | void builder_set_debug_sync(nvinfer1::IBuilder* builder, bool sync); 25 | bool builder_get_debug_sync(nvinfer1::IBuilder* builder); 26 | void builder_set_min_find_iterations(nvinfer1::IBuilder* builder, int min_find); 27 | int builder_get_min_find_iterations(nvinfer1::IBuilder* builder); 28 | void builder_set_average_find_iterations(nvinfer1::IBuilder* builder, int avg_find); 29 | int builder_get_average_find_iterations(nvinfer1::IBuilder* builder); 30 | bool builder_platform_has_fast_fp16(nvinfer1::IBuilder* builder); 31 | bool builder_platform_has_fast_int8(nvinfer1::IBuilder* builder); 32 | void builder_set_int8_mode(nvinfer1::IBuilder* builder, bool mode); 33 | bool 
builder_get_int8_mode(nvinfer1::IBuilder* builder); 34 | void builder_set_fp16_mode(nvinfer1::IBuilder* builder, bool mode); 35 | bool builder_get_fp16_mode(nvinfer1::IBuilder* builder); 36 | void builder_set_device_type(nvinfer1::IBuilder* builder, nvinfer1::ILayer* layer, DeviceType_t deviceType); 37 | DeviceType_t builder_get_device_type(nvinfer1::IBuilder* builder, nvinfer1::ILayer* layer); 38 | bool builder_is_device_type_set(nvinfer1::IBuilder* builder, nvinfer1::ILayer* layer); 39 | void builder_set_default_device_type(nvinfer1::IBuilder* builder, DeviceType_t deviceType); 40 | DeviceType_t builder_get_default_device_type(nvinfer1::IBuilder* builder); 41 | void builder_reset_device_type(nvinfer1::IBuilder* builder, nvinfer1::ILayer* layer); 42 | bool builder_can_run_on_dla(nvinfer1::IBuilder* builder, nvinfer1::ILayer* layer); 43 | int builder_get_max_dla_batch_size(nvinfer1::IBuilder* builder); 44 | void builder_allow_gpu_fallback(nvinfer1::IBuilder* builder, bool set_fallback_mode); 45 | int builder_get_nb_dla_cores(nvinfer1::IBuilder* builder); 46 | void builder_set_dla_core(nvinfer1::IBuilder* builder, int dla_core); 47 | int builder_get_dla_core(nvinfer1::IBuilder* builder); 48 | void builder_set_strict_type_constraints(nvinfer1::IBuilder* builder, bool mode); 49 | bool builder_get_strict_type_constraints(nvinfer1::IBuilder* builder); 50 | void builder_set_refittable(nvinfer1::IBuilder* builder, bool can_refit); 51 | bool builder_get_refittable(nvinfer1::IBuilder* builder); 52 | void builder_set_engine_capability(nvinfer1::IBuilder* builder, EngineCapabiliy_t engine_capability); 53 | EngineCapabiliy_t builder_get_engine_capability(nvinfer1::IBuilder* builder); 54 | #if defined(TRT6) || defined(TRT7) 55 | nvinfer1::INetworkDefinition *create_network_v2(nvinfer1::IBuilder* builder, uint32_t flags); 56 | #else 57 | nvinfer1::INetworkDefinition *create_network(nvinfer1::IBuilder* builder); 58 | #endif 59 | 60 | void builder_reset(nvinfer1::IBuilder* 
builder, nvinfer1::INetworkDefinition* network); 61 | #endif //LIBTRT_TRTBUILDER_H 62 | -------------------------------------------------------------------------------- /tensorrt-sys/trt-sys/TRTContext/TRTContext.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Created by mason on 9/17/19. 3 | // 4 | #include 5 | #include "NvInfer.h" 6 | #include "TRTContext.h" 7 | #include "../TRTProfiler/TRTProfilerInternal.hpp" 8 | //#include "../TRTUtils.hpp" 9 | 10 | //struct Context { 11 | // using IExecutionContextPtr = std::unique_ptr>; 12 | // IExecutionContextPtr internal_context; 13 | // 14 | // explicit Context(nvinfer1::IExecutionContext *executionContext) { 15 | // internal_context = IExecutionContextPtr(executionContext); 16 | // } 17 | // 18 | // ~Context() { 19 | // if (_concreteProfiler) { 20 | // _concreteProfiler->destroy(); 21 | // } 22 | // } 23 | // 24 | // ConcreteProfiler* _concreteProfiler = nullptr; 25 | //}; 26 | // 27 | //Context_t *create_execution_context(nvinfer1::IExecutionContext *execution_context) { 28 | // return new Context(execution_context); 29 | //} 30 | 31 | void destroy_excecution_context(nvinfer1::IExecutionContext *execution_context) { 32 | execution_context->destroy(); 33 | } 34 | 35 | void context_set_debug_sync(nvinfer1::IExecutionContext *execution_context, bool sync) { 36 | execution_context->setDebugSync(sync); 37 | } 38 | 39 | bool context_get_debug_sync(nvinfer1::IExecutionContext *execution_context) { 40 | return execution_context->getDebugSync(); 41 | } 42 | 43 | void context_set_name(nvinfer1::IExecutionContext *execution_context, const char *name) { 44 | execution_context->setName(name); 45 | } 46 | 47 | const char *context_get_name(nvinfer1::IExecutionContext *execution_context) { 48 | return execution_context->getName(); 49 | } 50 | 51 | void context_set_profiler(nvinfer1::IExecutionContext *context, CppProfiler* profiler) { 52 | context->setProfiler(profiler); 53 | } 54 | 
// Run synchronous inference on the execution context.
//
// `buffers` is the array of device buffer pointers, indexed by binding
// index (as reported by the engine's binding queries); `batch_size` is the
// number of samples batched into each buffer.
//
// NOTE(review): IExecutionContext::execute() returns a bool success flag
// that is discarded here, so callers on the Rust side cannot observe an
// inference failure. Changing the return type would alter the C ABI this
// shim exposes, so it is flagged rather than changed — consider
// propagating the bool in a future API revision.
void execute(nvinfer1::IExecutionContext *execution_context, void **buffers, int batch_size) {
    execution_context->execute(batch_size, &buffers[0]);
}
3 | // 4 | #include 5 | 6 | #include "TRTCudaEngine.h" 7 | 8 | void engine_destroy(nvinfer1::ICudaEngine* engine) { 9 | engine->destroy(); 10 | } 11 | 12 | int engine_get_nb_bindings(nvinfer1::ICudaEngine* engine) { 13 | return engine->getNbBindings(); 14 | } 15 | 16 | int engine_get_binding_index(nvinfer1::ICudaEngine* engine, const char* op_name) { 17 | return engine->getBindingIndex(op_name); 18 | } 19 | 20 | const char* engine_get_binding_name(nvinfer1::ICudaEngine* engine, int binding_index) { 21 | return engine->getBindingName(binding_index); 22 | } 23 | 24 | bool engine_binding_is_input(nvinfer1::ICudaEngine *engine, int binding_index) { 25 | return engine->bindingIsInput(binding_index); 26 | } 27 | 28 | nvinfer1::Dims engine_get_binding_dimensions(nvinfer1::ICudaEngine *engine, int binding_index) { 29 | return engine->getBindingDimensions(binding_index); 30 | } 31 | 32 | DataType_t engine_get_binding_data_type(nvinfer1::ICudaEngine *engine, int binding_index) { 33 | return static_cast(engine->getBindingDataType(binding_index)); 34 | } 35 | 36 | int engine_get_max_batch_size(nvinfer1::ICudaEngine *engine) { 37 | return engine->getMaxBatchSize(); 38 | } 39 | 40 | int engine_get_nb_layers(nvinfer1::ICudaEngine *engine) { 41 | return engine->getNbLayers(); 42 | } 43 | 44 | size_t engine_get_workspace_size(nvinfer1::ICudaEngine *engine) { 45 | return engine->getWorkspaceSize(); 46 | } 47 | 48 | nvinfer1::IExecutionContext* engine_create_execution_context(nvinfer1::ICudaEngine *engine) { 49 | return engine->createExecutionContext(); 50 | } 51 | 52 | nvinfer1::IExecutionContext* engine_create_execution_context_without_device_memory(nvinfer1::ICudaEngine *engine) { 53 | return engine->createExecutionContextWithoutDeviceMemory(); 54 | } 55 | 56 | nvinfer1::IHostMemory* engine_serialize(nvinfer1::ICudaEngine* engine) { 57 | return engine->serialize(); 58 | } 59 | 60 | TensorLocation_t engine_get_location(nvinfer1::ICudaEngine *engine, int binding_index) { 61 | return 
static_cast(engine->getLocation(binding_index)); 62 | } 63 | 64 | size_t engine_get_device_memory_size(nvinfer1::ICudaEngine *engine) { 65 | return engine->getDeviceMemorySize(); 66 | } 67 | 68 | bool engine_is_refittable(nvinfer1::ICudaEngine *engine) { 69 | return engine->isRefittable(); 70 | } 71 | -------------------------------------------------------------------------------- /tensorrt-sys/trt-sys/TRTCudaEngine/TRTCudaEngine.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by mason on 8/26/19. 3 | // 4 | 5 | #ifndef LIBTRT_TRTCUDAENGINE_H 6 | #define LIBTRT_TRTCUDAENGINE_H 7 | 8 | #include 9 | 10 | #include "../TRTContext/TRTContext.h" 11 | #include "../TRTHostMemory/TRTHostMemory.h" 12 | #include "../TRTDims/TRTDims.h" 13 | #include "../TRTEnums.h" 14 | 15 | void engine_destroy(nvinfer1::ICudaEngine* engine); 16 | nvinfer1::IExecutionContext* engine_create_execution_context(nvinfer1::ICudaEngine *engine); 17 | nvinfer1::IExecutionContext* engine_create_execution_context_without_device_memory(nvinfer1::ICudaEngine *engine); 18 | int engine_get_nb_bindings(nvinfer1::ICudaEngine* engine); 19 | int engine_get_binding_index(nvinfer1::ICudaEngine *engine, const char* op_name); 20 | const char* engine_get_binding_name(nvinfer1::ICudaEngine* engine, int binding_index); 21 | bool engine_binding_is_input(nvinfer1::ICudaEngine *engine, int binding_index); 22 | nvinfer1::Dims engine_get_binding_dimensions(nvinfer1::ICudaEngine *engine, int binding_index); 23 | DataType_t engine_get_binding_data_type(nvinfer1::ICudaEngine *engine, int binding_index); 24 | int engine_get_max_batch_size(nvinfer1::ICudaEngine *engine); 25 | int engine_get_nb_layers(nvinfer1::ICudaEngine *engine); 26 | size_t engine_get_workspace_size(nvinfer1::ICudaEngine *engine); 27 | nvinfer1::IHostMemory* engine_serialize(nvinfer1::ICudaEngine* engine); 28 | TensorLocation_t engine_get_location(nvinfer1::ICudaEngine *engine, int binding_index); 29 
| size_t engine_get_device_memory_size(nvinfer1::ICudaEngine *engine); 30 | bool engine_is_refittable(nvinfer1::ICudaEngine *engine); 31 | 32 | #endif //LIBTRT_TRTCUDAENGINE_H 33 | -------------------------------------------------------------------------------- /tensorrt-sys/trt-sys/TRTDims/TRTDims.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Created by mason on 4/30/20. 3 | // 4 | #include 5 | #include 6 | #include "TRTDims.h" 7 | 8 | nvinfer1::Dims create_dims(int nb_dims, const int* d, const DimensionType_t *dimension_types) { 9 | nvinfer1::Dims dims{}; 10 | dims.nbDims = nb_dims; 11 | memcpy(dims.d, d, nvinfer1::Dims::MAX_DIMS * sizeof(int)); 12 | memcpy(dims.type, dimension_types, nvinfer1::Dims::MAX_DIMS * sizeof(DimensionType)); 13 | 14 | return dims; 15 | } 16 | 17 | nvinfer1::Dims2 create_dims2(int dim1, int dim2) { 18 | return nvinfer1::Dims2(dim1, dim2); 19 | } 20 | 21 | nvinfer1::DimsHW create_dimsHW(int height , int width) { 22 | return nvinfer1::DimsHW(height, width); 23 | } 24 | 25 | nvinfer1::Dims3 create_dims3(int dim1, int dim2, int dim3) { 26 | return nvinfer1::Dims3(dim1, dim2, dim3); 27 | } 28 | 29 | nvinfer1::DimsCHW create_dimsCHW(int channels, int height, int width) { 30 | return nvinfer1::DimsCHW(channels, height, width); 31 | } 32 | 33 | nvinfer1::Dims4 create_dims4(int dim1, int dim2, int dim3, int dim4) { 34 | return nvinfer1::Dims4(dim1, dim2, dim3, dim4); 35 | } 36 | 37 | nvinfer1::DimsNCHW create_dimsNCHW(int batchSize, int channel, int height, int width) { 38 | return nvinfer1::DimsNCHW(batchSize, channel, height, width); 39 | } 40 | -------------------------------------------------------------------------------- /tensorrt-sys/trt-sys/TRTDims/TRTDims.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by mason on 11/22/19. 
3 | // 4 | 5 | #ifndef LIBTRT_TRTDIMS_H 6 | #define LIBTRT_TRTDIMS_H 7 | 8 | #include 9 | #include "../TRTEnums.h" 10 | 11 | nvinfer1::Dims create_dims(int nb_dims, const int* d, const DimensionType_t *dimension_types); 12 | nvinfer1::Dims2 create_dims2(int dim1, int dim2); 13 | nvinfer1::DimsHW create_dimsHW(int height, int width); 14 | nvinfer1::Dims3 create_dims3(int dim1, int dim2, int dim3); 15 | nvinfer1::DimsCHW create_dimsCHW(int channel, int height, int width); 16 | nvinfer1::Dims4 create_dims4(int dim1, int dim2, int dim3, int dim4); 17 | nvinfer1::DimsNCHW create_dimsNCHW(int index, int channel, int height, int width); 18 | 19 | 20 | #endif //LIBTRT_TRTDIMS_H 21 | -------------------------------------------------------------------------------- /tensorrt-sys/trt-sys/TRTEnums.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by mason on 10/10/20. 3 | // 4 | 5 | #ifndef LIBTRT_TRTENUMS_H 6 | #define LIBTRT_TRTENUMS_H 7 | 8 | #ifdef __cplusplus 9 | extern "C" { 10 | #endif 11 | 12 | enum class ActivationType { 13 | kRELU = 0, 14 | kSIGMOID = 1, 15 | kTANH = 2, 16 | kLEAKY_RELU = 3, 17 | kELU = 4, 18 | kSELU = 5, 19 | kSOFTSIGN = 6, 20 | kSOFTPLUS = 7, 21 | kCLIP = 8, 22 | kHARD_SIGMOID = 9, 23 | kSCALED_TANH = 10, 24 | kTHRESHOLDED_RELU = 11, 25 | }; 26 | typedef enum ActivationType ActivationType_t; 27 | 28 | enum class DataType { 29 | kFLOAT = 0, 30 | kHALF = 1, 31 | kINT8 = 2, 32 | kINT32 = 3, 33 | }; 34 | typedef enum DataType DataType_t; 35 | 36 | enum class DeviceType { 37 | kGPU = 0, 38 | kDLA = 1, 39 | }; 40 | typedef enum DeviceType DeviceType_t; 41 | 42 | enum class DimensionType { 43 | kSPATIAL = 0, 44 | kCHANNEL = 1, 45 | kINDEX = 2, 46 | kSEQUENCE = 3 47 | }; 48 | typedef enum DimensionType DimensionType_t; 49 | 50 | enum class ElementWiseOperation { 51 | kSUM = 0, 52 | kPROD = 1, 53 | kMAX = 2, 54 | kMIN = 3, 55 | kSUB = 4, 56 | kDIV = 5, 57 | kPOW = 6, 58 | }; 59 | typedef enum 
ElementWiseOperation ElementWiseOperation_t; 60 | 61 | enum class EngineCapabiliy { 62 | kDEFAULT = 0, 63 | kSAFE_GPU = 1, 64 | kSAFE_DLA = 2, 65 | }; 66 | typedef enum EngineCapabiliy EngineCapabiliy_t; 67 | 68 | enum class LayerType { 69 | kCONVOLUTION = 0, 70 | kFULLY_CONNECTED = 1, 71 | kACTIVATION = 2, 72 | kPOOLING = 3, 73 | kLRN = 4, 74 | kSCALE = 5, 75 | kSOFTMAX = 6, 76 | kDECONVOLUTION = 7, 77 | kCONCATENATION = 8, 78 | kELEMENTWISE = 9, 79 | kPLUGIN = 10, 80 | kRNN = 11, 81 | kUNARY = 12, 82 | kPADDING = 13, 83 | kSHUFFLE = 14, 84 | kREDUCE = 15, 85 | kTOPK = 16, 86 | kGATHER = 17, 87 | kMATRIX_MULTIPLY = 18, 88 | kRAGGED_SOFTMAX = 19, 89 | kCONSTANT = 20, 90 | kRNN_V2 = 21, 91 | kIDENTITY = 22, 92 | kPLUGIN_V2 = 23, 93 | kSLICE = 24, 94 | }; 95 | typedef enum LayerType LayerType_t; 96 | 97 | enum class PaddingMode { 98 | kEXPLICIT_ROUND_DOWN = 0, 99 | kEXPLICIT_ROUND_UP = 1, 100 | kSAME_UPPER = 2, 101 | kSAME_LOWER = 3, 102 | kCAFFE_ROUND_DOWN = 4, 103 | kCAFFE_ROUND_UP = 5, 104 | }; 105 | 106 | enum class PoolingType { 107 | kMAX = 0, 108 | kAVERAGE = 1, 109 | kMAX_AVERAGE_BLEND = 2, 110 | }; 111 | 112 | enum class TensorLocation { 113 | kDEVICE = 0, 114 | kHOST = 1, 115 | }; 116 | typedef enum TensorLocation TensorLocation_t; 117 | 118 | #ifdef __cplusplus 119 | }; 120 | #endif 121 | 122 | #endif //LIBTRT_TRTENUMS_H 123 | -------------------------------------------------------------------------------- /tensorrt-sys/trt-sys/TRTHostMemory/TRTHostMemory.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Created by mason on 1/19/20. 
3 | // 4 | #include 5 | #include "TRTHostMemory.h" 6 | 7 | void destroy_host_memory(nvinfer1::IHostMemory* host_memory) { 8 | host_memory->destroy(); 9 | } 10 | 11 | void* host_memory_get_data(nvinfer1::IHostMemory* host_memory) { 12 | return host_memory->data(); 13 | } 14 | 15 | size_t host_memory_get_size(nvinfer1::IHostMemory* host_memory) { 16 | return host_memory->size(); 17 | } 18 | -------------------------------------------------------------------------------- /tensorrt-sys/trt-sys/TRTHostMemory/TRTHostMemory.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by mason on 1/19/20. 3 | // 4 | 5 | #ifndef LIBTRT_TRTHOSTMEMORY_H 6 | #define LIBTRT_TRTHOSTMEMORY_H 7 | 8 | #include 9 | 10 | void destroy_host_memory(nvinfer1::IHostMemory* host_memory); 11 | 12 | void* host_memory_get_data(nvinfer1::IHostMemory* host_memory); 13 | size_t host_memory_get_size(nvinfer1::IHostMemory* host_memory); 14 | 15 | 16 | #endif //LIBTRT_TRTHOSTMEMORY_H 17 | -------------------------------------------------------------------------------- /tensorrt-sys/trt-sys/TRTLayer/TRTActivationLayer.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Created by mason on 10/13/20. 
3 | // 4 | 5 | #include "TRTActivationLayer.h" 6 | 7 | void activation_set_activation_type(nvinfer1::IActivationLayer *layer, ActivationType_t activationType) { 8 | layer->setActivationType(static_cast(activationType)); 9 | } 10 | 11 | ActivationType_t activation_get_activation_type(nvinfer1::IActivationLayer *layer) { 12 | return static_cast(layer->getActivationType()); 13 | } 14 | 15 | void activation_set_alpha(nvinfer1::IActivationLayer *layer, float alpha) { 16 | layer->setAlpha(alpha); 17 | } 18 | 19 | float activation_get_alpha(nvinfer1::IActivationLayer *layer) { 20 | return layer->getAlpha(); 21 | } 22 | 23 | void activation_set_beta(nvinfer1::IActivationLayer *layer, float beta) { 24 | layer->setBeta(beta); 25 | } 26 | 27 | float activation_get_beta(nvinfer1::IActivationLayer *layer) { 28 | return layer->getBeta(); 29 | } 30 | -------------------------------------------------------------------------------- /tensorrt-sys/trt-sys/TRTLayer/TRTActivationLayer.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by mason on 10/13/20. 
3 | // 4 | 5 | #ifndef LIBTRT_TRTACTIVATIONLAYER_H 6 | #define LIBTRT_TRTACTIVATIONLAYER_H 7 | 8 | #include "../TRTEnums.h" 9 | #include "TRTLayer.h" 10 | 11 | void activation_set_activation_type(nvinfer1::IActivationLayer *layer, ActivationType_t activationType); 12 | ActivationType_t activation_get_activation_type(nvinfer1::IActivationLayer *layer); 13 | void activation_set_alpha(nvinfer1::IActivationLayer *layer, float alpha); 14 | float activation_get_alpha(nvinfer1::IActivationLayer *layer); 15 | void activation_set_beta(nvinfer1::IActivationLayer *layer, float beta); 16 | float activation_get_beta(nvinfer1::IActivationLayer *layer); 17 | 18 | #endif //LIBTRT_TRTACTIVATIONLAYER_H 19 | -------------------------------------------------------------------------------- /tensorrt-sys/trt-sys/TRTLayer/TRTElementWiseLayer.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Created by mason on 10/12/20. 3 | // 4 | 5 | #include 6 | 7 | #include "TRTElementWiseLayer.h" 8 | 9 | void elementwise_set_operation(nvinfer1::IElementWiseLayer *element_wise_layer, ElementWiseOperation_t operation) { 10 | element_wise_layer->setOperation(static_cast(operation)); 11 | } 12 | 13 | ElementWiseOperation_t elementwise_get_operation(nvinfer1::IElementWiseLayer *element_wise_layer) { 14 | return static_cast(element_wise_layer->getOperation()); 15 | } 16 | -------------------------------------------------------------------------------- /tensorrt-sys/trt-sys/TRTLayer/TRTElementWiseLayer.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by mason on 10/12/20. 
3 | // 4 | 5 | #ifndef LIBTRT_TRTELEMENTWISELAYER_H 6 | #define LIBTRT_TRTELEMENTWISELAYER_H 7 | 8 | #include "TRTLayer.h" 9 | #include "../TRTEnums.h" 10 | 11 | void elementwise_set_operation(nvinfer1::IElementWiseLayer *element_wise_layer, ElementWiseOperation_t type); 12 | ElementWiseOperation_t elementwise_get_operation(nvinfer1::IElementWiseLayer *element_wise_layer); 13 | 14 | #endif //LIBTRT_TRTELEMENTWISELAYER_H 15 | -------------------------------------------------------------------------------- /tensorrt-sys/trt-sys/TRTLayer/TRTGatherLayer.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Created by mason on 10/12/20. 3 | // 4 | 5 | #include "TRTGatherLayer.h" 6 | 7 | int32_t gather_layer_get_gather_axis(nvinfer1::IGatherLayer *layer) { 8 | return layer->getGatherAxis(); 9 | } 10 | 11 | void gather_layer_set_gather_axis(nvinfer1::IGatherLayer *layer, int32_t axis) { 12 | layer->setGatherAxis(axis); 13 | } 14 | -------------------------------------------------------------------------------- /tensorrt-sys/trt-sys/TRTLayer/TRTGatherLayer.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by mason on 10/12/20. 3 | // 4 | 5 | #ifndef LIBTRT_TRTGATHERLAYER_H 6 | #define LIBTRT_TRTGATHERLAYER_H 7 | 8 | #include 9 | #include "TRTLayer.h" 10 | 11 | int32_t gather_layer_get_gather_axis(nvinfer1::IGatherLayer *layer); 12 | void gather_layer_set_gather_axis(nvinfer1::IGatherLayer *layer, int32_t axis); 13 | 14 | #endif //LIBTRT_TRTGATHERLAYER_H 15 | -------------------------------------------------------------------------------- /tensorrt-sys/trt-sys/TRTLayer/TRTLayer.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Created by mason on 10/7/20. 
3 | // 4 | #include 5 | #include "TRTLayer.h" 6 | 7 | LayerType_t layer_get_type(nvinfer1::ILayer *layer) { 8 | return static_cast(layer->getType()); 9 | } 10 | 11 | void layer_set_name(nvinfer1::ILayer* layer, const char* name) { 12 | layer->setName(name); 13 | } 14 | 15 | const char* layer_get_name(nvinfer1::ILayer *layer) { 16 | if(layer == nullptr) { 17 | return nullptr; 18 | } 19 | 20 | return layer->getName(); 21 | } 22 | 23 | int32_t layer_get_nb_inputs(nvinfer1::ILayer *layer) { 24 | return layer->getNbInputs(); 25 | } 26 | 27 | nvinfer1::ITensor* layer_get_input(nvinfer1::ILayer *layer, int32_t index) { 28 | return layer->getInput(index); 29 | } 30 | 31 | int32_t layer_get_nb_outputs(nvinfer1::ILayer *layer) { 32 | return layer->getNbOutputs(); 33 | } 34 | 35 | nvinfer1::ITensor* layer_get_output(nvinfer1::ILayer *layer, int32_t index) { 36 | return layer->getOutput(index); 37 | } 38 | 39 | void layer_set_input(nvinfer1::ILayer *layer, int32_t index, nvinfer1::ITensor *tensor) { 40 | layer->setInput(index, *tensor); 41 | } 42 | 43 | void layer_set_precision(nvinfer1::ILayer *layer, DataType_t precision) { 44 | layer->setPrecision(static_cast(precision)); 45 | } 46 | 47 | DataType_t layer_get_precision(nvinfer1::ILayer *layer) { 48 | return static_cast(layer->getPrecision()); 49 | } 50 | 51 | bool layer_precision_is_set(nvinfer1::ILayer *layer) { 52 | return layer->precisionIsSet(); 53 | } 54 | 55 | void layer_reset_precision(nvinfer1::ILayer *layer) { 56 | layer->resetPrecision(); 57 | } 58 | 59 | void layer_set_output_type(nvinfer1::ILayer *layer, int32_t index, DataType_t dataType) { 60 | layer->setOutputType(index, static_cast(dataType)); 61 | } 62 | 63 | DataType_t layer_get_output_type(nvinfer1::ILayer *layer, int32_t index) { 64 | return static_cast(layer->getOutputType(index)); 65 | } 66 | 67 | bool layer_output_type_is_set(nvinfer1::ILayer *layer, int32_t index) { 68 | return layer->outputTypeIsSet(index); 69 | } 70 | 71 | void 
layer_reset_output_type(nvinfer1::ILayer *layer, int32_t index) { 72 | layer->resetOutputType(index); 73 | } 74 | -------------------------------------------------------------------------------- /tensorrt-sys/trt-sys/TRTLayer/TRTLayer.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by mason on 10/7/20. 3 | // 4 | 5 | #ifndef LIBTRT_TRTLAYER_H 6 | #define LIBTRT_TRTLAYER_H 7 | 8 | #include 9 | #include 10 | #include "../TRTEnums.h" 11 | 12 | LayerType_t layer_get_type(nvinfer1::ILayer *layer); 13 | void layer_set_name(nvinfer1::ILayer *layer, const char* name); 14 | const char* layer_get_name(nvinfer1::ILayer *layer); 15 | int32_t layer_get_nb_inputs(nvinfer1::ILayer *layer); 16 | nvinfer1::ITensor* layer_get_input(nvinfer1::ILayer *layer, int32_t index); 17 | int32_t layer_get_nb_outputs(nvinfer1::ILayer *layer); 18 | nvinfer1::ITensor * layer_get_output(nvinfer1::ILayer *layer, int32_t index); 19 | void layer_set_input(nvinfer1::ILayer *layer, int32_t index, nvinfer1::ITensor *tensor); 20 | void layer_set_precision(nvinfer1::ILayer *layer, DataType_t precision); 21 | DataType_t layer_get_precision(nvinfer1::ILayer *layer); 22 | bool layer_precision_is_set(nvinfer1::ILayer *layer); 23 | void layer_reset_precision(nvinfer1::ILayer *layer); 24 | void layer_set_output_type(nvinfer1::ILayer *layer, int32_t index, DataType_t dataType); 25 | DataType_t layer_get_output_type(nvinfer1::ILayer *layer, int32_t index); 26 | bool layer_output_type_is_set(nvinfer1::ILayer *layer, int32_t index); 27 | void layer_reset_output_type(nvinfer1::ILayer *layer, int32_t index); 28 | 29 | #endif //LIBTRT_TRTLAYER_H 30 | -------------------------------------------------------------------------------- /tensorrt-sys/trt-sys/TRTLayer/TRTPoolingLayer.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Created by mason on 10/13/20. 
3 | // 4 | 5 | #include "TRTPoolingLayer.h" 6 | 7 | 8 | void pooling_set_pooling_type(nvinfer1::IPoolingLayer *layer, PoolingType poolingType) { 9 | layer->setPoolingType(static_cast(poolingType)); 10 | } 11 | 12 | PoolingType pooling_get_pooling_type(nvinfer1::IPoolingLayer *layer) { 13 | return static_cast(layer->getPoolingType()); 14 | } 15 | 16 | void pooling_set_window_size(nvinfer1::IPoolingLayer *layer, nvinfer1::DimsHW dims) { 17 | layer->setWindowSize(dims); 18 | } 19 | 20 | nvinfer1::DimsHW pooling_get_window_size(nvinfer1::IPoolingLayer *layer) { 21 | return layer->getWindowSize(); 22 | } 23 | 24 | void pooling_set_stride(nvinfer1::IPoolingLayer *layer, nvinfer1::DimsHW stride) { 25 | layer->setStride(stride); 26 | } 27 | 28 | nvinfer1::DimsHW pooling_get_stride(nvinfer1::IPoolingLayer *layer) { 29 | return layer->getStride(); 30 | } 31 | 32 | void pooling_set_padding(nvinfer1::IPoolingLayer *layer, nvinfer1::DimsHW padding) { 33 | layer->setPadding(padding); 34 | } 35 | 36 | nvinfer1::DimsHW pooling_get_padding(nvinfer1::IPoolingLayer *layer) { 37 | return layer->getPadding(); 38 | } 39 | 40 | void pooling_set_blend_factor(nvinfer1::IPoolingLayer *layer, float factor) { 41 | layer->setBlendFactor(factor); 42 | } 43 | 44 | float pooling_get_blend_factor(nvinfer1::IPoolingLayer *layer) { 45 | return layer->getBlendFactor(); 46 | } 47 | 48 | void pooling_set_average_count_excludes_padding(nvinfer1::IPoolingLayer *layer, bool exclusive) { 49 | layer->setAverageCountExcludesPadding(exclusive); 50 | } 51 | 52 | bool pooling_get_average_count_excludes_padding(nvinfer1::IPoolingLayer *layer) { 53 | return layer->getAverageCountExcludesPadding(); 54 | } 55 | 56 | void pooling_set_pre_padding(nvinfer1::IPoolingLayer *layer, nvinfer1::Dims pre_padding) { 57 | layer->setPrePadding(pre_padding); 58 | } 59 | 60 | nvinfer1::Dims pooling_get_pre_padding(nvinfer1::IPoolingLayer *layer) { 61 | return layer->getPrePadding(); 62 | } 63 | 64 | void 
pooling_set_post_padding(nvinfer1::IPoolingLayer *layer, nvinfer1::Dims post_padding) { 65 | layer->setPrePadding(post_padding); 66 | } 67 | 68 | nvinfer1::Dims pooling_get_post_padding(nvinfer1::IPoolingLayer *layer) { 69 | return layer->getPrePadding(); 70 | } 71 | 72 | void pooling_set_padding_mode(nvinfer1::IPoolingLayer *layer, PaddingMode mode) { 73 | layer->setPaddingMode(static_cast(mode)); 74 | } 75 | 76 | PaddingMode pooling_get_padding_mode(nvinfer1::IPoolingLayer *layer) { 77 | return static_cast(layer->getPaddingMode()); 78 | } 79 | -------------------------------------------------------------------------------- /tensorrt-sys/trt-sys/TRTLayer/TRTPoolingLayer.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by mason on 10/13/20. 3 | // 4 | 5 | #ifndef LIBTRT_TRTPOOLINGLAYER_H 6 | #define LIBTRT_TRTPOOLINGLAYER_H 7 | 8 | #include "../TRTEnums.h" 9 | #include "../TRTDims/TRTDims.h" 10 | #include "TRTLayer.h" 11 | 12 | void pooling_set_pooling_type(nvinfer1::IPoolingLayer *layer, PoolingType poolingType); 13 | PoolingType pooling_get_pooling_type(nvinfer1::IPoolingLayer *layer); 14 | 15 | void pooling_set_window_size(nvinfer1::IPoolingLayer *layer, nvinfer1::DimsHW dims); 16 | nvinfer1::DimsHW pooling_get_window_size(nvinfer1::IPoolingLayer *layer); 17 | 18 | void pooling_set_stride(nvinfer1::IPoolingLayer *layer, nvinfer1::DimsHW stride); 19 | nvinfer1::DimsHW pooling_get_stride(nvinfer1::IPoolingLayer *layer); 20 | 21 | void pooling_set_padding(nvinfer1::IPoolingLayer *layer, nvinfer1::DimsHW padding); 22 | nvinfer1::DimsHW pooling_get_padding(nvinfer1::IPoolingLayer *layer); 23 | 24 | void pooling_set_blend_factor(nvinfer1::IPoolingLayer *layer, float factor); 25 | float pooling_get_blend_factor(nvinfer1::IPoolingLayer *layer); 26 | 27 | void pooling_set_average_count_excludes_padding(nvinfer1::IPoolingLayer *layer, bool exclusive); 28 | bool 
pooling_get_average_count_excludes_padding(nvinfer1::IPoolingLayer *layer); 29 | 30 | void pooling_set_pre_padding(nvinfer1::IPoolingLayer *layer, nvinfer1::Dims pre_padding); 31 | nvinfer1::Dims pooling_get_pre_padding(nvinfer1::IPoolingLayer *layer); 32 | 33 | void pooling_set_post_padding(nvinfer1::IPoolingLayer *layer, nvinfer1::Dims post_padding); 34 | nvinfer1::Dims pooling_get_post_padding(nvinfer1::IPoolingLayer *layer); 35 | 36 | void pooling_set_padding_mode(nvinfer1::IPoolingLayer *layer, PaddingMode mode); 37 | PaddingMode pooling_get_padding_mode(nvinfer1::IPoolingLayer *layer); 38 | 39 | #endif //LIBTRT_TRTPOOLINGLAYER_H 40 | -------------------------------------------------------------------------------- /tensorrt-sys/trt-sys/TRTLogger/TRTLogger.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Created by mason on 8/25/19. 3 | // 4 | #include 5 | #include "NvInfer.h" 6 | #include "TRTLoggerInternal.hpp" 7 | 8 | void get_tensorrt_version(char *string) { 9 | sprintf(string, "%d.%d.%d", NV_TENSORRT_MAJOR, NV_TENSORRT_MINOR, NV_TENSORRT_PATCH); 10 | } 11 | 12 | Logger_t *create_logger(const int severity) { 13 | auto nvSeverity = static_cast(severity); 14 | 15 | return new Logger(nvSeverity); 16 | } 17 | 18 | void delete_logger(Logger_t *logger) { 19 | if (logger == nullptr) 20 | return; 21 | 22 | delete logger; 23 | } 24 | 25 | void set_logger_severity(const Logger_t* logger, const int severity) { 26 | auto nvSeverity = static_cast(severity); 27 | 28 | logger->internal_logger->severity(nvSeverity); 29 | } 30 | 31 | void log_error(Logger_t *logger, char *err) { 32 | if (logger == nullptr) 33 | return; 34 | auto &l = logger->internal_logger; 35 | l->log(nvinfer1::ILogger::Severity::kWARNING, err); 36 | } 37 | 38 | nvinfer1::ILogger &Logger::getLogger() const { 39 | return *this->internal_logger; 40 | } -------------------------------------------------------------------------------- 
/tensorrt-sys/trt-sys/TRTLogger/TRTLogger.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by mason on 8/25/19. 3 | // 4 | 5 | #ifndef TENSRORT_SYS_TRTLOGGER_H 6 | #define TENSRORT_SYS_TRTLOGGER_H 7 | 8 | #ifdef __cplusplus 9 | extern "C" { 10 | #endif 11 | 12 | struct Logger; 13 | typedef struct Logger Logger_t; 14 | 15 | void get_tensorrt_version(char* string); 16 | 17 | Logger_t* create_logger(const int severity); 18 | void set_logger_severity(const Logger_t* logger, const int severity); 19 | void log_error(Logger_t* logger, char* err); 20 | void delete_logger(Logger_t* logger); 21 | 22 | 23 | #ifdef __cplusplus 24 | }; 25 | #endif 26 | 27 | #endif //TENSRORT_SYS_TRTLOGGER_H 28 | -------------------------------------------------------------------------------- /tensorrt-sys/trt-sys/TRTLogger/TRTLoggerInternal.hpp: -------------------------------------------------------------------------------- 1 | // 2 | // Created by mason on 4/11/20. 
3 | // 4 | 5 | #ifndef LIBTRT_TRTLOGGERINTERNAL_HPP 6 | #define LIBTRT_TRTLOGGERINTERNAL_HPP 7 | 8 | #include 9 | #include "TRTLogger.h" 10 | #include 11 | 12 | class TRTLogger : public nvinfer1::ILogger { 13 | public: 14 | explicit TRTLogger(nvinfer1::ILogger::Severity severity) 15 | : mReportableSeverity(severity) { 16 | } 17 | 18 | void log(nvinfer1::ILogger::Severity severity, const char *msg) final { 19 | if (severity <= mReportableSeverity) 20 | printf("%s\n", msg); 21 | } 22 | 23 | void severity(nvinfer1::ILogger::Severity severity) { 24 | mReportableSeverity = severity; 25 | } 26 | 27 | private: 28 | nvinfer1::ILogger::Severity mReportableSeverity; 29 | }; 30 | 31 | struct Logger { 32 | std::unique_ptr internal_logger; 33 | 34 | explicit Logger(nvinfer1::ILogger::Severity severity) { 35 | internal_logger = std::make_unique(severity); 36 | }; 37 | 38 | [[nodiscard]] nvinfer1::ILogger& getLogger() const; 39 | }; 40 | 41 | 42 | #endif //LIBTRT_TRTLOGGERINTERNAL_HPP 43 | -------------------------------------------------------------------------------- /tensorrt-sys/trt-sys/TRTNetworkDefinition/TRTNetworkDefinition.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Created by mason on 11/27/19. 
//
#include "TRTNetworkDefinition.h"

// Thin wrappers over nvinfer1::INetworkDefinition for the Rust bindings.
// Each function forwards directly to the corresponding TensorRT method and
// assumes its pointer arguments are valid (the Rust side guarantees this).

// Releases the network definition (pre-TensorRT-8 `destroy()` ownership model).
void destroy_network(nvinfer1::INetworkDefinition *network) {
    network->destroy();
}

// Adds an input tensor with the given name, element type, and dimensions.
nvinfer1::ITensor *
network_add_input(nvinfer1::INetworkDefinition *network, const char *name, nvinfer1::DataType type,
                  nvinfer1::Dims dims) {
    return network->addInput(name, type, dims);
}

nvinfer1::ITensor *network_get_input(nvinfer1::INetworkDefinition *network, int32_t idx) {
    return network->getInput(idx);
}

int network_get_nb_layers(nvinfer1::INetworkDefinition *network) {
    return network->getNbLayers();
}

nvinfer1::ILayer *network_get_layer(nvinfer1::INetworkDefinition *network, int index) {
    return network->getLayer(index);
}

// Adds an identity (pass-through) layer over `inputTensor`.
nvinfer1::IIdentityLayer *
network_add_identity_layer(nvinfer1::INetworkDefinition *network, nvinfer1::ITensor *inputTensor) {
    return network->addIdentity(*inputTensor);
}

int network_get_nb_inputs(nvinfer1::INetworkDefinition *network) {
    return network->getNbInputs();
}

int network_get_nb_outputs(nvinfer1::INetworkDefinition *network) {
    return network->getNbOutputs();
}

nvinfer1::ITensor *network_get_output(nvinfer1::INetworkDefinition *network, int32_t index) {
    return network->getOutput(index);
}

void network_remove_tensor(nvinfer1::INetworkDefinition *network, nvinfer1::ITensor *tensor) {
    network->removeTensor(*tensor);
}

// Marks `tensor` as a network output.
void network_mark_output(nvinfer1::INetworkDefinition *network, nvinfer1::ITensor *tensor) {
    network->markOutput(*tensor);
}

void network_unmark_output(nvinfer1::INetworkDefinition *network, nvinfer1::ITensor *tensor) {
    network->unmarkOutput(*tensor);
}

// Adds an element-wise binary layer (sum, prod, max, ...) over two tensors.
nvinfer1::IElementWiseLayer *
network_add_element_wise(nvinfer1::INetworkDefinition *network, nvinfer1::ITensor *input1, nvinfer1::ITensor *input2,
                         nvinfer1::ElementWiseOperation op) {
    return network->addElementWise(*input1, *input2, op);
}

// Adds a gather layer selecting entries of `data` along `axis` using `indices`.
nvinfer1::IGatherLayer *
network_add_gather(nvinfer1::INetworkDefinition *network, nvinfer1::ITensor *data, nvinfer1::ITensor *indices,
                   int32_t axis) {
    return network->addGather(*data, *indices, axis);
}

// Adds an activation layer (ReLU, sigmoid, tanh, ...) over `input`.
nvinfer1::IActivationLayer *
network_add_activation(nvinfer1::INetworkDefinition *network, nvinfer1::ITensor *input, nvinfer1::ActivationType type) {
    return network->addActivation(*input, type);
}

// Adds a 2D pooling layer of the given type with window size `dims`.
nvinfer1::IPoolingLayer *
network_add_pooling(nvinfer1::INetworkDefinition *network, nvinfer1::ITensor *input, nvinfer1::PoolingType poolingType,
                    nvinfer1::DimsHW dims) {
    return network->addPooling(*input, poolingType, dims);
}
-------------------------------------------------------------------------------- /tensorrt-sys/trt-sys/TRTNetworkDefinition/TRTNetworkDefinition.h: --------------------------------------------------------------------------------
//
// Created by mason on 11/27/19.
3 | // 4 | 5 | #ifndef LIBTRT_TRTNETWORKDEFINITION_H 6 | #define LIBTRT_TRTNETWORKDEFINITION_H 7 | 8 | #include 9 | #include 10 | 11 | void destroy_network(nvinfer1::INetworkDefinition *network); 12 | nvinfer1::ITensor *network_add_input(nvinfer1::INetworkDefinition *network, const char *name, nvinfer1::DataType dataType, nvinfer1::Dims dims); 13 | nvinfer1::ITensor *network_get_input(nvinfer1::INetworkDefinition *network, int32_t idx); 14 | int network_get_nb_layers(nvinfer1::INetworkDefinition *network); 15 | nvinfer1::ILayer *network_get_layer(nvinfer1::INetworkDefinition *network, int index); 16 | int network_get_nb_inputs(nvinfer1::INetworkDefinition *network); 17 | int network_get_nb_outputs(nvinfer1::INetworkDefinition *network); 18 | nvinfer1::ITensor *network_get_output(nvinfer1::INetworkDefinition *network, int32_t index); 19 | void network_remove_tensor(nvinfer1::INetworkDefinition *network, nvinfer1::ITensor *tensor); 20 | void network_mark_output(nvinfer1::INetworkDefinition *network, nvinfer1::ITensor *tensor); 21 | void network_unmark_output(nvinfer1::INetworkDefinition *network, nvinfer1::ITensor *tensor); 22 | nvinfer1::IIdentityLayer *network_add_identity_layer(nvinfer1::INetworkDefinition *network, nvinfer1::ITensor *inputTensor); 23 | nvinfer1::IElementWiseLayer *network_add_element_wise(nvinfer1::INetworkDefinition *network, nvinfer1::ITensor *input1, nvinfer1::ITensor *input2, nvinfer1::ElementWiseOperation op); 24 | nvinfer1::IGatherLayer *network_add_gather(nvinfer1::INetworkDefinition *network, nvinfer1::ITensor *data, nvinfer1::ITensor *indices, int32_t axis); 25 | nvinfer1::IActivationLayer *network_add_activation(nvinfer1::INetworkDefinition *network, nvinfer1::ITensor *input, nvinfer1::ActivationType type); 26 | nvinfer1::IPoolingLayer *network_add_pooling(nvinfer1::INetworkDefinition *network, nvinfer1::ITensor *input, nvinfer1::PoolingType poolingType, nvinfer1::DimsHW dims); 27 | 28 | #endif //LIBTRT_TRTNETWORKDEFINITION_H 29 | 
-------------------------------------------------------------------------------- /tensorrt-sys/trt-sys/TRTOnnxParser/TRTOnnxParser.cpp: --------------------------------------------------------------------------------
#include <memory>    // NOTE(review): include target lost in extraction — presumed <memory> for std::unique_ptr; verify
#include <NvInfer.h> // NOTE(review): include target lost in extraction — verify

#include "NvOnnxParser.h"

#include "TRTOnnxParser.h"
#include "../TRTLogger/TRTLoggerInternal.hpp"
#include "../TRTUtils.hpp"

// Opaque FFI wrapper owning the ONNX parser. TRTDeleter calls destroy()
// instead of delete, as TensorRT-created objects require.
struct OnnxParser {
    // NOTE(review): template arguments lost in extraction — reconstructed; verify
    using IOnnxParserPtr = std::unique_ptr<nvonnxparser::IParser, TRTDeleter<nvonnxparser::IParser>>;
    IOnnxParserPtr internal_onnxParser;

    explicit OnnxParser(nvonnxparser::IParser *onnxParser) : internal_onnxParser(onnxParser) {};
};

// Creates a parser bound to `network`; diagnostics go through `logger`.
OnnxParser_t *onnxparser_create_parser(nvinfer1::INetworkDefinition *network, Logger_t *logger) {
    return new OnnxParser(nvonnxparser::createParser(*network, logger->getLogger()));
}

// Null-safe destructor for the opaque handle.
void onnxparser_destroy_parser(OnnxParser_t *onnx_parser) {
    if (onnx_parser == nullptr)
        return;

    delete onnx_parser;
}

// Parses an ONNX model file; returns false on null arguments or parse failure.
// `verbosity` is forwarded to the parser's logging level.
bool onnxparser_parse_from_file(const OnnxParser_t *onnx_parser, const char *file, int verbosity) {
    if (onnx_parser == nullptr || file == nullptr)
        return false;

    return onnx_parser->internal_onnxParser->parseFromFile(file, verbosity);
}
-------------------------------------------------------------------------------- /tensorrt-sys/trt-sys/TRTOnnxParser/TRTOnnxParser.h: --------------------------------------------------------------------------------
#ifndef LIBTRT_TRTONNXPARSER_H
#define LIBTRT_TRTONNXPARSER_H

#include <NvInfer.h> // NOTE(review): include target lost in extraction — verify
#include "../TRTNetworkDefinition/TRTNetworkDefinition.h"
#include "../TRTLogger/TRTLogger.h"

#ifdef __cplusplus
extern "C" {
#endif

// Opaque handle type exposed to the Rust bindings.
struct OnnxParser;
typedef struct OnnxParser OnnxParser_t;

OnnxParser_t* onnxparser_create_parser(nvinfer1::INetworkDefinition *network, Logger_t *logger);
void onnxparser_destroy_parser(OnnxParser_t* onnx_parser);

bool onnxparser_parse_from_file(const OnnxParser_t* onnx_parser, const char* file, int verbosity);

#ifdef __cplusplus
}; // NOTE(review): stray semicolon after extern "C" close brace — harmless, matches sibling headers
#endif

#endif //LIBTRT_TRTONNXPARSER_H
-------------------------------------------------------------------------------- /tensorrt-sys/trt-sys/TRTProfiler/TRTProfiler.cpp: --------------------------------------------------------------------------------
//
// Created by mason on 12/16/20.
//

#include "TRTProfilerInternal.hpp"
#include <cstdio> // NOTE(review): include target lost in extraction — verify

// Wraps the Rust-side profiler vtable in an nvinfer1::IProfiler implementation.
CppProfiler* create_profiler(Profiler_t * rust_profiler) {
    return new CppProfiler(rust_profiler);
}

void destroy_profiler(CppProfiler* profiler) {
    delete profiler;
}

// Notifies the Rust side on teardown so it can release its boxed state;
// passes the Rust profiler pointer as `self` plus its opaque context.
CppProfiler::~CppProfiler() {
    (*profiler->destroy)(profiler, profiler->context);
}

// TensorRT per-layer timing callback, forwarded to the Rust function pointer.
void CppProfiler::reportLayerTime(const char *layerName, float ms) {
    (*profiler->reportLayerTime)(profiler->context, layerName, ms);
}
-------------------------------------------------------------------------------- /tensorrt-sys/trt-sys/TRTProfiler/TRTProfiler.h: --------------------------------------------------------------------------------
//
// Created by mason on 10/7/20.
//

#ifndef LIBTRT_TRTPROFILER_H
#define LIBTRT_TRTPROFILER_H


// Profiler is defined by the Rust side (see TRTProfilerInternal.hpp for the
// matching C layout); CppProfiler adapts it to nvinfer1::IProfiler.
struct Profiler;
typedef struct Profiler Profiler_t;
class CppProfiler;

CppProfiler* create_profiler(Profiler_t * rust_profiler);
void destroy_profiler(CppProfiler* profiler);


#endif //LIBTRT_TRTPROFILER_H
-------------------------------------------------------------------------------- /tensorrt-sys/trt-sys/TRTProfiler/TRTProfilerInternal.hpp: --------------------------------------------------------------------------------
//
// Created by mason on 10/7/20.
3 | // 4 | 5 | #ifndef LIBTRT_TRTPROFILERINTERNAL_HPP 6 | #define LIBTRT_TRTPROFILERINTERNAL_HPP 7 | 8 | #include "TRTProfiler.h" 9 | #include 10 | 11 | struct Profiler { 12 | void (*reportLayerTime)(void *context, const char *layerName, float ms); 13 | 14 | void (*destroy)(void *self, void *context); 15 | 16 | void *context; 17 | }; 18 | 19 | class CppProfiler : public nvinfer1::IProfiler { 20 | public: 21 | explicit CppProfiler(Profiler_t *_profiler) : profiler(_profiler) {} 22 | ~CppProfiler(); 23 | 24 | void reportLayerTime(const char *layerName, float ms) override; 25 | 26 | Profiler_t *getInternalProfiler() { 27 | return profiler; 28 | } 29 | 30 | private: 31 | Profiler_t *profiler; 32 | }; 33 | 34 | #endif //LIBTRT_TRTPROFILERINTERNAL_HPP 35 | -------------------------------------------------------------------------------- /tensorrt-sys/trt-sys/TRTRuntime/TRTRuntime.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Created by mason on 8/25/19. 
3 | // 4 | #include "TRTRuntime.h" 5 | #include "../TRTLogger/TRTLoggerInternal.hpp" 6 | 7 | nvinfer1::IRuntime *create_infer_runtime(Logger_t *logger) { 8 | return nvinfer1::createInferRuntime(logger->getLogger()); 9 | } 10 | 11 | void destroy_infer_runtime(nvinfer1::IRuntime *runtime) { 12 | runtime->destroy(); 13 | } 14 | 15 | nvinfer1::ICudaEngine *deserialize_cuda_engine(nvinfer1::IRuntime *runtime, const void *blob, unsigned long long size) { 16 | return runtime->deserializeCudaEngine(blob, size, nullptr); 17 | } 18 | 19 | int runtime_get_nb_dla_cores(nvinfer1::IRuntime *runtime) { 20 | return runtime->getNbDLACores(); 21 | } 22 | 23 | int runtime_get_dla_core(nvinfer1::IRuntime *runtime) { 24 | return runtime->getDLACore(); 25 | } 26 | 27 | void runtime_set_dla_core(nvinfer1::IRuntime *runtime, int dla_core) { 28 | runtime->setDLACore(dla_core); 29 | } 30 | -------------------------------------------------------------------------------- /tensorrt-sys/trt-sys/TRTRuntime/TRTRuntime.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by mason on 8/25/19. 
3 | // 4 | 5 | #ifndef TENSRORT_SYS_TRTRUNTIME_H 6 | #define TENSRORT_SYS_TRTRUNTIME_H 7 | 8 | #include 9 | #include "../TRTLogger/TRTLogger.h" 10 | #include "../TRTCudaEngine/TRTCudaEngine.h" 11 | 12 | nvinfer1::IRuntime *create_infer_runtime(Logger_t *logger); 13 | void destroy_infer_runtime(nvinfer1::IRuntime *runtime); 14 | nvinfer1::ICudaEngine *deserialize_cuda_engine(nvinfer1::IRuntime *runtime, const void *blob, unsigned long long size); 15 | int runtime_get_nb_dla_cores(nvinfer1::IRuntime *runtime); 16 | int runtime_get_dla_core(nvinfer1::IRuntime *runtime); 17 | void runtime_set_dla_core(nvinfer1::IRuntime *runtime, int dla_core); 18 | 19 | #endif //TENSRORT_SYS_TRTRUNTIME_H 20 | -------------------------------------------------------------------------------- /tensorrt-sys/trt-sys/TRTTensor/TRTTensor.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Created by mason on 10/10/20. 3 | // 4 | 5 | #include "TRTTensor.h" 6 | 7 | const char* tensor_get_name(nvinfer1::ITensor *tensor) { 8 | return tensor->getName(); 9 | } 10 | 11 | void tensor_set_dimensions(nvinfer1::ITensor *tensor, nvinfer1::Dims dimensions) { 12 | tensor->setDimensions(dimensions); 13 | } 14 | 15 | -------------------------------------------------------------------------------- /tensorrt-sys/trt-sys/TRTTensor/TRTTensor.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by mason on 10/10/20. 
//

#ifndef LIBTRT_TRTTENSOR_H
#define LIBTRT_TRTTENSOR_H

#include <NvInfer.h> // NOTE(review): include target lost in extraction — verify
#include "../TRTDims/TRTDims.h"

const char* tensor_get_name(nvinfer1::ITensor *tensor);
void tensor_set_dimensions(nvinfer1::ITensor *tensor, nvinfer1::Dims dimensions); // only valid for input tensors

#endif //LIBTRT_TRTTENSOR_H
-------------------------------------------------------------------------------- /tensorrt-sys/trt-sys/TRTUffParser/TRTUffParser.cpp: --------------------------------------------------------------------------------
//
// Created by mason on 11/22/19.
//
#include <memory>  // NOTE(review): include target lost in extraction — presumed <memory> for std::unique_ptr; verify
#include <cstdlib> // NOTE(review): include target lost in extraction — verify

#include "NvUffParser.h"

#include "TRTUffParser.h"
#include "../TRTUtils.hpp"

// Opaque FFI wrapper owning the UFF parser. TRTDeleter calls destroy()
// instead of delete, as TensorRT-created objects require.
struct UffParser {
    // NOTE(review): template arguments lost in extraction — reconstructed; verify
    using IUffParserPtr = std::unique_ptr<nvuffparser::IUffParser, TRTDeleter<nvuffparser::IUffParser>>;
    IUffParserPtr internal_uffParser;

    explicit UffParser(nvuffparser::IUffParser *uffParser) : internal_uffParser(uffParser) {};
};

UffParser_t *uffparser_create_uff_parser() {
    return new UffParser(nvuffparser::createUffParser());
}

// Null-safe destructor for the opaque handle.
void uffparser_destroy_uff_parser(UffParser_t *uff_parser) {
    if (uff_parser == nullptr)
        return;

    delete uff_parser;
}

// Registers a named input with its dimensions; `input_order` is the integer
// value of nvuffparser::UffInputOrder (NCHW/NHWC/NC). Returns false on null args.
bool uffparser_register_input(const UffParser_t *uff_parser, const char *input_name, const nvinfer1::Dims dims, int input_order) {
    if (uff_parser == nullptr || input_name == nullptr)
        return false;

    // NOTE(review): cast target lost in extraction — presumed nvuffparser::UffInputOrder; verify
    auto inputOrder = static_cast<nvuffparser::UffInputOrder>(input_order);
    return uff_parser->internal_uffParser->registerInput(input_name, dims, inputOrder);
}

bool uffparser_register_output(const UffParser_t *uff_parser, const char *output_name) {
    if (uff_parser == nullptr || output_name == nullptr)
        return false;

    return uff_parser->internal_uffParser->registerOutput(output_name);
}

// Parses the UFF model at `file` into `network`, requesting FP32 weights.
bool uffparser_parse(const UffParser_t *uff_parser, const char *file, nvinfer1::INetworkDefinition *network) {
    if (uff_parser == nullptr || file == nullptr || network == nullptr)
        return false;

    return uff_parser->internal_uffParser->parse(file, *network, nvinfer1::DataType::kFLOAT);
}
-------------------------------------------------------------------------------- /tensorrt-sys/trt-sys/TRTUffParser/TRTUffParser.h: --------------------------------------------------------------------------------
//
// Created by mason on 11/22/19.
//

#ifndef LIBTRT_TRTUFFPARSER_H
#define LIBTRT_TRTUFFPARSER_H

#include <NvInfer.h>     // NOTE(review): include target lost in extraction — verify
#include <NvUffParser.h> // NOTE(review): include target lost in extraction — verify

#include "../TRTNetworkDefinition/TRTNetworkDefinition.h"

#ifdef __cplusplus
extern "C" {
#endif

// Opaque handle type exposed to the Rust bindings.
struct UffParser;
typedef struct UffParser UffParser_t;

UffParser_t* uffparser_create_uff_parser();
void uffparser_destroy_uff_parser(UffParser_t* uff_parser);

bool uffparser_register_input(const UffParser_t* uff_parser, const char* input_name, nvinfer1::Dims dims, int input_order);
bool uffparser_register_output(const UffParser_t* uff_parser, const char* output_name);

bool uffparser_parse(const UffParser_t* uff_parser, const char* file, nvinfer1::INetworkDefinition *network);

#ifdef __cplusplus
}; // NOTE(review): stray semicolon after extern "C" close brace — harmless, matches sibling headers
#endif

#endif //LIBTRT_TRTUFFPARSER_H
-------------------------------------------------------------------------------- /tensorrt-sys/trt-sys/TRTUtils.hpp: --------------------------------------------------------------------------------
//
// Created by mason on 4/11/20.
3 | // 4 | 5 | #ifndef LIBTRT_TRTUTILS_HPP 6 | #define LIBTRT_TRTUTILS_HPP 7 | 8 | template 9 | struct TRTDeleter { 10 | void operator()(T* ptr) { 11 | ptr->destroy(); 12 | } 13 | }; 14 | 15 | #endif //LIBTRT_TRTUTILS_HPP 16 | -------------------------------------------------------------------------------- /tensorrt-sys/trt-sys/tensorrt_api.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by mason on 8/25/19. 3 | // 4 | 5 | #ifndef TENSRORT_SYS_TENSORRT_API_H 6 | #define TENSRORT_SYS_TENSORRT_API_H 7 | #include 8 | 9 | #include "TRTEnums.h" 10 | #include "TRTLogger/TRTLogger.h" 11 | #include "TRTRuntime/TRTRuntime.h" 12 | #include "TRTCudaEngine/TRTCudaEngine.h" 13 | #include "TRTContext/TRTContext.h" 14 | #include "TRTUffParser/TRTUffParser.h" 15 | #include "TRTOnnxParser/TRTOnnxParser.h" 16 | #include "TRTDims/TRTDims.h" 17 | #include "TRTBuilder/TRTBuilder.h" 18 | #include "TRTNetworkDefinition/TRTNetworkDefinition.h" 19 | #include "TRTHostMemory/TRTHostMemory.h" 20 | #include "TRTProfiler/TRTProfiler.h" 21 | #include "TRTTensor/TRTTensor.h" 22 | #include "TRTLayer/TRTLayer.h" 23 | #include "TRTLayer/TRTElementWiseLayer.h" 24 | #include "TRTLayer/TRTGatherLayer.h" 25 | #include "TRTLayer/TRTActivationLayer.h" 26 | #include "TRTLayer/TRTPoolingLayer.h" 27 | 28 | #endif //TENSRORT_SYS_TENSORRT_API_H 29 | -------------------------------------------------------------------------------- /tensorrt/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "tensorrt-rs" 3 | version = "0.4.0" 4 | authors = ["Mason Stallmo "] 5 | license = "MIT" 6 | repository = "https://github.com/mstallmo/tensorrt-rs" 7 | edition = "2018" 8 | description = "Rust library for using Nvidia's TensorRT deep learning acceleration library" 9 | 10 | [features] 11 | default = ["trt-7"] 12 | 13 | trt-5 = ["tensorrt-sys/trt-5"] 14 | 15 | trt-6 = ["tensorrt-sys/trt-6"] 16 | 17 | 
trt-7 = ["tensorrt-sys/trt-7"]


[dependencies]
# Uncomment when working locally
#tensorrt-sys = { path = "../tensorrt-sys" }
tensorrt-sys = { git = "https://github.com/mstallmo/tensorrt-rs", branch = "develop" }
#tensorrt_rs_derive = { path = "../tensorrt_rs_derive" }
tensorrt_rs_derive = { git = "https://github.com/mstallmo/tensorrt-rs", branch = "develop" }
ndarray = "0.13"
ndarray-image = "0.2"
image = "0.23"
imageproc = "0.21.0"
bitflags = "1.2"
num-traits = "0.2.12"
num-derive = "0.3.2"
cuda-runtime-sys = "0.3.0-alpha.1"
anyhow = "1.0.35"

[dev-dependencies]
lazy_static = "1.4"


[[example]]
name = "onnx"
required-features = ["trt-7"]
-------------------------------------------------------------------------------- /tensorrt/README.md: --------------------------------------------------------------------------------
# TensorRT-RS
![Crates.io](https://img.shields.io/crates/v/tensorrt-rs)

:warning: __This crate currently only supports Linux__ :warning:

Rust library for creating and executing TensorRT engines.

This library depends on tensorrt-sys for Rust bindings to the underlying C++ TensorRT library. See the tensorrt-sys
README for information on prerequisite dependencies.

### Status
This crate is still very much in early stage development. Support for TensorRT functionality covers only the basics of what
is needed to read a model file in .uff format, parse that file into a TensorRT engine, execute that engine, and serialize
a binary version of that engine to disk. Currently TensorRT plugins are not supported so there may be issues when trying
to use a model that requires plugin support to operate.
16 | 17 | 18 | ### Upcoming Improvements to the Project 19 | - ~~Feature configuration for selecting the TensorRT library version~~ - added in 0.4.0 20 | - Support for CUDA streams for async execution 21 | - ~~Support for the TensorRT Onnx Parser~~ - added in 0.4.0 22 | - Support for all functionality in the TensorRT C++ library 23 | - Complete support for the classes that already have bindings. 24 | - Add support for loading custom plugins. 25 | - Add support for custom layer implementations 26 | 27 | Upcoming features are not constrained to those listed above. Any feature requests are welcome and appreciated! 28 | 29 | ### Usage 30 | 31 | #### TensorRT version 32 | TensorRT-RS supports multiple versions of bindings to the underlying TensorRT library. The current supported versions 33 | are TensorRT 5, TensorRT 6, and TensorRT 7. The various bindings are enabled via cargo features. By default bindings 34 | are generated for TensorRT 5. If you need to bind to one of the other supported versions of TensorRT set the feature 35 | in your `Cargo.toml`. 36 | ``` 37 | [dependencies] 38 | tensorrt-rs = { version = 0.4.0, default-features = false, features = ["trt-7"] } 39 | ``` 40 | 41 | #### TensorRT Library in Non default Location 42 | By default TensorRT will link to the TensorRT dylibs that are installed in the system library directory on Linux 43 | (there isn't a Windows default install location). If you would like to use a different version of TensorRT that is not 44 | in the default location you can set the path to the libraries via the `TRT_INSTALL_DIR` environment variable. The env 45 | variable should point to the root of the TensorRT folder. The build process will append `/lib` and `/include` to the ` 46 | TRT_INSTALL_DIR` where appropriate. 47 | 48 | ex. 
49 | ```shell script 50 | export TRT_INSTALL_DIR=~/TensorRT-7.1.3.4 51 | ``` 52 | 53 | #### CUDA Library in Non-default Location 54 | Related to linking to a [library in non-default location](#TensorRT-Library-in-Non-default-Location) we also support 55 | linking to a non default CUDA install location. When using a different TensorRT library than is installed in the default 56 | location on your system it's likely that the CUDA version that is installed will not be correct for that TensorRT 57 | version. This is also done via an environment variable `CUDA_INSTALL_DIR`. This variable should be set to the root of 58 | the directory as well such that appending `lib64` will result in finding the appropriate CUDA libraries. 59 | 60 | ex. 61 | ```shell script 62 | export CUDA_INSTALL_DIR=~/cuda-10.1 63 | ``` 64 | 65 | ### Examples 66 | See the examples directory for a basic example on loading a UFF model, creating an engine, and performing inference on 67 | an image using the engine. The sample uses the UFF MNIST model provided in the samples directory when installing TensorRT 68 | on Linux. 69 | 70 | Contributions are always welcome! Feel free to reach out with any questions or improvements and I will try to get back 71 | to you as soon as possible! -------------------------------------------------------------------------------- /tensorrt/examples/README.md: -------------------------------------------------------------------------------- 1 | # Examples 2 | 3 | This directory contains example applications for using the TensorRT-rs library. These are small examples that demonstrate 4 | basic usage of the library and associated data types to run model inference using TensorRT. Examples are named for the 5 | model used as well as the file format output from the training framework (uff or onnx). All sample data and models ship 6 | with the TensorRT library install from nvidia. 7 | 8 | These examples are adapted from the C++ examples that are provided with TensorRT. 
The logic is not a one-for-one port, but the
basic model inference idea is the same.

More examples are always welcome. Please feel free to open a PR with a new example if you feel that one is missing from
the list here.

## Getting Started

All examples expect models and input data to be available in the `assets` subfolder that resides in the top level
directory. This folder is not tracked by git and will need to be created and populated on your machine.

Models and input data are shipped with TensorRT and should be copied into the assets folder. Models live at the top level
and images live in the images sub-folder.

If you installed TensorRT via the .deb archive on Linux these assets can be found in `/usr/src/tensorrt/data`. In this
folder each of the assets will be split by specific example sub-folder (mnist, ssd, etc).

## mnist_uff
This example uses the MNIST digit dataset and classifies images of hand drawn digits. The output is 10 entries representing
the numbers 0..9 and will have an associated floating point value that indicates what digit the model identified.

For more information about this example see:
[TensorRT MNIST Sample](https://github.com/mstallmo/TensorRT/tree/master/samples/opensource/sampleUffMNIST)

### Run
```shell script
$ cargo run --example mnist_uff
```

## ssd_uff
This example uses the Single Shot MultiBox Detector to perform object detection in images. The output of the model will
be the bounding box data and confidence scores. This example takes advantage of the TensorRT plugins provided by nvidia,
specifically the NMS (non-max suppression) plugin.

To properly map the operations in the UFF graph to the appropriate TensorRT plugin there is a small preprocessing step
that needs to be done to the model before it can be loaded by TensorRT in the example.
The details can be found in the 44 | [TensorRT SSD Sample](https://github.com/mstallmo/TensorRT/tree/master/samples/opensource/sampleUffSSD#sampleuffssd-plugins) 45 | 46 | The result of this pre-processing should be placed into the `assets` in the top level directory. 47 | 48 | ### Run 49 | ```shell script 50 | $ cargo run --example ssd_uff 51 | ``` -------------------------------------------------------------------------------- /tensorrt/examples/basic/README.md: -------------------------------------------------------------------------------- 1 | # Basic Example 2 | 3 | This example demonstrates the basic creating of a TensorRT engine from 4 | a .uff file produced by any of the major deep learning frameworks. Specifically 5 | this example uses the MNIST .uff file provided in the samples of the TensorRT 6 | install. 7 | 8 | In this example we create a TensorRT engine from the path to the .uff file's location 9 | and print out some basic information about the created engine. 10 | 11 | To run this example: 12 | ``` 13 | cargo run --example basic 14 | ``` -------------------------------------------------------------------------------- /tensorrt/examples/basic/main.rs: -------------------------------------------------------------------------------- 1 | use std::path::Path; 2 | use tensorrt_rs::builder::{Builder, NetworkBuildFlags}; 3 | use tensorrt_rs::dims::DimsCHW; 4 | use tensorrt_rs::engine::Engine; 5 | use tensorrt_rs::runtime::Logger; 6 | use tensorrt_rs::uff::{UffFile, UffInputOrder, UffParser}; 7 | 8 | fn create_engine(logger: &Logger, uff_file: UffFile) -> Engine { 9 | let builder = Builder::new(&logger); 10 | let network = builder.create_network_v2(NetworkBuildFlags::DEFAULT); 11 | 12 | let uff_parser = UffParser::new(); 13 | let dim = DimsCHW::new(1, 28, 28); 14 | uff_parser 15 | .register_input("in", dim, UffInputOrder::Nchw) 16 | .unwrap(); 17 | uff_parser.register_output("out").unwrap(); 18 | uff_parser.parse(&uff_file, &network).unwrap(); 19 | 20 | 
builder.build_cuda_engine(&network) 21 | } 22 | 23 | fn main() { 24 | let logger = Logger::new(); 25 | let uff_file = UffFile::new(Path::new("../assets/lenet5.uff")).unwrap(); 26 | let engine = create_engine(&logger, uff_file); 27 | 28 | println!("Engine number of bindings: {}", engine.get_nb_bindings()); 29 | 30 | for binding_index in 0..engine.get_nb_bindings() { 31 | println!( 32 | "Binding name at {}: {}", 33 | binding_index, 34 | engine.get_binding_name(binding_index).unwrap() 35 | ); 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /tensorrt/examples/mnist_uff/main.rs: -------------------------------------------------------------------------------- 1 | use ndarray::Array; 2 | use ndarray_image; 3 | use std::iter::FromIterator; 4 | use std::path::Path; 5 | use tensorrt_rs::builder::{Builder, NetworkBuildFlags}; 6 | use tensorrt_rs::context::ExecuteInput; 7 | use tensorrt_rs::data_size::GB; 8 | use tensorrt_rs::dims::DimsCHW; 9 | use tensorrt_rs::engine::Engine; 10 | use tensorrt_rs::profiler::{DefaultProfiler, Profiler}; 11 | use tensorrt_rs::runtime::Logger; 12 | use tensorrt_rs::uff::{UffFile, UffInputOrder, UffParser}; 13 | 14 | fn create_engine(logger: &Logger, uff_file: UffFile) -> Engine { 15 | let builder = Builder::new(&logger); 16 | builder.set_max_workspace_size(1 * GB); 17 | let network = builder.create_network_v2(NetworkBuildFlags::DEFAULT); 18 | 19 | let uff_parser = UffParser::new(); 20 | let dim = DimsCHW::new(1, 28, 28); 21 | uff_parser 22 | .register_input("in", dim, UffInputOrder::Nchw) 23 | .unwrap(); 24 | uff_parser.register_output("out").unwrap(); 25 | uff_parser.parse(&uff_file, &network).unwrap(); 26 | 27 | builder.build_cuda_engine(&network) 28 | } 29 | 30 | fn main() { 31 | // Create TensorRT engine from .uff file 32 | let logger = Logger::new(); 33 | let uff_file = UffFile::new(Path::new("../assets/lenet5.uff")).unwrap(); 34 | let engine = create_engine(&logger, uff_file); 35 | 36 | 
let profiler = Profiler::new(DefaultProfiler::new()); 37 | 38 | // Create execution context 39 | let context = engine.create_execution_context(); 40 | context.set_profiler(&profiler); 41 | 42 | // Load image from disk 43 | let input_image = image::open("../assets/images/0.pgm").unwrap().into_luma(); 44 | println!("Image dimensions: {:?}", input_image.dimensions()); 45 | 46 | // Convert image to ndarray 47 | let array: ndarray_image::NdGray = ndarray_image::NdImage(&input_image).into(); 48 | println!("NdArray len: {}", array.len()); 49 | let mut pre_processed = Array::from_iter(array.iter().map(|&x| 1.0 - (x as f32) / 255.0)); 50 | 51 | // Run inference 52 | let mut output = ndarray::Array1::::zeros(10); 53 | let outputs = vec![ExecuteInput::Float(&mut output)]; 54 | context 55 | .execute(ExecuteInput::Float(&mut pre_processed), outputs) 56 | .unwrap(); 57 | println!("output: {}", output); 58 | } 59 | -------------------------------------------------------------------------------- /tensorrt/examples/onnx/main.rs: -------------------------------------------------------------------------------- 1 | use ndarray::Array; 2 | use ndarray_image; 3 | use std::iter::FromIterator; 4 | use std::path::PathBuf; 5 | use tensorrt_rs::builder::{Builder, NetworkBuildFlags}; 6 | use tensorrt_rs::context::ExecuteInput; 7 | use tensorrt_rs::data_size::GB; 8 | use tensorrt_rs::dims::Dims4; 9 | use tensorrt_rs::engine::Engine; 10 | use tensorrt_rs::onnx::{OnnxFile, OnnxParser}; 11 | use tensorrt_rs::runtime::Logger; 12 | 13 | fn create_engine( 14 | logger: &Logger, 15 | file: OnnxFile, 16 | batch_size: i32, 17 | workspace_size: usize, 18 | ) -> Engine { 19 | let builder = Builder::new(&logger); 20 | builder.set_max_workspace_size(1 * GB); 21 | let network = builder.create_network_v2(NetworkBuildFlags::EXPLICIT_BATCH); 22 | let verbosity = 7; 23 | 24 | builder.set_max_batch_size(batch_size); 25 | builder.set_max_workspace_size(workspace_size); 26 | 27 | let parser = 
OnnxParser::new(&network, &logger); 28 | parser.parse_from_file(&file, verbosity).unwrap(); 29 | 30 | let dim = Dims4::new(batch_size, 224, 224, 3); 31 | network.get_input(0).set_dimensions(dim); 32 | builder.build_cuda_engine(&network) 33 | } 34 | 35 | fn main() { 36 | let logger = Logger::new(); 37 | let file = OnnxFile::new(&PathBuf::from("../assets/efficientnet.onnx")).unwrap(); 38 | let engine = create_engine(&logger, file, 1, 1 * GB); 39 | 40 | let context = engine.create_execution_context(); 41 | 42 | let input_image = image::open("../assets/images/meme.jpg") 43 | .unwrap() 44 | .crop(0, 0, 100, 100) 45 | .into_rgb(); 46 | eprintln!("Image dimensions: {:?}", input_image.dimensions()); 47 | 48 | // Convert image to ndarray 49 | let array: ndarray_image::NdColor = ndarray_image::NdImage(&input_image).into(); 50 | println!("NdArray len: {}", array.len()); 51 | 52 | let mut pre_processed = Array::from_iter(array.iter().map(|&x| 1.0 - (x as f32) / 255.0)); 53 | 54 | // Run inference 55 | let mut output = ndarray::Array1::::zeros(1000); 56 | let outputs = vec![ExecuteInput::Float(&mut output)]; 57 | context 58 | .execute(ExecuteInput::Float(&mut pre_processed), outputs) 59 | .unwrap(); 60 | println!("output: {}", output); 61 | } 62 | -------------------------------------------------------------------------------- /tensorrt/examples/ssd_uff/main.rs: -------------------------------------------------------------------------------- 1 | use image::RgbImage; 2 | use imageproc::rect::Rect; 3 | use ndarray::{Array, Array1, Array3}; 4 | use std::iter::FromIterator; 5 | use std::path::Path; 6 | use tensorrt_rs::builder::{Builder, NetworkBuildFlags}; 7 | use tensorrt_rs::context::ExecuteInput; 8 | use tensorrt_rs::data_size::GB; 9 | use tensorrt_rs::dims::{Dim, DimsCHW}; 10 | use tensorrt_rs::engine::Engine; 11 | use tensorrt_rs::runtime::Logger; 12 | use tensorrt_rs::uff::{UffFile, UffInputOrder, UffParser}; 13 | 14 | fn create_engine(logger: &Logger, uff_file: &UffFile) -> 
Engine { 15 | let builder = Builder::new(&logger); 16 | builder.set_max_workspace_size(1 * GB); 17 | let network = builder.create_network_v2(NetworkBuildFlags::DEFAULT); 18 | 19 | let uff_parser = UffParser::new(); 20 | let dim = DimsCHW::new(3, 300, 300); 21 | uff_parser 22 | .register_input("Input", dim, UffInputOrder::Nchw) 23 | .unwrap(); 24 | uff_parser.register_output("NMS").unwrap(); 25 | uff_parser.parse(uff_file, &network).unwrap(); 26 | 27 | builder.build_cuda_engine(&network) 28 | } 29 | 30 | //Input formatting logic comes directly from the C++ code in the sampleUffSSD.cpp. 31 | //https://github.com/NVIDIA/TensorRT/blob/release/5.1/samples/opensource/sampleUffSSD/sampleUffSSD.cpp 32 | fn process_input(image: &RgbImage) -> Array1 { 33 | let mut base_array = Array3::::zeros((3, image.height() as usize, image.width() as usize)); 34 | for c in 0..3 { 35 | for j in 0..(image.height() * image.width()) as usize { 36 | /* FIX: plane stride was hard-coded as 300*300; use the actual image plane size so it matches the allocation above and stays correct for any resolution */ base_array.as_slice_mut().unwrap()[c * ((image.height() * image.width()) as usize) + j] = 37 | (2.0 / 255.0) * (image.as_flat_samples().as_slice()[j * 3 + c] as f32) - 1.0; 38 | } 39 | } 40 | Array::from_iter(base_array.iter().cloned()) 41 | } 42 | 43 | fn infer(engine: &Engine, input: &mut Array1) -> (ndarray::Array1, ndarray::Array1) { 44 | let context = engine.create_execution_context(); 45 | 46 | let binding_dim = engine.get_binding_dimensions(1); 47 | let dim_slice = &binding_dim.d()[0..binding_dim.nb_dims() as usize]; 48 | let vol = dim_slice.iter().fold(1, |acc, x| acc * x) as usize; 49 | let mut top_detections = Array1::::zeros(vol); 50 | 51 | let binding_dim = engine.get_binding_dimensions(2); 52 | let dim_slice = &binding_dim.d()[0..binding_dim.nb_dims() as usize]; 53 | let vol = dim_slice.iter().fold(1, |acc, x| acc * x) as usize; 54 | let mut keep_count = Array1::::zeros(vol); 55 | 56 | let outputs = vec![ 57 | ExecuteInput::Float(&mut top_detections), 58 | ExecuteInput::Integer(&mut keep_count), 59 | ]; 60 | let execute_input = ExecuteInput::Float(input); 61 | 
context.execute(execute_input, outputs).unwrap(); 62 | 63 | (top_detections, keep_count) 64 | } 65 | 66 | fn verify_output(image: &mut RgbImage, top_detections: &Array1, keep_count: &Array1) { 67 | for i in 0..keep_count[0] as usize { 68 | let det_base_index = i * 7; 69 | 70 | if top_detections[det_base_index + 2] > 0.5 { 71 | let min_x = top_detections[det_base_index + 3] * image.width() as f32; 72 | let min_y = top_detections[det_base_index + 4] * image.height() as f32; 73 | let max_x = top_detections[det_base_index + 5] * image.width() as f32; 74 | let max_y = top_detections[det_base_index + 6] * image.height() as f32; 75 | 76 | let rect = Rect::at((min_x) as i32, (min_y) as i32).of_size( 77 | (max_x) as u32 - (min_x) as u32, 78 | (max_y) as u32 - (min_y) as u32, 79 | ); 80 | 81 | imageproc::drawing::draw_hollow_rect_mut(image, rect, image::Rgb([255u8, 0u8, 0u8])); 82 | 83 | /* FIX: was top_detections[2], which always reported detection 0's confidence; use the per-detection index that the threshold test above already uses */ let confidence_string = format!("confidence {}", top_detections[det_base_index + 2] * 100.0); 84 | let coordinates_string = 85 | format!("coordinates ({}, {}), ({}, {})", min_x, min_y, max_x, max_y); 86 | println!( 87 | "Detected dog in the image with {} and {}", 88 | confidence_string, coordinates_string 89 | ); 90 | } 91 | } 92 | 93 | image.save("test.jpg").unwrap(); 94 | } 95 | 96 | fn main() { 97 | let logger = Logger::new(); 98 | let uff_file = UffFile::new(Path::new("../assets/sample_ssd_relu6.uff")).unwrap(); 99 | let engine = create_engine(&logger, &uff_file); 100 | 101 | let mut input_image = image::open("../assets/images/dog.ppm").unwrap().into_rgb(); 102 | let mut input_buffer = process_input(&input_image); 103 | 104 | let (top_detections, keep_count) = infer(&engine, &mut input_buffer); 105 | verify_output(&mut input_image, &top_detections, &keep_count); 106 | 107 | println!("Done!"); 108 | } 109 | -------------------------------------------------------------------------------- /tensorrt/src/builder/mod.rs: -------------------------------------------------------------------------------- 1 | 
#[cfg(test)] 2 | mod tests; 3 | 4 | use std::marker::PhantomData; 5 | 6 | use crate::engine::Engine; 7 | use crate::network::layer::Layer; 8 | use crate::network::Network; 9 | use crate::runtime::Logger; 10 | use num_derive::FromPrimitive; 11 | use num_traits::FromPrimitive; 12 | use std::os::raw::c_int; 13 | #[cfg(feature = "trt-5")] 14 | use tensorrt_sys::create_network; 15 | #[cfg(not(feature = "trt-5"))] 16 | use tensorrt_sys::create_network_v2; 17 | 18 | use tensorrt_sys::{ 19 | build_cuda_engine, builder_allow_gpu_fallback, builder_can_run_on_dla, 20 | builder_get_average_find_iterations, builder_get_debug_sync, builder_get_default_device_type, 21 | builder_get_device_type, builder_get_dla_core, builder_get_engine_capability, 22 | builder_get_fp16_mode, builder_get_half2_mode, builder_get_int8_mode, 23 | builder_get_max_batch_size, builder_get_max_dla_batch_size, builder_get_max_workspace_size, 24 | builder_get_min_find_iterations, builder_get_nb_dla_cores, builder_get_refittable, 25 | builder_get_strict_type_constraints, builder_is_device_type_set, 26 | builder_platform_has_fast_fp16, builder_platform_has_fast_int8, builder_reset, 27 | builder_reset_device_type, builder_set_average_find_iterations, builder_set_debug_sync, 28 | builder_set_default_device_type, builder_set_device_type, builder_set_dla_core, 29 | builder_set_engine_capability, builder_set_fp16_mode, builder_set_half2_mode, 30 | builder_set_int8_mode, builder_set_max_batch_size, builder_set_max_workspace_size, 31 | builder_set_min_find_iterations, builder_set_refittable, builder_set_strict_type_constraints, 32 | create_infer_builder, destroy_builder, 33 | }; 34 | 35 | #[repr(C)] 36 | #[derive(Eq, PartialEq, Debug, FromPrimitive)] 37 | pub enum DeviceType { 38 | GPU, 39 | DLA, 40 | } 41 | 42 | #[repr(C)] 43 | #[derive(Eq, PartialEq, Debug, FromPrimitive)] 44 | pub enum EngineCapability { 45 | Default, 46 | SafeGpu, 47 | SafeDla, 48 | } 49 | 50 | pub struct Builder<'a> { 51 | pub(crate) 
internal_builder: *mut tensorrt_sys::nvinfer1_IBuilder, 52 | pub(crate) logger: PhantomData<&'a Logger>, 53 | } 54 | 55 | bitflags! { 56 | pub struct NetworkBuildFlags: u32 { 57 | const DEFAULT = 0b0; 58 | const EXPLICIT_BATCH = 0b1; 59 | const EXPLICIT_PRECISION = 0b10; 60 | } 61 | } 62 | 63 | impl<'a> Builder<'a> { 64 | pub fn new(logger: &'a Logger) -> Self { 65 | let internal_builder = unsafe { create_infer_builder(logger.internal_logger) }; 66 | let logger = PhantomData; 67 | Self { 68 | internal_builder, 69 | logger, 70 | } 71 | } 72 | 73 | pub fn get_max_workspace_size(&self) -> usize { 74 | unsafe { builder_get_max_workspace_size(self.internal_builder) as usize } 75 | } 76 | 77 | pub fn set_max_workspace_size(&self, ws: usize) { 78 | unsafe { builder_set_max_workspace_size(self.internal_builder, ws as usize) } 79 | } 80 | 81 | pub fn get_max_batch_size(&self) -> i32 { 82 | unsafe { builder_get_max_batch_size(self.internal_builder) as i32 } 83 | } 84 | 85 | pub fn set_max_batch_size(&self, bs: i32) { 86 | unsafe { builder_set_max_batch_size(self.internal_builder, bs as i32) } 87 | } 88 | 89 | pub fn set_half2_mode(&self, mode: bool) { 90 | unsafe { builder_set_half2_mode(self.internal_builder, mode) } 91 | } 92 | 93 | pub fn get_half2_mode(&self) -> bool { 94 | unsafe { builder_get_half2_mode(self.internal_builder) } 95 | } 96 | 97 | pub fn set_debug_sync(&self, sync: bool) { 98 | unsafe { builder_set_debug_sync(self.internal_builder, sync) } 99 | } 100 | 101 | pub fn get_debug_sync(&self) -> bool { 102 | unsafe { builder_get_debug_sync(self.internal_builder) } 103 | } 104 | 105 | pub fn set_min_find_iterations(&self, min_find: i32) { 106 | unsafe { builder_set_min_find_iterations(self.internal_builder, min_find) } 107 | } 108 | 109 | pub fn get_min_find_iterations(&self) -> i32 { 110 | unsafe { builder_get_min_find_iterations(self.internal_builder) } 111 | } 112 | 113 | pub fn set_average_find_iterations(&self, avg_find: i32) { 114 | unsafe { 
builder_set_average_find_iterations(self.internal_builder, avg_find) } 115 | } 116 | 117 | pub fn get_average_find_iterations(&self) -> i32 { 118 | unsafe { builder_get_average_find_iterations(self.internal_builder) } 119 | } 120 | 121 | pub fn platform_has_fast_fp16(&self) -> bool { 122 | unsafe { builder_platform_has_fast_fp16(self.internal_builder) } 123 | } 124 | 125 | pub fn platform_has_fast_int8(&self) -> bool { 126 | unsafe { builder_platform_has_fast_int8(self.internal_builder) } 127 | } 128 | 129 | pub fn set_int8_mode(&self, mode: bool) { 130 | unsafe { builder_set_int8_mode(self.internal_builder, mode) } 131 | } 132 | 133 | pub fn get_int8_mode(&self) -> bool { 134 | unsafe { builder_get_int8_mode(self.internal_builder) } 135 | } 136 | 137 | pub fn set_fp16_mode(&self, mode: bool) { 138 | unsafe { builder_set_fp16_mode(self.internal_builder, mode) } 139 | } 140 | 141 | pub fn get_fp16_mode(&self) -> bool { 142 | unsafe { builder_get_fp16_mode(self.internal_builder) } 143 | } 144 | 145 | pub fn set_device_type(&self, layer: &T, device_type: DeviceType) { 146 | unsafe { 147 | builder_set_device_type( 148 | self.internal_builder, 149 | layer.get_internal_layer(), 150 | device_type as c_int, 151 | ) 152 | } 153 | } 154 | 155 | pub fn get_device_type(&self, layer: &dyn Layer) -> DeviceType { 156 | let primitive = 157 | unsafe { builder_get_device_type(self.internal_builder, layer.get_internal_layer()) }; 158 | FromPrimitive::from_i32(primitive).unwrap() 159 | } 160 | 161 | pub fn is_device_type_set(&self, layer: &dyn Layer) -> bool { 162 | unsafe { builder_is_device_type_set(self.internal_builder, layer.get_internal_layer()) } 163 | } 164 | 165 | pub fn set_default_device_type(&self, device_type: DeviceType) { 166 | unsafe { builder_set_default_device_type(self.internal_builder, device_type as c_int) } 167 | } 168 | 169 | pub fn get_default_device_type(&self) -> DeviceType { 170 | let primitive = unsafe { 
builder_get_default_device_type(self.internal_builder) }; 171 | FromPrimitive::from_i32(primitive).unwrap() 172 | } 173 | 174 | pub fn reset_device_type(&self, layer: &dyn Layer) { 175 | unsafe { builder_reset_device_type(self.internal_builder, layer.get_internal_layer()) } 176 | } 177 | 178 | pub fn can_run_on_dla(&self, layer: &dyn Layer) -> bool { 179 | unsafe { builder_can_run_on_dla(self.internal_builder, layer.get_internal_layer()) } 180 | } 181 | 182 | pub fn get_max_dla_batch_size(&self) -> i32 { 183 | unsafe { builder_get_max_dla_batch_size(self.internal_builder) } 184 | } 185 | 186 | pub fn allow_gpu_fallback(&self, set_fallback_mode: bool) { 187 | unsafe { builder_allow_gpu_fallback(self.internal_builder, set_fallback_mode) } 188 | } 189 | 190 | pub fn get_nb_dla_cores(&self) -> i32 { 191 | unsafe { builder_get_nb_dla_cores(self.internal_builder) } 192 | } 193 | 194 | pub fn set_dla_core(&self, dla_core: i32) { 195 | unsafe { builder_set_dla_core(self.internal_builder, dla_core) } 196 | } 197 | 198 | pub fn get_dla_core(&self) -> i32 { 199 | unsafe { builder_get_dla_core(self.internal_builder) } 200 | } 201 | 202 | pub fn set_strict_type_constraints(&self, mode: bool) { 203 | unsafe { builder_set_strict_type_constraints(self.internal_builder, mode) } 204 | } 205 | 206 | pub fn get_strict_type_constraints(&self) -> bool { 207 | unsafe { builder_get_strict_type_constraints(self.internal_builder) } 208 | } 209 | 210 | pub fn set_refittable(&self, can_refit: bool) { 211 | unsafe { builder_set_refittable(self.internal_builder, can_refit) } 212 | } 213 | 214 | pub fn get_refittable(&self) -> bool { 215 | unsafe { builder_get_refittable(self.internal_builder) } 216 | } 217 | 218 | pub fn set_engine_capability(&self, engine_capability: EngineCapability) { 219 | unsafe { builder_set_engine_capability(self.internal_builder, engine_capability as c_int) } 220 | } 221 | 222 | pub fn get_engine_capability(&self) -> EngineCapability { 223 | let primitive = unsafe { 
builder_get_engine_capability(self.internal_builder) }; 224 | FromPrimitive::from_i32(primitive).unwrap() 225 | } 226 | 227 | #[cfg(feature = "trt-5")] 228 | pub fn create_network(&self) -> Network { 229 | let internal_network = unsafe { create_network(self.internal_builder) }; 230 | Network { internal_network } 231 | } 232 | 233 | #[cfg(not(feature = "trt-5"))] 234 | pub fn create_network_v2(&self, flags: NetworkBuildFlags) -> Network { 235 | let internal_network = unsafe { create_network_v2(self.internal_builder, flags.bits()) }; 236 | Network { internal_network } 237 | } 238 | 239 | pub fn build_cuda_engine(&self, network: &Network) -> Engine { 240 | let internal_engine = 241 | unsafe { build_cuda_engine(self.internal_builder, network.internal_network) }; 242 | Engine { internal_engine } 243 | } 244 | 245 | pub fn reset(&self, network: Network) { 246 | unsafe { builder_reset(self.internal_builder, network.internal_network) } 247 | } 248 | } 249 | 250 | impl<'a> Drop for Builder<'a> { 251 | fn drop(&mut self) { 252 | unsafe { destroy_builder(self.internal_builder) }; 253 | } 254 | } 255 | -------------------------------------------------------------------------------- /tensorrt/src/builder/tests.rs: -------------------------------------------------------------------------------- 1 | use super::*; 2 | use crate::dims::DimsCHW; 3 | use crate::network::Network; 4 | use crate::uff::{UffFile, UffInputOrder, UffParser}; 5 | use lazy_static::lazy_static; 6 | use std::path::Path; 7 | use std::sync::Mutex; 8 | 9 | lazy_static! 
{ 10 | static ref LOGGER: Mutex = Mutex::new(Logger::new()); 11 | } 12 | 13 | fn create_network(logger: &Logger) -> (Network, Builder) { 14 | let builder = Builder::new(&logger); 15 | let network = builder.create_network_v2(NetworkBuildFlags::DEFAULT); 16 | 17 | let uff_parser = UffParser::new(); 18 | let dim = DimsCHW::new(1, 28, 28); 19 | 20 | uff_parser 21 | .register_input("in", dim, UffInputOrder::Nchw) 22 | .unwrap(); 23 | uff_parser.register_output("out").unwrap(); 24 | println!( 25 | "current dir: {}", 26 | std::env::current_dir().unwrap().display() 27 | ); 28 | let uff_file = UffFile::new(Path::new("../assets/lenet5.uff")).unwrap(); 29 | uff_parser.parse(&uff_file, &network).unwrap(); 30 | 31 | (network, builder) 32 | } 33 | 34 | #[test] 35 | fn set_half2_mode_true() { 36 | let logger = match LOGGER.lock() { 37 | Ok(guard) => guard, 38 | Err(poisoned) => poisoned.into_inner(), 39 | }; 40 | let builder = Builder::new(&logger); 41 | 42 | builder.set_half2_mode(true); 43 | assert_eq!(builder.get_half2_mode(), true); 44 | } 45 | 46 | #[test] 47 | fn set_half2_mode_false() { 48 | let logger = match LOGGER.lock() { 49 | Ok(guard) => guard, 50 | Err(poisoned) => poisoned.into_inner(), 51 | }; 52 | let builder = Builder::new(&logger); 53 | 54 | builder.set_half2_mode(false); 55 | assert_eq!(builder.get_half2_mode(), false); 56 | } 57 | 58 | #[test] 59 | fn set_debug_sync_true() { 60 | let logger = match LOGGER.lock() { 61 | Ok(guard) => guard, 62 | Err(poisoned) => poisoned.into_inner(), 63 | }; 64 | let builder = Builder::new(&logger); 65 | 66 | builder.set_debug_sync(true); 67 | assert_eq!(builder.get_debug_sync(), true); 68 | } 69 | 70 | #[test] 71 | fn set_debug_sync_false() { 72 | let logger = match LOGGER.lock() { 73 | Ok(guard) => guard, 74 | Err(poisoned) => poisoned.into_inner(), 75 | }; 76 | let builder = Builder::new(&logger); 77 | 78 | builder.set_debug_sync(false); 79 | assert_eq!(builder.get_debug_sync(), false); 80 | } 81 | 82 | #[test] 83 | fn 
set_min_find_iterations() { 84 | let logger = match LOGGER.lock() { 85 | Ok(guard) => guard, 86 | Err(poisoned) => poisoned.into_inner(), 87 | }; 88 | let builder = Builder::new(&logger); 89 | 90 | builder.set_min_find_iterations(10); 91 | assert_eq!(builder.get_min_find_iterations(), 10); 92 | } 93 | 94 | #[test] 95 | fn set_average_find_iterations() { 96 | let logger = match LOGGER.lock() { 97 | Ok(guard) => guard, 98 | Err(poisoned) => poisoned.into_inner(), 99 | }; 100 | let builder = Builder::new(&logger); 101 | 102 | builder.set_average_find_iterations(20); 103 | assert_eq!(builder.get_average_find_iterations(), 20); 104 | } 105 | 106 | #[test] 107 | fn platform_has_fast_fp16() { 108 | let logger = match LOGGER.lock() { 109 | Ok(guard) => guard, 110 | Err(poisoned) => poisoned.into_inner(), 111 | }; 112 | let builder = Builder::new(&logger); 113 | 114 | assert_eq!(builder.platform_has_fast_fp16(), true); 115 | } 116 | 117 | #[test] 118 | fn platform_has_fast_int8() { 119 | let logger = match LOGGER.lock() { 120 | Ok(guard) => guard, 121 | Err(poisoned) => poisoned.into_inner(), 122 | }; 123 | let builder = Builder::new(&logger); 124 | 125 | assert_eq!(builder.platform_has_fast_int8(), true); 126 | } 127 | 128 | #[test] 129 | fn set_int8_mode_true() { 130 | let logger = match LOGGER.lock() { 131 | Ok(guard) => guard, 132 | Err(poisoned) => poisoned.into_inner(), 133 | }; 134 | let builder = Builder::new(&logger); 135 | 136 | builder.set_int8_mode(true); 137 | assert_eq!(builder.get_int8_mode(), true); 138 | } 139 | 140 | #[test] 141 | fn set_int8_mode_false() { 142 | let logger = match LOGGER.lock() { 143 | Ok(guard) => guard, 144 | Err(poisoned) => poisoned.into_inner(), 145 | }; 146 | let builder = Builder::new(&logger); 147 | 148 | builder.set_int8_mode(false); 149 | assert_eq!(builder.get_int8_mode(), false); 150 | } 151 | 152 | #[test] 153 | fn set_fp16_mode_true() { 154 | let logger = match LOGGER.lock() { 155 | Ok(guard) => guard, 156 | Err(poisoned) => 
poisoned.into_inner(), 157 | }; 158 | let builder = Builder::new(&logger); 159 | 160 | builder.set_fp16_mode(true); 161 | assert_eq!(builder.get_fp16_mode(), true); 162 | } 163 | 164 | #[test] 165 | fn set_fp16_mode_false() { 166 | let logger = match LOGGER.lock() { 167 | Ok(guard) => guard, 168 | Err(poisoned) => poisoned.into_inner(), 169 | }; 170 | let builder = Builder::new(&logger); 171 | 172 | builder.set_fp16_mode(false); 173 | assert_eq!(builder.get_fp16_mode(), false); 174 | } 175 | 176 | #[test] 177 | fn set_device_type_gpu() { 178 | let logger = match LOGGER.lock() { 179 | Ok(guard) => guard, 180 | Err(poisoned) => poisoned.into_inner(), 181 | }; 182 | let (network, builder) = create_network(&logger); 183 | 184 | let layer = network.get_layer(0); 185 | builder.set_device_type(&layer, DeviceType::GPU); 186 | 187 | assert_eq!(builder.get_device_type(&layer), DeviceType::GPU); 188 | } 189 | 190 | #[test] 191 | fn is_device_type_set_true() { 192 | let logger = match LOGGER.lock() { 193 | Ok(guard) => guard, 194 | Err(poisoned) => poisoned.into_inner(), 195 | }; 196 | let (network, builder) = create_network(&logger); 197 | 198 | let layer = network.get_layer(0); 199 | builder.set_device_type(&layer, DeviceType::GPU); 200 | 201 | assert_eq!(builder.is_device_type_set(&layer), true); 202 | } 203 | 204 | #[test] 205 | fn is_device_type_set_false() { 206 | let logger = match LOGGER.lock() { 207 | Ok(guard) => guard, 208 | Err(poisoned) => poisoned.into_inner(), 209 | }; 210 | let (network, builder) = create_network(&logger); 211 | 212 | let layer = network.get_layer(0); 213 | 214 | assert_eq!(builder.is_device_type_set(&layer), false); 215 | } 216 | 217 | #[cfg(target_arch = "aarch64")] 218 | #[test] 219 | fn set_device_type_DLA() { 220 | let logger = match LOGGER.lock() { 221 | Ok(guard) => guard, 222 | Err(poisoned) => poisoned.into_inner(), 223 | }; 224 | let (network, builder) = create_network(&logger); 225 | 226 | builder.set_fp16_mode(true); 227 | let layer 
= network.get_layer(0); 228 | builder.set_device_type(&layer, DeviceType::DLA); 229 | 230 | assert_eq!(builder.get_device_type(&layer), DeviceType::DLA); 231 | } 232 | 233 | #[cfg(target_arch = "aarch64")] 234 | #[test] 235 | fn set_default_device_type_GPU() { 236 | let logger = match LOGGER.lock() { 237 | Ok(guard) => guard, 238 | Err(poisoned) => poisoned.into_inner(), 239 | }; 240 | let builder = Builder::new(&logger); 241 | 242 | builder.set_fp16_mode(true); 243 | builder.set_default_device_type(DeviceType::GPU); 244 | 245 | assert_eq!(builder.get_default_device_type(), DeviceType::GPU); 246 | } 247 | 248 | #[test] 249 | fn reset_device_type() { 250 | let logger = match LOGGER.lock() { 251 | Ok(guard) => guard, 252 | Err(poisoned) => poisoned.into_inner(), 253 | }; 254 | let (network, builder) = create_network(&logger); 255 | 256 | let layer = network.get_layer(0); 257 | builder.set_device_type(&layer, DeviceType::GPU); 258 | builder.reset_device_type(&layer); 259 | 260 | assert_eq!(builder.get_device_type(&layer), DeviceType::GPU); 261 | } 262 | 263 | #[cfg(target_arch = "aarch64")] 264 | #[test] 265 | fn can_run_on_dla() { 266 | let logger = match LOGGER.lock() { 267 | Ok(guard) => guard, 268 | Err(poisoned) => poisoned.into_inner(), 269 | }; 270 | let (network, builder) = create_network(&logger); 271 | 272 | let layer = network.get_layer(0); 273 | assert_eq!(builder.can_run_on_dla(&layer), true); 274 | } 275 | 276 | #[test] 277 | fn get_max_dla_batch_size() { 278 | let logger = match LOGGER.lock() { 279 | Ok(guard) => guard, 280 | Err(poisoned) => poisoned.into_inner(), 281 | }; 282 | let builder = Builder::new(&logger); 283 | 284 | assert_eq!(builder.get_max_dla_batch_size(), 1); 285 | } 286 | 287 | #[test] 288 | fn allow_gpu_fallback_true() { 289 | let logger = match LOGGER.lock() { 290 | Ok(guard) => guard, 291 | Err(poisoned) => poisoned.into_inner(), 292 | }; 293 | let builder = Builder::new(&logger); 294 | 295 | builder.allow_gpu_fallback(true); 296 | 
} 297 | 298 | #[test] 299 | fn get_nb_dla_cores() { 300 | let logger = match LOGGER.lock() { 301 | Ok(guard) => guard, 302 | Err(poisoned) => poisoned.into_inner(), 303 | }; 304 | let builder = Builder::new(&logger); 305 | 306 | assert_eq!(builder.get_nb_dla_cores(), 0); 307 | } 308 | 309 | #[test] 310 | fn set_dla_core() { 311 | let logger = match LOGGER.lock() { 312 | Ok(guard) => guard, 313 | Err(poisoned) => poisoned.into_inner(), 314 | }; 315 | let builder = Builder::new(&logger); 316 | 317 | builder.set_dla_core(1); 318 | 319 | assert_eq!(builder.get_dla_core(), 1); 320 | } 321 | 322 | #[test] 323 | fn reset_builder() { 324 | let logger = match LOGGER.lock() { 325 | Ok(guard) => guard, 326 | Err(poisoned) => poisoned.into_inner(), 327 | }; 328 | let builder = Builder::new(&logger); 329 | assert_eq!(builder.get_half2_mode(), false); 330 | builder.set_half2_mode(true); 331 | 332 | let network = builder.create_network_v2(NetworkBuildFlags::EXPLICIT_BATCH); 333 | assert_eq!(builder.get_half2_mode(), true); 334 | 335 | builder.reset(network); 336 | assert_eq!(builder.get_half2_mode(), false); 337 | } 338 | 339 | #[test] 340 | fn set_strict_type_constraints_true() { 341 | let logger = match LOGGER.lock() { 342 | Ok(guard) => guard, 343 | Err(poisoned) => poisoned.into_inner(), 344 | }; 345 | let builder = Builder::new(&logger); 346 | 347 | builder.set_strict_type_constraints(true); 348 | 349 | assert_eq!(builder.get_strict_type_constraints(), true); 350 | } 351 | 352 | #[test] 353 | fn set_strict_type_constraints_false() { 354 | let logger = match LOGGER.lock() { 355 | Ok(guard) => guard, 356 | Err(poisoned) => poisoned.into_inner(), 357 | }; 358 | let builder = Builder::new(&logger); 359 | 360 | builder.set_strict_type_constraints(false); 361 | 362 | assert_eq!(builder.get_strict_type_constraints(), false); 363 | } 364 | 365 | #[test] 366 | fn set_refittable_true() { 367 | let logger = match LOGGER.lock() { 368 | Ok(guard) => guard, 369 | Err(poisoned) => 
poisoned.into_inner(), 370 | }; 371 | let builder = Builder::new(&logger); 372 | 373 | builder.set_refittable(true); 374 | 375 | assert_eq!(builder.get_refittable(), true); 376 | } 377 | 378 | #[test] 379 | fn set_refittable_false() { 380 | let logger = match LOGGER.lock() { 381 | Ok(guard) => guard, 382 | Err(poisoned) => poisoned.into_inner(), 383 | }; 384 | let builder = Builder::new(&logger); 385 | 386 | builder.set_refittable(false); 387 | 388 | assert_eq!(builder.get_refittable(), false); 389 | } 390 | 391 | #[test] 392 | fn set_engine_capability() { 393 | let logger = match LOGGER.lock() { 394 | Ok(guard) => guard, 395 | Err(poisoned) => poisoned.into_inner(), 396 | }; 397 | let builder = Builder::new(&logger); 398 | 399 | builder.set_engine_capability(EngineCapability::Default); 400 | 401 | assert_eq!(builder.get_engine_capability(), EngineCapability::Default); 402 | } 403 | -------------------------------------------------------------------------------- /tensorrt/src/context.rs: -------------------------------------------------------------------------------- 1 | use crate::check_cuda; 2 | use crate::profiler::{IProfiler, Profiler}; 3 | use anyhow::Error; 4 | use cuda_runtime_sys::{cudaFree, cudaMalloc, cudaMemcpy, cudaMemcpyKind}; 5 | use ndarray; 6 | use ndarray::Dimension; 7 | use num_traits::Num; 8 | use std::ffi::{CStr, CString}; 9 | use std::mem::size_of; 10 | use std::os::raw::c_void; 11 | use std::ptr; 12 | use std::vec::Vec; 13 | use tensorrt_sys::{ 14 | context_get_debug_sync, context_get_name, context_set_debug_sync, context_set_name, 15 | context_set_profiler, destroy_excecution_context, execute, nvinfer1_IExecutionContext, 16 | }; 17 | 18 | pub enum ExecuteInput<'a, D: Dimension> { 19 | Integer(&'a mut ndarray::Array), 20 | Float(&'a mut ndarray::Array), 21 | } 22 | 23 | struct DeviceBuffer { 24 | device_ptr: *mut c_void, 25 | } 26 | 27 | impl DeviceBuffer { 28 | pub fn new(host_data: &ndarray::Array) -> Result { 29 | let mut device_ptr: *mut 
c_void = ptr::null_mut(); 30 | check_cuda!(cudaMalloc( 31 | &mut device_ptr, 32 | host_data.len() * size_of::() 33 | )); 34 | 35 | check_cuda!(cudaMemcpy( 36 | device_ptr, 37 | host_data.as_ptr() as *const c_void, 38 | host_data.len() * size_of::(), 39 | cudaMemcpyKind::cudaMemcpyHostToDevice, 40 | )); 41 | 42 | Ok(DeviceBuffer { device_ptr }) 43 | } 44 | 45 | pub fn new_uninit(size: usize) -> Result { 46 | let mut device_ptr: *mut c_void = ptr::null_mut(); 47 | check_cuda!(cudaMalloc(&mut device_ptr, size)); 48 | Ok(DeviceBuffer { device_ptr }) 49 | } 50 | 51 | pub fn as_mut_ptr(&self) -> *mut c_void { 52 | self.device_ptr 53 | } 54 | 55 | pub fn copy_to_host( 56 | &self, 57 | host_data: &mut ndarray::Array, 58 | ) -> Result<(), Error> { 59 | check_cuda!(cudaMemcpy( 60 | host_data.as_mut_ptr() as *mut c_void, 61 | self.device_ptr, 62 | host_data.len() * size_of::(), 63 | cudaMemcpyKind::cudaMemcpyDeviceToHost, 64 | )); 65 | Ok(()) 66 | } 67 | } 68 | 69 | impl Drop for DeviceBuffer { 70 | fn drop(&mut self) { 71 | if !self.device_ptr.is_null() { 72 | unsafe { 73 | cudaFree(self.device_ptr); 74 | } 75 | } 76 | } 77 | } 78 | 79 | pub struct Context { 80 | pub(crate) internal_context: *mut nvinfer1_IExecutionContext, 81 | } 82 | 83 | impl Context { 84 | pub fn set_debug_sync(&self, sync: bool) { 85 | unsafe { context_set_debug_sync(self.internal_context, sync) } 86 | } 87 | 88 | pub fn get_debug_sync(&self) -> bool { 89 | unsafe { context_get_debug_sync(self.internal_context) } 90 | } 91 | 92 | pub fn set_name(&mut self, context_name: &str) { 93 | unsafe { 94 | context_set_name( 95 | self.internal_context, 96 | CString::new(context_name).unwrap().as_ptr(), 97 | ) 98 | }; 99 | } 100 | 101 | pub fn get_name(&self) -> String { 102 | let context_name = unsafe { 103 | let raw_context_name = context_get_name(self.internal_context); 104 | CStr::from_ptr(raw_context_name) 105 | }; 106 | context_name.to_str().unwrap().to_string() 107 | } 108 | 109 | pub fn set_profiler(&self, 
profiler: &Profiler

) {
        unsafe { context_set_profiler(self.internal_context, profiler.internal_profiler) }
    }

    // Disabled until the `IProfiler` binding can safely recover the concrete
    // profiler type from the raw pointer held by the C++ side.
    // pub fn get_profiler<T>(&self) -> &T {
    //     unsafe {
    //         let profiler_ptr =
    //             context_get_profiler(self.internal_context) as *mut ProfilerBinding;
    //         &(*(*profiler_ptr).context)
    //     }
    // }

    /// Runs synchronous inference.
    ///
    /// Copies `input_data` into a freshly allocated device buffer, allocates one
    /// uninitialized device buffer per entry of `output_data`, executes the
    /// engine, then copies every output buffer back into `output_data`.
    ///
    /// NOTE(review): the batch size handed to the FFI `execute` call is
    /// hard-coded to 1 — confirm whether callers ever need larger batches.
    pub fn execute(
        &self,
        input_data: ExecuteInput,
        mut output_data: Vec<ExecuteInput>,
    ) -> Result<(), Error> {
        // Slot 0 holds the input; slots 1.. mirror `output_data`.
        let mut buffers = Vec::<DeviceBuffer>::with_capacity(output_data.len() + 1);
        let dev_buffer = match input_data {
            ExecuteInput::Integer(val) => DeviceBuffer::new(val)?,
            ExecuteInput::Float(val) => DeviceBuffer::new(val)?,
        };
        buffers.push(dev_buffer);

        for output in &output_data {
            let device_buffer = match output {
                ExecuteInput::Integer(val) => {
                    DeviceBuffer::new_uninit(val.len() * size_of::<i32>())?
                }
                ExecuteInput::Float(val) => {
                    DeviceBuffer::new_uninit(val.len() * size_of::<f32>())?
                }
            };
            buffers.push(device_buffer);
        }

        let mut bindings = buffers
            .iter()
            .map(|elem| elem.as_mut_ptr())
            .collect::<Vec<_>>();

        unsafe {
            execute(self.internal_context, bindings.as_mut_ptr(), 1);
        }

        // Skip the input buffer at index 0 and copy every output back to host.
        for (idx, output) in buffers.iter().skip(1).enumerate() {
            match &mut output_data[idx] {
                ExecuteInput::Integer(val) => output.copy_to_host(val)?,
                ExecuteInput::Float(val) => output.copy_to_host(val)?,
            }
        }
        Ok(())
    }
}

impl Drop for Context {
    fn drop(&mut self) {
        // NOTE(review): the misspelled FFI name ("excecution") matches the
        // symbol exported by trt-sys; it must not be "fixed" here alone.
        unsafe { destroy_excecution_context(self.internal_context) };
    }
}

unsafe impl Send for Context {}
unsafe impl Sync for Context {}

#[cfg(test)]
mod tests {
    use crate::builder::{Builder, NetworkBuildFlags};
    use crate::data_size::GB;
    use crate::dims::DimsCHW;
    use crate::engine::Engine;
    use crate::profiler::RustProfiler;
    use crate::runtime::Logger;
    use crate::uff::{UffFile, UffInputOrder, UffParser};
    use lazy_static::lazy_static;
    use std::path::Path;
    use std::sync::Mutex;

    lazy_static! {
        static ref LOGGER: Mutex<Logger> = Mutex::new(Logger::new());
    }

    // Builds a CUDA engine from the checked-in LeNet-5 UFF model.
    fn setup_engine_test_uff(logger: &Logger) -> Engine {
        let builder = Builder::new(&logger);
        builder.set_max_workspace_size(1 * GB);
        let network = builder.create_network_v2(NetworkBuildFlags::DEFAULT);

        let uff_parser = UffParser::new();
        let dim = DimsCHW::new(1, 28, 28);

        uff_parser
            .register_input("in", dim, UffInputOrder::Nchw)
            .unwrap();
        uff_parser.register_output("out").unwrap();
        let uff_file = UffFile::new(Path::new("../assets/lenet5.uff")).unwrap();
        uff_parser.parse(&uff_file, &network).unwrap();

        builder.build_cuda_engine(&network)
    }

    #[test]
    fn set_debug_sync_true() {
        let logger = match LOGGER.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        let engine = setup_engine_test_uff(&logger);
        let context = engine.create_execution_context();

        context.set_debug_sync(true);
        assert_eq!(context.get_debug_sync(), true);
    }

    // Commenting this out until we can come up with a better solution to the
    // `IProfiler` interface binding.
    // #[test]
    // fn set_profiler() {
    //     let logger = match LOGGER.lock() {
    //         Ok(guard) => guard,
    //         Err(poisoned) => poisoned.into_inner(),
    //     };
    //     let engine = setup_engine_test_uff(&logger);
    //     let context = engine.create_execution_context();
    //
    //     let mut profiler = RustProfiler::new();
    //     context.set_profiler(&mut profiler);
    //
    //     let other_profiler = context.get_profiler::<RustProfiler>();
    //     assert_eq!(
    //         &profiler as *const RustProfiler,
    //         other_profiler as *const RustProfiler
    //     );
    // }
}
-------------------------------------------------------------------------------- /tensorrt/src/data_size.rs: --------------------------------------------------------------------------------
// Size constants expressed in bytes.
pub const MB: usize = 1024 * 1024;
pub const GB: usize = 1024 * 1024 * 1024;
-------------------------------------------------------------------------------- /tensorrt/src/dims.rs: --------------------------------------------------------------------------------
use std::error;
use std::fmt::Formatter;
use tensorrt_sys::{
    create_dims, create_dims2, create_dims3, create_dims4, create_dimsCHW, create_dimsHW,
    create_dimsNCHW, nvinfer1_Dims, nvinfer1_Dims2, nvinfer1_Dims3, nvinfer1_Dims4,
    nvinfer1_DimsCHW, nvinfer1_DimsHW, nvinfer1_DimsNCHW,
};

mod private {
    // Sealing trait: `Dim` can only be implemented inside this crate.
    pub trait DimsPrivate {
        fn get_internal_dims(&self) -> super::nvinfer1_Dims;
    }
}

/// Common read-only view over every TensorRT dimension wrapper.
pub trait Dim: private::DimsPrivate {
    fn nb_dims(&self) -> i32 {
        self.get_internal_dims().nbDims
    }

    fn d(&self) -> [i32; 8] {
        self.get_internal_dims().d
    }
}

#[repr(C)]
pub enum DimensionType {
    // NOTE(review): "Spacial" is an upstream misspelling of "Spatial";
    // renaming the variant would break the public API.
    Spacial,
    Channel,
    Index,
    Sequence,
}

/// Generic N-dimensional shape wrapping `nvinfer1::Dims`.
#[repr(transparent)]
pub struct Dims(pub nvinfer1_Dims);

impl Dims {
    pub fn new(num_dims: i32, dimension_sizes: [i32; 8], dimension_types:
[i32; 8]) -> Dims {
        let raw = unsafe {
            create_dims(
                num_dims,
                &dimension_sizes as *const i32,
                &dimension_types as *const i32,
            )
        };
        Dims(raw)
    }
}

impl private::DimsPrivate for Dims {
    fn get_internal_dims(&self) -> nvinfer1_Dims {
        self.0
    }
}
impl Dim for Dims {}

/// Two-dimensional shape wrapping `nvinfer1::Dims2`.
#[repr(transparent)]
pub struct Dims2(nvinfer1_Dims2);

impl Dims2 {
    pub fn new(dim1: i32, dim2: i32) -> Dims2 {
        Dims2(unsafe { create_dims2(dim1, dim2) })
    }
}

impl private::DimsPrivate for Dims2 {
    fn get_internal_dims(&self) -> nvinfer1_Dims {
        self.0._base
    }
}

impl Dim for Dims2 {}

/// Height/width shape wrapping `nvinfer1::DimsHW`.
#[repr(transparent)]
pub struct DimsHW(pub nvinfer1_DimsHW);

impl DimsHW {
    pub fn new(height: i32, width: i32) -> DimsHW {
        DimsHW(unsafe { create_dimsHW(height, width) })
    }
}

impl private::DimsPrivate for DimsHW {
    fn get_internal_dims(&self) -> nvinfer1_Dims {
        self.0._base._base
    }
}

impl Dim for DimsHW {}

/// Three-dimensional shape wrapping `nvinfer1::Dims3`.
#[repr(transparent)]
pub struct Dims3(nvinfer1_Dims3);

impl Dims3 {
    pub fn new(dim1: i32, dim2: i32, dim3: i32) -> Dims3 {
        Dims3(unsafe { create_dims3(dim1, dim2, dim3) })
    }
}

impl private::DimsPrivate for Dims3 {
    fn get_internal_dims(&self) -> nvinfer1_Dims {
        self.0._base
    }
}

impl Dim for Dims3 {}

/// Channel/height/width shape wrapping `nvinfer1::DimsCHW`.
#[repr(transparent)]
pub struct DimsCHW(nvinfer1_DimsCHW);

impl DimsCHW {
    pub fn new(channels: i32, height: i32, width: i32) -> DimsCHW {
        DimsCHW(unsafe { create_dimsCHW(channels, height, width) })
    }
}

impl private::DimsPrivate for DimsCHW {
    fn get_internal_dims(&self) -> nvinfer1_Dims {
        self.0._base._base
    }
}

impl Dim for DimsCHW {}

/// Four-dimensional shape wrapping `nvinfer1::Dims4`.
#[repr(transparent)]
pub struct Dims4(nvinfer1_Dims4);

impl Dims4 {
    pub fn new(dim1: i32, dim2: i32, dim3: i32, dim4: i32) -> Dims4 {
        Dims4(unsafe { create_dims4(dim1, dim2, dim3, dim4) })
    }
}

impl private::DimsPrivate for Dims4 {
    fn get_internal_dims(&self) -> nvinfer1_Dims {
        self.0._base
    }
}

impl Dim for Dims4 {}

/// Batch/channel/height/width shape wrapping `nvinfer1::DimsNCHW`.
#[repr(transparent)]
pub struct DimsNCHW(nvinfer1_DimsNCHW);

impl DimsNCHW {
    pub fn new(index: i32, channels: i32, height: i32, width: i32) -> DimsNCHW {
        DimsNCHW(unsafe { create_dimsNCHW(index, channels, height, width) })
    }
}

impl private::DimsPrivate for DimsNCHW {
    fn get_internal_dims(&self) -> nvinfer1_Dims {
        self.0._base._base
    }
}

impl Dim for DimsNCHW {}

/// Error raised when a shape is incompatible with an operation.
#[derive(Debug, Clone)]
pub struct DimsShapeError {
    message: String,
}

impl DimsShapeError {
    pub fn new(message: &str) -> Self {
        DimsShapeError {
            message: message.to_string(),
        }
    }
}

impl std::fmt::Display for DimsShapeError {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.message)
    }
}

impl error::Error for DimsShapeError {}
-------------------------------------------------------------------------------- /tensorrt/src/engine.rs: --------------------------------------------------------------------------------
use crate::context::Context;
use crate::dims::Dims;
use num_derive::FromPrimitive;
use num_traits::FromPrimitive;
use std::convert::TryInto;
use std::ffi::{CStr, CString};
use std::slice;
use
tensorrt_sys::{
    destroy_host_memory, engine_binding_is_input, engine_create_execution_context,
    engine_create_execution_context_without_device_memory, engine_destroy,
    engine_get_binding_data_type, engine_get_binding_dimensions, engine_get_binding_index,
    engine_get_binding_name, engine_get_device_memory_size, engine_get_location,
    engine_get_max_batch_size, engine_get_nb_bindings, engine_get_nb_layers,
    engine_get_workspace_size, engine_is_refittable, engine_serialize, host_memory_get_data,
    host_memory_get_size, nvinfer1_ICudaEngine,
};

/// Binding precision, mirroring `nvinfer1::DataType` discriminants.
#[repr(C)]
#[derive(Debug, FromPrimitive, Eq, PartialEq)]
pub enum DataType {
    Float,
    Half,
    Int8,
    Int32,
}

/// Where a binding's memory lives, mirroring `nvinfer1::TensorLocation`.
#[repr(C)]
#[derive(Debug, FromPrimitive, Eq, PartialEq)]
pub enum TensorLocation {
    Host,
    Device,
}

/// Safe wrapper around a raw `nvinfer1::ICudaEngine` pointer.
#[derive(Debug)]
pub struct Engine {
    pub(crate) internal_engine: *mut nvinfer1_ICudaEngine,
}

impl Engine {
    /// Number of I/O bindings; 0 when the underlying pointer is null.
    pub fn get_nb_bindings(&self) -> i32 {
        if self.internal_engine.is_null() {
            0
        } else {
            unsafe { engine_get_nb_bindings(self.internal_engine) }
        }
    }

    /// Name of the binding at `binding_index`, or `None` when out of range.
    pub fn get_binding_name(&self, binding_index: i32) -> Option<String> {
        if binding_index >= self.get_nb_bindings() {
            return None;
        }

        let binding_name = unsafe {
            let raw_binding_name =
                engine_get_binding_name(self.internal_engine, binding_index.try_into().unwrap());
            CStr::from_ptr(raw_binding_name)
        };
        Some(binding_name.to_str().unwrap().to_string())
    }

    /// Index of the binding named `binding_name`; `None` when unknown
    /// (the FFI signals that with -1).
    pub fn get_binding_index(&self, binding_name: &str) -> Option<i32> {
        // Keep the CString alive in a local for the duration of the FFI call.
        let name = CString::new(binding_name).unwrap();
        let binding_index =
            unsafe { engine_get_binding_index(self.internal_engine, name.as_ptr()) };

        if binding_index == -1 {
            None
        } else {
            Some(binding_index)
        }
    }

    pub fn binding_is_input(&self, binding_index: i32) -> bool {
        unsafe { engine_binding_is_input(self.internal_engine, binding_index) }
    }

    pub fn get_binding_dimensions(&self, binding_index: i32) -> Dims {
        let raw_dims =
            unsafe { engine_get_binding_dimensions(self.internal_engine, binding_index) };
        Dims(raw_dims)
    }

    pub fn get_binding_data_type(&self, binding_index: i32) -> DataType {
        let primitive =
            unsafe { engine_get_binding_data_type(self.internal_engine, binding_index) };
        FromPrimitive::from_i32(primitive).unwrap()
    }

    pub fn get_max_batch_size(&self) -> i32 {
        unsafe { engine_get_max_batch_size(self.internal_engine) }
    }

    pub fn get_nb_layers(&self) -> i32 {
        unsafe { engine_get_nb_layers(self.internal_engine) }
    }

    pub fn get_workspace_size(&self) -> usize {
        unsafe { engine_get_workspace_size(self.internal_engine) }
    }

    pub fn create_execution_context(&self) -> Context {
        let execution_context = unsafe { engine_create_execution_context(self.internal_engine) };
        Context {
            internal_context: execution_context,
        }
    }

    pub fn create_execution_context_without_device_memory(&self) -> Context {
        let execution_context =
            unsafe { engine_create_execution_context_without_device_memory(self.internal_engine) };
        Context {
            internal_context: execution_context,
        }
    }

    /// Serializes the engine into an owned host-memory blob.
    pub fn serialize(&self) -> HostMemory {
        let memory = unsafe { engine_serialize(self.internal_engine) };
        HostMemory { memory }
    }

    pub fn get_location(&self, binding_index: i32) -> TensorLocation {
        let primitive = unsafe { engine_get_location(self.internal_engine, binding_index) };
        FromPrimitive::from_i32(primitive).unwrap()
    }

    pub fn get_device_memory_size(&self) -> usize {
        unsafe { engine_get_device_memory_size(self.internal_engine) }
}

    pub fn is_refittable(&self) -> bool {
        unsafe { engine_is_refittable(self.internal_engine) }
    }
}

unsafe impl Send for Engine {}
unsafe impl Sync for Engine {}

impl Drop for Engine {
    fn drop(&mut self) {
        if !self.internal_engine.is_null() {
            unsafe { engine_destroy(self.internal_engine) };
        }
    }
}

/// Owned serialized-engine blob returned by `Engine::serialize`.
pub struct HostMemory {
    pub(crate) memory: *mut tensorrt_sys::nvinfer1_IHostMemory,
}

impl HostMemory {
    /// Borrows the blob's bytes; valid for the lifetime of `self`.
    pub fn data(&self) -> &[u8] {
        let ptr = unsafe { host_memory_get_data(self.memory) };
        let size = unsafe { host_memory_get_size(self.memory) };
        unsafe { slice::from_raw_parts(ptr as *const u8, size) }
    }
}

impl AsRef<[u8]> for HostMemory {
    fn as_ref(&self) -> &[u8] {
        self.data()
    }
}

impl Drop for HostMemory {
    fn drop(&mut self) {
        unsafe {
            destroy_host_memory(self.memory);
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::builder::{Builder, NetworkBuildFlags};
    use crate::data_size::GB;
    use crate::dims::DimsCHW;
    use crate::runtime::{Logger, Runtime};
    use crate::uff::{UffFile, UffInputOrder, UffParser};
    use lazy_static::lazy_static;
    use std::fs::{remove_file, write, File};
    use std::io::prelude::*;
    use std::path::Path;
    use std::sync::Mutex;

    lazy_static! {
        static ref LOGGER: Mutex<Logger> = Mutex::new(Logger::new());
    }

    // Builds a CUDA engine from the checked-in LeNet-5 UFF model.
    fn setup_engine_test_uff(logger: &Logger) -> Engine {
        let builder = Builder::new(&logger);
        builder.set_max_workspace_size(1 * GB);
        let network = builder.create_network_v2(NetworkBuildFlags::DEFAULT);

        let uff_parser = UffParser::new();
        let dim = DimsCHW::new(1, 28, 28);

        uff_parser
            .register_input("in", dim, UffInputOrder::Nchw)
            .unwrap();
        uff_parser.register_output("out").unwrap();
        let uff_file = UffFile::new(Path::new("../assets/lenet5.uff")).unwrap();
        uff_parser.parse(&uff_file, &network).unwrap();

        builder.build_cuda_engine(&network)
    }

    #[test]
    fn get_nb_bindings() {
        let logger = match LOGGER.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        let engine = setup_engine_test_uff(&logger);

        assert_eq!(2, engine.get_nb_bindings());
    }

    #[test]
    fn get_engine_binding_name() {
        let logger = match LOGGER.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        let engine = setup_engine_test_uff(&logger);

        assert_eq!("in", engine.get_binding_name(0).unwrap());
    }

    #[test]
    fn get_invalid_engine_binding() {
        let logger = match LOGGER.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        let engine = setup_engine_test_uff(&logger);

        assert_eq!(None, engine.get_binding_name(3));
    }

    #[test]
    fn binding_is_input() {
        let logger = match LOGGER.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        let engine = setup_engine_test_uff(&logger);

        assert_eq!(engine.binding_is_input(0), true);
    }

    #[test]
    fn get_binding_index() {
        let logger = match LOGGER.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        let engine = setup_engine_test_uff(&logger);

        assert_eq!(Some(0), engine.get_binding_index("in"));
    }

    #[test]
    fn get_binding_data_type() {
        let logger = match LOGGER.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        let engine = setup_engine_test_uff(&logger);

        assert_eq!(engine.get_binding_data_type(0), DataType::Float);
    }

    #[test]
    fn get_max_batch_size() {
        let logger = match LOGGER.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        let engine = setup_engine_test_uff(&logger);

        assert_eq!(engine.get_max_batch_size(), 1);
    }

    #[test]
    fn get_nb_layers() {
        let logger = match LOGGER.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        let engine = setup_engine_test_uff(&logger);

        assert_eq!(engine.get_nb_layers(), 7);
    }

    #[test]
    fn get_workspace_size() {
        let logger = match LOGGER.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        let engine = setup_engine_test_uff(&logger);

        assert_eq!(engine.get_workspace_size(), 0);
    }

    // Round-trips the engine through serialize/deserialize and compares
    // binding metadata.  Fixes vs. original: "seralized" typo in the local
    // names, and a duplicated copy-pasted binding-name assertion removed.
    #[test]
    fn serialize() {
        let logger = match LOGGER.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        let uff_engine = setup_engine_test_uff(&logger);
        let serialized_path = Path::new("../lenet5.engine");
        write(serialized_path, uff_engine.serialize()).unwrap();

        assert!(serialized_path.exists());

        let logger = Logger::new();
        let runtime = Runtime::new(&logger);

        let mut f = File::open(serialized_path).unwrap();
        let mut buffer = Vec::new();
        f.read_to_end(&mut buffer).unwrap();

        let serialized_engine = runtime.deserialize_cuda_engine(buffer);

        assert_eq!(
            uff_engine.get_nb_bindings(),
            serialized_engine.get_nb_bindings()
        );

        for i in 0..uff_engine.get_nb_bindings() {
            assert_eq!(
                uff_engine.get_binding_name(i),
                serialized_engine.get_binding_name(i)
            );
        }

        remove_file(serialized_path).unwrap();
    }

    #[test]
    fn get_location() {
        let logger = match LOGGER.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        let engine = setup_engine_test_uff(&logger);

        assert_eq!(engine.get_location(0), TensorLocation::Host);
    }

    #[test]
    fn get_device_memory_size() {
        let logger = match LOGGER.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        let engine = setup_engine_test_uff(&logger);

        assert_eq!(engine.get_device_memory_size(), 57856);
    }

    #[test]
    fn is_refittable() {
        let logger = match LOGGER.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        let engine = setup_engine_test_uff(&logger);

        assert_eq!(engine.is_refittable(), false);
    }
}
-------------------------------------------------------------------------------- /tensorrt/src/lib.rs: --------------------------------------------------------------------------------
#[macro_use]
extern crate bitflags;

// Re-export the crates users need to build inputs for the API.
pub use image;
pub use ndarray;

pub mod builder;
pub mod context;
pub mod data_size;
pub mod dims;
pub mod engine;
pub mod network;
pub mod onnx;
pub mod profiler;
pub mod runtime;
pub mod uff;

mod utils;
--------------------------------------------------------------------------------
/tensorrt/src/network/layer/activation_layer.rs: --------------------------------------------------------------------------------
use super::*;
use num_derive::FromPrimitive;
use num_traits::FromPrimitive;
use tensorrt_rs_derive::Layer;
use tensorrt_sys::{
    activation_get_activation_type, activation_get_alpha, activation_get_beta,
    activation_set_activation_type, activation_set_alpha, activation_set_beta,
    nvinfer1_IActivationLayer,
};

/// Activation function selector, mirroring `nvinfer1::ActivationType`.
#[repr(C)]
#[derive(Debug, Eq, PartialEq, FromPrimitive)]
pub enum ActivationType {
    Relu,
    Sigmoid,
    Tanh,
    LeakyRelu,
    Elu,
    Selu,
    SoftSign,
    SoftPlus,
    Clip,
    HardSigmoid,
    ScaledTanh,
    ThresholdedRelu,
}

/// Wrapper over `nvinfer1::IActivationLayer`.
#[derive(Layer)]
pub struct ActivationLayer {
    pub(crate) internal_layer: *mut nvinfer1_IActivationLayer,
}

impl ActivationLayer {
    pub fn get_activation_type(&self) -> ActivationType {
        let raw = unsafe { activation_get_activation_type(self.internal_layer) };
        FromPrimitive::from_i32(raw).unwrap()
    }

    pub fn set_activation_type(&self, activation_type: ActivationType) {
        unsafe { activation_set_activation_type(self.internal_layer, activation_type as c_int) }
    }

    pub fn get_alpha(&self) -> f32 {
        unsafe { activation_get_alpha(self.internal_layer) }
    }

    pub fn set_alpha(&self, alpha: f32) {
        unsafe { activation_set_alpha(self.internal_layer, alpha) }
    }

    pub fn get_beta(&self) -> f32 {
        unsafe { activation_get_beta(self.internal_layer) }
    }

    pub fn set_beta(&self, beta: f32) {
        unsafe { activation_set_beta(self.internal_layer, beta) }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::builder::{Builder, NetworkBuildFlags};
    use crate::dims::DimsHW;
    use crate::network::Network;
    use crate::runtime::Logger;
    use lazy_static::lazy_static;
    use std::sync::Mutex;

    lazy_static! {
        static ref LOGGER: Mutex<Logger> = Mutex::new(Logger::new());
    }

    fn create_network(logger: &Logger) -> Network {
        let builder = Builder::new(logger);
        builder.create_network_v2(NetworkBuildFlags::EXPLICIT_BATCH)
    }

    #[test]
    fn get_activation_type() {
        let logger = match LOGGER.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        let network = create_network(&logger);

        let input1 = network.add_input("new_input1", DataType::Float, DimsHW::new(10, 10));
        let activation_layer = network.add_activation(&input1, ActivationType::Relu);

        assert_eq!(activation_layer.get_activation_type(), ActivationType::Relu);
    }

    #[test]
    fn set_activation_type() {
        let logger = match LOGGER.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        let network = create_network(&logger);

        let input1 = network.add_input("new_input1", DataType::Float, DimsHW::new(10, 10));
        let activation_layer = network.add_activation(&input1, ActivationType::Relu);

        activation_layer.set_activation_type(ActivationType::Sigmoid);
        assert_eq!(
            activation_layer.get_activation_type(),
            ActivationType::Sigmoid
        );
    }

    #[test]
    fn set_alpha() {
        let logger = match LOGGER.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        let network = create_network(&logger);

        let input1 = network.add_input("new_input1", DataType::Float, DimsHW::new(10, 10));
        let activation_layer = network.add_activation(&input1, ActivationType::Relu);

        activation_layer.set_alpha(1.0);
        assert_eq!(activation_layer.get_alpha(), 1.0);
    }

    #[test]
    fn set_beta() {
        let logger = match LOGGER.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        let network = create_network(&logger);

        let input1 = network.add_input("new_input1", DataType::Float, DimsHW::new(10, 10));
        let activation_layer = network.add_activation(&input1, ActivationType::Relu);

        activation_layer.set_beta(2.0);
        assert_eq!(activation_layer.get_beta(), 2.0);
    }
}
-------------------------------------------------------------------------------- /tensorrt/src/network/layer/element_wise_layer.rs: --------------------------------------------------------------------------------
use super::*;
use num_derive::FromPrimitive;
use num_traits::FromPrimitive;
use tensorrt_rs_derive::Layer;
use tensorrt_sys::{
    elementwise_get_operation, elementwise_set_operation, nvinfer1_IElementWiseLayer,
};

/// Binary operation selector, mirroring `nvinfer1::ElementWiseOperation`.
#[repr(C)]
#[derive(Debug, FromPrimitive, Eq, PartialEq)]
pub enum ElementWiseOperation {
    Sum,
    Prod,
    Max,
    Min,
    Sub,
    Div,
    Pow,
}

/// Wrapper over `nvinfer1::IElementWiseLayer`.
#[derive(Layer)]
pub struct ElementWiseLayer {
    pub(crate) internal_layer: *mut nvinfer1_IElementWiseLayer,
}

impl ElementWiseLayer {
    pub fn get_operation(&self) -> ElementWiseOperation {
        let raw = unsafe { elementwise_get_operation(self.internal_layer) };
        FromPrimitive::from_i32(raw).unwrap()
    }

    pub fn set_operation(&self, op: ElementWiseOperation) {
        unsafe { elementwise_set_operation(self.internal_layer, op as c_int) }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::builder::{Builder, NetworkBuildFlags};
    use crate::dims::DimsHW;
    use crate::network::Network;
    use crate::runtime::Logger;
    use lazy_static::lazy_static;
    use std::sync::Mutex;

    lazy_static!
{
        static ref LOGGER: Mutex<Logger> = Mutex::new(Logger::new());
    }

    fn create_network(logger: &Logger) -> Network {
        let builder = Builder::new(logger);
        builder.create_network_v2(NetworkBuildFlags::EXPLICIT_BATCH)
    }

    #[test]
    fn get_operation() {
        let logger = match LOGGER.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        let network = create_network(&logger);

        let input1 = network.add_input("new_input1", DataType::Float, DimsHW::new(1, 1));
        let input2 = network.add_input("new_input2", DataType::Float, DimsHW::new(1, 1));
        let element_wise_layer =
            network.add_element_wise_layer(&input1, &input2, ElementWiseOperation::Sum);

        assert_eq!(
            element_wise_layer.get_operation(),
            ElementWiseOperation::Sum
        );
    }

    #[test]
    fn set_operation() {
        let logger = match LOGGER.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        let network = create_network(&logger);

        let input1 = network.add_input("new_input1", DataType::Float, DimsHW::new(1, 1));
        let input2 = network.add_input("new_input2", DataType::Float, DimsHW::new(1, 1));
        let element_wise_layer =
            network.add_element_wise_layer(&input1, &input2, ElementWiseOperation::Sum);
        element_wise_layer.set_operation(ElementWiseOperation::Prod);

        assert_eq!(
            element_wise_layer.get_operation(),
            ElementWiseOperation::Prod
        );
    }
}
-------------------------------------------------------------------------------- /tensorrt/src/network/layer/gather_layer.rs: --------------------------------------------------------------------------------
use super::*;
use tensorrt_rs_derive::Layer;
use tensorrt_sys::{
    gather_layer_get_gather_axis, gather_layer_set_gather_axis, nvinfer1_IGatherLayer,
};

/// Wrapper over `nvinfer1::IGatherLayer`.
#[derive(Layer)]
pub struct GatherLayer {
    pub(crate) internal_layer: *mut nvinfer1_IGatherLayer,
}

impl GatherLayer {
    pub fn get_gather_axis(&self) -> i32 {
        unsafe { gather_layer_get_gather_axis(self.internal_layer) }
    }

    pub fn set_gather_axis(&self, axis: i32) {
        unsafe { gather_layer_set_gather_axis(self.internal_layer, axis) }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::builder::{Builder, NetworkBuildFlags};
    use crate::dims::DimsHW;
    use crate::network::Network;
    use crate::runtime::Logger;
    use lazy_static::lazy_static;
    use std::sync::Mutex;

    lazy_static! {
        static ref LOGGER: Mutex<Logger> = Mutex::new(Logger::new());
    }

    fn create_network(logger: &Logger) -> Network {
        let builder = Builder::new(logger);
        builder.create_network_v2(NetworkBuildFlags::EXPLICIT_BATCH)
    }

    #[test]
    fn get_gather_axis() {
        let logger = match LOGGER.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        let network = create_network(&logger);

        let input1 = network.add_input("new_input1", DataType::Float, DimsHW::new(10, 10));
        let input2 = network.add_input("new_input2", DataType::Float, DimsHW::new(10, 10));
        let gather_layer = network.add_gather_layer(&input1, &input2, 1);

        assert_eq!(gather_layer.get_gather_axis(), 1);
    }

    #[test]
    fn set_gather_axis() {
        let logger = match LOGGER.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        let network = create_network(&logger);

        let input1 = network.add_input("new_input1", DataType::Float, DimsHW::new(10, 10));
        let input2 = network.add_input("new_input2", DataType::Float, DimsHW::new(10, 10));
        let gather_layer = network.add_gather_layer(&input1, &input2, 1);

        gather_layer.set_gather_axis(0);
        assert_eq!(gather_layer.get_gather_axis(), 0);
    }
}
-------------------------------------------------------------------------------- /tensorrt/src/network/layer/identity_layer.rs: --------------------------------------------------------------------------------
use super::*;
use tensorrt_rs_derive::Layer;
use tensorrt_sys::nvinfer1_IIdentityLayer;

/// Wrapper over `nvinfer1::IIdentityLayer`; all shared behavior comes from the
/// derived `Layer` impl.
#[derive(Layer)]
pub struct IdentityLayer {
    pub(crate) internal_layer: *mut nvinfer1_IIdentityLayer,
}
-------------------------------------------------------------------------------- /tensorrt/src/network/layer/mod.rs: --------------------------------------------------------------------------------
pub use activation_layer::{ActivationLayer, ActivationType};
pub use element_wise_layer::{ElementWiseLayer, ElementWiseOperation};
pub use gather_layer::GatherLayer;
pub use identity_layer::IdentityLayer;
pub use pooling_layer::{PaddingMode, PoolingLayer, PoolingType};

mod activation_layer;
mod element_wise_layer;
mod gather_layer;
mod identity_layer;
mod pooling_layer;

use crate::engine::DataType;
use crate::network::Tensor;
use num_derive::FromPrimitive;
use num_traits::FromPrimitive;
use std::ffi::{CStr, CString};
use std::os::raw::c_int;
use tensorrt_rs_derive::Layer;
use tensorrt_sys::{
    layer_get_input, layer_get_name, layer_get_nb_inputs, layer_get_nb_outputs, layer_get_output,
    layer_get_output_type, layer_get_precision, layer_get_type, layer_output_type_is_set,
    layer_precision_is_set, layer_reset_output_type, layer_reset_precision, layer_set_input,
    layer_set_name, layer_set_output_type, layer_set_precision, nvinfer1_ILayer,
};

/// Layer kind discriminator, mirroring `nvinfer1::LayerType` ordering.
#[repr(C)]
#[derive(Debug, FromPrimitive, Eq, PartialEq)]
pub enum LayerType {
    Convolution,
    FullyConnected,
    Activation,
    Pooling,
    LRN,
    Scale,
    SoftMax,
    DeConvolution,
    Concatenation,
    ElementWise,
    Plugin,
    Rnn,
    Unary,
    Padding,
    Shuffle,
    Reduce,
    TopK,
    Gather,
    MatrixMultiply,
    RaggedSoftMax,
    Constant,
    RnnV2,
    Identity,
    PluginV2,
    Slice,
}

/// Operations shared by every layer wrapper; backed by `nvinfer1::ILayer`.
pub trait Layer: private::LayerPrivate {
    fn get_type(&self) -> LayerType {
        let raw = unsafe { layer_get_type(self.get_internal_layer()) };
        FromPrimitive::from_i32(raw).unwrap()
    }

    fn set_name(&self, name: &str) {
        unsafe {
            layer_set_name(
                self.get_internal_layer(),
                CString::new(name).unwrap().as_ptr(),
            )
        }
    }

    fn get_name(&self) -> String {
        let raw_string = unsafe {
            let ptr = layer_get_name(self.get_internal_layer());
            CStr::from_ptr(ptr)
        };
        raw_string.to_str().unwrap().to_string()
    }

    fn get_nb_inputs(&self) -> i32 {
        unsafe { layer_get_nb_inputs(self.get_internal_layer()) }
    }

    fn get_input(&self, index: i32) -> Tensor {
        let internal_tensor = unsafe { layer_get_input(self.get_internal_layer(), index) };
        Tensor { internal_tensor }
    }

    fn set_input(&self, index: i32, tensor: &Tensor) {
        unsafe { layer_set_input(self.get_internal_layer(), index, tensor.internal_tensor) }
    }

    fn get_nb_outputs(&self) -> i32 {
        unsafe { layer_get_nb_outputs(self.get_internal_layer()) }
    }

    fn get_output(&self, index: i32) -> Tensor {
        let internal_tensor = unsafe { layer_get_output(self.get_internal_layer(), index) };
        Tensor { internal_tensor }
    }

    fn set_precision(&self, precision: DataType) {
        unsafe { layer_set_precision(self.get_internal_layer(), precision as c_int) }
    }

    fn get_precision(&self) -> DataType {
        let raw = unsafe { layer_get_precision(self.get_internal_layer()) };
        FromPrimitive::from_i32(raw).unwrap()
    }

    fn precision_is_set(&self) -> bool {
        unsafe { layer_precision_is_set(self.get_internal_layer()) }
    }

    fn reset_precision(&self) {
        unsafe { layer_reset_precision(self.get_internal_layer()) }
    }

    fn set_output_type(&self, index: i32, data_type: DataType) {
        unsafe { layer_set_output_type(self.get_internal_layer(), index, data_type as c_int) }
    }

    fn get_output_type(&self, index: i32) -> DataType {
        let raw = unsafe { layer_get_output_type(self.get_internal_layer(), index) };
        FromPrimitive::from_i32(raw).unwrap()
    }

    fn output_type_is_set(&self, index: i32) -> bool {
        unsafe { layer_output_type_is_set(self.get_internal_layer(), index) }
    }

    // NOTE(review): "rest_output_type" is an upstream typo for
    // "reset_output_type"; renaming it would break the public trait.
    fn rest_output_type(&self, index: i32) {
        unsafe { layer_reset_output_type(self.get_internal_layer(), index) }
    }
}

mod private {
    use tensorrt_sys::nvinfer1_ILayer;

    // Sealing trait: keeps `Layer` implementable only inside this crate.
    pub trait LayerPrivate {
        fn get_internal_layer(&self) -> *mut nvinfer1_ILayer;
    }
}

/// Type-erased layer handle for layers without a dedicated wrapper.
#[derive(Layer)]
pub struct BaseLayer {
    pub(crate) internal_layer: *mut nvinfer1_ILayer,
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::builder::{Builder, NetworkBuildFlags};
    use crate::dims::DimsCHW;
    use crate::engine::DataType;
    use crate::network::Network;
    use crate::runtime::Logger;
    use crate::uff::{UffFile, UffInputOrder, UffParser};
    use lazy_static::lazy_static;
    use std::env::current_dir;
    use std::path::Path;
    use std::sync::Mutex;

    lazy_static!
{
    // Restored stripped generic: the shared logger must be `Mutex<Logger>`.
    // Tests serialize on it because TensorRT logging is process-global.
    static ref LOGGER: Mutex<Logger> = Mutex::new(Logger::new());
}

    /// Builds an empty network with default flags.
    fn create_network(logger: &Logger) -> Network {
        let builder = Builder::new(logger);
        builder.create_network_v2(NetworkBuildFlags::DEFAULT)
    }

    /// Builds a network by parsing the bundled LeNet-5 UFF model.
    fn create_network_from_uff(logger: &Logger) -> Network {
        let builder = Builder::new(&logger);
        let network = builder.create_network_v2(NetworkBuildFlags::DEFAULT);

        let uff_parser = UffParser::new();
        let dim = DimsCHW::new(1, 28, 28);

        uff_parser
            .register_input("in", dim, UffInputOrder::Nchw)
            .unwrap();
        uff_parser.register_output("out").unwrap();
        println!("{}", current_dir().unwrap().display());
        let uff_file = UffFile::new(Path::new("../assets/lenet5.uff")).unwrap();
        uff_parser.parse(&uff_file, &network).unwrap();

        network
    }

    #[test]
    fn get_type() {
        let logger = match LOGGER.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        let network = create_network_from_uff(&logger);

        let layer = network.get_layer(1);
        assert_eq!(layer.get_type(), LayerType::Convolution);
    }

    #[test]
    fn set_name() {
        let logger = match LOGGER.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        let network = create_network_from_uff(&logger);

        let layer = network.get_layer(1);
        assert_eq!(layer.get_name(), "conv1");
        layer.set_name("first_conv");
        assert_eq!(layer.get_name(), "first_conv");
    }

    #[test]
    fn get_name() {
        let logger = match LOGGER.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        let network = create_network_from_uff(&logger);

        let layer = network.get_layer(1);
        assert_eq!(layer.get_name(), "conv1");
    }

    #[test]
    fn get_nb_inputs() {
        let logger = match
LOGGER.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        let network = create_network_from_uff(&logger);

        let layer = network.get_layer(1);

        assert_eq!(layer.get_nb_inputs(), 1);
    }

    #[test]
    fn get_nb_outputs() {
        let logger = match LOGGER.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        let network = create_network_from_uff(&logger);
        let layer = network.get_layer(1);

        assert_eq!(layer.get_nb_outputs(), 1);
    }

    #[test]
    fn get_output() {
        let logger = match LOGGER.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        let network = create_network_from_uff(&logger);
        let layer = network.get_layer(1);

        assert_eq!(layer.get_output(0).get_name(), "conv1");
    }

    #[test]
    fn set_input() {
        let logger = match LOGGER.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        let uff_network = create_network_from_uff(&logger);
        let output_tensor = uff_network.get_layer(21).get_output(0);

        let network = create_network(&logger);
        let tensor = network.add_input("new_input", DataType::Float, DimsCHW::new(1, 28, 28));
        let layer = network.add_identity_layer(&tensor);

        assert_eq!(layer.get_input(0).get_name(), "new_input");
        layer.set_input(0, &output_tensor);
        assert_eq!(layer.get_input(0).get_name(), "matmul2");
    }

    #[test]
    fn get_precision() {
        let logger = match LOGGER.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        let network = create_network_from_uff(&logger);
        let layer = network.get_layer(1);

        assert_eq!(layer.get_precision(), DataType::Float);
    }

    #[test]
    fn set_precision() {
        let logger = match LOGGER.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        let network = create_network_from_uff(&logger);
        let layer = network.get_layer(1);

        layer.set_precision(DataType::Half);
        assert_eq!(layer.get_precision(), DataType::Half);
    }

    #[test]
    fn precision_is_set_true() {
        let logger = match LOGGER.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        let network = create_network_from_uff(&logger);
        let layer = network.get_layer(1);
        layer.set_precision(DataType::Half);

        assert_eq!(layer.precision_is_set(), true);
    }

    #[test]
    fn precision_is_set_false() {
        let logger = match LOGGER.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        let network = create_network_from_uff(&logger);
        let layer = network.get_layer(1);

        assert_eq!(layer.precision_is_set(), false);
    }

    #[test]
    fn reset_precision() {
        let logger = match LOGGER.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        let network = create_network_from_uff(&logger);
        let layer = network.get_layer(1);

        layer.set_precision(DataType::Half);
        assert_eq!(layer.precision_is_set(), true);
        layer.reset_precision();
        assert_eq!(layer.precision_is_set(), false);
    }

    #[test]
    fn set_output_type() {
        let logger = match LOGGER.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        let network = create_network_from_uff(&logger);
        let layer = network.get_layer(1);

        layer.set_output_type(0, DataType::Half);
        assert_eq!(layer.get_output_type(0), DataType::Half);
    }

    #[test]
    fn output_type_is_set() {
        let logger = match LOGGER.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        let network = create_network_from_uff(&logger);
        let layer = network.get_layer(1);

        layer.set_output_type(0, DataType::Half);
        assert_eq!(layer.output_type_is_set(0), true);
    }

    #[test]
    fn rest_output_type() {
        let logger = match LOGGER.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        let network = create_network_from_uff(&logger);
        let layer = network.get_layer(1);

        layer.set_output_type(0, DataType::Half);
        assert_eq!(layer.output_type_is_set(0), true);
        layer.rest_output_type(0);
        assert_eq!(layer.output_type_is_set(0), false);
    }
}
-------------------------------------------------------------------------------- /tensorrt/src/network/layer/pooling_layer.rs: --------------------------------------------------------------------------------
use super::*;
use crate::dims::{Dim, Dims, DimsHW};
use num_derive::FromPrimitive;
use num_traits::FromPrimitive;
use tensorrt_rs_derive::Layer;
use tensorrt_sys::{
    nvinfer1_IPoolingLayer, pooling_get_average_count_excludes_padding, pooling_get_blend_factor,
    pooling_get_padding, pooling_get_padding_mode, pooling_get_pooling_type,
    pooling_get_post_padding, pooling_get_pre_padding, pooling_get_stride, pooling_get_window_size,
    pooling_set_average_count_excludes_padding, pooling_set_blend_factor, pooling_set_padding,
    pooling_set_padding_mode, pooling_set_pooling_type, pooling_set_post_padding,
    pooling_set_pre_padding, pooling_set_stride, pooling_set_window_size,
};

/// Pooling algorithm; discriminants mirror TensorRT's `nvinfer1::PoolingType`.
#[repr(C)]
#[derive(FromPrimitive, Debug, Eq, PartialEq)]
pub enum PoolingType {
    Max,
    Average,
    MaxAverageBlend,
}

/// Padding strategy; discriminants mirror TensorRT's `nvinfer1::PaddingMode`.
#[repr(C)]
#[derive(FromPrimitive, Debug, Eq, PartialEq)]
pub enum PaddingMode {
    ExplicitRoundDown,
    ExplicitRoundUp,
    SameUpper,
    SameLower,
CaffeRoundDown, 31 | CaffeRoundUp, 32 | } 33 | 34 | #[derive(Layer)] 35 | pub struct PoolingLayer { 36 | pub(crate) internal_layer: *mut nvinfer1_IPoolingLayer, 37 | } 38 | 39 | impl PoolingLayer { 40 | pub fn get_pooling_type(&self) -> PoolingType { 41 | let raw = unsafe { pooling_get_pooling_type(self.internal_layer) }; 42 | FromPrimitive::from_i32(raw).unwrap() 43 | } 44 | 45 | pub fn set_pooling_type(&self, pooling_type: PoolingType) { 46 | unsafe { pooling_set_pooling_type(self.internal_layer, pooling_type as i32) } 47 | } 48 | 49 | pub fn get_window_size(&self) -> DimsHW { 50 | let raw = unsafe { pooling_get_window_size(self.internal_layer) }; 51 | DimsHW(raw) 52 | } 53 | 54 | pub fn set_window_size(&self, dims: DimsHW) { 55 | unsafe { pooling_set_window_size(self.internal_layer, dims.0) } 56 | } 57 | 58 | pub fn get_stride(&self) -> DimsHW { 59 | let raw = unsafe { pooling_get_stride(self.internal_layer) }; 60 | DimsHW(raw) 61 | } 62 | 63 | pub fn set_stride(&self, dims: DimsHW) { 64 | unsafe { pooling_set_stride(self.internal_layer, dims.0) } 65 | } 66 | 67 | pub fn get_padding(&self) -> DimsHW { 68 | let raw = unsafe { pooling_get_padding(self.internal_layer) }; 69 | DimsHW(raw) 70 | } 71 | 72 | pub fn set_padding(&self, padding: DimsHW) { 73 | unsafe { pooling_set_padding(self.internal_layer, padding.0) } 74 | } 75 | 76 | pub fn get_blend_factor(&self) -> f32 { 77 | unsafe { pooling_get_blend_factor(self.internal_layer) } 78 | } 79 | 80 | pub fn set_blend_factor(&self, factor: f32) { 81 | unsafe { pooling_set_blend_factor(self.internal_layer, factor) } 82 | } 83 | 84 | pub fn get_average_count_excludes_padding(&self) -> bool { 85 | unsafe { pooling_get_average_count_excludes_padding(self.internal_layer) } 86 | } 87 | 88 | pub fn set_average_count_excludes_padding(&self, exclusive: bool) { 89 | unsafe { pooling_set_average_count_excludes_padding(self.internal_layer, exclusive) } 90 | } 91 | 92 | pub fn get_pre_padding(&self) -> Dims { 93 | let raw = unsafe 
{ pooling_get_pre_padding(self.internal_layer) }; 94 | Dims(raw) 95 | } 96 | 97 | pub fn set_pre_padding(&self, padding: T) { 98 | unsafe { pooling_set_pre_padding(self.internal_layer, padding.get_internal_dims()) } 99 | } 100 | 101 | pub fn get_post_padding(&self) -> Dims { 102 | let raw = unsafe { pooling_get_post_padding(self.internal_layer) }; 103 | Dims(raw) 104 | } 105 | 106 | pub fn set_post_padding(&self, padding: T) { 107 | unsafe { pooling_set_post_padding(self.internal_layer, padding.get_internal_dims()) } 108 | } 109 | 110 | pub fn get_padding_mode(&self) -> PaddingMode { 111 | let raw = unsafe { pooling_get_padding_mode(self.internal_layer) }; 112 | FromPrimitive::from_i32(raw).unwrap() 113 | } 114 | 115 | pub fn set_padding_mode(&self, mode: PaddingMode) { 116 | unsafe { pooling_set_padding_mode(self.internal_layer, mode as i32) } 117 | } 118 | } 119 | 120 | #[cfg(test)] 121 | mod tests { 122 | use super::*; 123 | use crate::builder::{Builder, NetworkBuildFlags}; 124 | use crate::dims::{Dim, DimsCHW, DimsHW}; 125 | use crate::network::Network; 126 | use crate::runtime::Logger; 127 | use lazy_static::lazy_static; 128 | use std::sync::Mutex; 129 | 130 | lazy_static! 
{
    // Restored stripped generic: `Mutex<Logger>` serializes TensorRT tests.
    static ref LOGGER: Mutex<Logger> = Mutex::new(Logger::new());
}

    /// Builds an empty network with explicit-batch flags.
    fn create_network(logger: &Logger) -> Network {
        let builder = Builder::new(logger);
        builder.create_network_v2(NetworkBuildFlags::EXPLICIT_BATCH)
    }

    #[test]
    fn get_pooling_type() {
        let logger = match LOGGER.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        let network = create_network(&logger);
        let input1 = network.add_input("new_input1", DataType::Float, DimsCHW::new(1, 28, 28));
        let pooling = network.add_pooling(&input1, PoolingType::Max, DimsHW::new(10, 10));

        assert_eq!(pooling.get_pooling_type(), PoolingType::Max);
    }

    #[test]
    fn set_pooling_type() {
        let logger = match LOGGER.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        let network = create_network(&logger);
        let input1 = network.add_input("new_input1", DataType::Float, DimsCHW::new(1, 28, 28));
        let pooling = network.add_pooling(&input1, PoolingType::Max, DimsHW::new(10, 10));

        pooling.set_pooling_type(PoolingType::Average);

        assert_eq!(pooling.get_pooling_type(), PoolingType::Average);
    }

    #[test]
    fn get_window_size() {
        let logger = match LOGGER.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        let network = create_network(&logger);
        let input1 = network.add_input("new_input1", DataType::Float, DimsCHW::new(1, 28, 28));
        let pooling = network.add_pooling(&input1, PoolingType::Max, DimsHW::new(10, 10));

        let dims = pooling.get_window_size();
        assert_eq!(dims.d()[0], 10);
        assert_eq!(dims.d()[1], 10);
    }

    #[test]
    fn set_window_size() {
        let logger = match LOGGER.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        let network = create_network(&logger);
        let input1 = network.add_input("new_input1", DataType::Float, DimsCHW::new(1, 28, 28));
        let pooling = network.add_pooling(&input1, PoolingType::Max, DimsHW::new(10, 10));

        pooling.set_window_size(DimsHW::new(20, 20));

        let dims = pooling.get_window_size();
        assert_eq!(dims.d()[0], 20);
        assert_eq!(dims.d()[1], 20);
    }

    #[test]
    fn get_stride() {
        let logger = match LOGGER.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        let network = create_network(&logger);
        let input1 = network.add_input("new_input1", DataType::Float, DimsCHW::new(1, 28, 28));
        let pooling = network.add_pooling(&input1, PoolingType::Max, DimsHW::new(10, 10));

        let stride = pooling.get_stride();
        assert_eq!(stride.d()[0], 10);
        assert_eq!(stride.d()[1], 10);
    }

    #[test]
    fn set_stride() {
        let logger = match LOGGER.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        let network = create_network(&logger);
        let input1 = network.add_input("new_input1", DataType::Float, DimsCHW::new(1, 28, 28));
        let pooling = network.add_pooling(&input1, PoolingType::Max, DimsHW::new(10, 10));

        pooling.set_stride(DimsHW::new(20, 20));
        let stride = pooling.get_stride();

        assert_eq!(stride.d()[0], 20);
        assert_eq!(stride.d()[1], 20);
    }

    #[test]
    fn get_padding() {
        let logger = match LOGGER.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        let network = create_network(&logger);
        let input1 = network.add_input("new_input1", DataType::Float, DimsCHW::new(1, 28, 28));
        let pooling = network.add_pooling(&input1, PoolingType::Max, DimsHW::new(10, 10));

        let padding = pooling.get_padding();
        assert_eq!(padding.d()[0], 0);
        assert_eq!(padding.d()[1], 0);
    }

    #[test]
    fn set_padding() {
        let logger = match LOGGER.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        let network = create_network(&logger);
        let input1 = network.add_input("new_input1", DataType::Float, DimsCHW::new(1, 28, 28));
        let pooling = network.add_pooling(&input1, PoolingType::Max, DimsHW::new(10, 10));

        pooling.set_padding(DimsHW::new(0, 10));
        let padding = pooling.get_padding();
        assert_eq!(padding.d()[0], 0);
        assert_eq!(padding.d()[1], 10);
    }

    #[test]
    fn get_blend_factor() {
        let logger = match LOGGER.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        let network = create_network(&logger);
        let input1 = network.add_input("new_input1", DataType::Float, DimsCHW::new(1, 28, 28));
        let pooling =
            network.add_pooling(&input1, PoolingType::MaxAverageBlend, DimsHW::new(10, 10));

        assert_eq!(pooling.get_blend_factor(), 0.0);
    }

    #[test]
    fn set_blend_factor() {
        let logger = match LOGGER.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        let network = create_network(&logger);
        let input1 = network.add_input("new_input1", DataType::Float, DimsCHW::new(1, 28, 28));
        let pooling =
            network.add_pooling(&input1, PoolingType::MaxAverageBlend, DimsHW::new(10, 10));

        pooling.set_blend_factor(0.5);
        assert_eq!(pooling.get_blend_factor(), 0.5);
    }

    #[test]
    fn get_average_count_excludes_padding() {
        let logger = match LOGGER.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        let network = create_network(&logger);
        let input1 = network.add_input("new_input1", DataType::Float, DimsCHW::new(1, 28, 28));
        let pooling = network.add_pooling(&input1, PoolingType::Average, DimsHW::new(10, 10));

        assert_eq!(pooling.get_average_count_excludes_padding(), true);
    }

    #[test]
    fn set_average_count_excludes_padding() {
        let logger = match LOGGER.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        let network = create_network(&logger);
        let input1 = network.add_input("new_input1", DataType::Float, DimsCHW::new(1, 28, 28));
        let pooling = network.add_pooling(&input1, PoolingType::Average, DimsHW::new(10, 10));

        pooling.set_average_count_excludes_padding(false);
        assert_eq!(pooling.get_average_count_excludes_padding(), false);
    }

    #[test]
    fn get_pre_padding() {
        let logger = match LOGGER.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        let network = create_network(&logger);
        let input1 = network.add_input("new_input1", DataType::Float, DimsCHW::new(1, 28, 28));
        let pooling = network.add_pooling(&input1, PoolingType::Average, DimsHW::new(10, 10));

        let padding = pooling.get_pre_padding();
        assert_eq!(padding.d()[0], 0);
        assert_eq!(padding.d()[1], 0);
        assert_eq!(padding.nb_dims(), 2);
    }

    #[test]
    fn set_pre_padding() {
        let logger = match LOGGER.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        let network = create_network(&logger);
        let input1 = network.add_input("new_input1", DataType::Float, DimsCHW::new(1, 28, 28));
        let pooling = network.add_pooling(&input1, PoolingType::Average, DimsHW::new(10, 10));

        pooling.set_pre_padding(DimsHW::new(10, 10));
        let padding = pooling.get_pre_padding();
        assert_eq!(padding.d()[0], 10);
        assert_eq!(padding.d()[1], 10);
        assert_eq!(padding.nb_dims(), 2);
    }

    #[test]
    fn get_post_padding() {
        let logger = match LOGGER.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        let network = create_network(&logger);
        let input1 = network.add_input("new_input1", DataType::Float, DimsCHW::new(1, 28, 28));
        let pooling = network.add_pooling(&input1, PoolingType::Average, DimsHW::new(10, 10));

        let padding = pooling.get_post_padding();
        assert_eq!(padding.d()[0], 0);
        assert_eq!(padding.d()[1], 0);
        assert_eq!(padding.nb_dims(), 2);
    }

    #[test]
    fn set_post_padding() {
        let logger = match LOGGER.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        let network = create_network(&logger);
        let input1 = network.add_input("new_input1", DataType::Float, DimsCHW::new(1, 28, 28));
        let pooling = network.add_pooling(&input1, PoolingType::Average, DimsHW::new(10, 10));

        pooling.set_post_padding(DimsHW::new(10, 10));
        let padding = pooling.get_post_padding();
        assert_eq!(padding.d()[0], 10);
        assert_eq!(padding.d()[1], 10);
        assert_eq!(padding.nb_dims(), 2);
    }

    #[test]
    fn get_padding_mode() {
        let logger = match LOGGER.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        let network = create_network(&logger);
        let input1 = network.add_input("new_input1", DataType::Float, DimsCHW::new(1, 28, 28));
        let pooling = network.add_pooling(&input1, PoolingType::Average, DimsHW::new(10, 10));

        assert_eq!(pooling.get_padding_mode(), PaddingMode::ExplicitRoundDown);
    }

    #[test]
    fn set_padding_mode() {
        let logger = match LOGGER.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        let network = create_network(&logger);
        let input1 = network.add_input("new_input1", DataType::Float, DimsCHW::new(1, 28, 28));
        let pooling = network.add_pooling(&input1, PoolingType::Average,
DimsHW::new(10, 10));

        pooling.set_padding_mode(PaddingMode::ExplicitRoundUp);
        assert_eq!(pooling.get_padding_mode(), PaddingMode::ExplicitRoundUp);
    }
}
-------------------------------------------------------------------------------- /tensorrt/src/network/mod.rs: --------------------------------------------------------------------------------
pub mod layer;

use crate::dims::{Dim, DimsHW};
use crate::engine::DataType;
use layer::*;
use std::ffi::{CStr, CString};
use std::os::raw::c_int;
use tensorrt_sys::{
    destroy_network, network_add_activation, network_add_element_wise, network_add_gather,
    network_add_identity_layer, network_add_input, network_add_pooling, network_get_input,
    network_get_layer, network_get_nb_inputs, network_get_nb_layers, network_get_nb_outputs,
    network_get_output, network_mark_output, network_remove_tensor, network_unmark_output,
    nvinfer1_INetworkDefinition, nvinfer1_ITensor, tensor_get_name, tensor_set_dimensions,
};

/// Wrapper around TensorRT's `INetworkDefinition`; owns the underlying
/// network and destroys it on drop.
pub struct Network {
    pub(crate) internal_network: *mut nvinfer1_INetworkDefinition,
}

impl Network {
    /// Number of input tensors registered on the network.
    pub fn get_nb_inputs(&self) -> i32 {
        unsafe { network_get_nb_inputs(self.internal_network) }
    }

    /// Registers a new network input. Generic parameter `<T: Dim>` restored —
    /// it was stripped by the extraction; the bound is required for
    /// `get_internal_dims()`.
    pub fn add_input<T: Dim>(&self, name: &str, data_type: DataType, dims: T) -> Tensor {
        let internal_tensor = unsafe {
            network_add_input(
                self.internal_network,
                CString::new(name).unwrap().as_ptr(),
                data_type as c_int,
                dims.get_internal_dims(),
            )
        };
        Tensor { internal_tensor }
    }

    /// Input tensor at `idx`.
    pub fn get_input(&self, idx: i32) -> Tensor {
        let internal_tensor = unsafe { network_get_input(self.internal_network, idx) };
        Tensor { internal_tensor }
    }

    /// Number of layers in the network.
    pub fn get_nb_layers(&self) -> i32 {
        unsafe { network_get_nb_layers(self.internal_network) }
    }

    /// Type-erased handle to the layer at `index`.
    pub fn get_layer(&self, index: i32) -> BaseLayer {
        let internal_layer = unsafe { network_get_layer(self.internal_network, index) };
        BaseLayer { internal_layer }
    }

    /// Number of output tensors marked on the network.
    pub fn get_nb_outputs(&self) -> i32 {
        unsafe { network_get_nb_outputs(self.internal_network) }
    }

    /// Output tensor at `index`.
    pub fn get_output(&self, index: i32) -> Tensor {
        let internal_tensor = unsafe { network_get_output(self.internal_network, index) };
        Tensor { internal_tensor }
    }

    /// Removes `tensor` from the network.
    pub fn remove_tensor(&self, tensor: &Tensor) {
        unsafe { network_remove_tensor(self.internal_network, tensor.internal_tensor) }
    }

    /// Marks `output_tensor` as a network output.
    pub fn mark_output(&self, output_tensor: &Tensor) {
        unsafe { network_mark_output(self.internal_network, output_tensor.internal_tensor) }
    }

    /// Unmarks `output_tensor` as a network output.
    pub fn unmark_output(&self, output_tensor: &Tensor) {
        unsafe { network_unmark_output(self.internal_network, output_tensor.internal_tensor) }
    }

    /// Appends an identity layer fed by `input_tensor`.
    pub fn add_identity_layer(&self, input_tensor: &Tensor) -> IdentityLayer {
        let internal_layer = unsafe {
            network_add_identity_layer(self.internal_network, input_tensor.internal_tensor)
        };
        IdentityLayer { internal_layer }
    }

    /// Appends an element-wise layer combining two tensors with `op`.
    pub fn add_element_wise_layer(
        &self,
        input_tensor1: &Tensor,
        input_tensor2: &Tensor,
        op: ElementWiseOperation,
    ) -> ElementWiseLayer {
        let internal_layer = unsafe {
            network_add_element_wise(
                self.internal_network,
                input_tensor1.internal_tensor,
                input_tensor2.internal_tensor,
                op as c_int,
            )
        };
        ElementWiseLayer { internal_layer }
    }

    /// Appends a gather layer selecting from `data` by `indicies` along `axis`.
    pub fn add_gather_layer(&self, data: &Tensor, indicies: &Tensor, axis: i32) -> GatherLayer {
        let internal_layer = unsafe {
            network_add_gather(
                self.internal_network,
                data.internal_tensor,
                indicies.internal_tensor,
                axis,
            )
        };
        GatherLayer { internal_layer }
    }

    /// Appends an activation layer applying `activation_type` to `input`.
    pub fn add_activation(
        &self,
        input: &Tensor,
        activation_type: ActivationType,
    ) ->
ActivationLayer {
        let internal_layer = unsafe {
            network_add_activation(
                self.internal_network,
                input.internal_tensor,
                activation_type as c_int,
            )
        };
        ActivationLayer { internal_layer }
    }

    /// Appends a pooling layer with the given algorithm and window size.
    pub fn add_pooling(
        &self,
        input: &Tensor,
        pooling_type: PoolingType,
        window_size: DimsHW,
    ) -> PoolingLayer {
        let internal_layer = unsafe {
            network_add_pooling(
                self.internal_network,
                input.internal_tensor,
                pooling_type as c_int,
                window_size.0,
            )
        };
        PoolingLayer { internal_layer }
    }
}

impl Drop for Network {
    fn drop(&mut self) {
        // Releases the underlying INetworkDefinition via the C shim.
        unsafe { destroy_network(self.internal_network) };
    }
}

/// Non-owning handle to an `ITensor` inside a network.
pub struct Tensor {
    pub(crate) internal_tensor: *mut nvinfer1_ITensor,
}

impl Tensor {
    /// Tensor name as an owned `String`.
    pub fn get_name(&self) -> String {
        unsafe {
            CStr::from_ptr(tensor_get_name(self.internal_tensor))
                .to_str()
                .unwrap()
                .to_owned()
        }
    }

    /// Sets the tensor dimensions. Generic parameter `<D: Dim>` restored — it
    /// was stripped by the extraction; the bound is required for
    /// `get_internal_dims()`.
    pub fn set_dimensions<D: Dim>(&mut self, dims: D) {
        unsafe { tensor_set_dimensions(self.internal_tensor, dims.get_internal_dims()) };
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::builder::{Builder, NetworkBuildFlags};
    use crate::dims::{DimsCHW, DimsHW};
    use crate::runtime::Logger;
    use crate::uff::{UffFile, UffInputOrder, UffParser};
    use layer::LayerType;
    use lazy_static::lazy_static;
    use std::path::Path;
    use std::sync::Mutex;

    lazy_static!
{
    // Restored stripped generic: `Mutex<Logger>` serializes TensorRT tests.
    static ref LOGGER: Mutex<Logger> = Mutex::new(Logger::new());
}

    /// Builds an empty network with default flags.
    fn create_network(logger: &Logger) -> Network {
        let builder = Builder::new(logger);
        builder.create_network_v2(NetworkBuildFlags::DEFAULT)
    }

    /// Builds a network by parsing the bundled LeNet-5 UFF model.
    fn create_network_from_uff(logger: &Logger) -> Network {
        let builder = Builder::new(&logger);
        let network = builder.create_network_v2(NetworkBuildFlags::DEFAULT);

        let uff_parser = UffParser::new();
        let dim = DimsCHW::new(1, 28, 28);

        uff_parser
            .register_input("in", dim, UffInputOrder::Nchw)
            .unwrap();
        uff_parser.register_output("out").unwrap();
        let uff_file = UffFile::new(Path::new("../assets/lenet5.uff")).unwrap();
        uff_parser.parse(&uff_file, &network).unwrap();

        network
    }

    #[test]
    fn get_nb_layers_uff() {
        let logger = match LOGGER.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        let network = create_network_from_uff(&logger);

        assert_eq!(network.get_nb_layers(), 24);
    }

    #[test]
    fn layer_name() {
        let logger = match LOGGER.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        let network = create_network_from_uff(&logger);

        let layer = network.get_layer(0);
        assert_eq!(layer.get_name(), "wc1");
    }

    #[test]
    fn get_nb_inputs() {
        let logger = match LOGGER.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        let network = create_network_from_uff(&logger);

        assert_eq!(network.get_nb_inputs(), 1);
    }

    #[test]
    fn add_input() {
        let logger = match LOGGER.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        let network = create_network(&logger);

        let tensor = network.add_input("new_input", DataType::Float, DimsCHW::new(1, 28, 28));
        assert_eq!(tensor.get_name(), "new_input");
    }

    #[test]
    fn get_input() {
        let logger = match LOGGER.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        let network = create_network_from_uff(&logger);

        assert_eq!(network.get_input(0).get_name(), "in");
    }

    #[test]
    fn get_nb_outputs() {
        let logger = match LOGGER.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        let network = create_network_from_uff(&logger);

        assert_eq!(network.get_nb_outputs(), 1);
    }

    #[test]
    fn get_output() {
        let logger = match LOGGER.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        let network = create_network_from_uff(&logger);

        assert_eq!(network.get_output(0).get_name(), "out");
    }

    #[test]
    fn remove_tensor() {
        let logger = match LOGGER.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        let uff_network = create_network_from_uff(&logger);
        let output_tensor = uff_network.get_layer(21).get_output(0);

        let network = create_network(&logger);
        let tensor = network.add_input("new_input", DataType::Float, DimsCHW::new(1, 28, 28));
        let layer = network.add_identity_layer(&tensor);

        assert_eq!(network.get_layer(0).get_input(0).get_name(), "new_input");
        layer.set_input(0, &output_tensor);
        network.remove_tensor(&tensor);
        assert_eq!(network.get_nb_inputs(), 0);
        assert_eq!(network.get_layer(0).get_input(0).get_name(), "matmul2");
    }

    #[test]
    fn mark_output() {
        let logger = match LOGGER.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        let network = create_network_from_uff(&logger);
        let new_output_tensor = network.get_layer(21).get_output(0);

        assert_eq!(network.get_nb_outputs(), 1);
        network.mark_output(&new_output_tensor);
        assert_eq!(network.get_nb_outputs(), 2);
    }

    #[test]
    fn unmark_output() {
        let logger = match LOGGER.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        let network = create_network_from_uff(&logger);
        let new_output_tensor = network.get_layer(21).get_output(0);

        assert_eq!(network.get_nb_outputs(), 1);
        network.mark_output(&new_output_tensor);
        assert_eq!(network.get_nb_outputs(), 2);
        network.unmark_output(&new_output_tensor);
        assert_eq!(network.get_nb_outputs(), 1);
    }

    #[test]
    fn add_identity_layer() {
        let logger = match LOGGER.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        let network = create_network(&logger);
        let tensor = network.add_input("new_input", DataType::Float, DimsCHW::new(1, 28, 28));
        network.add_identity_layer(&tensor);
        assert_eq!(network.get_nb_layers(), 1);
    }

    #[test]
    fn add_element_wise_layer() {
        let logger = match LOGGER.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        let network = create_network(&logger);
        let input1 = network.add_input("new_input1", DataType::Float, DimsCHW::new(1, 28, 28));
        let input2 = network.add_input("new_input2", DataType::Float, DimsCHW::new(1, 28, 28));
        network.add_element_wise_layer(&input1, &input2, ElementWiseOperation::Sum);

        assert_eq!(network.get_nb_layers(), 1);
        assert_eq!(network.get_layer(0).get_type(), LayerType::ElementWise);
    }

    #[test]
    fn add_gather_layer() {
        let logger = match LOGGER.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        let network = create_network(&logger);
        let input1 = network.add_input("new_input1", DataType::Float, DimsCHW::new(1, 28, 28));
        let input2 = network.add_input("new_input2", DataType::Float, DimsCHW::new(1, 28, 28));
        network.add_gather_layer(&input1, &input2, 1);

        assert_eq!(network.get_nb_layers(), 1);
        assert_eq!(network.get_layer(0).get_type(), LayerType::Gather);
    }

    #[test]
    fn add_activation() {
        let logger = match LOGGER.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        let network = create_network(&logger);
        let input1 = network.add_input("new_input1", DataType::Float, DimsCHW::new(1, 28, 28));
        network.add_activation(&input1, ActivationType::Relu);

        assert_eq!(network.get_layer(0).get_type(), LayerType::Activation);
    }

    #[test]
    fn add_pooling() {
        let logger = match LOGGER.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        let network = create_network(&logger);
        let input1 = network.add_input("new_input1", DataType::Float, DimsCHW::new(1, 28, 28));

        network.add_pooling(&input1, PoolingType::Max, DimsHW::new(10, 10));
        assert_eq!(network.get_layer(0).get_type(), LayerType::Pooling);
    }
}
-------------------------------------------------------------------------------- /tensorrt/src/onnx.rs: --------------------------------------------------------------------------------
use crate::network::Network;
use crate::runtime::Logger;
use std::error;
use std::ffi::CString;
use std::fmt::Formatter;
use std::io;
use std::marker::PhantomData;
use std::path::{Path, PathBuf};
use tensorrt_sys::{
    onnxparser_create_parser, onnxparser_destroy_parser, onnxparser_parse_from_file,
};

/// Validated path to an ONNX model file.
pub struct OnnxFile(PathBuf);

impl OnnxFile {
    // NOTE(review): the return type's generic arguments were stripped by the
    // extraction; `Result<OnnxFile, io::Error>` matches the `std::io` import
    // and the existence check below — confirm against the original source.
    pub fn new(file_name: &Path) -> Result<OnnxFile, io::Error> {
        if !file_name.exists() {
            return
Err(io::Error::new( 19 | io::ErrorKind::NotFound, 20 | "ONNX file does not exist", 21 | )); 22 | } 23 | 24 | if file_name.extension().unwrap() != "onnx" { 25 | return Err(io::Error::new( 26 | io::ErrorKind::InvalidInput, 27 | "Invalid ONNX file. ONNX files should have a .onnx ending", 28 | )); 29 | } 30 | 31 | Ok(OnnxFile(file_name.to_path_buf())) 32 | } 33 | 34 | pub fn path(&self) -> CString { 35 | CString::new(self.0.to_str().unwrap()).unwrap() 36 | } 37 | } 38 | 39 | pub struct OnnxParser<'a, 'b> { 40 | internal_onnxparser: *mut tensorrt_sys::OnnxParser_t, 41 | pub(crate) network: PhantomData<&'a Network>, 42 | pub(crate) logger: PhantomData<&'b Logger>, 43 | } 44 | 45 | impl<'a, 'b> OnnxParser<'a, 'b> { 46 | // const Network_t *network, Logger_t *logger 47 | pub fn new(network: &'a Network, logger: &'b Logger) -> Self { 48 | let internal_onnxparser = 49 | unsafe { onnxparser_create_parser(network.internal_network, logger.internal_logger) }; 50 | Self { 51 | internal_onnxparser, 52 | network: PhantomData, 53 | logger: PhantomData, 54 | } 55 | } 56 | 57 | pub fn parse_from_file( 58 | &self, 59 | onnx_file: &OnnxFile, 60 | verbosity: i32, 61 | ) -> Result<(), OnnxParseError> { 62 | let res = unsafe { 63 | onnxparser_parse_from_file( 64 | self.internal_onnxparser, 65 | onnx_file.path().as_ptr(), 66 | verbosity, 67 | ) 68 | }; 69 | 70 | if res { 71 | Ok(()) 72 | } else { 73 | Err(OnnxParseError::new("Error parsing ONNX file")) 74 | } 75 | } 76 | } 77 | 78 | impl<'a, 'b> Drop for OnnxParser<'a, 'b> { 79 | fn drop(&mut self) { 80 | unsafe { onnxparser_destroy_parser(self.internal_onnxparser) }; 81 | } 82 | } 83 | 84 | #[derive(Debug, Clone)] 85 | pub struct OnnxParseError { 86 | message: String, 87 | } 88 | 89 | impl OnnxParseError { 90 | pub fn new(message: &str) -> Self { 91 | Self { 92 | message: message.to_string(), 93 | } 94 | } 95 | } 96 | 97 | impl std::fmt::Display for OnnxParseError { 98 | fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { 99 | 
write!(f, "{}", self.message) 100 | } 101 | } 102 | 103 | impl error::Error for OnnxParseError {} 104 | -------------------------------------------------------------------------------- /tensorrt/src/profiler.rs: -------------------------------------------------------------------------------- 1 | use crate::binding_func; 2 | use std::ffi::CStr; 3 | use std::os::raw::c_char; 4 | use tensorrt_sys::{create_profiler, destroy_profiler, CppProfiler, Profiler_t}; 5 | 6 | pub trait IProfiler { 7 | fn report_layer_time(&self, layer_name: *const c_char, ms: f32); 8 | } 9 | 10 | pub struct Profiler { 11 | pub(crate) internal_profiler: *mut CppProfiler, 12 | _supplied_profiler: P, 13 | } 14 | 15 | impl Profiler
<P> /* generic parameters lost in extraction; original source reads: impl<P: IProfiler> Profiler<P> */
{ 16 | pub fn new(mut rust_profiler: P) -> Self { 17 | let profiler_ptr = 18 | Box::into_raw(Box::new(ProfilerBinding::new(&mut rust_profiler))) as *mut Profiler_t; 19 | let internal_profiler = unsafe { create_profiler(profiler_ptr) }; 20 | 21 | Profiler { 22 | internal_profiler, 23 | _supplied_profiler: rust_profiler, 24 | } 25 | } 26 | } 27 | 28 | impl Drop for Profiler
<P> /* generic parameters lost in extraction; original source reads: impl<P: IProfiler> Drop for Profiler<P> */
{ 29 | fn drop(&mut self) { 30 | unsafe { 31 | destroy_profiler(self.internal_profiler); 32 | } 33 | } 34 | } 35 | 36 | pub struct DefaultProfiler {} 37 | 38 | impl DefaultProfiler { 39 | pub fn new() -> Self { 40 | DefaultProfiler {} 41 | } 42 | } 43 | 44 | impl IProfiler for DefaultProfiler { 45 | fn report_layer_time(&self, layer_name: *const c_char, ms: f32) { 46 | println!( 47 | "{} took {} ms", 48 | unsafe { CStr::from_ptr(layer_name) }.to_str().unwrap(), 49 | ms 50 | ); 51 | } 52 | } 53 | 54 | #[repr(C)] 55 | struct ProfilerBinding 56 | where 57 | T: IProfiler, 58 | { 59 | pub report_layer_time: unsafe extern "C" fn(*mut T, *const c_char, f32), 60 | destroy: unsafe extern "C" fn(*mut Profiler_t, *mut T), 61 | pub context: *mut T, 62 | } 63 | 64 | impl ProfilerBinding 65 | where 66 | T: IProfiler, 67 | { 68 | pub fn new(profiler: &mut T) -> Self { 69 | binding_func!(report_layer_time(layer_name: *const c_char, ms: f32)); 70 | 71 | //This is a little un-orthodox but having this extern function allows us to 72 | //cleanup the memory that we lose when passing the ProfilerBinding as a raw pointer to 73 | //the C++ bindings. The pointer that gets sent to C++ becomes owned by a C++ object and when 74 | //that object is destroyed this will get called so we can destroy the memory allocated for 75 | //the pointer properly. 
76 | unsafe extern "C" fn destroy(binding: *mut Profiler_t, _: *mut T) { 77 | Box::from_raw(binding as *mut ProfilerBinding); 78 | } 79 | 80 | let context: *mut T = &mut *profiler; 81 | ProfilerBinding { 82 | report_layer_time, 83 | destroy, 84 | context, 85 | } 86 | } 87 | } 88 | -------------------------------------------------------------------------------- /tensorrt/src/runtime.rs: -------------------------------------------------------------------------------- 1 | use std::marker::PhantomData; 2 | 3 | use crate::engine::Engine; 4 | use bitflags::_core::ffi::c_void; 5 | use tensorrt_sys::{ 6 | create_infer_runtime, create_logger, delete_logger, deserialize_cuda_engine, 7 | destroy_infer_runtime, nvinfer1_IRuntime, runtime_get_dla_core, runtime_get_nb_dla_cores, 8 | runtime_set_dla_core, set_logger_severity, 9 | }; 10 | 11 | #[repr(C)] 12 | pub enum LoggerSeverity { 13 | InternalError, 14 | Error, 15 | Warning, 16 | Info, 17 | Verbose, 18 | } 19 | 20 | pub struct Logger { 21 | pub(crate) internal_logger: *mut tensorrt_sys::Logger_t, 22 | } 23 | 24 | impl Logger { 25 | pub fn new() -> Logger { 26 | let logger = unsafe { create_logger(LoggerSeverity::Warning as i32) }; 27 | Logger { 28 | internal_logger: logger, 29 | } 30 | } 31 | 32 | pub fn severity(self, severity: LoggerSeverity) -> Logger { 33 | unsafe { 34 | set_logger_severity(self.internal_logger, severity as i32); 35 | }; 36 | self 37 | } 38 | } 39 | 40 | unsafe impl Send for Logger {} 41 | unsafe impl Sync for Logger {} 42 | 43 | impl Drop for Logger { 44 | fn drop(&mut self) { 45 | unsafe { delete_logger(self.internal_logger) }; 46 | } 47 | } 48 | 49 | #[derive(Clone)] 50 | pub struct Runtime<'a> { 51 | pub(crate) internal_runtime: *mut nvinfer1_IRuntime, 52 | pub(crate) logger: PhantomData<&'a Logger>, 53 | } 54 | 55 | impl<'a> Runtime<'a> { 56 | pub fn new(logger: &'a Logger) -> Self { 57 | let internal_runtime = unsafe { create_infer_runtime(logger.internal_logger) }; 58 | let logger = PhantomData; 59 
| Self { 60 | internal_runtime, 61 | logger, 62 | } 63 | } 64 | 65 | pub fn deserialize_cuda_engine(&self, buffer: Vec) -> Engine { 66 | let internal_engine = unsafe { 67 | deserialize_cuda_engine( 68 | self.internal_runtime, 69 | buffer.as_ptr() as *const c_void, 70 | buffer.len() as u64, 71 | ) 72 | }; 73 | 74 | Engine { internal_engine } 75 | } 76 | 77 | pub fn get_nb_dla_cores(&self) -> i32 { 78 | unsafe { runtime_get_nb_dla_cores(self.internal_runtime) } 79 | } 80 | 81 | pub fn get_dla_core(&self) -> i32 { 82 | unsafe { runtime_get_dla_core(self.internal_runtime) } 83 | } 84 | 85 | pub fn set_dla_core(&self, dla_core: i32) { 86 | unsafe { runtime_set_dla_core(self.internal_runtime, dla_core) } 87 | } 88 | } 89 | 90 | impl<'a> Drop for Runtime<'a> { 91 | fn drop(&mut self) { 92 | unsafe { destroy_infer_runtime(self.internal_runtime) }; 93 | } 94 | } 95 | 96 | #[cfg(test)] 97 | mod tests { 98 | use super::*; 99 | use lazy_static::lazy_static; 100 | use std::sync::Mutex; 101 | 102 | lazy_static! 
{ 103 | static ref LOGGER: Mutex = Mutex::new(Logger::new()); 104 | } 105 | 106 | #[test] 107 | fn get_nb_dla_cores() { 108 | let logger = match LOGGER.lock() { 109 | Ok(guard) => guard, 110 | Err(poisoned) => poisoned.into_inner(), 111 | }; 112 | 113 | let runtime = Runtime::new(&logger); 114 | 115 | assert_eq!(runtime.get_nb_dla_cores(), 0); 116 | } 117 | 118 | #[test] 119 | fn get_dla_core() { 120 | let logger = match LOGGER.lock() { 121 | Ok(guard) => guard, 122 | Err(poisoned) => poisoned.into_inner(), 123 | }; 124 | 125 | let runtime = Runtime::new(&logger); 126 | 127 | assert_eq!(runtime.get_dla_core(), 0); 128 | } 129 | 130 | #[cfg(target_arch = "aarch64")] 131 | #[test] 132 | fn set_dla_core() { 133 | let logger = match LOGGER.lock() { 134 | Ok(guard) => guard, 135 | Err(poisoned) => poisoned.into_inner(), 136 | }; 137 | 138 | let runtime = Runtime::new(&logger); 139 | runtime.set_dla_core(1); 140 | assert_eq!(runtime.get_dla_core(), 1); 141 | } 142 | } 143 | -------------------------------------------------------------------------------- /tensorrt/src/uff.rs: -------------------------------------------------------------------------------- 1 | use crate::dims::Dim; 2 | use crate::network::Network; 3 | use std::error; 4 | use std::ffi::CString; 5 | use std::fmt::Formatter; 6 | use std::io; 7 | use std::path::{Path, PathBuf}; 8 | use tensorrt_sys::{ 9 | uffparser_create_uff_parser, uffparser_destroy_uff_parser, uffparser_parse, 10 | uffparser_register_input, uffparser_register_output, 11 | }; 12 | 13 | #[repr(C)] 14 | pub enum UffInputOrder { 15 | Nchw, 16 | Nhwc, 17 | Nc, 18 | } 19 | 20 | pub struct UffFile(PathBuf); 21 | 22 | impl UffFile { 23 | pub fn new(file_name: &Path) -> Result { 24 | if !file_name.exists() { 25 | return Err(io::Error::new( 26 | io::ErrorKind::NotFound, 27 | "UFF file does not exist", 28 | )); 29 | } 30 | 31 | if file_name.extension().unwrap() != "uff" { 32 | return Err(io::Error::new( 33 | io::ErrorKind::InvalidInput, 34 | "Invalid 
UFF file. UFF files should have a .uff ending", 35 | )); 36 | } 37 | 38 | Ok(UffFile(file_name.to_path_buf())) 39 | } 40 | 41 | pub fn path(&self) -> CString { 42 | CString::new(self.0.to_str().unwrap()).unwrap() 43 | } 44 | } 45 | 46 | pub struct UffParser { 47 | internal_uffparser: *mut tensorrt_sys::UffParser_t, 48 | } 49 | 50 | impl UffParser { 51 | pub fn new() -> UffParser { 52 | let parser = unsafe { uffparser_create_uff_parser() }; 53 | UffParser { 54 | internal_uffparser: parser, 55 | } 56 | } 57 | 58 | pub fn register_input( 59 | &self, 60 | input_name: &str, 61 | dims: impl Dim, 62 | input_order: UffInputOrder, 63 | ) -> Result<(), UFFRegistrationError> { 64 | let res = unsafe { 65 | uffparser_register_input( 66 | self.internal_uffparser, 67 | CString::new(input_name).unwrap().as_ptr(), 68 | dims.get_internal_dims(), 69 | input_order as i32, 70 | ) 71 | }; 72 | 73 | if res { 74 | Ok(()) 75 | } else { 76 | Err(UFFRegistrationError::new("Input Registration Failed")) 77 | } 78 | } 79 | 80 | pub fn register_output(&self, output_name: &str) -> Result<(), UFFRegistrationError> { 81 | let res = unsafe { 82 | uffparser_register_output( 83 | self.internal_uffparser, 84 | CString::new(output_name).unwrap().as_ptr(), 85 | ) 86 | }; 87 | 88 | if res { 89 | Ok(()) 90 | } else { 91 | Err(UFFRegistrationError::new("Output Registration Failed")) 92 | } 93 | } 94 | 95 | pub fn parse(&self, uff_file: &UffFile, network: &Network) -> Result<(), UFFParseError> { 96 | let res = unsafe { 97 | uffparser_parse( 98 | self.internal_uffparser, 99 | uff_file.path().as_ptr(), 100 | network.internal_network, 101 | ) 102 | }; 103 | 104 | if res { 105 | Ok(()) 106 | } else { 107 | Err(UFFParseError::new("Error parsing UFF file")) 108 | } 109 | } 110 | } 111 | 112 | impl Drop for UffParser { 113 | fn drop(&mut self) { 114 | unsafe { uffparser_destroy_uff_parser(self.internal_uffparser) }; 115 | } 116 | } 117 | 118 | #[derive(Debug, Clone)] 119 | pub struct UFFRegistrationError { 120 | 
message: String, 121 | } 122 | 123 | impl UFFRegistrationError { 124 | pub fn new(message: &str) -> Self { 125 | UFFRegistrationError { 126 | message: message.to_string(), 127 | } 128 | } 129 | } 130 | 131 | impl std::fmt::Display for UFFRegistrationError { 132 | fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { 133 | write!(f, "{}", self.message) 134 | } 135 | } 136 | 137 | impl error::Error for UFFRegistrationError {} 138 | 139 | #[derive(Debug, Clone)] 140 | pub struct UFFParseError { 141 | message: String, 142 | } 143 | 144 | impl UFFParseError { 145 | pub fn new(message: &str) -> Self { 146 | UFFParseError { 147 | message: message.to_string(), 148 | } 149 | } 150 | } 151 | 152 | impl std::fmt::Display for UFFParseError { 153 | fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { 154 | write!(f, "{}", self.message) 155 | } 156 | } 157 | 158 | impl error::Error for UFFParseError {} 159 | -------------------------------------------------------------------------------- /tensorrt/src/utils.rs: -------------------------------------------------------------------------------- 1 | mod macros { 2 | #[macro_export] 3 | macro_rules! check_cuda { 4 | ($expression:expr) => { 5 | unsafe { 6 | let res = $expression; 7 | if res != cuda_runtime_sys::cudaError_t::cudaSuccess { 8 | let error_message = 9 | std::ffi::CStr::from_ptr(cuda_runtime_sys::cudaGetErrorString(res)); 10 | return Err(anyhow::anyhow!("{}", error_message.to_str().unwrap())); 11 | } 12 | } 13 | }; 14 | } 15 | 16 | #[macro_export] 17 | macro_rules! 
binding_func { 18 | ($function_name:ident<$Trait:path>( $($arg_name:ident : $arg_ty:ty),*) ) => { 19 | 20 | unsafe extern "C" fn $function_name( 21 | context: *mut T, 22 | $($arg_name: $arg_ty),* 23 | ) { 24 | let profiler_ref: &mut T = &mut *context; 25 | profiler_ref.$function_name($($arg_name),*); 26 | } 27 | }; 28 | } 29 | } 30 | -------------------------------------------------------------------------------- /tensorrt_rs_derive/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "tensorrt_rs_derive" 3 | version = "0.1.0" 4 | authors = ["Mason Stallmo "] 5 | edition = "2018" 6 | 7 | [lib] 8 | proc-macro = true 9 | 10 | [dependencies] 11 | syn = "1.0.44" 12 | quote = "1.0.7" -------------------------------------------------------------------------------- /tensorrt_rs_derive/src/lib.rs: -------------------------------------------------------------------------------- 1 | use proc_macro::TokenStream; 2 | use quote::quote; 3 | use syn; 4 | 5 | #[proc_macro_derive(Layer)] 6 | pub fn layer_derive(input: TokenStream) -> TokenStream { 7 | let ast = syn::parse(input).unwrap(); 8 | 9 | impl_layer_derive(&ast) 10 | } 11 | 12 | fn impl_layer_derive(ast: &syn::DeriveInput) -> TokenStream { 13 | let name = &ast.ident; 14 | let gen = quote! { 15 | impl private::LayerPrivate for #name { 16 | fn get_internal_layer(&self) -> *mut tensorrt_sys::nvinfer1_ILayer { 17 | self.internal_layer as *mut tensorrt_sys::nvinfer1_ILayer 18 | } 19 | } 20 | 21 | impl Layer for #name {} 22 | }; 23 | 24 | gen.into() 25 | } 26 | 27 | #[proc_macro_derive(Dim)] 28 | pub fn dim_derive(input: TokenStream) -> TokenStream { 29 | let ast = syn::parse(input).unwrap(); 30 | 31 | impl_dim_derive(&ast) 32 | } 33 | 34 | fn impl_dim_derive(ast: &syn::DeriveInput) -> TokenStream { 35 | let name = &ast.ident; 36 | let gen = quote! 
{ 37 | impl private::DimsPrivate for #name { 38 | fn get_internal_dims(&self) -> *mut tensorrt_sys::Dims_t { 39 | self.internal_dims 40 | } 41 | } 42 | 43 | impl Dim for #name {} 44 | }; 45 | 46 | gen.into() 47 | } 48 | --------------------------------------------------------------------------------