├── .gitignore
├── Bug.md
├── GTC-China-2017-NVIDIA-INT8.pdf
├── Hackathon2022
├── LayerNormPlugin
│ ├── LayerNormPlugin.cu
│ ├── LayerNormPlugin.h
│ ├── Makefile
│ ├── SkipLayerNormV1Plugin.cu
│ ├── SkipLayerNormV1Plugin.h
│ ├── SkipLayerNormV2Plugin.cu
│ ├── SkipLayerNormV2Plugin.h
│ ├── common.cuh
│ └── testLayerNormPlugin.py
├── build.sh
├── calibrator.py
├── decoder2trt.py
├── encoder2trt.py
└── encoder_small.py
├── README.md
├── Samples
├── sampleGoogleNet.md
├── sampleINT8.md
├── sampleMNIST.md
└── sampleMNISTAPI.md
├── TensorRT8.5.3
├── English.md
├── a-title_Developer-Guide-NVIDIA-Deep-Learning-TensorRT-Documentation
├── assets
│ ├── 1695349016-1542a53eb400f837845f37b2bedb9d05.png
│ ├── 1695349016-15dd6688b76bdc3d5a16526df91cc631.png
│ ├── 1695349016-2c3934e69ddc53dc474139fe65c49c19.png
│ ├── 1695349016-4e01c008d3875b259cc4cd3da884010e.png
│ ├── 1695349016-536836b9f148a211a3109b46588aea3f.png
│ ├── 1695349016-584559c808bb6b459734d88699daabe1.png
│ ├── 1695349016-5b172dabb4f50368376eee4819ddcb87.png
│ ├── 1695349016-63cc642586086b5be42c04375200c8c9.png
│ ├── 1695349016-656ec99160033df259b215cd7e03af2f.png
│ ├── 1695349016-718f4af533bab6c57307cd4131866023.png
│ ├── 1695349016-7324dda2de00b8d4b99431311c1d901d.png
│ ├── 1695349016-7c4a391e39adc9b201561f4384d8575c.png
│ ├── 1695349016-8167eeb1e237bd2c809028a411e1e9cb.png
│ ├── 1695349016-8c33d06b8c5ffd9dc50eb77f1bbe80d0.png
│ ├── 1695349016-90fbabf1bcd97f82bbffa8751a548cdc.png
│ ├── 1695349016-98a76f9452e7b3c5a2979a9a4d8f828f.png
│ ├── 1695349016-9b422126aef86f0a15d7bfcdcdf37ee9.png
│ ├── 1695349016-a782c77d3e0eff2354898ccef63c5de0.png
│ ├── 1695349016-ad186379984e814039de4d58a0e26c53.png
│ ├── 1695349016-ae831a5e3c8c02af4c7ac82636845a70.png
│ ├── 1695349016-cc50888fa52ed8f93e53ca71ce566c63.png
│ ├── 1695349016-d14711f74598da455c69c20ed5a5cbd1.png
│ ├── 1695349016-dffd0a9679aeefdc5176a6aa55feaa7c.png
│ ├── 1695349016-e24efeac58e23de168680d4f48e18f16.png
│ ├── 1695349016-e829de0bc2b85ec285546dcf1456982a.png
│ └── 1695349016-f9c6506c20f52b409ddfc74a8a4317a2.png
├── index.json
└── readme.md
├── TensorRT_2.1.0_User_Guide.md
├── blogs
├── Conformer Encoder GPU 加速策略较全面汇总.md
├── TensorRT Github 开源部分介绍.md
├── TensorRT Plugin使用方式简介-以leaky relu层为例.md
├── TensorRT 可借鉴代码汇总.md
├── TensorRT 转换模型的几种方式比较.md
├── 使用TensorRT实现leaky relu层.md
└── 写于20200829.md
├── cublas&cudnn_int8_demo
├── README.md
└── cublasGemmEx
│ ├── cuBlasGemmEx.pdf
│ └── gemmInt8_rect.cpp
├── img
├── easy_tensorrt_api
│ ├── 0.png
│ ├── 1.png
│ └── 2.png
└── flow_of_leaky_relu.png
├── resource_for_billibilli
├── debug_plugin
│ ├── CMakeLists.txt
│ ├── README.md
│ ├── cmake
│ │ ├── modules
│ │ │ ├── find_library_create_target.cmake
│ │ │ └── set_ifndef.cmake
│ │ └── toolchains
│ │ │ ├── cmake_aarch64-android.toolchain
│ │ │ ├── cmake_aarch64.toolchain
│ │ │ ├── cmake_ppc64le.toolchain
│ │ │ ├── cmake_qnx.toolchain
│ │ │ ├── cmake_x64_win.toolchain
│ │ │ └── cmake_x86_64.toolchain
│ ├── include
│ │ ├── NvCaffeParser.h
│ │ ├── NvInfer.h
│ │ ├── NvInferPlugin.h
│ │ ├── NvInferPluginUtils.h
│ │ ├── NvInferRuntime.h
│ │ ├── NvInferRuntimeCommon.h
│ │ ├── NvInferVersion.h
│ │ ├── NvOnnxConfig.h
│ │ ├── NvOnnxParser.h
│ │ ├── NvUffParser.h
│ │ └── NvUtils.h
│ └── plugin
│ │ ├── CMakeLists.txt
│ │ ├── debug_plugin
│ │ ├── CMakeLists.txt
│ │ ├── debug_dynamic_plugin.cu
│ │ ├── debug_dynamic_plugin.h
│ │ ├── debug_kernel.cu
│ │ ├── debug_kernel.h
│ │ ├── debug_plugin.cu
│ │ └── debug_plugin.h
│ │ ├── half.h
│ │ ├── infer_plugin_api.cc
│ │ ├── infer_plugin_api.h
│ │ ├── logger.cc
│ │ ├── logger.h
│ │ ├── logging.h
│ │ ├── plugin_common.h
│ │ └── serialize.hpp
├── doc
│ └── TensorRT-Developer-Guide注释版本.pdf
└── 总结.pptx
├── shenlan_homework_bert
├── .gitignore
├── CMakeLists.txt
├── README.md
├── builder.py
├── builder.sh
├── calibrator.py
├── calibrator_data.txt
├── model2onnx.py
├── onnx2trt.py
├── trt_helper.py
├── 基础款LayerNormPlugin.zip
└── 基础款LayerNormPlugin
│ └── 基础款LayerNormPlugin
│ ├── LayerNormPlugin.cu
│ ├── LayerNormPlugin.h
│ ├── Makefile
│ └── testLayerNormPlugin.py
└── 视频版资料
└── TensorRT plugin fp16加速.pdf
/.gitignore:
--------------------------------------------------------------------------------
1 | TensorRT_2.1.0_User_Guide.html
--------------------------------------------------------------------------------
/Bug.md:
--------------------------------------------------------------------------------
1 | # Bug
2 | This page records bugs and errors found in TensorRT.
3 |
4 | - Version: TensorRT 2.0
5 | - Bug category: documentation error
6 | - Contributor: [LitLeo][1]
7 | - Bug description:
8 | The weights of TensorRT's fully connected layer are stored in column-major order, but the documentation says they are row-major.
9 | In the Data Formats chapter, the original text reads:
10 | ``` bash
11 | Fully Connected weights are in contiguous row-major layout
12 | ```
13 | However, in the enqueue function of Samples/samplePlugin (which implements a fully connected layer), the weights are consumed in column-major order, as the following code shows:
14 | ``` c++
15 | CHECK(cublasSgemm(mCublas, CUBLAS_OP_T, CUBLAS_OP_N, nbOutputChannels, batchSize, nbInputChannels, &kONE,
16 | reinterpret_cast<const float*>(mKernelWeights.values), nbInputChannels,
17 | reinterpret_cast<const float*>(inputs[0]), nbInputChannels, &kZERO,
18 | reinterpret_cast<float*>(outputs[0]), nbOutputChannels));
19 | ```
20 |
21 |
22 | [1]: https://github.com/LitLeo
23 |
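To make the layout concrete, here is a small CPU sketch (mine, not from the sample; `fcReference` and its parameter names are purely illustrative) that computes the same result as the cublasSgemm call quoted above and spells out the weight indexing it implies:

``` c++
#include <vector>

// CPU reference for the GEMM above: with CUBLAS_OP_T and lda = nbInputChannels,
// the weight that multiplies input channel i of output channel o is read from
// values[o * nbInputChannels + i].
void fcReference(const std::vector<float>& values,  // mKernelWeights.values
                 const std::vector<float>& input,   // batchSize rows of nbInputChannels
                 std::vector<float>& output,        // batchSize rows of nbOutputChannels
                 int batchSize, int nbInputChannels, int nbOutputChannels)
{
    for (int b = 0; b < batchSize; ++b)
        for (int o = 0; o < nbOutputChannels; ++o)
        {
            float acc = 0.f;
            for (int i = 0; i < nbInputChannels; ++i)
                acc += values[o * nbInputChannels + i] * input[b * nbInputChannels + i];
            output[b * nbOutputChannels + o] = acc;
        }
}
```

In this indexing, the nbInputChannels weights belonging to one output channel sit contiguously in memory, which is the layout the report above calls column-major when the weight matrix is viewed as [nbInputChannels x nbOutputChannels].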
--------------------------------------------------------------------------------
/GTC-China-2017-NVIDIA-INT8.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LitLeo/TensorRT_Tutorial/1d370d82c78eddccadf4a8490c5c859db35a78c9/GTC-China-2017-NVIDIA-INT8.pdf
--------------------------------------------------------------------------------
/Hackathon2022/LayerNormPlugin/LayerNormPlugin.cu:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (c) 2019-2021, NVIDIA CORPORATION. All rights reserved.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 |
18 | #include <cassert>
19 | #include <numeric>
20 | #include <cub/cub.cuh>
21 |
22 | #include "LayerNormPlugin.h"
23 | #include "common.cuh"
24 |
25 | using namespace nvinfer1;
26 |
27 | PluginFieldCollection LayerNormPluginCreator::fc_{};
28 | std::vector<PluginField> LayerNormPluginCreator::attr_;
29 |
30 | /*__global__ void layerNormKernel(float *pInput, float *pOutput)*/
31 | /*{*/
32 | /*const int tx = threadIdx.x, index = blockIdx.x * 256 + threadIdx.x;*/
33 |
34 | /*__shared__ float temp[128];*/
35 |
36 | /*float value0 = pInput[index];*/
37 | /*float value1 = pInput[index + 128];*/
38 |
39 | /*temp[tx] = value0 + value1;*/
40 | /*__syncthreads();*/
41 |
42 | /*for (int stride = 64; stride >= 1; stride /= 2)*/
43 | /*{*/
44 | /*if (tx < stride)*/
45 | /*{*/
46 | /*temp[tx] += temp[tx + stride];*/
47 | /*}*/
48 | /*__syncthreads();*/
49 | /*}*/
50 | /*float mean = temp[0] / 256;*/
51 | /*__syncthreads();*/
52 |
53 | /*temp[tx] = (value0 - mean) * (value0 - mean) + (value1 - mean) * (value1 - mean);*/
54 | /*__syncthreads();*/
55 |
56 | /*for (int stride = 64; stride >= 1; stride /= 2)*/
57 | /*{*/
58 | /*if (tx < stride)*/
59 | /*{*/
60 | /*temp[tx] += temp[tx + stride];*/
61 | /*}*/
62 | /*__syncthreads();*/
63 | /*}*/
64 | /*float var = temp[0] / 256;*/
65 |
66 | /*pOutput[index] = (value0 - mean) * rsqrtf(var + 6e-6);*/
67 | /*pOutput[index + 128] = (value1 - mean) * rsqrtf(var + 6e-6);*/
68 | /*}*/
69 |
70 |
71 | template <typename T, int TPB>
72 | __global__ void layer_norm_kernel_small(
73 | const int ld, const T* input, const T* beta, const T* gamma, T* output)
74 | {
75 |
76 | const T rld = T(1) / T(ld);
77 | const int offset = blockIdx.x * ld;
78 |
79 | cub::Sum pairSum;
80 | // reduce x and x^2
81 | kvp<T> threadData(0, 0);
82 | const int idx = offset + threadIdx.x;
83 | T val = 0;
84 |
85 | if (threadIdx.x < ld)
86 | {
87 |
88 | val = input[idx];
89 |
90 | const T rldval = rld * val;
91 | threadData = pairSum(threadData, kvp<T>(rldval, rldval * val));
92 | }
93 |
94 | layerNormSmall<T, TPB>(val, threadData, ld, idx, beta, gamma, output);
95 | }
96 |
97 | template <typename T, int TPB>
98 | __global__ void layer_norm_kernel(
99 | const int ld, const T* input, const T* beta, const T* gamma, T* output)
100 | {
101 | const T rld = T(1) / T(ld);
102 | const int offset = blockIdx.x * ld;
103 |
104 | cub::Sum pairSum;
105 | // reduce x and x^2
106 | kvp<T> threadData(0, 0);
107 |
108 | for (int i = threadIdx.x; i < ld; i += TPB)
109 | {
110 | const int idx = offset + i;
111 | T val = T(input[idx]);
112 |
113 | const T rldval = rld * val;
114 | threadData = pairSum(threadData, kvp<T>(rldval, rldval * val));
115 | output[idx] = val;
116 | }
117 |
118 | layerNorm<T, TPB>(threadData, ld, offset, beta, gamma, output);
119 | if (blockIdx.x == 0 && threadIdx.x == 0) {
120 | printf("%f %f %f %f\n", __half2float(gamma[0]), __half2float(beta[0]), __half2float(input[0]), __half2float(output[0]));
121 | }
122 | }
123 |
124 | template <typename T>
125 | int compute_layer_norm_tpl(cudaStream_t stream, const int ld, const int n, const T* input, const T* beta,
126 | const T* gamma, T* output) {
127 |
128 | // this must be true because n is the total size of the tensor
129 | assert(n % ld == 0);
130 | const int gridSize = n / ld;
131 | /*constexpr int VPT = 16 / sizeof(T);*/
132 | if (ld <= 32) {
133 | constexpr int blockSize = 32;
134 | layer_norm_kernel_small<T, blockSize>
135 | <<<gridSize, blockSize, 0, stream>>>(ld, input, beta, gamma, output);
136 | } else if (ld <= 128) {
137 | constexpr int blockSize = 128;
138 | layer_norm_kernel_small<T, blockSize>
139 | <<<gridSize, blockSize, 0, stream>>>(ld, input, beta, gamma, output);
140 | } else if (ld <= 384) {
141 | constexpr int blockSize = 384;
142 | layer_norm_kernel_small<T, blockSize>
143 | <<<gridSize, blockSize, 0, stream>>>(ld, input, beta, gamma, output);
144 | } else {
145 | constexpr int blockSize = 256;
146 | layer_norm_kernel<T, blockSize>
147 | <<<gridSize, blockSize, 0, stream>>>(ld, input, beta, gamma, output);
148 | }
149 | (cudaPeekAtLastError());
150 |
151 | return 0;
152 | }
153 |
154 | int compute_layer_norm(cudaStream_t stream, const int ld, const int n, const float* input,
155 | const float* gamma, const float* beta, float* output) {
156 | return compute_layer_norm_tpl(stream, ld, n, input, beta, gamma, output);
157 | }
158 |
159 | int compute_layer_norm(cudaStream_t stream, const int ld, const int n, const half* input,
160 | const half* gamma, const half* beta, half* output) {
161 | return compute_layer_norm_tpl(stream, ld, n, input, beta, gamma, output);
162 | }
163 |
164 | inline int64_t volume(const nvinfer1::Dims& d) {
165 | return std::accumulate(d.d, d.d + d.nbDims, 1, std::multiplies<int64_t>());
166 | }
167 |
168 | int32_t LayerNormPlugin::enqueue(const PluginTensorDesc* inputDesc, const PluginTensorDesc* outputDesc, const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept
169 | {
170 | /*const int nBlock = inputDesc[0].dims.d[0] * inputDesc[0].dims.d[1];*/
171 |
172 | /*layerNormKernel <<>>((float *)inputs[0], (float *)outputs[0]);*/
173 |
174 | const int input_volume = volume(inputDesc[0].dims);
175 | const int dim = inputDesc[0].dims.d[inputDesc[0].dims.nbDims - 1];
176 | const int S = input_volume / dim;
177 |
178 | int status = -1;
179 |
180 | /*const size_t word_size = getElementSize(DataType::kFLOAT);*/
181 |
182 | // Our plugin outputs only one tensor
183 | const float* input = static_cast<const float*>(inputs[0]);
184 | const float* gamma_ptr = static_cast<const float*>(inputs[1]);
185 | const float* beta_ptr = static_cast<const float*>(inputs[2]);
186 | float* output = static_cast<float*>(outputs[0]);
187 |
188 | status = compute_layer_norm(stream, dim, input_volume, input, gamma_ptr, beta_ptr, output);
189 |
190 | return 0;
191 | }
192 |
193 |
194 | REGISTER_TENSORRT_PLUGIN(LayerNormPluginCreator);
195 |
196 |
--------------------------------------------------------------------------------
/Hackathon2022/LayerNormPlugin/LayerNormPlugin.h:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (c) 2019-2021, NVIDIA CORPORATION. All rights reserved.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | #include <NvInfer.h>
18 | #include <iostream>
19 | #include <string>
20 | #include <vector>
21 | // +------- Debug wrapper --------------------------------------------------------------------------
22 | #if DEBUG
23 | #define WHERE_AM_I() do {printf("[%s]: this=->%p\n",__func__,this);} while(0);
24 | #else
25 | #define WHERE_AM_I()
26 | #endif // DEBUG
27 |
28 | using namespace std;
29 |
30 | // +------- Plugin ---------------------------------------------------------------------------------
31 | namespace
32 | {
33 | static const char* PLUGIN_NAME{"LayerNorm"};
34 | static const char* PLUGIN_VERSION{"1"};
35 | } // namespace
36 |
37 | namespace nvinfer1
38 | {
39 |
40 | // +------- Plugin body ----------------------------------------------------------------------------
41 | class LayerNormPlugin: public IPluginV2DynamicExt
42 | {
43 | private:
44 | std::string name_;
45 | std::string namespace_;
46 |
47 | public:
48 | LayerNormPlugin(const std::string& name) : name_(name)
49 | {
50 | WHERE_AM_I();
51 | }
52 |
53 | LayerNormPlugin(const std::string& name, const void* data, size_t length) : name_(name)
54 | {
55 | WHERE_AM_I();
56 | }
57 |
58 | LayerNormPlugin() = delete;
59 |
60 | ~LayerNormPlugin()
61 | {
62 | WHERE_AM_I();
63 | }
64 |
65 | size_t getSerializationSize() const noexcept override
66 | {
67 | WHERE_AM_I();
68 | return 0;
69 | }
70 |
71 | void serialize(void *buffer) const noexcept override
72 | {
73 | WHERE_AM_I();
74 | }
75 |
76 | IPluginV2DynamicExt* clone() const noexcept override
77 | {
78 | WHERE_AM_I();
79 | return new LayerNormPlugin(name_);
80 | }
81 |
82 | int getNbOutputs() const noexcept override
83 | {
84 | WHERE_AM_I();
85 | return 1;
86 | }
87 |
88 | DimsExprs getOutputDimensions(int32_t outputIndex, const DimsExprs* inputs, int32_t nbInputs, IExprBuilder& exprBuilder) noexcept override
89 | {
90 | WHERE_AM_I();
91 | return inputs[0];
92 | }
93 |
94 | bool supportsFormatCombination(int32_t pos, const PluginTensorDesc* inOut, int32_t nbInputs, int32_t nbOutputs) noexcept override
95 | {
96 | WHERE_AM_I();
97 | if(inOut[pos].format != TensorFormat::kLINEAR)
98 | {
99 | return false;
100 | }
101 |
102 | cout << "inOut[pos].type " << (int)inOut[pos].type << endl;
103 | bool res = false;
104 | switch(pos)
105 | {
106 | case 0:
107 | res = (inOut[pos].type == DataType::kFLOAT); break;
108 | case 1:
109 | res = inOut[pos].type == inOut[0].type; break;
110 | case 2:
111 | res = inOut[pos].type == inOut[0].type; break;
112 | case 3:
113 | res = inOut[pos].type == inOut[0].type; break;
114 | default:// should NOT be here
115 | res = false;
116 | }
117 | return res;
118 | }
119 |
120 | DataType getOutputDataType(int outputIndex, const DataType* inputTypes, int nbInputs) const noexcept override
121 | {
122 | WHERE_AM_I();
123 | return DataType::kFLOAT;
124 | }
125 |
126 | void configurePlugin(const DynamicPluginTensorDesc* in, int32_t nbInputs,const DynamicPluginTensorDesc* out, int32_t nbOutputs) noexcept override
127 | {
128 | WHERE_AM_I();
129 | }
130 |
131 | size_t getWorkspaceSize(const PluginTensorDesc* inputs, int32_t nbInputs, const PluginTensorDesc* outputs,int32_t nbOutputs) const noexcept override
132 | {
133 | WHERE_AM_I();
134 | return 0;
135 | }
136 |
137 | void setPluginNamespace(const char* szNamespace) noexcept override
138 | {
139 | WHERE_AM_I();
140 | namespace_ = szNamespace;
141 | }
142 | const char* getPluginNamespace() const noexcept override
143 | {
144 | WHERE_AM_I();
145 | return namespace_.c_str();
146 | }
147 | const char* getPluginType() const noexcept override
148 | {
149 | WHERE_AM_I();
150 | return PLUGIN_NAME;
151 | }
152 | const char* getPluginVersion() const noexcept override
153 | {
154 | WHERE_AM_I();
155 | return PLUGIN_VERSION;
156 | }
157 | int initialize() noexcept override
158 | {
159 | WHERE_AM_I();
160 | return 0;
161 | }
162 | void terminate() noexcept override
163 | {
164 | WHERE_AM_I();
165 | return;
166 | }
167 |
168 | void destroy() noexcept override
169 | {
170 | WHERE_AM_I();
171 | }
172 |
173 | int32_t enqueue(const PluginTensorDesc* inputDesc, const PluginTensorDesc* outputDesc, const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept override;
174 | }; // class LayerNormPlugin
175 |
176 | class LayerNormPluginCreator : public IPluginCreator
177 | {
178 | private:
179 | static PluginFieldCollection fc_;
180 | static std::vector<PluginField> attr_;
181 | std::string namespace_;
182 |
183 | public:
184 | LayerNormPluginCreator()
185 | {
186 | fc_.nbFields = attr_.size();
187 | fc_.fields = attr_.data();
188 | }
189 |
190 | ~LayerNormPluginCreator() {}
191 |
192 | IPluginV2* createPlugin(const char* name, const PluginFieldCollection* fc) noexcept override
193 | {
194 | WHERE_AM_I();
195 | return new LayerNormPlugin(name);
196 | }
197 |
198 | IPluginV2* deserializePlugin(const char* name, const void* serialData, size_t serialLength) noexcept override
199 | {
200 | return new LayerNormPlugin(name, serialData, serialLength);
201 | }
202 |
203 | void setPluginNamespace(const char* szNamespace) noexcept override
204 | {
205 | namespace_ = szNamespace;
206 | }
207 |
208 | const char* getPluginNamespace() const noexcept override
209 | {
210 | return namespace_.c_str();
211 | }
212 |
213 | const char* getPluginName() const noexcept override
214 | {
215 | return PLUGIN_NAME;
216 | }
217 |
218 | const char* getPluginVersion() const noexcept override
219 | {
220 | return PLUGIN_VERSION;
221 | }
222 |
223 | const PluginFieldCollection* getFieldNames() noexcept override
224 | {
225 | return &fc_;
226 | }
227 | }; // class LayerNormPluginCreator
228 |
229 | } // namespace nvinfer1
230 |
231 |
--------------------------------------------------------------------------------
/Hackathon2022/LayerNormPlugin/Makefile:
--------------------------------------------------------------------------------
1 | CUDA_PATH = /usr/local/cuda
2 | TRT_PATH = /usr/lib/x86_64-linux-gnu
3 | NVCC = $(CUDA_PATH)/bin/nvcc
4 | #SM = 61
5 | # 61 for GTX1070, 75 for T4, 80 for A30
6 | GENCODE = -gencode=arch=compute_60,code=sm_60 -gencode=arch=compute_61,code=sm_61 -gencode=arch=compute_70,code=sm_70 -gencode=arch=compute_75,code=sm_75 -gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_86,code=sm_86
7 | CUFLAG = -w -std=c++14 -O3 -UDEBUG -Xcompiler -fPIC $(GENCODE)
8 | CPPFLAG = -w -std=c++14 -O3 -use_fast_math
9 | SOFLAG = $(CUFLAG) -shared
10 | INCLUDE = -I. -I$(CUDA_PATH)/include
11 | LDFLAG = -L$(CUDA_PATH)/lib64 -lcudart -lcublas -lcublasLt -L$(TRT_PATH)/lib -lnvinfer
12 |
13 | SRC_CU = $(shell find ./ -name '*.cu')
14 |
15 | all: LayerNorm.so
16 |
17 | %.o: %.cu
18 | $(NVCC) $(CUFLAG) $(INCLUDE) -o $@ -c $<
19 |
20 | LayerNorm.so: $(SRC_CU:.cu=.o)
21 | $(NVCC) $(SOFLAG) $(LDFLAG) -o $@ $^
22 |
23 | .PHONY: clean
24 | clean:
25 | rm -rf ./*.so ./*.o ./*.d ./*.trt
26 |
27 | .PHONY: test
28 | test:
29 | clear
30 | python testLayerNormPlugin.py
31 |
32 |
--------------------------------------------------------------------------------
/Hackathon2022/LayerNormPlugin/SkipLayerNormV1Plugin.cu:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (c) 2019-2021, NVIDIA CORPORATION. All rights reserved.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 |
18 | #include <cassert>
19 | #include <numeric>
20 | #include <cub/cub.cuh>
21 |
22 | #include "SkipLayerNormV1Plugin.h"
23 | #include "common.cuh"
24 |
25 | using namespace nvinfer1;
26 |
27 | PluginFieldCollection SkipLayerNormV1PluginCreator::fc_{};
28 | std::vector<PluginField> SkipLayerNormV1PluginCreator::attr_;
29 |
30 | template <typename T, int TPB>
31 | __global__ void skip_layer_norm_v1_kernel_small(
32 | const int ld, const T* input, const T* skip, const T* beta, const T* gamma, T* output)
33 | {
34 |
35 | const T rld = T(1) / T(ld);
36 | const int offset = blockIdx.x * ld;
37 |
38 | cub::Sum pairSum;
39 | // reduce x and x^2
40 | kvp<T> threadData(0, 0);
41 | const int idx = offset + threadIdx.x;
42 | T val = 0;
43 |
44 | if (threadIdx.x < ld)
45 | {
46 |
47 | val = input[idx] + skip[idx];
48 |
49 | const T rldval = rld * val;
50 | threadData = pairSum(threadData, kvp<T>(rldval, rldval * val));
51 | }
52 |
53 | layerNormSmall<T, TPB>(val, threadData, ld, idx, beta, gamma, output);
54 | }
55 |
56 | template <typename T, int TPB>
57 | __global__ void skip_layer_norm_v1_kernel(
58 | const int ld, const T* input, const T* skip, const T* beta, const T* gamma, T* output)
59 | {
60 | const T rld = T(1) / T(ld);
61 | const int offset = blockIdx.x * ld;
62 |
63 | cub::Sum pairSum;
64 | // reduce x and x^2
65 | kvp<T> threadData(0, 0);
66 |
67 | for (int i = threadIdx.x; i < ld; i += TPB)
68 | {
69 | const int idx = offset + i;
70 | T val = T(input[idx] + skip[idx]);
71 |
72 | const T rldval = rld * val;
73 | threadData = pairSum(threadData, kvp<T>(rldval, rldval * val));
74 | output[idx] = val;
75 | }
76 |
77 | layerNorm<T, TPB>(threadData, ld, offset, beta, gamma, output);
78 | if (blockIdx.x == 0 && threadIdx.x == 0) {
79 | printf("%f %f %f %f\n", __half2float(gamma[0]), __half2float(beta[0]), __half2float(input[0]), __half2float(output[0]));
80 | }
81 | }
82 |
83 | template <typename T>
84 | int compute_skip_layer_norm_v1_tpl(cudaStream_t stream, const int ld, const int n,
85 | const T* input, const T* skip, const T* beta,
86 | const T* gamma, T* output) {
87 |
88 | // this must be true because n is the total size of the tensor
89 | assert(n % ld == 0);
90 | const int gridSize = n / ld;
91 | /*constexpr int VPT = 16 / sizeof(T);*/
92 | if (ld <= 32) {
93 | constexpr int blockSize = 32;
94 | skip_layer_norm_v1_kernel_small<T, blockSize>
95 | <<<gridSize, blockSize, 0, stream>>>(ld, input, skip, beta, gamma, output);
96 | } else if (ld <= 128) {
97 | constexpr int blockSize = 128;
98 | skip_layer_norm_v1_kernel_small<T, blockSize>
99 | <<<gridSize, blockSize, 0, stream>>>(ld, input, skip, beta, gamma, output);
100 | } else if (ld <= 384) {
101 | constexpr int blockSize = 384;
102 | skip_layer_norm_v1_kernel_small<T, blockSize>
103 | <<<gridSize, blockSize, 0, stream>>>(ld, input, skip, beta, gamma, output);
104 | } else {
105 | constexpr int blockSize = 256;
106 | skip_layer_norm_v1_kernel<T, blockSize>
107 | <<<gridSize, blockSize, 0, stream>>>(ld, input, skip, beta, gamma, output);
108 | }
109 | (cudaPeekAtLastError());
110 |
111 | return 0;
112 | }
113 |
114 | int compute_skip_layer_norm_v1(cudaStream_t stream, const int ld, const int n, const float* input,
115 | const float* skip, const float* gamma, const float* beta, float* output) {
116 | return compute_skip_layer_norm_v1_tpl(stream, ld, n, input, skip, beta, gamma, output);
117 | }
118 |
119 | int compute_skip_layer_norm_v1(cudaStream_t stream, const int ld, const int n, const half* input,
120 | const half* skip, const half* gamma, const half* beta, half* output) {
121 | return compute_skip_layer_norm_v1_tpl(stream, ld, n, input, skip, beta, gamma, output);
122 | }
123 |
124 | inline int64_t volume(const nvinfer1::Dims& d) {
125 | return std::accumulate(d.d, d.d + d.nbDims, 1, std::multiplies<int64_t>());
126 | }
127 |
128 | int32_t SkipLayerNormV1Plugin::enqueue(const PluginTensorDesc* inputDesc, const PluginTensorDesc* outputDesc, const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept
129 | {
130 | /*const int nBlock = inputDesc[0].dims.d[0] * inputDesc[0].dims.d[1];*/
131 |
132 | /*layerNormKernel <<>>((float *)inputs[0], (float *)outputs[0]);*/
133 |
134 | const int input_volume = volume(inputDesc[0].dims);
135 | const int dim = inputDesc[0].dims.d[inputDesc[0].dims.nbDims - 1];
136 | const int S = input_volume / dim;
137 |
138 | int status = -1;
139 |
140 | /*const size_t word_size = getElementSize(DataType::kFLOAT);*/
141 |
142 | // Our plugin outputs only one tensor
143 | const float* input = static_cast<const float*>(inputs[0]);
144 | const float* skip = static_cast<const float*>(inputs[1]);
145 | const float* gamma_ptr = static_cast<const float*>(inputs[2]);
146 | const float* beta_ptr = static_cast<const float*>(inputs[3]);
147 | float* output = static_cast<float*>(outputs[0]);
148 |
149 | status = compute_skip_layer_norm_v1(stream, dim, input_volume, input, skip, gamma_ptr, beta_ptr, output);
150 |
151 | return 0;
152 | }
153 |
154 |
155 | REGISTER_TENSORRT_PLUGIN(SkipLayerNormV1PluginCreator);
156 |
157 |
--------------------------------------------------------------------------------
/Hackathon2022/LayerNormPlugin/SkipLayerNormV1Plugin.h:
--------------------------------------------------------------------------------
1 | #include <NvInfer.h>
2 | #include <iostream>
3 | #include <string>
4 | #include <vector>
5 | // +------- Debug wrapper --------------------------------------------------------------------------
6 | #if DEBUG
7 | #define WHERE_AM_I() do {printf("[%s]: this=->%p\n",__func__,this);} while(0);
8 | #else
9 | #define WHERE_AM_I()
10 | #endif // DEBUG
11 |
12 | using namespace std;
13 |
14 | // +------- Plugin ---------------------------------------------------------------------------------
15 | namespace
16 | {
17 | static const char* PLUGIN_NAME{"SkipLayerNormV1"};
18 | static const char* PLUGIN_VERSION{"1"};
19 | } // namespace
20 |
21 | namespace nvinfer1
22 | {
23 |
24 | // +------- Plugin body ----------------------------------------------------------------------------
25 | class SkipLayerNormV1Plugin: public IPluginV2DynamicExt
26 | {
27 | private:
28 | std::string name_;
29 | std::string namespace_;
30 |
31 | public:
32 | SkipLayerNormV1Plugin(const std::string& name) : name_(name)
33 | {
34 | WHERE_AM_I();
35 | }
36 |
37 | SkipLayerNormV1Plugin(const std::string& name, const void* data, size_t length) : name_(name)
38 | {
39 | WHERE_AM_I();
40 | }
41 |
42 | SkipLayerNormV1Plugin() = delete;
43 |
44 | ~SkipLayerNormV1Plugin()
45 | {
46 | WHERE_AM_I();
47 | }
48 |
49 | size_t getSerializationSize() const noexcept override
50 | {
51 | WHERE_AM_I();
52 | return 0;
53 | }
54 |
55 | void serialize(void *buffer) const noexcept override
56 | {
57 | WHERE_AM_I();
58 | }
59 |
60 | IPluginV2DynamicExt* clone() const noexcept override
61 | {
62 | WHERE_AM_I();
63 | return new SkipLayerNormV1Plugin(name_);
64 | }
65 |
66 | int getNbOutputs() const noexcept override
67 | {
68 | WHERE_AM_I();
69 | return 1;
70 | }
71 |
72 | DimsExprs getOutputDimensions(int32_t outputIndex, const DimsExprs* inputs, int32_t nbInputs, IExprBuilder& exprBuilder) noexcept override
73 | {
74 | WHERE_AM_I();
75 | return inputs[0];
76 | }
77 |
78 | bool supportsFormatCombination(int32_t pos, const PluginTensorDesc* inOut, int32_t nbInputs, int32_t nbOutputs) noexcept override
79 | {
80 | WHERE_AM_I();
81 | if(inOut[pos].format != TensorFormat::kLINEAR)
82 | {
83 | return false;
84 | }
85 |
86 | cout << "inOut[pos].type " << (int)inOut[pos].type << endl;
87 | bool res = false;
88 | switch(pos)
89 | {
90 | case 0:
91 | res = (inOut[pos].type == DataType::kFLOAT); break;
92 | case 1:
93 | res = inOut[pos].type == inOut[0].type; break;
94 | case 2:
95 | res = inOut[pos].type == inOut[0].type; break;
96 | case 3:
97 | res = inOut[pos].type == inOut[0].type; break;
98 | case 4: res = inOut[pos].type == inOut[0].type; break; // 4 inputs + 1 output, so pos runs from 0 to 4
99 | default: res = false; // should NOT be here
100 | }
101 | return res;
102 | }
103 |
104 | DataType getOutputDataType(int outputIndex, const DataType* inputTypes, int nbInputs) const noexcept override
105 | {
106 | WHERE_AM_I();
107 | return DataType::kFLOAT;
108 | }
109 |
110 | void configurePlugin(const DynamicPluginTensorDesc* in, int32_t nbInputs,const DynamicPluginTensorDesc* out, int32_t nbOutputs) noexcept override
111 | {
112 | WHERE_AM_I();
113 | }
114 |
115 | size_t getWorkspaceSize(const PluginTensorDesc* inputs, int32_t nbInputs, const PluginTensorDesc* outputs,int32_t nbOutputs) const noexcept override
116 | {
117 | WHERE_AM_I();
118 | return 0;
119 | }
120 |
121 | void setPluginNamespace(const char* szNamespace) noexcept override
122 | {
123 | WHERE_AM_I();
124 | namespace_ = szNamespace;
125 | }
126 | const char* getPluginNamespace() const noexcept override
127 | {
128 | WHERE_AM_I();
129 | return namespace_.c_str();
130 | }
131 | const char* getPluginType() const noexcept override
132 | {
133 | WHERE_AM_I();
134 | return PLUGIN_NAME;
135 | }
136 | const char* getPluginVersion() const noexcept override
137 | {
138 | WHERE_AM_I();
139 | return PLUGIN_VERSION;
140 | }
141 | int initialize() noexcept override
142 | {
143 | WHERE_AM_I();
144 | return 0;
145 | }
146 | void terminate() noexcept override
147 | {
148 | WHERE_AM_I();
149 | return;
150 | }
151 |
152 | void destroy() noexcept override
153 | {
154 | WHERE_AM_I();
155 | }
156 |
157 | int32_t enqueue(const PluginTensorDesc* inputDesc, const PluginTensorDesc* outputDesc, const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept override;
158 | }; // class SkipLayerNormV1Plugin
159 |
160 | class SkipLayerNormV1PluginCreator : public IPluginCreator
161 | {
162 | private:
163 | static PluginFieldCollection fc_;
164 | static std::vector<PluginField> attr_;
165 | std::string namespace_;
166 |
167 | public:
168 | SkipLayerNormV1PluginCreator()
169 | {
170 | fc_.nbFields = attr_.size();
171 | fc_.fields = attr_.data();
172 | }
173 |
174 | ~SkipLayerNormV1PluginCreator() {}
175 |
176 | IPluginV2* createPlugin(const char* name, const PluginFieldCollection* fc) noexcept override
177 | {
178 | WHERE_AM_I();
179 | return new SkipLayerNormV1Plugin(name);
180 | }
181 |
182 | IPluginV2* deserializePlugin(const char* name, const void* serialData, size_t serialLength) noexcept override
183 | {
184 | return new SkipLayerNormV1Plugin(name, serialData, serialLength);
185 | }
186 |
187 | void setPluginNamespace(const char* szNamespace) noexcept override
188 | {
189 | namespace_ = szNamespace;
190 | }
191 |
192 | const char* getPluginNamespace() const noexcept override
193 | {
194 | return namespace_.c_str();
195 | }
196 |
197 | const char* getPluginName() const noexcept override
198 | {
199 | return PLUGIN_NAME;
200 | }
201 |
202 | const char* getPluginVersion() const noexcept override
203 | {
204 | return PLUGIN_VERSION;
205 | }
206 |
207 | const PluginFieldCollection* getFieldNames() noexcept override
208 | {
209 | return &fc_;
210 | }
211 | }; // class SkipLayerNormV1PluginCreator
212 |
213 | } // namespace nvinfer1
214 |
215 |
--------------------------------------------------------------------------------
/Hackathon2022/LayerNormPlugin/SkipLayerNormV2Plugin.cu:
--------------------------------------------------------------------------------
1 |
2 | #include <cassert>
3 | #include <numeric>
4 | #include <cub/cub.cuh>
5 |
6 | #include "SkipLayerNormV2Plugin.h"
7 | #include "common.cuh"
8 |
9 | using namespace nvinfer1;
10 |
11 | PluginFieldCollection SkipLayerNormV2PluginCreator::fc_{};
12 | std::vector<PluginField> SkipLayerNormV2PluginCreator::attr_;
13 |
14 | template <typename T, int TPB>
15 | __global__ void skip_layer_norm_v2_kernel_small(
16 | const int ld, const T* input, const T* skip, const T* beta, const T* gamma, T* output, T* add_output)
17 | {
18 |
19 | const T rld = T(1) / T(ld);
20 | const int offset = blockIdx.x * ld;
21 |
22 | cub::Sum pairSum;
23 | // reduce x and x^2
24 | kvp<T> threadData(0, 0);
25 | const int idx = offset + threadIdx.x;
26 | T val = 0;
27 |
28 | if (threadIdx.x < ld)
29 | {
30 |
31 | val = input[idx] + skip[idx];
32 | add_output[idx] = val;
33 |
34 | const T rldval = rld * val;
35 | threadData = pairSum(threadData, kvp<T>(rldval, rldval * val));
36 | }
37 |
38 | layerNormSmall<T, TPB>(val, threadData, ld, idx, beta, gamma, output);
39 | }
40 |
41 | template <typename T, int TPB>
42 | __global__ void skip_layer_norm_v2_kernel(
43 | const int ld, const T* input, const T* skip, const T* beta, const T* gamma, T* output, T* add_output)
44 | {
45 | const T rld = T(1) / T(ld);
46 | const int offset = blockIdx.x * ld;
47 |
48 | cub::Sum pairSum;
49 | // reduce x and x^2
50 | kvp<T> threadData(0, 0);
51 |
52 | for (int i = threadIdx.x; i < ld; i += TPB)
53 | {
54 | const int idx = offset + i;
55 | T val = T(input[idx] + skip[idx]);
56 | add_output[idx] = val;
57 |
58 | const T rldval = rld * val;
59 | threadData = pairSum(threadData, kvp<T>(rldval, rldval * val));
60 | output[idx] = val;
61 | }
62 |
63 | layerNorm<T, TPB>(threadData, ld, offset, beta, gamma, output);
64 | if (blockIdx.x == 0 && threadIdx.x == 0) {
65 | printf("%f %f %f %f\n", __half2float(gamma[0]), __half2float(beta[0]), __half2float(input[0]), __half2float(output[0]));
66 | }
67 | }
68 |
69 | template <typename T>
70 | int compute_skip_layer_norm_v2_tpl(cudaStream_t stream, const int ld, const int n,
71 | const T* input, const T* skip, const T* beta, const T* gamma, T* output, T* add_output) {
72 |
73 | // this must be true because n is the total size of the tensor
74 | assert(n % ld == 0);
75 | const int gridSize = n / ld;
76 | /*constexpr int VPT = 16 / sizeof(T);*/
77 | if (ld <= 32) {
78 | constexpr int blockSize = 32;
79 | skip_layer_norm_v2_kernel_small<T, blockSize>
80 | <<<gridSize, blockSize, 0, stream>>>(ld, input, skip, beta, gamma, output, add_output);
81 | } else if (ld <= 128) {
82 | constexpr int blockSize = 128;
83 | skip_layer_norm_v2_kernel_small<T, blockSize>
84 | <<<gridSize, blockSize, 0, stream>>>(ld, input, skip, beta, gamma, output, add_output);
85 | } else if (ld <= 384) {
86 | constexpr int blockSize = 384;
87 | skip_layer_norm_v2_kernel_small<T, blockSize>
88 | <<<gridSize, blockSize, 0, stream>>>(ld, input, skip, beta, gamma, output, add_output);
89 | } else {
90 | constexpr int blockSize = 256;
91 | skip_layer_norm_v2_kernel<T, blockSize>
92 | <<<gridSize, blockSize, 0, stream>>>(ld, input, skip, beta, gamma, output, add_output);
93 | }
94 | (cudaPeekAtLastError());
95 |
96 | return 0;
97 | }
98 |
99 | int compute_skip_layer_norm_v2(cudaStream_t stream, const int ld, const int n, const float* input,
100 | const float* skip, const float* gamma, const float* beta, float* output, float* add_output) {
101 | return compute_skip_layer_norm_v2_tpl(stream, ld, n, input, skip, beta, gamma, output, add_output);
102 | }
103 |
104 | int compute_skip_layer_norm_v2(cudaStream_t stream, const int ld, const int n, const half* input,
105 | const half* skip, const half* gamma, const half* beta, half* output, half* add_output) {
106 | return compute_skip_layer_norm_v2_tpl(stream, ld, n, input, skip, beta, gamma, output, add_output);
107 | }
108 |
109 | inline int64_t volume(const nvinfer1::Dims& d) {
110 | return std::accumulate(d.d, d.d + d.nbDims, 1, std::multiplies<int64_t>());
111 | }
112 |
113 | int32_t SkipLayerNormV2Plugin::enqueue(const PluginTensorDesc* inputDesc, const PluginTensorDesc* outputDesc, const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept
114 | {
115 | /*const int nBlock = inputDesc[0].dims.d[0] * inputDesc[0].dims.d[1];*/
116 |
117 | /*layerNormKernel <<>>((float *)inputs[0], (float *)outputs[0]);*/
118 |
119 | const int input_volume = volume(inputDesc[0].dims);
120 | const int dim = inputDesc[0].dims.d[inputDesc[0].dims.nbDims - 1];
121 | const int S = input_volume / dim;
122 |
123 | int status = -1;
124 |
125 | /*const size_t word_size = getElementSize(DataType::kFLOAT);*/
126 |
127 | // Our plugin outputs only one tensor
128 | const float* input = static_cast<const float*>(inputs[0]);
129 | const float* skip = static_cast<const float*>(inputs[1]);
130 | const float* gamma_ptr = static_cast<const float*>(inputs[2]);
131 | const float* beta_ptr = static_cast<const float*>(inputs[3]);
132 | float* output = static_cast<float*>(outputs[0]);
133 | float* add_output = static_cast<float*>(outputs[1]);
134 |
135 | status = compute_skip_layer_norm_v2(stream, dim, input_volume, input, skip, gamma_ptr, beta_ptr, output, add_output);
136 |
137 | return 0;
138 | }
139 |
140 |
141 | REGISTER_TENSORRT_PLUGIN(SkipLayerNormV2PluginCreator);
142 |
143 |
--------------------------------------------------------------------------------
/Hackathon2022/LayerNormPlugin/SkipLayerNormV2Plugin.h:
--------------------------------------------------------------------------------
1 | #include <NvInfer.h>
2 | #include <iostream>
3 | #include <string>
4 | #include <vector>
5 | // +------- Debug wrapper --------------------------------------------------------------------------
6 | #if DEBUG
7 | #define WHERE_AM_I() do {printf("[%s]: this=->%p\n",__func__,this);} while(0);
8 | #else
9 | #define WHERE_AM_I()
10 | #endif // DEBUG
11 |
12 | using namespace std;
13 |
14 | // +------- Plugin ---------------------------------------------------------------------------------
15 | namespace
16 | {
17 | static const char* PLUGIN_NAME{"SkipLayerNormV2"};
18 | static const char* PLUGIN_VERSION{"1"};
19 | } // namespace
20 |
21 | namespace nvinfer1
22 | {
23 |
24 | // +------- Plugin body ----------------------------------------------------------------------------
25 | class SkipLayerNormV2Plugin: public IPluginV2DynamicExt
26 | {
27 | private:
28 | std::string name_;
29 | std::string namespace_;
30 |
31 | public:
32 | SkipLayerNormV2Plugin(const std::string& name) : name_(name)
33 | {
34 | WHERE_AM_I();
35 | }
36 |
37 | SkipLayerNormV2Plugin(const std::string& name, const void* data, size_t length) : name_(name)
38 | {
39 | WHERE_AM_I();
40 | }
41 |
42 | SkipLayerNormV2Plugin() = delete;
43 |
44 | ~SkipLayerNormV2Plugin()
45 | {
46 | WHERE_AM_I();
47 | }
48 |
49 | size_t getSerializationSize() const noexcept override
50 | {
51 | WHERE_AM_I();
52 | return 0;
53 | }
54 |
55 | void serialize(void *buffer) const noexcept override
56 | {
57 | WHERE_AM_I();
58 | }
59 |
60 | IPluginV2DynamicExt* clone() const noexcept override
61 | {
62 | WHERE_AM_I();
63 | return new SkipLayerNormV2Plugin(name_);
64 | }
65 |
66 | int getNbOutputs() const noexcept override
67 | {
68 | WHERE_AM_I();
69 | return 2;
70 | }
71 |
72 | DimsExprs getOutputDimensions(int32_t outputIndex, const DimsExprs* inputs, int32_t nbInputs, IExprBuilder& exprBuilder) noexcept override
73 | {
74 | WHERE_AM_I();
75 | return inputs[0];
76 | }
77 |
78 | bool supportsFormatCombination(int32_t pos, const PluginTensorDesc* inOut, int32_t nbInputs, int32_t nbOutputs) noexcept override
79 | {
80 | return true;  // accept every format/type combination; the checks below are never reached
81 | WHERE_AM_I();
82 | if(inOut[pos].format != TensorFormat::kLINEAR)
83 | {
84 | return false;
85 | }
86 |
87 | cout << "inOut[pos].type " << (int)inOut[pos].type << endl;
88 | bool res = false;
89 | switch(pos)
90 | {
91 | case 0:
92 | res = (inOut[pos].type == DataType::kFLOAT); break;
93 | case 1:
94 | res = inOut[pos].type == inOut[0].type; break;
95 | case 2:
96 | res = inOut[pos].type == inOut[0].type; break;
97 | case 3:
98 | res = inOut[pos].type == inOut[0].type; break;
99 | default:// should NOT be here
100 | res = false;
101 | }
102 | return res;
103 | }
104 |
105 | DataType getOutputDataType(int outputIndex, const DataType* inputTypes, int nbInputs) const noexcept override
106 | {
107 | WHERE_AM_I();
108 | return DataType::kFLOAT;
109 | }
110 |
111 | void configurePlugin(const DynamicPluginTensorDesc* in, int32_t nbInputs,const DynamicPluginTensorDesc* out, int32_t nbOutputs) noexcept override
112 | {
113 | WHERE_AM_I();
114 | }
115 |
116 | size_t getWorkspaceSize(const PluginTensorDesc* inputs, int32_t nbInputs, const PluginTensorDesc* outputs,int32_t nbOutputs) const noexcept override
117 | {
118 | WHERE_AM_I();
119 | return 0;
120 | }
121 |
122 | void setPluginNamespace(const char* szNamespace) noexcept override
123 | {
124 | WHERE_AM_I();
125 | namespace_ = szNamespace;
126 | }
127 | const char* getPluginNamespace() const noexcept override
128 | {
129 | WHERE_AM_I();
130 | return namespace_.c_str();
131 | }
132 | const char* getPluginType() const noexcept override
133 | {
134 | WHERE_AM_I();
135 | return PLUGIN_NAME;
136 | }
137 | const char* getPluginVersion() const noexcept override
138 | {
139 | WHERE_AM_I();
140 | return PLUGIN_VERSION;
141 | }
142 | int initialize() noexcept override
143 | {
144 | WHERE_AM_I();
145 | return 0;
146 | }
147 | void terminate() noexcept override
148 | {
149 | WHERE_AM_I();
150 | return;
151 | }
152 |
153 | void destroy() noexcept override
154 | {
155 | WHERE_AM_I();
156 | }
157 |
158 | int32_t enqueue(const PluginTensorDesc* inputDesc, const PluginTensorDesc* outputDesc, const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept override;
159 | }; // class SkipLayerNormV2Plugin
160 |
161 | class SkipLayerNormV2PluginCreator : public IPluginCreator
162 | {
163 | private:
164 | static PluginFieldCollection fc_;
165 | static std::vector<PluginField> attr_;
166 | std::string namespace_;
167 |
168 | public:
169 | SkipLayerNormV2PluginCreator()
170 | {
171 | fc_.nbFields = attr_.size();
172 | fc_.fields = attr_.data();
173 | }
174 |
175 | ~SkipLayerNormV2PluginCreator() {}
176 |
177 | IPluginV2* createPlugin(const char* name, const PluginFieldCollection* fc) noexcept override
178 | {
179 | WHERE_AM_I();
180 | return new SkipLayerNormV2Plugin(name);
181 | }
182 |
183 | IPluginV2* deserializePlugin(const char* name, const void* serialData, size_t serialLength) noexcept override
184 | {
185 | return new SkipLayerNormV2Plugin(name, serialData, serialLength);
186 | }
187 |
188 | void setPluginNamespace(const char* szNamespace) noexcept override
189 | {
190 | namespace_ = szNamespace;
191 | }
192 |
193 | const char* getPluginNamespace() const noexcept override
194 | {
195 | return namespace_.c_str();
196 | }
197 |
198 | const char* getPluginName() const noexcept override
199 | {
200 | return PLUGIN_NAME;
201 | }
202 |
203 | const char* getPluginVersion() const noexcept override
204 | {
205 | return PLUGIN_VERSION;
206 | }
207 |
208 | const PluginFieldCollection* getFieldNames() noexcept override
209 | {
210 | return &fc_;
211 | }
212 | }; // class SkipLayerNormV2PluginCreator
213 |
214 | } // namespace nvinfer1
215 |
216 |
--------------------------------------------------------------------------------
/Hackathon2022/LayerNormPlugin/testLayerNormPlugin.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
17 | import os
18 | import ctypes
19 | import numpy as np
20 | from cuda import cudart  # use the CUDA Runtime API
21 | import tensorrt as trt
22 |
23 | soFilePath = './LayerNorm.so'
24 | nBS = 4
25 | nSL = 64
26 | nEmbedding = 256
27 | epsilon = 6e-6
28 |
29 | np.random.seed(97)
30 |
31 | def check(a, b, weak = False):
32 | if weak:
33 | return np.all( np.abs(a - b) < epsilon)
34 | else:
35 | return np.all( a == b )
36 |
37 | def layerNormCPU(bufferH):
38 | _x = bufferH[0]
39 | nEmbed = bufferH[0].shape[2]
40 | _0 = np.mean(_x,2)[:,:,np.newaxis]
41 | _1 = _x - _0
42 | _2 = _1 * _1
43 | _3 = np.mean(_2,2)[:,:,np.newaxis]
44 | _4 = np.array(epsilon,dtype=np.float32)
45 | _5 = _4.reshape(1,1,1)
46 | _6 = _3 + _5
47 | _7 = np.sqrt(_6)
48 | _8 = 1 / _7 # 1/sqrt(...)
49 | _9 = _1 * _8
50 | return _9
51 |
52 | def getLayerNormPlugin():
53 | for c in trt.get_plugin_registry().plugin_creator_list:
54 | #print(c.name)
55 | if c.name == 'LayerNorm':
56 | return c.create_plugin(c.name, trt.PluginFieldCollection([]))
57 | return None
58 |
59 | def run():
60 | logger = trt.Logger(trt.Logger.ERROR)
61 | trt.init_libnvinfer_plugins(logger, '')
62 | ctypes.cdll.LoadLibrary(soFilePath)
63 |
64 | builder = trt.Builder(logger)
65 | network = builder.create_network(1<<0)
66 | config = builder.create_builder_config()
67 | config.max_workspace_size = 6 << 30
68 | config.flags = 0
69 |
70 | inputTensorList = []
71 | inputTensorList.append( network.add_input('inputT', trt.float32, [-1,-1,256]) )
72 |
73 | profile = builder.create_optimization_profile()
74 | profile.set_shape('inputT',[1,4,256],[4,64,256],[16,256,256])
75 | config.add_optimization_profile(profile)
76 |
77 | pluginLayer = network.add_plugin_v2(inputTensorList, getLayerNormPlugin())
78 |
79 | network.mark_output(pluginLayer.get_output(0))
80 |
81 | engineString = builder.build_serialized_network(network, config)
82 | engine = trt.Runtime(logger).deserialize_cuda_engine(engineString)
83 |
84 | context = engine.create_execution_context()
85 | context.set_binding_shape(0,[nBS,nSL,nEmbedding])
86 | print("Binding all? %s"%(["No","Yes"][int(context.all_binding_shapes_specified)]))
87 |
88 | nInput = np.sum([ engine.binding_is_input(i) for i in range(engine.num_bindings) ])
89 | nOutput = engine.num_bindings - nInput
90 | for i in range(engine.num_bindings):
91 | print("input ->" if engine.binding_is_input(i) else "output->",engine.get_binding_dtype(i),engine.get_binding_shape(i),context.get_binding_shape(i))
92 |
93 | bufferH = []
94 | bufferH.append( np.random.rand(nBS,nSL,nEmbedding).astype(np.float32).reshape(nBS,nSL,nEmbedding) * 2 - 1)
95 | bufferH.append(np.empty(context.get_binding_shape(1),dtype=trt.nptype(engine.get_binding_dtype(1))))
96 |
97 | bufferD = []
98 | for i in range(engine.num_bindings):
99 | bufferD.append(cudart.cudaMalloc(bufferH[i].nbytes)[1])
100 |
101 | for i in range(nInput):
102 | cudart.cudaMemcpy(bufferD[i], bufferH[i].ctypes.data, bufferH[i].nbytes, cudart.cudaMemcpyKind.cudaMemcpyHostToDevice)
103 |
104 | context.execute_v2(bufferD)
105 |
106 | for i in range(nInput, nInput + nOutput):
107 | cudart.cudaMemcpy(bufferH[i].ctypes.data, bufferD[i], bufferH[i].nbytes, cudart.cudaMemcpyKind.cudaMemcpyDeviceToHost)
108 |
109 | print("check result:")
110 | temp1 = bufferH[-1]
111 | temp2 = layerNormCPU(bufferH[:1])
112 | print(check(temp1,temp2,True), "max diff=%f"%(np.max(np.abs(temp1 - temp2))) )
113 |
114 | for b in bufferD:
115 | cudart.cudaFree(b)
116 |
117 | if __name__ == '__main__':
118 | os.system("rm -f ./*.trt")
119 | np.set_printoptions(precision = 4, linewidth = 200, suppress = True)
120 | run()
121 |
--------------------------------------------------------------------------------
/Hackathon2022/build.sh:
--------------------------------------------------------------------------------
1 | cd LayerNormPlugin && make && cd ..
2 | python encoder2trt.py && python decoder2trt.py && cp LayerNormPlugin/*.so .
3 |
--------------------------------------------------------------------------------
/Hackathon2022/calibrator.py:
--------------------------------------------------------------------------------
1 | import tensorrt as trt
2 | import os
3 |
4 | import numpy as np
5 | import pycuda.driver as cuda
6 | import pycuda.autoinit
7 | from sys import getsizeof
8 |
9 | # import pycuda.driver as cuda
10 | # import pycuda.autoinit
11 | # import numpy as np
12 | # import helpers.tokenization as tokenization
13 | # import helpers.data_processing as dp
14 |
15 | class EncoderCalibrator(trt.IInt8LegacyCalibrator):
16 | def __init__(self, calibration_data_file, cache_file, batch_size):
17 | # Whenever you specify a custom constructor for a TensorRT class,
18 | # you MUST call the constructor of the parent explicitly.
19 | trt.IInt8LegacyCalibrator.__init__(self)
20 |
21 | self.cache_file = cache_file
22 |
23 | # self.feat_list = feat_list
24 | # self.feat_len_list = feat_len_list
25 | self.batch_size = batch_size
26 | self.current_index = 0
27 |
28 | print("start read " + calibration_data_file)
29 | # feat_name_list = []
30 | self.feat_list = []
31 | self.feat_len_list = []
32 | data = np.load(calibration_data_file)
33 | for i in data.files:
34 | if "speech-" in i:
35 | self.feat_list.append(data[i])
36 | print(i)
37 | print(data[i].shape)
38 | if "speech_lengths" in i:
39 | self.feat_len_list.append(data[i])
40 | print(i)
41 | print(data[i].shape)
42 |
43 | if len(self.feat_list) != len(self.feat_len_list):
44 | print("len(feat_list) != len(feat_len_list)")
45 | assert(0)
46 |
47 | self.num_inputs = len(self.feat_list)
48 | # self.num_inputs = 1
49 |
50 | self.d_feat = None
51 | self.d_feat_len = None
52 |
53 | def free(self):
54 | pass
55 |
56 | def get_batch_size(self):
57 | return self.batch_size
58 |
59 | # TensorRT passes along the names of the engine bindings to the get_batch function.
60 | # You don't necessarily have to use them, but they can be useful to understand the order of
61 | # the inputs. The bindings list is expected to have the same ordering as 'names'.
62 | def get_batch(self, names):
63 | # print("self.num_inputs:" + str(self.num_inputs))
64 | # print("self.current_index:" + str(self.current_index))
65 | if self.current_index >= self.num_inputs:
66 | print("Calibrating index {:} batch size {:} exceed max input limit {:} sentences".format(self.current_index, self.batch_size, self.num_inputs))
67 | return None
68 |
69 |
70 | # Each entry of feat_list / feat_len_list already holds one full calibration
71 | # batch, so the arrays for the current index are used directly.
72 | np_feats = self.feat_list[self.current_index]
73 | np_feat_lens = self.feat_len_list[self.current_index]
77 | # print(np_feats.shape)
78 | # print(np_feat_lens.shape)
79 | self.d_feat = cuda.mem_alloc(np_feats.size * 4)
80 | self.d_feat_len = cuda.mem_alloc(np_feat_lens.size * 4)
81 |
82 | print(getsizeof(np_feats))
83 | print(self.d_feat_len)
84 |
85 | cuda.memcpy_htod(self.d_feat, np_feats.ravel())
86 | cuda.memcpy_htod(self.d_feat_len, np_feat_lens.ravel())
87 |
88 | self.current_index += 1
89 | return [self.d_feat, self.d_feat_len]
90 |
91 | # t_feats = torch.from_numpy(np_feats).cuda()
92 | # t_feat_lens = torch.from_numpy(np_feat_lens).cuda()
93 |
94 | # return [t_feats.data_ptr(), t_feat_lens.data_ptr()]
95 |
96 | def read_calibration_cache(self):
97 | # If there is a cache, use it instead of calibrating again. Otherwise, implicitly return None.
98 | if os.path.exists(self.cache_file):
99 | with open(self.cache_file, "rb") as f:
100 | return f.read()
101 |
102 | def write_calibration_cache(self, cache):
103 | with open(self.cache_file, "wb") as f:
104 | f.write(cache)
105 | f.flush()
106 | os.fsync(f)
107 |
108 | def get_quantile(self):
109 | return 0.9999
110 |
111 | def get_regression_cutoff(self):
112 | return 1.0
113 |
114 | def read_histogram_cache(self, length):
115 | return None
116 |
117 | def write_histogram_cache(self, ptr, length):
118 | return None
119 |
120 |
121 | def main():
122 | c = EncoderCalibrator("/workspace/data/calibration.npz", "encoder.cache", 100)
123 | c.get_batch("input")
124 | c.get_batch("input")
125 | c.get_batch("input")
126 | c.get_batch("input")
127 |
128 | if __name__ == '__main__':
129 | main()
130 |
--------------------------------------------------------------------------------
/Hackathon2022/decoder2trt.py:
--------------------------------------------------------------------------------
1 | import tensorrt as trt
2 | from glob import glob
3 | import ctypes
4 | import os
5 |
6 | import numpy as np
7 | import onnx
8 |
9 | import onnx_graphsurgeon as gs
10 | import onnx
11 | import numpy as np
12 | from onnx import TensorProto
13 |
14 | def decoder_surgeon(src_onnx, dst_onnx):
15 |
16 | graph = gs.import_onnx(onnx.load(src_onnx))
17 |
18 | # layer_norm
19 | start_node = None
20 | end_node = None
21 |
22 | weight_node = None
23 | bias_node = None
24 |
25 | layer_norm_layer_idx = 0
26 |
27 | for node in graph.nodes:
28 |
29 | if node.op == 'ReduceMean' and node.o(0).op == "Sub":
30 | start_node = node
31 | sub_node = node.o(0)
32 | if sub_node.o(0).op == "Pow":
33 | pow_node = sub_node.o(0)
34 | # print(pow_node)
35 | if pow_node.o(0).op == "ReduceMean":
36 | rm_node = pow_node.o(0)
37 | # print(rm_node)
38 | if rm_node.o(0).op == "Add":
39 | add_node = rm_node.o(0)
40 | # print(add_node)
41 | if add_node.o(0).op == "Sqrt":
42 | sqrt_node = add_node.o(0)
43 | if sqrt_node.o(0).op == "Div":
44 | div_node = sqrt_node.o(0)
45 | if div_node.o(0).op == "Mul":
46 | mul_node = div_node.o(0)
47 | weight_node = mul_node
48 | if mul_node.o(0).op == "Add":
49 | add_node = mul_node.o(0)
50 | bias_node = add_node
51 | end_node = add_node
52 |
53 | layer_norm_plugin = gs.Node("LayerNorm", "LayerNorm-" + str(layer_norm_layer_idx))
54 | layer_norm_layer_idx = layer_norm_layer_idx + 1
55 | graph.nodes.append(layer_norm_plugin)
56 |
57 | print(start_node)
58 | print(end_node)
59 | print(weight_node.inputs)
60 | print(bias_node.inputs)
61 | print("=======================")
62 | layer_norm_plugin.inputs = [start_node.inputs[0], weight_node.inputs[1], bias_node.inputs[1]]
63 | layer_norm_plugin.outputs = end_node.outputs
64 |
65 | start_node.inputs = []
66 | end_node.outputs = []
67 |
68 | # layer_norm
69 | start_node = None
70 | end_node = None
71 |
72 | weight_node = None
73 | bias_node = None
74 |
75 | # graph.outputs.append(Expand_23_node.outputs[0])
76 | # print(graph.outputs )
77 | graph.cleanup()
78 | onnx.save(gs.export_onnx(graph), dst_onnx)
79 |
80 | model = onnx.load(dst_onnx)
81 | # print(graph.outputs )
82 | # assert(0)
83 |
84 | def onnx2trt(onnxFile, plan_name):
85 |
86 | soFileList = glob("LayerNormPlugin/*.so")
87 |
88 | if len(soFileList) > 0:
89 | print("Find Plugin %s!"%soFileList)
90 | else:
91 | print("No Plugin!")
92 | for soFile in soFileList:
93 | ctypes.cdll.LoadLibrary(soFile)
94 |
95 | logger = trt.Logger(trt.Logger.VERBOSE)
96 |
97 | builder = trt.Builder(logger)
98 | config = builder.create_builder_config()
99 | profile = builder.create_optimization_profile()
100 | network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
--------------------------------------------------------------------------------
/Hackathon2022/encoder_small.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 | def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
5 | batch_size = lengths.size(0)
6 | max_len = max_len if max_len > 0 else lengths.max().item()
7 | seq_range = torch.arange(0, max_len, dtype=torch.int64, device=lengths.device)
8 | seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
9 | seq_length_expand = lengths.unsqueeze(-1)
10 | mask = seq_range_expand >= seq_length_expand
11 | return mask
12 |
13 | feat = torch.zeros(1, 100, 80)
14 | feat_len = torch.tensor([feat.size(1)]).int()
15 |
16 | class BaseEncoder(torch.nn.Module):
17 | def forward(self, xs, xs_lens):
18 | T = xs.size(1)
19 | masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T)
20 |
21 | masks = masks[:, :, :-2:2][:, :, :-2:2]
22 |
23 | masks1 = ~masks
24 | masks2 = torch.unsqueeze(masks, 1).int()
25 |
26 | xs = F.log_softmax(xs, dim=-1)
27 |
28 | return xs, masks1, masks2
29 |
30 |
31 | model = BaseEncoder()
32 | # score, mask = model(feat, feat_len)
33 |
34 | input_names, output_names = ["feat", "feat_len"], ["output", "masks1", "masks2"]
35 | dynamic_axes= {'feat':{0:'batch_size', 1: 'seq_len'},
36 | 'feat_len': {0: 'batch_size'},
37 | 'output':{0:'batch_size', 1: 'seq_len'}}
38 |
39 | torch.onnx.export(model, (feat, feat_len), "mask.onnx",
40 | opset_version=13, verbose=True,
41 | input_names=input_names, output_names=output_names,
42 | dynamic_axes = dynamic_axes)
43 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Chinese Translation Documents
2 | - 2023-9-27 Added a Chinese translation of the TensorRT 8.5.3 documentation (ChatGPT translation plus careful proofreading), chapters 1-2
3 |
4 | # The latest video versions are recommended! Listed below
5 | - [TensorRT Tutorial (1): How to Choose a TensorRT Version][21]
6 | - [TensorRT Tutorial (2): Building the TensorRT Open-Source Code][22]
7 | - [TensorRT Tutorial (3.1): Walking Through the TensorRT Docs - Basic Usage][23]
8 | - [TensorRT Tutorial (3.2): Walking Through the TensorRT Docs - Sample Code Worth Borrowing][24]
9 | - [TensorRT Tutorial (3.3.1): Plugin Examples and How They Work][25]
10 | - [TensorRT Tutorial (3.3.2): How to Build Your Own Plugin Library][26]
11 | - [TensorRT Plugin FP16 Acceleration Experience][27]
12 |
13 | - Materials for the video versions are in the 视频版资料 directory
14 |
15 | ## Progress Log
16 | - 2017-04-27 Project started; GitHub repository created.
17 | - 2017-09-30 TensorRT 3 has just been released; organized the currently available resources.
18 | - 2017-10-18 Added blog: Implementing a leaky relu layer with TensorRT
19 | - 2017-11-11 Resources: added Google's open-source INT8 library
20 | - 2017-11-25 Added blog: A Brief Introduction to Using TensorRT Plugins, with the leaky relu Layer as an Example
21 | - 2020-8-31 Added blog "An Introduction to the Open-Source Part of TensorRT on GitHub"
22 | - 2020-9-7 Added blog "A Summary of TensorRT Code Worth Borrowing"
23 | - 2022-11-2 Added blog "A Fairly Comprehensive Summary of GPU Acceleration Strategies for the Conformer Encoder"
24 | - 2022-11-2 Added blog "A Comparison of Several Ways to Convert Models with TensorRT"
25 |
26 | ----
27 |
28 | ## 资源整理
29 | - [TensorRT 3 RC][1]和[TensorRT 2.1][2] 下载链接
30 | - [TensorRT 2.1 官方在线文档][3]
31 | - NVIDIA 介绍TensorRT的blog-[Deploying Deep Neural Networks with NVIDIA TensorRT][4]
32 | - GTC 2017介绍TensorRT 的[PPT][5]和[视频][6],内含INT8 Quantization和Calibration的实现原理。
33 | - 新增cublas 和 cudnn的INT8 [demo][7]
34 | - 新增本人在GTC China 2017 Community Corner主题NVIDIA INT8的PPT, [GTC-China-2017-NVIDIA-INT8.pdf][8]
35 | - 新增google的INT8开源库[gemmlowp][9],目前支持ARM和CPU优化
36 | - “子棐之GPGPU”公众号所写的《TensorRT系列》博客,NVIDIA的工程师出的,从入门篇到INT8篇再到FP16篇最后收尾于Custom Layer篇,内容逻辑清楚,干货满满,自愧不如。附四篇博客链接:[TensorRT 系列之入门篇][10],[TensorRT系列之INT8篇][11],[TensorRT系列之FP16篇][12],[TensorRT系列之Custom Layer篇][13]。
37 | - [《高性能深度学习支持引擎实战——TensorRT》][14],主要内容:一、TensorRT理论介绍:基础介绍TensorRT是什么;做了哪些优化;为什么在有了框架的基础上还需要TensorRT的优化引擎。二、TensorRT高阶介绍:对于进阶的用户,出现TensorRT不支持的网络层该如何处理;
38 |
39 | ---
40 | ## 博客
41 | - [使用TensorRT实现leaky relu层][15]
42 | - [TensorRT Plugin使用方式简介-以leaky relu层为例][16]
43 |
44 | # TensorRT_Tutorial
45 |
46 | TensorRT作为NVIDIA推出的c++库,能够实现高性能推理(inference)过程。最近,NVIDIA发布了TensorRT 2.0 Early Access版本,重大更改就是支持INT8类型。在当今DL大行其道的时代,INT8在缩小模型大小、加速运行速度方面具有非常大的优势。Google新发布的TPU就采用了8-bit的数据类型。
47 |
48 | 本人目前在使用TensorRT进行INT8的探究。已经被TensorRT不完善的文档坑了一次了。所以想自力更生做一个TensorRT Tutorial,主要包括三部分:
49 | - TensorRT User Guide 翻译;
50 | - TensorRT samples 介绍分析讲解;
51 | - TensorRT 使用经验。
52 |
53 | 感谢每一位为该翻译项目做出贡献的同学.
54 |
55 | 内容来源:
56 | TensorRT 下载页面:
57 | https://developer.nvidia.com/nvidia-tensorrt-20-download
58 |
59 | TensorRT 文档、Samples
60 | 安装后对应目录中
61 |
62 | ## 参与者(按参与时间排序)
63 | TensorRT User Guide 翻译
64 | - [LitLeo][18]
65 |
66 | TensorRT samples 介绍分析讲解
67 | - [LitLeo][20]
68 |
69 | TensorRT 使用经验。
70 |
71 | 欲参与者请加QQ群:483063470
72 |
73 | 支持捐赠项目
74 |
75 |
76 |
77 | ## 招实习生
78 | 【实习】【腾讯北京AILAB】招募AI异构加速实习生
79 | 简历直接给负责人,给简历保证迅速反馈。
80 | 基本条件: 熟悉c++,至少实习6个月
81 | 工作内容:
82 | 1. 使用c++复现框架训练的模型并进行CPU、GPU、ARM加速,达到上线的性能要求。
83 | 2. 调研各种inference框架并投入生产
84 | 加分项:
85 | 1. 写过或者维护过深度学习框架代码;
86 | 2. 会CUDA 开发,能自己写kernel,会用cublas,cudnn等库;
87 | 3. linux cpu c++编程能力,会写avx、会用mkl;
88 | 4. 熟悉深度学习计算过程
89 | 5. 学习能力强,实习时间长
90 | 联系方式: lityangweiguang@163.com
91 |
92 | [1]: https://developer.nvidia.com/nvidia-tensorrt3rc-download
93 | [2]: https://developer.nvidia.com/nvidia-tensorrt-download
94 | [3]: http://docs.nvidia.com/deeplearning/sdk/tensorrt-user-guide/index.html
95 | [4]: https://devblogs.nvidia.com/parallelforall/deploying-deep-learning-nvidia-tensorrt/
96 | [5]: http://on-demand.gputechconf.com/gtc/2017/presentation/s7310-8-bit-inference-with-tensorrt.pdf
97 | [6]: http://on-demand.gputechconf.com/gtc/2017/video/s7310-szymon-migacz-8-bit-inference-with-tensorrt.mp4
98 | [7]: https://github.com/LitLeo/TensorRT_Tutorial/tree/master/cublas&cudnn_int8_demo
99 | [8]: https://github.com/LitLeo/TensorRT_Tutorial/blob/master/GTC-China-2017-NVIDIA-INT8.pdf
100 | [9]: https://github.com/google/gemmlowp
101 | [10]: https://mp.weixin.qq.com/s/E5qbMsuc7UBnNmYBzq__5Q
102 | [11]: https://mp.weixin.qq.com/s/wyqxUlXxgA9Eaxf0AlAVzg
103 | [12]: https://mp.weixin.qq.com/s/nuEVZlS6JfqRQo30S0W-Ww?scene=25#wechat_redirect
104 | [13]: https://mp.weixin.qq.com/s/xabDoauJc16z3-gpyre8zA
105 | [14]: https://mp.weixin.qq.com/s/F_VvLTWfg-COZKrQAtOSwg
106 | [15]: https://github.com/LitLeo/TensorRT_Tutorial/blob/master/blogs/%E4%BD%BF%E7%94%A8TensorRT%E5%AE%9E%E7%8E%B0leaky%20relu%E5%B1%82.md
107 | [16]: https://github.com/LitLeo/TensorRT_Tutorial/blob/master/blogs/TensorRT%20Plugin%E4%BD%BF%E7%94%A8%E6%96%B9%E5%BC%8F%E7%AE%80%E4%BB%8B-%E4%BB%A5leaky%20relu%E5%B1%82%E4%B8%BA%E4%BE%8B.md
108 | [17]: https://github.com/LitLeo/TensorRT_Tutorial/blob/master/Bug.md
109 | [18]: https://github.com/LitLeo
110 | [19]: https://github.com/MoyanZitto
111 | [20]: https://github.com/LitLeo
112 | [21]: https://www.bilibili.com/video/BV1Nf4y1v7sa/
113 | [22]: https://www.bilibili.com/video/BV1x5411n76K/
114 | [23]: https://www.bilibili.com/video/BV19V411t7LV/
115 | [24]: https://www.bilibili.com/video/BV1DT4y1A7Rx/
116 | [25]: https://www.bilibili.com/video/BV1op4y1p7bj/
117 | [26]: https://www.bilibili.com/video/BV1Qi4y1N7YS/
118 | [27]: https://www.bilibili.com/video/BV19Y411g7YY/
119 |
120 |
121 |
--------------------------------------------------------------------------------
/Samples/sampleGoogleNet.md:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LitLeo/TensorRT_Tutorial/1d370d82c78eddccadf4a8490c5c859db35a78c9/Samples/sampleGoogleNet.md
--------------------------------------------------------------------------------
/Samples/sampleINT8.md:
--------------------------------------------------------------------------------
1 | ## SampleINT8:8-bit校准与推断
2 | SampleINT8说明了以8 bit整数(INT8)进行推理的步骤。SampleINT8使用MNIST训练集进行验证,但也可用于校准和评分其他网络。使用以下命令在MNIST上运行示例。
3 | `./sample_int8 mnist`
4 | **注意**:INT8推理只能在计算能力6.1及以上的GPU上使用。
5 |
6 | INT8的引擎仍从32-bit(float)的网络定义中构建,但是要比32-bit 和16-bit的引擎复杂得多。具体而言,TensorRT在构建网络时,必须校准网络以确定如何最好地用8-bit表示权重和激活值。这需要一组能代表该网络实际输入的数据,即校准集(the calibration set),以及两个参数:回归截断(regression cutoff)和分位数(quantile)。
7 |
8 | 应用程序必须通过实现INT8Calibrator接口来指定校准集和参数。对于ImageNet网络和MNIST,500张图像是一个合理的校准集规模。请参考[选择校准集参数](#choose_calibration_parameters)一节查看确定回归截断点与分位数的设置细节。
9 |
10 | ### IInt8Calibrator接口
11 |
12 | `IInt8Calibrator`接口含有为builder指定校准集和校准参数的方法。此外,因为校准是一个需要运行很多次,代价较高的过程,`IInt8Calibrator`还提供了缓存中间值的方法。缓存的细节将在[缓存](#caching)一节讨论。最简单的实现方式是立即从`write()`方法返回,并从`read()`方法中返回`nullptr`。
13 |
14 | #### 校准集
15 |
16 | 一旦校准开始,builder就会调用`getBatchSize()`以获取校准集的Batch Size,校准集的每一个batch数据大小都必须为该值。接着,方法`getBatch()`会被反复调用以获得batch数据,直到它返回false为止:
17 | ```C++
18 | bool getBatch(void* bindings[], const char* names[], int nbBindings) override
19 | {
20 | if (!mStream.next())
21 | return false;
22 |
23 | CHECK(cudaMemcpy(mDeviceInput, mStream.getBatch(), mInputCount * sizeof(float), cudaMemcpyHostToDevice));
24 | assert(!strcmp(names[0], INPUT_BLOB_NAME));
25 | bindings[0] = mDeviceInput;
26 | return true;
27 | }
28 | ```
29 | 对于每个输入张量,指向GPU内存中数据的指针必须被写入`bindings`数组中,而`names`数组包含了输入张量的名字,`names`数组中的名字与`bindings`数组中的指针按位置一一对应。两个数组的大小都是`nbBindings`。
30 |
31 | **注意:**校准集必须能够代表在TensorRT运行时的输入数据。例如,对图像分类任务而言,校准集不能只由来自一部分类别的图片构成。另外,任何在推断前执行的图像处理过程,如缩放、裁剪或去均值,也必须对校准集的样本执行。
32 |
33 | #### 校准集参数
34 | 这些方法是很明确的:
35 | ```c++
36 | double getQuantile() const override { return mQuantile; }
37 | double getRegressionCutoff() const override { return mCutoff; }
38 | ```
39 |
40 | ### 配置Builder
41 | 对于INT8推断,输入模型必须由32-bit的权重确定。
42 | ```c++
43 | const IBlobNameToTensor* blobNameToTensor =
44 | parser->parse(locateFile(deployFile).c_str(),
45 | locateFile(modelFile).c_str(),
46 | *network,
47 | DataType::kFLOAT);
48 | ```
49 | builder有额外的两个方法:
50 | ```c++
51 | builder->setInt8Mode(true);
52 | builder->setInt8Calibrator(calibrator);
53 | ```
54 |
55 | 一旦模型被builder构建完成,它可以与Float32的网络一样使用:输入和输出仍然是32-bit的浮点数。
56 |
57 |
58 |
59 | ### 校准集缓存
60 |
61 | 校准过程可能较为缓慢,因此`IInt8Calibrator`提供了用于缓存中间结果的方法。高效使用这些方法需要对校准过程的细节有一定了解。
62 |
63 | 当构建一个INT8的引擎时,builder执行了下面的步骤:
64 |
65 | 1. 构建一个32-bit的引擎,在校准集上运行,并为每个张量记录其激活值分布的直方图
66 | 2. 由直方图以及截断参数和分位数构建一个校准表
67 | 3. 从网络定义和校准表构建INT8引擎
68 |
69 | 直方图与校准表都可以被缓存
70 |
71 | 当需要多次重复构建同一个模型(例如在不同平台上)时,缓存校准表是很有用的。它捕获了由网络和校准集推导出的参数,以及截断参数与分位数。参数被记录在校准表中,当表中的参数与校准器指定的参数不匹配时,校准表将被忽略。当网络或校准集发生变化时,应该由应用程序使校准表缓存失效。
72 |
73 | 当基于同样的校准集对同一个网络进行校准参数搜索时,直方图缓存是很有用的,因为它使得直方图的构建只需要运行一次。同样,当网络或校准集发生变化时,应该由应用程序使直方图缓存失效。
74 |
75 | 缓存按照下面的方式使用:
76 |
77 | - 如果校准表存在,则跳过校准过程,否则:
78 | - 如果直方图缓存存在,则跳过直方图构造过程,否则:
79 | - 构造Float32网络,并在校准集上运行网络,得到直方图
80 | - 依据直方图与参数构造校准表
81 | - 依据校准表和网络定义构造INT8的网络
82 |
83 | 已经缓存的数据通过指针和长度参数传递,例如:
84 | ```c++
85 | const void* readHistogramCache(size_t& length) override
86 | {
87 | length = mHistogramCache.size();
88 | return length ? &mHistogramCache[0] : nullptr;
89 | }
90 |
91 | void writeHistogramCache(const void* cache, size_t length) override
92 | {
93 | mHistogramCache.clear();
94 | std::copy_n(reinterpret_cast<const char*>(cache), length, std::back_inserter(mHistogramCache));
95 | }
96 | ```
97 |
98 |
99 |
100 | ### 选择校准集参数
101 |
102 | 截断参数与分位数都是[0,1]间的数字,其具体含义在附带的white paper中讨论。为了找到最佳的校准集参数,我们可以在额外的图片上对不同的参数组合打分,并从中挑选校准集参数,`searchCalibrations()`展示了具体做法。对ImageNet网络而言,使用了5000张图片来寻找最佳校准参数。因为校准过程会在不同的截断参数与分位数下运行多次,我们强烈建议使用直方图缓存。
--------------------------------------------------------------------------------
/Samples/sampleMNIST.md:
--------------------------------------------------------------------------------
1 |
2 | ## SampleMNIST:简单使用方法
3 | SampleMNIST 使用训练好的MNIST caffe模型来演示典型的构建和执行过程。
4 | 构建阶段,直接调用ICaffeParser接口的parse()函数读取caffe model。
5 |
6 | ### 日志
7 | 构建网络之前需要先重载实现log类,可以用来报告error、warning和informational 信息。
8 |
9 | ### Build过程:caffeToGIEModel
10 | ### 引擎反序列化(引擎重构)
11 | ### Execution过程:doInference()
--------------------------------------------------------------------------------
/Samples/sampleMNISTAPI.md:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LitLeo/TensorRT_Tutorial/1d370d82c78eddccadf4a8490c5c859db35a78c9/Samples/sampleMNISTAPI.md
--------------------------------------------------------------------------------
/TensorRT8.5.3/a-title_Developer-Guide-NVIDIA-Deep-Learning-TensorRT-Documentation:
--------------------------------------------------------------------------------
1 | -
--------------------------------------------------------------------------------
/TensorRT8.5.3/assets/1695349016-1542a53eb400f837845f37b2bedb9d05.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LitLeo/TensorRT_Tutorial/1d370d82c78eddccadf4a8490c5c859db35a78c9/TensorRT8.5.3/assets/1695349016-1542a53eb400f837845f37b2bedb9d05.png
--------------------------------------------------------------------------------
/TensorRT8.5.3/assets/1695349016-15dd6688b76bdc3d5a16526df91cc631.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LitLeo/TensorRT_Tutorial/1d370d82c78eddccadf4a8490c5c859db35a78c9/TensorRT8.5.3/assets/1695349016-15dd6688b76bdc3d5a16526df91cc631.png
--------------------------------------------------------------------------------
/TensorRT8.5.3/assets/1695349016-2c3934e69ddc53dc474139fe65c49c19.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LitLeo/TensorRT_Tutorial/1d370d82c78eddccadf4a8490c5c859db35a78c9/TensorRT8.5.3/assets/1695349016-2c3934e69ddc53dc474139fe65c49c19.png
--------------------------------------------------------------------------------
/TensorRT8.5.3/assets/1695349016-4e01c008d3875b259cc4cd3da884010e.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LitLeo/TensorRT_Tutorial/1d370d82c78eddccadf4a8490c5c859db35a78c9/TensorRT8.5.3/assets/1695349016-4e01c008d3875b259cc4cd3da884010e.png
--------------------------------------------------------------------------------
/TensorRT8.5.3/assets/1695349016-536836b9f148a211a3109b46588aea3f.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LitLeo/TensorRT_Tutorial/1d370d82c78eddccadf4a8490c5c859db35a78c9/TensorRT8.5.3/assets/1695349016-536836b9f148a211a3109b46588aea3f.png
--------------------------------------------------------------------------------
/TensorRT8.5.3/assets/1695349016-584559c808bb6b459734d88699daabe1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LitLeo/TensorRT_Tutorial/1d370d82c78eddccadf4a8490c5c859db35a78c9/TensorRT8.5.3/assets/1695349016-584559c808bb6b459734d88699daabe1.png
--------------------------------------------------------------------------------
/TensorRT8.5.3/assets/1695349016-5b172dabb4f50368376eee4819ddcb87.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LitLeo/TensorRT_Tutorial/1d370d82c78eddccadf4a8490c5c859db35a78c9/TensorRT8.5.3/assets/1695349016-5b172dabb4f50368376eee4819ddcb87.png
--------------------------------------------------------------------------------
/TensorRT8.5.3/assets/1695349016-63cc642586086b5be42c04375200c8c9.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LitLeo/TensorRT_Tutorial/1d370d82c78eddccadf4a8490c5c859db35a78c9/TensorRT8.5.3/assets/1695349016-63cc642586086b5be42c04375200c8c9.png
--------------------------------------------------------------------------------
/TensorRT8.5.3/assets/1695349016-656ec99160033df259b215cd7e03af2f.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LitLeo/TensorRT_Tutorial/1d370d82c78eddccadf4a8490c5c859db35a78c9/TensorRT8.5.3/assets/1695349016-656ec99160033df259b215cd7e03af2f.png
--------------------------------------------------------------------------------
/TensorRT8.5.3/assets/1695349016-718f4af533bab6c57307cd4131866023.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LitLeo/TensorRT_Tutorial/1d370d82c78eddccadf4a8490c5c859db35a78c9/TensorRT8.5.3/assets/1695349016-718f4af533bab6c57307cd4131866023.png
--------------------------------------------------------------------------------
/TensorRT8.5.3/assets/1695349016-7324dda2de00b8d4b99431311c1d901d.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LitLeo/TensorRT_Tutorial/1d370d82c78eddccadf4a8490c5c859db35a78c9/TensorRT8.5.3/assets/1695349016-7324dda2de00b8d4b99431311c1d901d.png
--------------------------------------------------------------------------------
/TensorRT8.5.3/assets/1695349016-7c4a391e39adc9b201561f4384d8575c.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LitLeo/TensorRT_Tutorial/1d370d82c78eddccadf4a8490c5c859db35a78c9/TensorRT8.5.3/assets/1695349016-7c4a391e39adc9b201561f4384d8575c.png
--------------------------------------------------------------------------------
/TensorRT8.5.3/assets/1695349016-8167eeb1e237bd2c809028a411e1e9cb.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LitLeo/TensorRT_Tutorial/1d370d82c78eddccadf4a8490c5c859db35a78c9/TensorRT8.5.3/assets/1695349016-8167eeb1e237bd2c809028a411e1e9cb.png
--------------------------------------------------------------------------------
/TensorRT8.5.3/assets/1695349016-8c33d06b8c5ffd9dc50eb77f1bbe80d0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LitLeo/TensorRT_Tutorial/1d370d82c78eddccadf4a8490c5c859db35a78c9/TensorRT8.5.3/assets/1695349016-8c33d06b8c5ffd9dc50eb77f1bbe80d0.png
--------------------------------------------------------------------------------
/TensorRT8.5.3/assets/1695349016-90fbabf1bcd97f82bbffa8751a548cdc.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LitLeo/TensorRT_Tutorial/1d370d82c78eddccadf4a8490c5c859db35a78c9/TensorRT8.5.3/assets/1695349016-90fbabf1bcd97f82bbffa8751a548cdc.png
--------------------------------------------------------------------------------
/TensorRT8.5.3/assets/1695349016-98a76f9452e7b3c5a2979a9a4d8f828f.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LitLeo/TensorRT_Tutorial/1d370d82c78eddccadf4a8490c5c859db35a78c9/TensorRT8.5.3/assets/1695349016-98a76f9452e7b3c5a2979a9a4d8f828f.png
--------------------------------------------------------------------------------
/TensorRT8.5.3/assets/1695349016-9b422126aef86f0a15d7bfcdcdf37ee9.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LitLeo/TensorRT_Tutorial/1d370d82c78eddccadf4a8490c5c859db35a78c9/TensorRT8.5.3/assets/1695349016-9b422126aef86f0a15d7bfcdcdf37ee9.png
--------------------------------------------------------------------------------
/TensorRT8.5.3/assets/1695349016-a782c77d3e0eff2354898ccef63c5de0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LitLeo/TensorRT_Tutorial/1d370d82c78eddccadf4a8490c5c859db35a78c9/TensorRT8.5.3/assets/1695349016-a782c77d3e0eff2354898ccef63c5de0.png
--------------------------------------------------------------------------------
/TensorRT8.5.3/assets/1695349016-ad186379984e814039de4d58a0e26c53.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LitLeo/TensorRT_Tutorial/1d370d82c78eddccadf4a8490c5c859db35a78c9/TensorRT8.5.3/assets/1695349016-ad186379984e814039de4d58a0e26c53.png
--------------------------------------------------------------------------------
/TensorRT8.5.3/assets/1695349016-ae831a5e3c8c02af4c7ac82636845a70.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LitLeo/TensorRT_Tutorial/1d370d82c78eddccadf4a8490c5c859db35a78c9/TensorRT8.5.3/assets/1695349016-ae831a5e3c8c02af4c7ac82636845a70.png
--------------------------------------------------------------------------------
/TensorRT8.5.3/assets/1695349016-cc50888fa52ed8f93e53ca71ce566c63.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LitLeo/TensorRT_Tutorial/1d370d82c78eddccadf4a8490c5c859db35a78c9/TensorRT8.5.3/assets/1695349016-cc50888fa52ed8f93e53ca71ce566c63.png
--------------------------------------------------------------------------------
/TensorRT8.5.3/assets/1695349016-d14711f74598da455c69c20ed5a5cbd1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LitLeo/TensorRT_Tutorial/1d370d82c78eddccadf4a8490c5c859db35a78c9/TensorRT8.5.3/assets/1695349016-d14711f74598da455c69c20ed5a5cbd1.png
--------------------------------------------------------------------------------
/TensorRT8.5.3/assets/1695349016-dffd0a9679aeefdc5176a6aa55feaa7c.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LitLeo/TensorRT_Tutorial/1d370d82c78eddccadf4a8490c5c859db35a78c9/TensorRT8.5.3/assets/1695349016-dffd0a9679aeefdc5176a6aa55feaa7c.png
--------------------------------------------------------------------------------
/TensorRT8.5.3/assets/1695349016-e24efeac58e23de168680d4f48e18f16.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LitLeo/TensorRT_Tutorial/1d370d82c78eddccadf4a8490c5c859db35a78c9/TensorRT8.5.3/assets/1695349016-e24efeac58e23de168680d4f48e18f16.png
--------------------------------------------------------------------------------
/TensorRT8.5.3/assets/1695349016-e829de0bc2b85ec285546dcf1456982a.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LitLeo/TensorRT_Tutorial/1d370d82c78eddccadf4a8490c5c859db35a78c9/TensorRT8.5.3/assets/1695349016-e829de0bc2b85ec285546dcf1456982a.png
--------------------------------------------------------------------------------
/TensorRT8.5.3/assets/1695349016-f9c6506c20f52b409ddfc74a8a4317a2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LitLeo/TensorRT_Tutorial/1d370d82c78eddccadf4a8490c5c859db35a78c9/TensorRT8.5.3/assets/1695349016-f9c6506c20f52b409ddfc74a8a4317a2.png
--------------------------------------------------------------------------------
/TensorRT8.5.3/index.json:
--------------------------------------------------------------------------------
1 | {
2 | "version": "2.0",
3 | "clipId": "1695349016",
4 | "format": "md",
5 | "title": "Developer Guide :: NVIDIA Deep Learning TensorRT Documentation",
6 | "link": "https://docs.nvidia.com/deeplearning/tensorrt/archives/tensorrt-853/developer-guide/index.html",
7 | "category": "default",
8 | "tags": [],
9 | "created_at": "2023-09-22 10:16:56",
10 | "mainPath": "index.md",
11 | "paths": [
12 | "index.json",
13 | "a-title_Developer-Guide-NVIDIA-Deep-Learning-TensorRT-Documentation",
14 | "assets/1695349016-718f4af533bab6c57307cd4131866023.png",
15 | "assets/1695349016-2c3934e69ddc53dc474139fe65c49c19.png",
16 | "assets/1695349016-ae831a5e3c8c02af4c7ac82636845a70.png",
17 | "assets/1695349016-5b172dabb4f50368376eee4819ddcb87.png",
18 | "assets/1695349016-f9c6506c20f52b409ddfc74a8a4317a2.png",
19 | "assets/1695349016-a782c77d3e0eff2354898ccef63c5de0.png",
20 | "assets/1695349016-536836b9f148a211a3109b46588aea3f.png",
21 | "assets/1695349016-cc50888fa52ed8f93e53ca71ce566c63.png",
22 | "assets/1695349016-90fbabf1bcd97f82bbffa8751a548cdc.png",
23 | "assets/1695349016-1542a53eb400f837845f37b2bedb9d05.png",
24 | "assets/1695349016-dffd0a9679aeefdc5176a6aa55feaa7c.png",
25 | "assets/1695349016-8c33d06b8c5ffd9dc50eb77f1bbe80d0.png",
26 | "assets/1695349016-8167eeb1e237bd2c809028a411e1e9cb.png",
27 | "assets/1695349016-9b422126aef86f0a15d7bfcdcdf37ee9.png",
28 | "assets/1695349016-e24efeac58e23de168680d4f48e18f16.png",
29 | "assets/1695349016-e829de0bc2b85ec285546dcf1456982a.png",
30 | "assets/1695349016-15dd6688b76bdc3d5a16526df91cc631.png",
31 | "assets/1695349016-656ec99160033df259b215cd7e03af2f.png",
32 | "assets/1695349016-7324dda2de00b8d4b99431311c1d901d.png",
33 | "assets/1695349016-d14711f74598da455c69c20ed5a5cbd1.png",
34 | "assets/1695349016-98a76f9452e7b3c5a2979a9a4d8f828f.png",
35 | "assets/1695349016-4e01c008d3875b259cc4cd3da884010e.png",
36 | "assets/1695349016-ad186379984e814039de4d58a0e26c53.png",
37 | "assets/1695349016-584559c808bb6b459734d88699daabe1.png",
38 | "assets/1695349016-7c4a391e39adc9b201561f4384d8575c.png",
39 | "assets/1695349016-63cc642586086b5be42c04375200c8c9.png",
40 | "index.md"
41 | ]
42 | }
--------------------------------------------------------------------------------
/TensorRT_2.1.0_User_Guide.md:
--------------------------------------------------------------------------------
1 | # TensorRT 2.0 User Guide
2 |
3 | ---
4 |
5 | [toc]
6 |
7 | ## 介绍
8 | NVIDIA TensorRT是一个C++库,在NVIDIA GPU上能够实现高性能的推理(inference )过程。TensorRT优化网络的方式有:对张量和层进行合并,转换权重,选择高效的中间数据格式,以及依据层的参数和实测性能,从一个丰富的核仓库中进行筛选。
9 |
10 | 编译TensorRT 2.0 要求GCC >= 4.8
11 |
12 | TensorRT 2.0 现在支持以下layer类型:
13 |
14 | - **Convolution**:卷积层,可无bias。目前仅支持2D卷积(即对4D的输入进行卷积并输出4D输出)。**Note:**该卷积层的操作实际计算的是“相关”而不是“卷积”(严格的卷积定义需要卷积核反转),如果你想通过TensorRT的API而不是通过caffe parser library导入权重,这是一个需要注意的地方。
15 | - **Activation**: 激活层,支持ReLU, tanh和sigmoid.
16 | - **Pooling**: 池化层,支持最大值池化和均值池化
17 | - **Scale**: 可以使用常量对每一个张量, 通道或权重进行仿射变换和取幂操作。**BatchNormalization**可由该层实现。
18 | - **ElementWise**: 两个张量按元素求和、求乘积或取最大
19 | - **LRN**: 局部响应归一化层,仅支持通道间归一化
20 | - **Fully-connected**:全连接层,可无bias
21 | - **SoftMax**: Softmax层,仅支持通道间计算softmax
22 | - **Deconvolution**: 反卷积层,可无bias
23 | - **RNN**: 循环网络层,支持GRU和LSTM
24 |
25 | TensorRT是一个独立的深度学习部署框架,对caffe尤其友好。TensorRT提供了一个针对caffe的模型解析器NvCaffeParser,可以通过几行代码解析caffe生成的model并定义网络。NvCaffeParer使用上面定义的层来实现Caffe中的Convolution, ReLU, Sigmoid, TanH, Pooling, Power, BatchNorm, Eltwise, LRN, InnerProduct, SoftMax, Scale, 和Deconvolution。而目前,NvCaffeParse不支持下面的Caffe层:
26 |
27 | - Deconvolution groups
28 | - Dilated convolutions
29 | - PReLU
30 | - Leaky ReLU
31 | - 除通道间scale的其他Scale层
32 | - 含有两个以上输入的ElementWise操作
33 |
34 | **Note:** TensorRT不支持caffe的旧prototxt格式,特别地,prototxt中定义的层类型应当为由双引号分割的字符串。
35 |
36 | ## 快速开始指南
37 | 【注】本部分由[TensorRT下载页面][1]翻译而来。
38 |
39 | TensorRT原名GIE。GIE又名TensorRT 1.0,TensorRT 2.0正式改名。
40 | TensorRT 2.0非常大的改动点是支持INT8类型(TensorRT 1.0支持FP16)。
41 | 使用TensorRT 2.0的硬件要求:Tesla P4, Tesla P40, GeForce TitanX Pascal, GeForce GTX 1080, DRIVE PX 2 dGPU
42 | 软件要求:CUDA 8.0
43 | ### Ubuntu 下安装方式
44 | 安装命令:
45 |
46 | 1. 验证你是否安装了CUDA 8.0 .
47 | 2. 下载TensorRT的deb包
48 | 3. 从TensrRT的deb包安装,命令为:
49 |
50 | ```bash
51 | sudo dpkg -i nv-tensorrt-repo-ubuntu1404-7-ea-cuda8.0_2.0.1-1_amd64.deb
52 | sudo apt-get update
53 | sudo apt-get install tensorrt-2
54 | ```
55 | 4. 验证安装:
56 | ```bash
57 | dpkg -l | grep tensorrt-2
58 | ```
59 |
60 | 5. 若安装成功,你应看到:
61 | `tensorrt-2 2.0.0-1+cuda amd64 Meta package of TensorRT`
62 | 同样,通过命令:
63 | `dpkg -l | grep nvinfer2`
64 |
65 | 6. 你应看到:
66 | `libnvinfer2 2.0.0-1+cuda amd64 TensorRT runtime libraries`
67 |
68 | 注意:TensorRT 2.0现在只提供了Ubuntu 14.04和16.04两个版本。
69 |
70 | ### Centos 7 下安装方式
71 |
72 | TensorRT对Ubuntu系统友好,如果是企业级系统(比如centos)可以下载下来解压然后手动安装。
73 | 前提条件:建议Centos 7以上,即gcc 版本要大于4.8,因为TensorRT内使用了大量的c++ 11特性。如果你是大神可以在Centos 6上升级gcc 到4.8并把一些依赖问题搞定。
74 | 安装步骤如下:
75 |
76 | 1. 下载deb安装包,然后解压,一路挑着大文件解压,找到两个头文件NvCaffeParser.h。NvInfer.h和对应的so文件,libnvcaffe_parser.so.2.0.0,libnvinfer.so.2.0.0。
77 | 2. 然后安装方式就跟cudnn一样了,*.h上传到CUDA_HOME/include下,lib文件上传到CUDA_HOME/lib64目录下(lib文件记得添加libnvinfer.so和libnvcaffe_parser.so的链接)
78 | 3. 安装完毕,如果要在Centos上跑samples,记得要修改一下Makefile
79 |
80 | ## 快速开始
81 | 使用TensorRT包括两个步骤(1)打开冰箱;(2)把大象装进去:
82 |
83 | - build阶段,TensorRT进行网络定义、执行优化并生成推理引擎
84 | - execution阶段,需要将input和output在GPU上开辟空间并将input传输到GPU上,调用推理接口得到output结果,再将结果拷贝回host端。
85 |
86 | build阶段比较耗时间,特别是在嵌入式平台上。所以典型的使用方式就是将build后的引擎序列化(序列化后可写到硬盘里)供以后使用。
87 |
88 | build阶段对网络进行了以下优化:
89 |
90 | - 去掉没有被使用过的输出层
91 | - 将convolution、bias和ReLU操作融合到一起
92 | - 将相似度比较高的参数和相同的Tensor进行聚合(例如,GoogleNet v5的Inception模块中的1*1卷积)
93 | - 通过将层的输出直接导向其最终位置来简化串接的层
94 |
95 | 此外,TensorRT在虚拟数据(Dummy Data)上运行层,以在kernel仓库中筛选出运行最快的,并在适当的时候执行权重预格式化和内存优化。
96 |
97 | ### 网络定义
98 | 网络定义是由Layers和Tensors组成的。
99 |
100 | 每一层都一组输入tensor和一组输出tensor,根据层类型和输入tensor来计算输出tensor。不同类型的层具有不同的参数,比如卷积size和stride,以及卷积滤波器权值。
101 |
102 | tensor是网络的输入或者输出。tensor的数据目前支持16bit和32bit浮点数和三维(通道,宽,高)。输入tensor的数据大小由程序猿指定,输出tensor的大小自动就算出来了。
103 |
104 | 每一层和tensor都有一个名字,在分析或者读构建日志时非常有用。
105 |
106 | 当使用caffe parser时,tensor和层的名字直接从caffe prototxt读取。
107 |
108 | ## SampleMNIST:简单使用方法
109 |
110 | ## SampleGoogleNet:性能分析与16-bit推断
111 | ### 性能分析
112 | ### half2模式
113 |
114 | ## SampleINT8:8-bit校准与推断
115 |
116 | ## giexec:一个命令行包装器
117 | 在示例程序的文件夹中包含有一个TensorRT的命令行包装,它在基于任意数据对网络做benchmark,以及从这些模型生成序列化引擎很有用。命令行参数如下:
118 | ```bash
119 | Mandatory params:
120 | --deploy= Caffe deploy file
121 | --output= Output blob name (can be specified multiple times)
122 |
123 | Optional params:
124 | --model= Caffe model file (default = no model, random weights used)
125 | --batch=N Set batch size (default = 1)
126 | --device=N Set cuda device to N (default = 0)
127 | --iterations=N Run N iterations (default = 10)
128 | --avgRuns=N Set avgRuns to N - perf is measured as an average of avgRuns (default=10)
129 | --workspace=N Set workspace size in megabytes (default = 16)
130 | --half2 Run in paired fp16 mode (default = false)
131 | --int8 Run in int8 mode (default = false)
132 | --verbose Use verbose logging (default = false)
133 | --hostTime Measure host time rather than GPU time (default = false)
134 | --engine= Generate a serialized GIE engine
135 | --calib= Read INT8 calibration cache file
136 | ```
137 | 例如:
138 | ```bash
139 | giexec --deploy=mnist.prototxt --model=mnist.caffemodel --output=prob
140 | ```
141 | 如果没有提供“--model”,则权重将被随机生成
142 |
143 | 该样例没有展示任何前述未曾包含的TensorRT特性
144 |
145 | ## 在多GPU上使用TensorRT
146 | 每个`ICudaEngine`对象在通过builder或反序列化而实例化时均被builder限制于一个指定的GPU内。要进行GPU的选择,需要在进行反序列化或调用builder之前调用`cudaSetDevice()`。每个`IExecutionContext`都被限制在产生它的引擎所在的GPU内,当调用`execute()`或`enqueue()`时,请在必要时调用`cudaSetDevice()`以保证线程与正确的设备绑定。
147 |
148 | ## 数据格式
149 | TensorRT的输入输出张量均为以NCHW形式存储的32-bit张量。NCHW指张量的维度顺序为batch维(N)-通道维(C)-高度(H)-宽度(W)
150 |
151 | 对权重而言:
152 |
153 | - 卷积核存储为KCRS形式,其中K轴为卷积核数目的维度,即卷积层输出通道维。C轴为是输入张量的通道维。R和S分别是卷积核的高和宽
154 | - 全连接层按照行主序形式存储 这里是错的!!全连接层中weights的存储方式是col-major,详见[Bugs](https://github.com/LitLeo/TensorRT_Tutorial/blob/master/Bug.md)
155 | - 反卷积层按照CKRS形式存储,各维含义同上
156 |
157 | ## FAQ
158 | **Q:如何在TensorRT中使用自定义层?**
159 | A:当前版本的TensorRT不支持自定义层。要想在TensorRT中使用自定义层,可以创建两个TensorRT工作流,自定义层夹在中间执行。比如:
160 |
161 | ``` c++
162 | IExecutionContext *contextA = engineA->createExecutionContext();
163 | IExecutionContext *contextB = engineB->createExecutionContext();
164 |
165 | <...>
166 |
167 | contextA->enqueue(batchSize, buffersA, stream, nullptr);
168 | myLayer(outputFromA, inputToB, stream);
169 | contextB->enqueue(batchSize, buffersB, stream, nullptr);
170 | ```
171 |
172 | **Q:如何构造对若干不同可能的batch size优化了的引擎?**
173 | A:尽管TensorRT允许在给定的一个batch size下优化模型,并在运行时送入任何小于该batch size的数据,但模型在更小size的数据上的性能可能没有被很好的优化。为了面对不同batch大小优化模型,你应该对每种batch size都运行一下builder和序列化。未来的TensorRT可能能基于单一引擎对多种batch size进行优化,并允许在当不同batch size下层使用相同的权重形式时,共享层的权重。
174 |
175 | **Q:如何选择最佳的工作空间大小**:
176 | A: 一些TensorRT算法需要GPU上额外的工作空间。方法`IBuilder::setMaxWorkspaceSize()`控制了能够分配的工作空间的最大值,并阻止builder考察那些需要更多空间的算法。在运行时,当创造一个`IExecutionContext`时,空间将被自动分配。分配的空间将不多于所需求的空间,即使在`IBuilder::setMaxWorkspaceSize()`中设置的空间还有很多。应用程序因此应该允许TensorRT builder使用尽可能多的空间。在运行时,TensorRT分配的空间不会超过此数,通常而言要少得多。
177 |
178 | [1]: https://developer.nvidia.com/nvidia-tensorrt-20-download
179 |
--------------------------------------------------------------------------------
/blogs/Conformer Encoder GPU 加速策略较全面汇总.md:
--------------------------------------------------------------------------------
1 | 最近一直在做WeNet conformer encoder模型的GPU TensorRT加速,也有幸参加了NVIDIA Hackathon 2022 加速 WeNet 的比赛,并阅读了NVIDIA 内部团队关于 WeNet TensorRT 加速的代码。学习到了很多东西,抛砖引玉进行汇总,欢迎各位大佬拍砖。
2 |
3 | 以下加速策略描述以TensorRT为例进行简单描述。
4 | PS: 阅读前需要非常了解conformer encoder的模型结构,并比较熟悉TensorRT or CUDA。
5 |
6 | # 一、 算子级别优化
7 | 这一大章节主要内容是算子级别的优化,核心优化思路有两个:
8 | 1. 合并所有能合并的算子。
9 | 2. 等价替换部分算子,去除冗余。
10 |
11 | 按照模型结构构建目录结构如下。
12 | ## ConformerEncoder
13 | ### 1.1 make_pad_mask
14 | 对应代码
15 | ```
16 | masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T)
17 | ```
18 | mask 逻辑的优化对整体模型影响较大,效果也比较明显。
19 |
20 | 输入 feat_len 维度是[B],类型是int。受python语言限制,原mask逻辑是将feat_len 变为多维的bool类型 mask,在后面的subsampling,attn_softmax 和 conv_maskedfill步骤中使用。
21 |
22 | 将这部分逻辑转换成onnx会发现,多出了大量的equal,not等 bool 级别的算子,非常冗余。
23 |
24 | 使用c++ or cuda 上线时,直接使用feat_len即可,可以省去大量的算子(示意代码见下)。但这样改动也意味着后面使用mask的算子都不能再使用标准算子了,需要自己实现plugin。具体改动见其他模块相对应的部分。
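
下面用一段最小的 PyTorch 示意说明这个思路(`masked_softmax_from_lengths` 是假设的函数名,只用来演示 plugin 内部直接拿 key 下标和 feat_len 比较的做法,并不是 plugin 的真实实现):

```python
import torch

def masked_softmax_from_lengths(scores, feat_len):
    # scores: [B, H, T_q, T_k] 注意力打分; feat_len: [B] 每句的有效长度
    # 不再物化 [B, 1, T] 的 bool mask,而是用 key 下标和 feat_len 现场比较
    t_k = scores.size(-1)
    key_idx = torch.arange(t_k, device=scores.device)                        # [T_k]
    invalid = key_idx[None, None, None, :] >= feat_len[:, None, None, None]  # [B, 1, 1, T_k]
    return torch.softmax(scores.masked_fill(invalid, float("-inf")), dim=-1)

scores = torch.randn(2, 4, 8, 8)
feat_len = torch.tensor([8, 5])
probs = masked_softmax_from_lengths(scores, feat_len)
```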
25 |
26 | ### 1.2 subsampling: Conv2dSubsampling4
27 | #### 1.2.1 mask 逻辑优化
28 | ```
29 | x_mask[:, :, :-2:2][:, :, :-2:2]
30 | ```
31 | 受python语言限制,简单一行 mask 操作 x_mask[:, :, :-2:2][:, :, :-2:2],需要数个基础算子拼接才能实现。
32 | 没有任何疑问,这里要做成一个plugin。
33 | 翻译一下,逻辑为F = ceil((x - 2) / 2), F(F(time))
34 |
35 | PS: 这段python代码是有问题的,会导致单句和batch执行结果不一致的问题,详细见 https://github.com/wenet-e2e/wenet/issues/1513
36 |
37 | #### 1.2.2 out linear 逻辑优化
38 | ```
39 | x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
40 | ```
41 | 原逻辑为:transpose + reshape + linear + bias,转成成trt 算子就是shuffle + matrix_multiply + add 操作,维度变化如下图。
42 |
43 | ```mermaid
44 | stateDiagram-v2
45 | input --> transpose(1,2)+reshape: [batch, dim, time, dim1]
46 | transpose(1,2)+reshape --> gemm+bias: [batch, time, dim*dim1]
47 | gemm+bias --> output: [batch, time, dim]
48 | ```
49 |
50 | 可以等价转换成 conv + transpose操作。
51 | 需要先将linear_weight[dim, dim *dim1]进行reshape,linear_weight.reshape(dim, dim, 1, dim1)
52 | 维度变化如下图。
53 |
54 | ```mermaid
55 | stateDiagram-v2
56 | input --> conv: [batch, dim, time, dim1]
57 | conv--> transpose(1,2)+reshape: [batch, dim, time, 1]
58 | transpose(1,2)+reshape--> output: [batch, time, dim]
59 | ```
60 | 这里的加速有两个点:
61 | 1. 将gemm + bias 两个算子替换成了 conv 一个算子。
62 | 2. 大大减少了 transpose 算子的计算量。
63 |
64 | 详细可参考第 23 分钟 https://www.bilibili.com/video/BV1i3411G7vN/?spm_id_from=333.999.0.0&vd_source=58a8fa4cc926efbd7338631b3957cc73
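
上面 linear 等价替换为 conv 的权重 reshape 方式,可以用下面这段最小的 PyTorch 代码验证(batch/time/dim/dim1 均为假设的示例数值):

```python
import torch

batch, time, dim, dim1 = 2, 16, 256, 19
x = torch.randn(batch, dim, time, dim1)                 # subsampling 卷积之后的输出
linear = torch.nn.Linear(dim * dim1, dim)

# 原逻辑: transpose + reshape + linear
ref = linear(x.transpose(1, 2).contiguous().view(batch, time, dim * dim1))

# 等价逻辑: conv(kernel = 1 x dim1) + squeeze + transpose,conv 权重由 linear 权重 reshape 得到
conv = torch.nn.Conv2d(dim, dim, kernel_size=(1, dim1))
conv.weight.data = linear.weight.data.reshape(dim, dim, 1, dim1)
conv.bias.data = linear.bias.data
out = conv(x).squeeze(-1).transpose(1, 2)               # [batch, time, dim]

print(torch.allclose(ref, out, atol=1e-4))
```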
65 |
66 | ### 1.3. pos_enc: RelPositionalEncoding
67 | ```
68 | self.pe = self.pe.to(x.device)
69 | x = x * self.xscale
70 | pos_emb = self.position_encoding(offset, x.size(1), False)
71 | return self.dropout(x), self.dropout(pos_emb)
72 | ```
73 |
74 | 这里没什么复杂的逻辑,先一个scale,再根据输入维度截取 pos_emb。涉及到维度操作,使用trt 基础算子实现还是麻烦,直接做成一个算子,将scale也包进去,完事。
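
融合后的算子要做的事情,用 PyTorch 描述大致如下(最小示意,pe、xscale 的取值是假设的):

```python
import torch

def rel_pos_enc_ref(x, pe, xscale):
    # 融合算子只做两件事: x 乘 xscale,以及按当前输入长度截取预计算好的 pe
    pos_emb = pe[:, :x.size(1)]
    return x * xscale, pos_emb

x = torch.randn(2, 64, 256)
pe = torch.randn(1, 5000, 256)       # 预先计算好的位置编码表
x_scaled, pos_emb = rel_pos_enc_ref(x, pe, xscale=256 ** 0.5)
```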
75 |
76 | ### 1.4. encoders: ConformerEncoderLayer
77 | 这里是重头戏。前面的模块都只有一层,而模型有 N 层 encoder,这里每一步的加速效果都会被放大 N 倍。
78 | #### 1.4.1 attn: RelPositionMultiHeadedAttention
79 | att模块的加速技术也比较成熟了,核心思路还是将qkvp gemm 合并在一起,将四个gemm合并成一个大gemm。
80 | 大的改动如下(权重合并的等价性可参考列表后的示意代码):
81 | x[batch, time, dim], qkvp weight[dim, dim]
82 | 1. 将qkvp weights 合并成一个 qkvp_weights[dim, 4dim], x.mul(qkvp_weights) => qkvp_out[batch, time, 4dim]。
83 | 2. +qkv bias, +pos_bias_u 和 pos_bias_v和transpose可以合并成一个plugin。
84 | 3. (matrix_ac + matrix_bd) / math.sqrt(self.d_k) 和 softmax 合并成一个plugin。
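
以 q/k/v 三个投影为例,权重合并的等价性可以用下面这段最小的 PyTorch 代码说明(pos 投影的输入是 pos_emb,将其广播到相同形状后可以按同样方式拼进大权重,这里只演示合并本身):

```python
import torch

batch, time, dim = 2, 32, 256
x = torch.randn(batch, time, dim)
linear_q = torch.nn.Linear(dim, dim)
linear_k = torch.nn.Linear(dim, dim)
linear_v = torch.nn.Linear(dim, dim)

# 原逻辑: 三个独立的 gemm
q, k, v = linear_q(x), linear_k(x), linear_v(x)

# 合并逻辑: 拼成一个 [3*dim, dim] 的大权重,只做一次 gemm 再 split
qkv_w = torch.cat([linear_q.weight, linear_k.weight, linear_v.weight], dim=0)
qkv_b = torch.cat([linear_q.bias, linear_k.bias, linear_v.bias], dim=0)
q2, k2, v2 = (x @ qkv_w.t() + qkv_b).split(dim, dim=-1)

print(torch.allclose(q, q2, atol=1e-5),
      torch.allclose(k, k2, atol=1e-5),
      torch.allclose(v, v2, atol=1e-5))
```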
85 |
86 | #### 1.4.2 conv_module: ConvolutionModule
87 | 这里的加速思路比较新颖。核心代码如下
88 | ```
89 | x = x.transpose(1, 2) # (#batch, channels, time)
90 | x = nn.functional.pad(x, (self.lorder, 0), 'constant', 0.0)
91 |
92 | # GLU mechanism
93 | x = self.pointwise_conv1(x) # (batch, 2*channel, dim)
94 | x = nn.functional.glu(x, dim=1) # (batch, channel, dim)
95 | # 1D Depthwise Conv
96 | x = self.depthwise_conv(x)
97 | x = x.transpose(1, 2)
98 | x = self.activation(self.norm(x))
99 | x = x.transpose(1, 2)
100 | x = self.pointwise_conv2(x)
101 | x.transpose(1, 2)
102 | ```
103 |
104 | 输入x的维度为 (batch, time, channel),整个模块就是在最后一维上进行多次1D卷积和layernorm操作。但conv无法直接作用于最后一维,所以一共需要4个transpose为此服务。
105 | ```mermaid
106 | stateDiagram-v2
107 | input --> 0_transpose(1,2): [batch, time, channels]
108 | 0_transpose(1,2) --> pad: [batch, channels, time]
109 | pad -->pointwise_conv1+glu_dim1: [batch, channels, time+pad]
110 | pointwise_conv1+glu_dim1 --> depthwise_conv: [batch, channels, time+pad]
111 | depthwise_conv --> 1_transpose(1,2): [batch, channels, time]
112 | 1_transpose(1,2) --> activation(norm(x)): [batch, time, channels]
113 | activation(norm(x)) --> 2_transpose(1,2): [batch, time, channels]
114 | 2_transpose(1,2) --> pointwise_conv2: [batch, channels, time]
115 | pointwise_conv2 --> 3_transpose(1,2): [batch, channels, time]
116 | 3_transpose(1,2) --> output: [batch, time, channels]
117 | ```
118 |
119 | 这里加速的核心就是如何去除这四个冗余的转置。
120 |
121 | 加速步骤如下:
122 | 1. 将两个 pointwise_conv 转换为 Linear(等价性示意代码见本小节末尾)
123 | 2. depthwise_conv 没有办法替换为现成的算子,需要编写对应的plugin。
124 | 3. 将 self.pointwise_conv1.bias 和 glu 进行合并。
125 |
126 | 优化后流程维度变化如下:
127 | ```mermaid
128 | stateDiagram-v2
129 | input --> gemm+glu_dim2: [batch, time, channels]
130 | gemm+glu_dim2 --> depthwise_conv_plugin: [batch, time, channels]
131 | depthwise_conv_plugin --> activation(norm(x)): [batch, time, channels]
132 | activation(norm(x)) --> gemm: [batch, time, channels]
133 | ```
134 |
135 | PS: 在 depthwise_conv_plugin 计算中,通过直接按0算pad,可以直接将前面的pad操作去掉。
136 |
137 | PPS: 这里pad操作的位置有问题,详细见:https://github.com/wenet-e2e/wenet/issues/1527
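
步骤 1 中 pointwise_conv 与 Linear 的等价性,可以用下面这段最小的 PyTorch 代码验证(batch/time/channels 为假设的示例数值):

```python
import torch

batch, time, channels = 2, 32, 256
x = torch.randn(batch, time, channels)

pw = torch.nn.Conv1d(channels, 2 * channels, kernel_size=1)      # pointwise_conv1
ref = pw(x.transpose(1, 2)).transpose(1, 2)                      # conv 只能作用在 [B, C, T] 布局上

lin = torch.nn.Linear(channels, 2 * channels)
lin.weight.data = pw.weight.data.squeeze(-1)                     # [2C, C, 1] -> [2C, C]
lin.bias.data = pw.bias.data
out = lin(x)                                                     # 全程保持 [B, T, C],省掉 transpose

print(torch.allclose(ref, out, atol=1e-5))
```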
138 |
139 | #### 1.4.3 layer_norm
140 | layernorm 值得单独拉出来说一下。
141 | 因为 layernorm 存在多个版本、计算方式不同等原因,trt 并没有提供 layernorm 算子,onnx 也没有。torch layernorm 算子转 onnx 时,会被打散成一堆算子。在 trt8+ 中,Myelin 会将这一堆算子合并成一个 foreign_node,但速度不如优化过的 plugin 快。
142 |
143 | 详细优化思路见这个视频 第28分钟:https://www.bilibili.com/video/BV1i3411G7vN?spm_id_from=333.999.0.0
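
一个 layernorm plugin 在单个 kernel 里要完成的计算,用 PyTorch 写出来就是下面几行(最小参考实现,用 torch 自带的 layer_norm 做对照):

```python
import torch

def layer_norm_ref(x, gamma, beta, eps=1e-5):
    # 融合后的 plugin 在一个 kernel 内完成: 按最后一维求均值/方差,再做 scale 和 shift
    mean = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, unbiased=False, keepdim=True)
    return (x - mean) / torch.sqrt(var + eps) * gamma + beta

x = torch.randn(2, 16, 256)
gamma, beta = torch.ones(256), torch.zeros(256)
print(torch.allclose(layer_norm_ref(x, gamma, beta),
                     torch.nn.functional.layer_norm(x, (256,), gamma, beta),
                     atol=1e-5))
```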
144 |
145 | #### 1.4.4 scale_addbias_residual_prelayernorm
146 | 还可以将 scale, bias, residual, layernorm进行合并。在ConformerEncoderLayer中,该合并策略可以应用到三处。
147 | ```python
148 | x = self.feed_forward_macaron(x)
149 | x = residual + self.ff_scale * x
150 | residual = x
151 | if self.normalize_before:
152 | x = self.norm_mha(x)
153 | ```
154 | 代码为例,不考虑 feed_forward_macaron 模块,可以将逻辑梳理为:
155 | ```python
156 | # 输入为x 和 residual,输出为 out_residual 和 out
157 | x = x + self.feed_forward_macaron.bias
158 | out_residual = residual + self.ff_scale * x
159 | out = self.norm_mha(out_residual)
160 | ```
161 | 梳理后的逻辑,所有的算子可以在一个 kernel 中实现。
162 |
163 | # 二、整体优化策略
164 | 这一块做的不多,就瞎扯了。
165 | ## 2.1 half + int8
166 | 1. 建议优先做 fp16,简单速度快。在 T4 上测试发现单fp16的速度甚至能媲美单INT8的速度。
167 | 2. int8 在 conformer 模型上容易溢出,模型效果影响较大。想要把效果调回来需要费不少劲,但收益并不诱人。
168 | 3. plugin half 版本,有些plugin 可以使用 half2 加速,虽然是微弱的加速,聊胜于无。
169 | 4. 想要追求极致的性能,需要fp16+int8。
170 | 5. 追求更极致的性能,可以直接将一些plugin算子实现int8版本,除了能更快,还能免掉一些类型reformat操作。这个nv一直在尝试,但公开的代码比较少。
171 |
172 | ## 2.2 varlen
173 | 没做过
174 |
175 | # 三、加速效果
176 |
177 | 测试模型,语音识别 conformer 12 层 encoder stream模型,一个chunk是16帧。
178 | NVIDIA T4 显卡,下表中是测试一个chunk的时间。
179 |
180 | 在逐步的应用half加速和大部分算子级计算策略后,加速比达到了8倍之多,RTF 为 0.0157。
181 |
182 | | 模型 | 一个chunk的时间 |
183 | | --- | --- |
184 | conformer 12l float | 24.48ms
185 | conformer 12l half | 7.89ms
186 | conformer 12l half V2 | 4.42696 ms
187 | conformer 12 half V3 | 3.9395 ms
188 | conformer 12 half V4 | 3.39684 ms
189 | conformer 12 half V5 | 3.00334 ms
190 |
191 | 一点点做下来,速度快了起来~
192 |
193 | 整句模型我还没做完,NVIDIA 内部给出整句模型测试数据,RTF甚至低到了0.000452.
194 |
--------------------------------------------------------------------------------
/blogs/TensorRT Github 开源部分介绍.md:
--------------------------------------------------------------------------------
1 | 此文档基于TensorRT7.0。
2 |
3 | TensorRT(下面简称“TRT”)目前由两个部分组成,闭源部分和开源部分。闭源部分就是官方提供的库,是TRT的核心部分;开源部分在github上,包含Parser(caffe,onnx)、Sample和一些plugin。
4 |
5 | 其中开源部分中,有一部分要值得提一下,就是bert相关的plugin,位于demo/bert目录下。TRT在5.0版本中,提供了一个bert的demo,包括了一套完整的代码,实现将bert tf模型使用TRT上线,内容包括提取tf模型权值的python脚本,bert模型所需要的plugin,以及读取权值、搭建bert网络并运行和测试的代码,对于tf bert模型的使用者来说,可以说是很方便了。在6.0和7.0中,demo/bert目录被删除,其中的plugin修改为dynamic input shape版本,并放到plugin中,而其余代码则在7.1版本中才再次提供。这部分会在另一篇博客中(如果我能更下去的话==)做更详细的解释,这里就不赘述了。
6 |
7 | ## 简单介绍
8 | TRT作为NV inference 的大杀器,出世以来虽然速度很快(真香),但一直被"闭源"、"难用"所诟病。TRT github 虽然只是TRT开源出来的一小部分非核心代码,但也是有很多干货的。下面对各目录做一个简单的介绍。
9 |
10 | parser目录,主要包括官方维护的两个parser。caffe和onnx parser,能够将caffe格式模型和onnx格式模型直接转成TRT格式。其中caffe parser随着caffe逐渐退出训练框架舞台已经慢慢不维护了。
11 |
12 | plugin目录,包含官方提供的一些plugin,大部分plugin都是跟计算机视觉网络相关的,6.0开始加入了bert网络的plugin。
13 |
14 | sample目录,就是怕你不会用,提供的一些demo。还提供了一个trtexec工具,我没用过,就不多说了。
15 |
16 | ## 无网情况下编译
17 |
18 | NV虽然只开源了部分非核心代码,但还是很有用的。有用就要用上,第一步当然是编译。但是部分情况下,服务器可能是没有网的(懂的人自然懂)。而trt git在手动源码编译的时候,竟然需要联网下载protobuf……这里针对这个问题简单说一下。
19 |
20 | TRT的cmake编译选项中,有parser、plugin、sample等多个选项开关。其中部分sample编译依赖parser(caffe和onnx)模块,而parser模块依赖protobuf,而这个protobuf是根据你指定的版本先联网下载再编译的。这个设计对于无网的服务器,实在是不友好……
21 |
22 | 因为这个小麻烦而大改CMakeLists.txt实在是有点不值当。下面简单介绍一下比较简单的解决方案。
23 |
24 | 1. 不编译依赖parser的那部分sample,直接在TensorRT/samples/opensource/CMakeLists.txt中删掉即可。
25 | 2. 替换protobuf的下载链接,在另一台机器或者本机上搭建一个apache,把相应版本的protobuf放上去即可。具体修改 TensorRT/third_party/protobuf.cmake line22
26 | 3. 直接把protobuf放到编译目录下,然后修改TensorRT/third_party/protobuf.cmake 的部分代码,不从网络下载,而直接选择本地文件并编译即可。(这个cmake大佬可以试试,我没试过,感觉还挺麻烦的zz)
27 |
28 |
--------------------------------------------------------------------------------
/blogs/TensorRT Plugin使用方式简介-以leaky relu层为例.md:
--------------------------------------------------------------------------------
1 | # TensorRT Plugin使用方式简介-以leaky relu层为例
2 |
3 | ------
4 | ## 写在前面
5 | TensorRT plugin用于实现TensorRT不支持的网络层,比如leaky relu。本文以leaky relu为例,简单介绍plugin的使用方式以及plugin层的serialization和de-serialization的原理。
6 |
7 | 之前我已经分享过使用leaky relu曲线救国的解决方案,但是实验结果表明该方案比较慢,leaky relu的plugin实现方式性能更好。
8 |
9 | ## leaky relu的plugin实现
10 | 添加自定义层主要包括两个步骤,
11 | 1. 继承IPlugin接口,创建自定义层类
12 | 2. 将该自定义层添加到网络中
13 |
14 | 首先来简单介绍IPlugin接口类的成员函数,详细见TensorRT-3.0.0\include\NvInfer.h文件中的类定义。
15 |
16 | ``` c++
17 |
18 | // 获得该自定义层的输出个数,比如leaky relu层的输出个数为1
19 | virtual int getNbOutputs() const = 0;
20 |
21 | // 得到输出Tensor的维数
22 | virtual Dims getOutputDimensions(int index, const Dims* inputs, int nbInputDims) = 0;
23 |
24 | // 配置该层的参数。该函数在initialize()函数之前被构造器调用。它为该层提供了一个机会,可以根据其权重、尺寸和最大批量大小来做出算法选择。
25 | virtual void configure(const Dims* inputDims, int nbInputs, const Dims* outputDims, int nbOutputs, int maxBatchSize) = 0;
26 |
27 | // 对该层进行初始化,在engine创建时被调用。
28 | virtual int initialize() = 0;
29 |
30 | // 该函数在engine被摧毁时被调用
31 | virtual void terminate() = 0;
32 |
33 | // 获得该层所需的临时显存大小。
34 | virtual size_t getWorkspaceSize(int maxBatchSize) const = 0;
35 |
36 | // 执行该层
37 | virtual int enqueue(int batchSize, const void*const * inputs, void** outputs, void* workspace, cudaStream_t stream) = 0;
38 |
39 | // 获得该层进行serialization操作所需要的内存大小
40 | virtual size_t getSerializationSize() = 0;
41 |
42 | // 序列化该层,根据序列化大小getSerializationSize(),将该类的参数和额外内存空间全都写入到系列化buffer中。
43 | virtual void serialize(void* buffer) = 0;
44 | ```
45 | 根据类成员函数和leaky relu层的原理,设计LeakyReluPlugin类,可以很容易计算出的成员变量和各个成员函数的返回值。LeakyReluPlugin类实现代码如下。
46 | ``` c++
47 |
48 | __global__ void _leakyReluKer(float const *in, float *out, int size)
49 | {
50 | int index = threadIdx.x + blockIdx.x * blockDim.x;
51 | if (index >= size)
52 | return ;
53 |
54 | if (in[index] < 0)
55 | out[index] = in[index] * 0.1;
56 | else
57 | out[index] = in[index];
58 | }
59 |
60 | class LeakyReluPlugin : public IPlugin
61 | {
62 | public:
63 | LeakyReluPlugin() {}
64 | LeakyReluPlugin(const void* buffer, size_t size)
65 | {
66 | assert(size == sizeof(mSize));
67 | mSize = *reinterpret_cast<const size_t*>(buffer);
68 | }
69 |
70 | int getNbOutputs() const override
71 | {
72 | return 1;
73 | }
74 | Dims getOutputDimensions(int index, const Dims* inputs, int nbInputDims) override
75 | {
76 | assert(nbInputDims == 1);
77 | assert(index == 0);
78 | assert(inputs[index].nbDims == 3);
79 | return DimsCHW(inputs[0].d[0], inputs[0].d[1], inputs[0].d[2]);
80 | }
81 |
82 | int initialize() override
83 | {
84 | return 0;
85 | }
86 |
87 | void terminate() override
88 | {
89 | }
90 |
91 | size_t getWorkspaceSize(int) const override
92 | {
93 | return 0;
94 | }
95 |
96 | // currently it is not possible for a plugin to execute "in place". Therefore we memcpy the data from the input to the output buffer
97 | int enqueue(int batchSize, const void*const *inputs, void** outputs, void*, cudaStream_t stream) override
98 | {
99 | int block_size = 256;
100 | int grid_size = (mSize + block_size - 1) / block_size;
101 | _leakyReluKer<<<grid_size, block_size, 0, stream>>>(
102 | reinterpret_cast<float const*>(inputs[0]),
103 | reinterpret_cast<float*>(outputs[0]), mSize);
104 | getLastCudaError("_leakyReluKer");
105 | return 0;
106 | }
107 |
108 | size_t getSerializationSize() override
109 | {
110 | return sizeof(mSize);
111 | }
112 |
113 | void serialize(void* buffer) override
114 | {
115 | *reinterpret_cast<size_t*>(buffer) = mSize;
116 | }
117 |
118 | void configure(const Dims*inputs, int nbInputs, const Dims* outputs, int nbOutputs, int) override
119 | {
120 | mSize = inputs[0].d[0] * inputs[0].d[1] * inputs[0].d[2];
121 | }
122 |
123 | protected:
124 | size_t mSize;
125 | };
126 | ```
127 |
128 | 然后插入到网络中即可,代码如下。
129 |
130 | ```
131 | LeakyReluPlugin *lr = new LeakyReluPlugin();
132 | auto plugin_lr = network->addPlugin(&inputs_v[0], 1, *lr);
133 | plugin_lr->setName(PLUGIN_LEAKY_RELU_NAME);
134 | ```
135 |
136 | 然后运行网络即可。
137 |
138 | ## plugin层的serialization和deserialization的详解
139 | plugin的创建和使用的文档比较健全,照着文档来就行了。但序列化和反序列化这一部分文档中说的比较少,故在这里做详解。
140 |
141 | 序列化非常简单,在plugin类中实现getSerializationSize()和serialize()函数,然后一行代码序列化即可。
142 | gieModelStream = engine_->serialize();
143 |
144 | 重点在于反序列化,反序列化的步骤如下。
145 | 1. 根据序列化serialize()函数内的写入buffer的顺序构建IPluginFactory类。
146 | 2. 在反序列化时将IPluginFactory传入,用于将buffer中的数据反序列化为自定义层类。
147 |
148 | IPluginFactory接口类代码解释如下。
149 | 请注意layerName参数。
150 | ```
151 | class IPluginFactory
152 | {
153 | public:
154 | /**
155 | * \brief 根据序列化数据,反序列化为plugin类
156 | *
157 | * \param 网络层的名字,该参数非常重要,是反序列化为哪种plugin类的唯一凭证。
158 | * \param 序列化数据
159 | * \param 该层序列化后的序列化数据的长度
160 | *
161 | * \return the plugin
162 | *
163 | * \see IPlugin::serialize()
164 | */
165 | virtual IPlugin* createPlugin(const char* layerName, const void* serialData, size_t serialLength) = 0;
166 | };
167 |
168 | ```
169 | 以leaky relu为例,PluginFactory类实现如下。
170 |
171 | ```
172 | class PluginFactory : public nvinfer1::IPluginFactory
173 | {
174 | public:
175 | // deserialization plugin implementation
176 | IPlugin* createPlugin(const char* layerName, const void* serialData, size_t serialLength) override
177 | {
178 | IPlugin *plugin = nullptr;
179 | if (strstr(layerName, PLUGIN_LEAKY_RELU_NAME) != NULL)
180 | {
181 | plugin = new LeakyReluPlugin(serialData, serialLength);
182 | }
183 |
184 | return plugin;
185 | }
186 |
187 |
188 | std::unique_ptr<LeakyReluPlugin> mLR{ nullptr };
189 | };
190 | ```
191 | 然后在deserialize的时候,将PluginFactory传入即可,代码如下。
192 |
193 | ```
194 | engine_ = runtime->deserializeCudaEngine(buffer, length, &pluginFactory);
195 | ```
196 | **实验结果表明,leaky relu的plugin实现方式速度明显快于曲线救国的实现方式!**
197 |
198 |
--------------------------------------------------------------------------------
/blogs/TensorRT 可借鉴代码汇总.md:
--------------------------------------------------------------------------------
1 | 本文档主要针对API方式build网络方式。
2 | 语音领域现在仍处于成熟前的发展期,网络结构日新月异,模型大小逐渐变大,性能要求随之提高。因此我不太习惯用parser的形式去解析模型并构造网络。原因一是TRT目前支持功能有限,就算成熟框架(pytorch,tf)也容易遇到op不支持的情况,更何况还有kadi的存在;另一个是有时候性能要求高,全靠parser不好手动做优化。
3 | TensorRT的文档目前可以说是已经很成熟了,但仍架不住是闭源软件,使用API 构造网络的时候,不太耐操,仍需小心翼翼。学会API的正确使用方法,事半功倍。学习API的方法,自然是要借鉴别人的成熟代码~~
4 | 下面列一些比较成熟的代码,供借鉴。
5 | 1. 首先肯定是TenorRT自身的sample和plugin
6 | 2. onnx parser的源码,主要是builtin_op_importers.cpp
7 | 3. pytorch 源码中TensorRT相关部分
8 | 4. TF 源码中TensorRT相关部分
9 | 5. GitHub TensorRT issue和NVIDIA 论坛中的问答。
10 |
11 | 踩过的一些路:
12 | 1. RNNv2 layer中,双向lstm要求seql_len是固定的,无法适用于dynamic shape. 可以用loop layer特性和其他计算接口“拼接”出来一个支持动态seqlen的lstm。详见 TensorRT sample sampleCharRNN.cpp
13 | 2. IFullyConnectedLayer 进行矩阵乘的时候,会把{C, H, W} reshape 成 {1, C*H*W},输出为{K, 1, 1},就比较蛋疼。如何在不同场景下使用不同的接口实现矩阵乘,进行性能和显存的优化,详见以上各种代码。
14 |
--------------------------------------------------------------------------------
/blogs/TensorRT 转换模型的几种方式比较.md:
--------------------------------------------------------------------------------
1 | 最近一直在做 Conformer 模型的 TensorRT 加速,总结了几种常用的 TensorRT 转换模型方法,分别是转换工具,TensorRT API 搭建和转换 和超大Plugin。根据自己的见解,对这三种方法进行分析和比较。
2 |
3 | ## 一、 介绍
4 | ### 1.1 转换工具
5 | 转换工具可以分为三小类:(1)训练框架自带 TRT 工具;(2)第三方转换工具;(3)TVM。
6 | 框架框架自带TRT工具,主要代表是TF-TRT和Torch-TensorRT。顾名思义,这两个工具分别是针对TensorFlow和PyTorch模型,由官方支持,将相对应的模型转成TensorRT格式。
7 | 第三方 转换工具 的典型代表就是 onnx-tensorrt parser,原先应该是个人维护,后来NVIDIA接手了。
8 | TVM个人感觉不太成熟,尤其是在支持动态输入方面,了解的也不多,这里不多评价。
9 |
10 | **开发流程**
11 | 以 onnx-tensorrt parser 为例。对于标准模型非常简单,几行代码就能搞定。
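
对于标准模型,这几行代码大致如下(最小示意,按 TensorRT 7/8.x 的 Python API 写法,model.onnx / model.plan 是假设的路径):

```python
import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
parser = trt.OnnxParser(network, logger)

with open("model.onnx", "rb") as f:                      # 假设的模型路径
    if not parser.parse(f.read()):
        raise RuntimeError(parser.get_error(0))

config = builder.create_builder_config()
config.max_workspace_size = 1 << 30                      # 1 GiB,TensorRT 8.4 之前的写法
engine = builder.build_engine(network, config)
with open("model.plan", "wb") as f:
    f.write(engine.serialize())
```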
12 | 对于需要使用plugin 情况,就比较麻烦了。比如layernorm算子加速,开发流程为:(1)编写 layernorm plugin 和测试代码;(2)编写onnx 算子替换代码,从模型结构中搜索出layernorm散装算子,进行替换,生成新onnx模型;(3)编写转换代码,转换模型。
13 |
14 | 举例代码:https://github.com/NVIDIA/trt-samples-for-hackathon-cn/tree/master/Hackathon/2022/code/LayerNormPlugin
15 |
16 | ### 1.2 TensorRT API 搭建和转换
17 | 转换工具一般是黑盒,输入模型,直接就输出了转换完 TensorRT 模型,非常方便。
18 | 使用 TensorRT API 包含搭建网络和转换模型两个步骤,其中还包含各种配置参数,步骤麻烦。加上TensorRT API 设计真的一言难尽,使用 API 转换网络属实费时费力。正因为如此,才催生了各种转换工具。
19 | 但 API 也是有优点的,在某些场景下,使用API方式构建更方便,甚至是必须的。比如(1)使用现有TRT转换工具无法成功转换,存在不支持的算子,比如fastmoe结构;(2)想要对模型做深度合并加速,比如TensorRT OOS中的BERT demo。
20 |
21 | **开发流程**
22 | 1. 熟悉整个网络结构,很细节的那种。需要看透模型训练代码或者给的onnx模型。
23 | 2. 如果是只能拿到onnx模型,则需要分析读取模型权值。直接拿到训练代码的可以省去这步。
24 | 3. 调用TensorRT API,对模型中的算子进行一一替换,填充TensorRT Network。
25 | 4. 编写每一个 plugin 和单元测试代码,并嵌入到网络中。
26 | 5. 设置构建参数,调用构建API 构建网络。
27 |
28 | 举例代码:https://github.com/wang-xinyu/tensorrtx/tree/master/alexnet
29 | 这个git里用的是c++ api,构建网络推荐使用python api,更简单方便。
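
用 Python API 填充网络大致是下面这种写法(最小示意,层的组合只是举例,权重用零占位,实际应从训练好的模型中读取):

```python
import numpy as np
import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))

# 手工搭一个 matmul + bias + relu,权重这里用零占位
x = network.add_input("x", trt.float32, (1, 1, 256))
w = network.add_constant((1, 256, 128), trt.Weights(np.zeros((1, 256, 128), dtype=np.float32)))
mm = network.add_matrix_multiply(x, trt.MatrixOperation.NONE,
                                 w.get_output(0), trt.MatrixOperation.NONE)
b = network.add_constant((1, 1, 128), trt.Weights(np.zeros((1, 1, 128), dtype=np.float32)))
add = network.add_elementwise(mm.get_output(0), b.get_output(0), trt.ElementWiseOperation.SUM)
relu = network.add_activation(add.get_output(0), trt.ActivationType.RELU)
network.mark_output(relu.get_output(0))

config = builder.create_builder_config()
engine = builder.build_engine(network, config)
```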
30 |
31 | ### 1.3 超大 Plugin
32 | 超大 Plugin 是指将整个模型或者某个大模块实现在一个Plugin内,在plugin内部非常自由。一套代码基本只适用于一类模型,开发成本较高。目前只见过 NV 的团队针对一类模型做这样的事情,比如 FasterTransformer。
33 |
34 | **开发流程**
35 | 1. 熟悉整个网络结构,很细节的那种。需要看透模型训练代码或者给的onnx模型。
36 | 2. 分析读取模型权值。
37 | 3. 调用库 or 写 kernel 实现所有算子;应该写一套测试代码即可。
38 | 4. 填充网络,设置构建参数,调用构建API 构建网络。
39 | PS:这种方法一般没几个layer,填充网络步骤比较简单,就和构建网络合并在一起了。
40 |
41 | 举例代码就是FasterTransformer了。
42 |
43 | ## 二、 分析和比较
44 | 上述三种转换方式,在实际应用中,从易用性,开发成本,性能,和灵活性四个方面进行分析,数字代表星级,最高5星:
45 |
46 | | 方法 | 易用性 | 开发成本 | 性能 | 灵活性 |
47 | | --- | --- | --- | --- | --- |
48 | | 转换工具 | 4 | 2 | 2 | 2
49 | | API | 1 | 5 | 4.5 | 5
50 | | 超大Plugin | 1 | 3.5 | 5 | 4
51 |
52 |
53 | **易用性**
54 | 易用性这里指使用难度。毫无疑问,转换工具就是为了解决TensorRT 使用难度高痛点而生的。各大转换工具基本都可以零门槛,几行代码实现转换模型。
55 | 但遇见需要使用plugin情况会麻烦一些,评级4星。
56 | API和超大Plugin也是毫无疑问,门槛比较高。前者需要熟练TensorRT的API和构建步骤,后者需要熟练CUDA。评星最低1星。
57 |
58 | **开发成本**
59 | 开发成本和易用性其实比较像,只所以单独拎出来比较一下,是因为API和超大Plugin虽然开发成本都很高,但也有些许差别。
60 | 从开发流程可以看出,API方法多了填充网络(这个是大头)和plugin单元测试代码的工作量,开发成本更高一些。
61 |
62 | **性能**
63 | 转换工具的最大缺点就是性能。转换工具的本质就是使用TensorRT API,将网络结构一对一地转换成TensorRT模型,是强依赖于模型结构的。而受python语言和模型训练限制,模型结构往往离最优还有不少的距离。但从TensorRT 8.4开始,着重增强了Myelin的算子融合功能,就算模型结构不是最优,trt本身也能在构建模型的过程中尽可能地提升性能,所以评星2星。
64 | API和超大Plugin在难度高、开发成本也高的情况下,还一直有不少人使用,就是为了追求极致的速度。尤其在目前"降本增效"的大环境下,速度的提升显得更重要一些。两者都可以实现大部分优化策略,只存在一些细微的差别。
65 | 1. 在int8加速方面,API方法不受影响。但超大Plugin 实现 int8就麻烦太多了。但目前的显卡fp16的性能就已经很不错了,int8的性能提升也就没有之前那么诱人了。
66 | 2. TensorRT 的加速技术之一就是可以适配硬件和输入大小选择运行最快的实现方式。超大Plugin也基本跟这个无缘了,但对于gemm等算子,也可以通过cublas/lt的方式选择最优的kernel。
67 | 3. 超大Plugin可以实现稀疏矩阵乘、INT4矩阵乘等新技术,API方式实现这些就只能等TensorRT的更新了。
68 |
69 | **灵活性**
70 | 实际产业应用中,经常有魔改模型和加速模型等需求,灵活性指的就是是否能够满足这类需求。
71 | 转换工具满足魔改模型需求还行,加速模型方面实在是比较麻烦,评星2星。
72 | API方法在灵活性方面是最高的,能够轻松满足各种需求,评星5星。
73 | 超大Plugin方式,一套代码基本只适用于一类模型,每次魔改模型可能发生二次开发和加配置参数这种情况,评星4星。
74 |
75 | ## 三、个人的一些看法
76 | 个人认为,在NLP和语音领域,API方法是最适用于工业领域的TensorRT上线方法。尤其是在目前降本增效的大环境下,有能力的话,直接上API。(最适用并不代表一定要用……)
77 | 为什么 NLP 和 语音领域?因为这两个领域的模型结构相对比较复杂,输入大小一般是动态的,模型结构还属于不断更新的状态(主要还是看google……),向大模型发展等原因,使用现成的转换工具,往往速度不够好。
78 |
--------------------------------------------------------------------------------
/blogs/使用TensorRT实现leaky relu层.md:
--------------------------------------------------------------------------------
1 | # 使用TensorRT实现leaky relu层
2 |
3 | ------
4 |
5 | 最近在使用TensorRT加速CNN的时候,发现TensorRT支持relu但不支持leaky relu,这里分享出网上找到的曲线救国的[解决方案][1]和实现。
6 | ## 解决方案
7 | 解决方案比较简单,使用scale层、relu层和ElementWise层实现leaky relu,具体流程如下:
8 | ![flow_of_leaky_relu][2]
9 |
10 | ## 实现方式
11 | 实现方式有两种。
12 |
13 | 1. 在训练时就这样定义,然后直接把训练好的模型丢给TensorRT。比如使用caffe,在prototxt文件中这样定义leaky relu层([详见这里][3]),然后再使用TensorRT中的NvCaffeParser转化就行了。
14 | 2. 自己用API实现,代码如下:
15 | ``` c++
16 | void LeakyRelu(INetworkDefinition *network, ITensor *it)
17 | {
18 | Weights power{DataType::kFLOAT, nullptr, 0};
19 | Weights shift{DataType::kFLOAT, nullptr, 0};
20 |
21 | float *scale_params = new float[2];
22 |
23 | // scale_1 * 0.1
24 | scale_params[0] = 0.1f;
25 |
26 | Weights scale{DataType::kFLOAT, &scale_params[0], 1};
27 | auto scale_1 = network->addScale(*it, ScaleMode::kUNIFORM, shift, scale, power);
28 | assert(scale_1 != nullptr);
29 |
30 | // relu + scale_2 * 0.9;
31 | auto relu = network->addActivation(*it, ActivationType::kRELU);
32 | assert(relu != nullptr);
33 |
34 | scale_params[1] = 0.9f;
35 | Weights scale1{DataType::kFLOAT, &scale_params[1], 1};
36 | auto scale_2 = network->addScale(*relu->getOutput(0), ScaleMode::kUNIFORM, shift, scale1, power);
37 | assert(scale_2 != nullptr);
38 |
39 | // result = scale_1 + scale_2
40 | auto ew = network->addElementWise(*scale_1->getOutput(0), *scale_2->getOutput(0), ElementWiseOperation::kSUM);
41 | assert(ew != nullptr);
42 | }
43 | ```
44 | ## 坑
45 | 其实这个实现是比较简单的,没太有必要写出来。但是TensorRT给的scale层demo里有个坑,分享给大家。
46 | sampleMNISTAPI里给出的scale层代码如下:
47 | ```c++
48 | // Create a scale layer with default power/shift and specified scale parameter.
49 | float scale_param = 0.0125f;
50 | Weights power{DataType::kFLOAT, nullptr, 0};
51 | Weights shift{DataType::kFLOAT, nullptr, 0};
52 | Weights scale{DataType::kFLOAT, &scale_param, 1};
53 | auto scale_1 = network->addScale(*data, ScaleMode::kUNIFORM, shift, scale, power);
54 | assert(scale_1 != nullptr);
55 | ```
56 | 第5行是scale的参数,根据参数个数可以实现对于per-tensor, per-channel, or per-element进行scale操作,所以其中scale的第二个参数是个指针,指向一个数组。
57 |
58 | 这个坑在于第2行和第5行。因为是要对整个Tensor进行scale操作,所以Weights scale内参数个数为1,因此就声明了一个局部变量scale_param。问题就在这,scale_param是局部变量,在栈空间,函数结束就会被释放掉!所以在实现LeakyRelu()函数的时候,第六行的scale_params一定不能用局部变量存储,我选择放到堆空间中。
59 |
60 | 从这个例子中深深反思自己的c++能力,以后在用现成代码的时候一定要仔细检查,看是否符合目前的需求。
61 |
62 |
63 | [1]: https://github.com/TLESORT/YOLO-TensorRT-GIE-
64 | [2]: https://raw.githubusercontent.com/LitLeo/TensorRT_Tutorial/master/img/flow_of_leaky_relu.png
65 | [3]: https://devtalk.nvidia.com/default/topic/990426/jetson-tx1/tensorrt-yolo-inference-error/post/5087820/
--------------------------------------------------------------------------------
/blogs/写于20200829.md:
--------------------------------------------------------------------------------
1 | Hi 大家好,我诈尸更新啦。
2 |
3 | 上一次更新还是三年前第一次在搜狗语音实习接触TensorRT的时候,那会TRT刚出,实在是没有什么资料可以查,就随手写了这个git。
4 |
5 | 那会TRT实在是有点简陋,不太适合应用到语音领域,就放弃了。没想到三年没更新,竟然已经有300多个star了(我知道你们可能是被名字骗进来的zz)。目前随着 TRT 开始支持dynamic input shape 和 loop 功能,功能强大了很多,又开始接触TRT,毕竟TRT的速度实在是香。
6 |
7 | TRT的dynamic input 和 loop 特性,使TRT在语音和nlp领域的可用性大大增加。最近几个月基于TRT7.0在语音领域的应用做了一些功能,也踩了不少的坑。
8 |
9 | 为了不辜负这300多个star,就再更新点东西
10 |
--------------------------------------------------------------------------------
/cublas&cudnn_int8_demo/README.md:
--------------------------------------------------------------------------------
1 | cublas使用cublasGemmEx函数的CUDA_R_32I计算模式来实现INT8加速。需要注意的坑是,alpha 和 beta 这两个参数必须为 int类型,cublas文档没有写明白。
2 |
3 | cudnn 的卷积INT8加速为使用cudnnConvolutionForward的四种INT8配置(INT8, INT8_EXT, INT8x4, INT8x4_EXT),按自己需求决定使用哪个函数。[demo在这里][1],他的这个代码有点小错误,cudnn cudnnConvolutionForward INT8输入要求是4的倍数,详细要求见cudnn文档,[问题讨论在这里][2]。
4 |
5 |
6 | [1]: https://github.com/jesryu/cudnn_conv_int8
7 | [2]: https://devtalk.nvidia.com/default/topic/1005119/cudnn-v6-int8-convolution-failing-with-cudnn_status_not_supported/
--------------------------------------------------------------------------------
/cublas&cudnn_int8_demo/cublasGemmEx/cuBlasGemmEx.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LitLeo/TensorRT_Tutorial/1d370d82c78eddccadf4a8490c5c859db35a78c9/cublas&cudnn_int8_demo/cublasGemmEx/cuBlasGemmEx.pdf
--------------------------------------------------------------------------------
/cublas&cudnn_int8_demo/cublasGemmEx/gemmInt8_rect.cpp:
--------------------------------------------------------------------------------
1 | #include <iostream>
2 | using namespace std;
3 |
4 | #include
5 | #include
6 | #include
7 |
8 | // cuda
9 | #include <cuda_runtime.h>
10 | #include <cublas_v2.h>
11 |
12 | #define N 4
13 | #define Value 2
14 | #define checkCudaAPIErrors(F) if ((F) != cudaSuccess) \
15 | { printf("Error at line %d in file %s: %s\n", __LINE__, __FILE__, cudaGetErrorString(cudaGetLastError())); exit(-1); }
16 |
17 | void initArray(char * a, const int size) {
18 | for (int i = 0; i < size; ++i)
19 | {
20 | a[i] = Value;
21 | }
22 | }
23 | static const char *_cudaGetErrorEnum(cublasStatus_t error)
24 | {
25 | switch (error)
26 | {
27 | case CUBLAS_STATUS_SUCCESS:
28 | return "CUBLAS_STATUS_SUCCESS";
29 |
30 | case CUBLAS_STATUS_NOT_INITIALIZED:
31 | return "CUBLAS_STATUS_NOT_INITIALIZED";
32 |
33 | case CUBLAS_STATUS_ALLOC_FAILED:
34 | return "CUBLAS_STATUS_ALLOC_FAILED";
35 |
36 | case CUBLAS_STATUS_INVALID_VALUE:
37 | return "CUBLAS_STATUS_INVALID_VALUE";
38 |
39 | case CUBLAS_STATUS_ARCH_MISMATCH:
40 | return "CUBLAS_STATUS_ARCH_MISMATCH";
41 |
42 | case CUBLAS_STATUS_MAPPING_ERROR:
43 | return "CUBLAS_STATUS_MAPPING_ERROR";
44 |
45 | case CUBLAS_STATUS_EXECUTION_FAILED:
46 | return "CUBLAS_STATUS_EXECUTION_FAILED";
47 |
48 | case CUBLAS_STATUS_INTERNAL_ERROR:
49 | return "CUBLAS_STATUS_INTERNAL_ERROR";
50 |
51 | case CUBLAS_STATUS_NOT_SUPPORTED:
52 | return "CUBLAS_STATUS_NOT_SUPPORTED";
53 |
54 | case CUBLAS_STATUS_LICENSE_ERROR:
55 | return "CUBLAS_STATUS_LICENSE_ERROR";
56 | }
57 |
58 | return "";
59 | }
60 |
61 | #define checkcuBlasError(F) if ((F) != CUBLAS_STATUS_SUCCESS) \
62 | { printf("Error at line %d in file %s: %s\n", __LINE__, __FILE__, _cudaGetErrorEnum(F)); exit(-1); }
63 |
64 | /** @main function ****************
65 | **********************************/
66 | int main(int argc, char** argv)
67 | {
68 | // test_count
69 | int iters = 1;
70 |
71 | int alpha = 1;
72 | int beta = 0;
73 |
74 | float TFlops;
75 | cublasStatus_t cublasStat;
76 |
77 | int n[N] = {512, 512, 512, 512};
78 | int k[N] = {2048, 2048, 2048, 2048};
79 | int m[N] = {4, 8, 16, 32};
80 |
81 | int devID = 0;
82 | cudaSetDevice(devID);
83 | cudaDeviceProp devProp;
84 | cudaGetDeviceProperties(&devProp, devID);
85 | printf("Device : %s, compute SM %d.%d.\n",devProp.name, devProp.major, devProp.minor);
86 |
87 | cublasHandle_t handle;
88 | checkcuBlasError(cublasCreate(&handle));
89 |
90 | FILE *output = NULL;
91 | char filename[20] = "result.txt";
92 |
93 | cudaEvent_t start, stop;
94 | float time_used = 0.0;
95 | cudaEventCreate(&start);
96 | cudaEventCreate(&stop);
97 |
98 | char *d_A, *d_B;
99 | int *d_C; // note the result is accumulated in int
100 | char *h_A, *h_B;
101 | int *h_C; // note the result is accumulated in int
102 |
103 |
104 | if ((output = fopen(filename, "w")) == NULL)
105 | {
106 | printf("Can not open file : %s\n", filename);
107 | exit(1);
108 | }
109 | fprintf(output, "m \t k \t n \t Time \t TFlops\n");
110 |
111 | for (int i = 0; i < N; i++)
--------------------------------------------------------------------------------
/resource_for_billibilli/debug_plugin/CMakeLists.txt:
--------------------------------------------------------------------------------
155 | list(APPEND GPU_ARCHS 80)
156 | else()
157 | message(WARNING "Detected CUDA version is < 11.0. SM80 not supported.")
158 | endif()
159 |
160 | message(STATUS "GPU_ARCHS is not defined. Generating CUDA code for default SMs: ${GPU_ARCHS}")
161 | endif()
162 | set(BERT_GENCODES)
163 | # Generate SASS for each architecture
164 | foreach(arch ${GPU_ARCHS})
165 | if (${arch} GREATER_EQUAL 70)
166 | set(BERT_GENCODES "${BERT_GENCODES} -gencode arch=compute_${arch},code=sm_${arch}")
167 | endif()
168 | set(GENCODES "${GENCODES} -gencode arch=compute_${arch},code=sm_${arch}")
169 | endforeach()
170 | # Generate PTX for the last architecture in the list.
171 | list(GET GPU_ARCHS -1 LATEST_SM)
172 | set(GENCODES "${GENCODES} -gencode arch=compute_${LATEST_SM},code=compute_${LATEST_SM}")
173 | if (${LATEST_SM} GREATER_EQUAL 70)
174 | set(BERT_GENCODES "${BERT_GENCODES} -gencode arch=compute_${LATEST_SM},code=compute_${LATEST_SM}")
175 | endif()
176 | set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler -Wno-deprecated-declarations")
177 |
178 | ############################################################################################
179 | # TensorRT
180 |
181 | if(BUILD_PLUGINS)
182 | add_subdirectory(plugin)
183 | else()
184 | find_library_create_target(nvinfer_plugin nvinfer_plugin SHARED ${TRT_OUT_DIR} ${TRT_LIB_DIR})
185 | endif()
186 |
187 | #if(BUILD_PARSERS)
188 | #add_subdirectory(parsers)
189 | #else()
190 | #find_library_create_target(nvcaffeparser nvparsers SHARED ${TRT_OUT_DIR} ${TRT_LIB_DIR})
191 | #find_library_create_target(nvonnxparser nvonnxparser SHARED ${TRT_OUT_DIR} ${TRT_LIB_DIR})
192 | #endif()
193 |
194 | if(BUILD_SAMPLES)
195 | add_subdirectory(samples)
196 | endif()
197 |
198 |
--------------------------------------------------------------------------------
/resource_for_billibilli/debug_plugin/README.md:
--------------------------------------------------------------------------------
1 | 1. debug_plugin is a demo plugin for debugging: it prints the output of each layer.
2 | 2. Supports TensorRT 6 & 7.
3 |
4 | # build
5 | ```
6 | mkdir build && cd build
7 | export TRT_RELEASE=/your/trt/path
8 | cmake .. -DTRT_LIB_DIR=$TRT_RELEASE/lib -DTRT_OUT_DIR=`pwd`/out
9 | ```
10 |
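After configuring, building with `make` in the build directory should produce the plugin library under `out/` (e.g. `libnvinfer_plugin.so`). A minimal usage sketch, assuming the creators have been registered (e.g. via `initLibNvInferPlugins`), is to splice the `DebugPlugin` onto a tensor whose values you want printed at inference time; `someTensor` and the layer name below are placeholders:

```cpp
#include "NvInfer.h"
#include "NvInferPlugin.h"

using namespace nvinfer1;

// Tap `someTensor`: the DebugPlugin copies it through unchanged and prints it in enqueue().
// type_id 0 selects DataType::kFLOAT, input_num is the number of tensors to tap.
ITensor* addDebugTap(INetworkDefinition* network, ITensor* someTensor)
{
    int typeId = 0, inputNum = 1;
    PluginField fields[] = {
        PluginField{"type_id", &typeId, PluginFieldType::kINT32, 1},
        PluginField{"input_num", &inputNum, PluginFieldType::kINT32, 1},
    };
    PluginFieldCollection fc{2, fields};

    auto creator = getPluginRegistry()->getPluginCreator("DebugPlugin", "1");
    IPluginV2* plugin = creator->createPlugin("debug_after_some_layer", &fc);

    ITensor* inputs[] = {someTensor};
    IPluginV2Layer* layer = network->addPluginV2(inputs, 1, *plugin);
    return layer->getOutput(0);  // use this in place of someTensor downstream
}
```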
--------------------------------------------------------------------------------
/resource_for_billibilli/debug_plugin/cmake/modules/find_library_create_target.cmake:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
17 | macro(find_library_create_target target_name lib libtype hints)
18 | message(STATUS "========================= Importing and creating target ${target_name} ==========================")
19 | message(STATUS "Looking for library ${lib}")
20 | if (CMAKE_BUILD_TYPE STREQUAL "Debug")
21 | find_library(${lib}_LIB_PATH ${lib}${TRT_DEBUG_POSTFIX} HINTS ${hints} NO_DEFAULT_PATH)
22 | endif()
23 | find_library(${lib}_LIB_PATH ${lib} HINTS ${hints} NO_DEFAULT_PATH)
24 | find_library(${lib}_LIB_PATH ${lib})
25 | message(STATUS "Library that was found ${${lib}_LIB_PATH}")
26 | add_library(${target_name} ${libtype} IMPORTED)
27 | set_property(TARGET ${target_name} PROPERTY IMPORTED_LOCATION ${${lib}_LIB_PATH})
28 | message(STATUS "==========================================================================================")
29 | endmacro()
30 |
--------------------------------------------------------------------------------
/resource_for_billibilli/debug_plugin/cmake/modules/set_ifndef.cmake:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
17 | function (set_ifndef variable value)
18 | if(NOT DEFINED ${variable})
19 | set(${variable} ${value} PARENT_SCOPE)
20 | endif()
21 | endfunction()
22 |
--------------------------------------------------------------------------------
/resource_for_billibilli/debug_plugin/cmake/toolchains/cmake_aarch64-android.toolchain:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
17 | set(CMAKE_SYSTEM_NAME Linux)
18 | set(CMAKE_SYSTEM_PROCESSOR aarch64)
19 |
20 | set(CMAKE_C_COMPILER $ENV{AARCH64_ANDROID_CC})
21 | set(CMAKE_CXX_COMPILER $ENV{AARCH64_ANDROID_CC})
22 |
23 | set(CMAKE_C_FLAGS "$ENV{AARCH64_ANDROID_CFLAGS} -pie -fPIE"
24 | CACHE STRING "" FORCE)
25 | set(CMAKE_CXX_FLAGS "${CMAKE_C_FLAGS}" CACHE STRING "" FORCE)
26 |
27 | set(CMAKE_C_COMPILER_TARGET aarch64-none-linux-android)
28 | set(CMAKE_CXX_COMPILER_TARGET aarch64-none-linux-android)
29 |
30 | set(CMAKE_C_COMPILER_FORCED TRUE)
31 | set(CMAKE_CXX_COMPILER_FORCED TRUE)
32 |
33 | set(CUDA_TOOLKIT_ROOT_DIR ${CUDA_ROOT})
34 | set(CUDA_INCLUDE_DIRS ${CUDA_ROOT}/include)
35 |
36 | set(CMAKE_CUDA_HOST_COMPILER ${CMAKE_CXX_COMPILER} CACHE STRING "" FORCE)
37 | set(CMAKE_CUDA_FLAGS "-I${CUDA_INCLUDE_DIRS} -Xcompiler=\"-fPIC ${CMAKE_CXX_FLAGS}\"" CACHE STRING "" FORCE)
38 | set(CMAKE_CUDA_COMPILER_FORCED TRUE)
39 |
40 |
41 | set(CUDA_LIBS -L${CUDA_ROOT}/lib64)
42 |
43 | set(ADDITIONAL_PLATFORM_LIB_FLAGS ${CUDA_LIBS} -lcublas -lcudart -lnvToolsExt -lculibos -lcudadevrt -llog)
44 |
45 |
46 | set(DISABLE_SWIG TRUE)
47 | set(TRT_PLATFORM_ID "aarch64-android")
48 |
--------------------------------------------------------------------------------
/resource_for_billibilli/debug_plugin/cmake/toolchains/cmake_aarch64.toolchain:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
17 | set(CMAKE_SYSTEM_NAME Linux)
18 | set(CMAKE_SYSTEM_PROCESSOR aarch64)
19 | set(TRT_PLATFORM_ID "aarch64")
20 | set(CUDA_PLATFORM_ID "aarch64-linux")
21 |
22 | set(CMAKE_C_COMPILER /usr/bin/aarch64-linux-gnu-gcc)
23 | set(CMAKE_CXX_COMPILER /usr/bin/aarch64-linux-gnu-g++)
24 |
25 | set(CMAKE_C_FLAGS "" CACHE STRING "" FORCE)
26 | set(CMAKE_CXX_FLAGS "" CACHE STRING "" FORCE)
27 |
28 | set(CMAKE_C_COMPILER_TARGET aarch64)
29 | set(CMAKE_CXX_COMPILER_TARGET aarch64)
30 |
31 | set(CMAKE_C_COMPILER_FORCED TRUE)
32 | set(CMAKE_CXX_COMPILER_FORCED TRUE)
33 |
34 | set(CUDA_ROOT /usr/local/cuda-${CUDA_VERSION}/targets/${CUDA_PLATFORM_ID} CACHE STRING "CUDA ROOT dir")
35 |
36 | set(CUDA_TOOLKIT_ROOT_DIR ${CUDA_ROOT})
37 | set(CUDA_INCLUDE_DIRS ${CUDA_ROOT}/include)
38 |
39 | set(RT_LIB /usr/aarch64-linux-gnu/lib/librt.so)
40 |
41 | set(CMAKE_CUDA_HOST_COMPILER ${CMAKE_CXX_COMPILER} CACHE STRING "" FORCE)
42 | set(CMAKE_CUDA_FLAGS "-cudart none -I${CUDA_INCLUDE_DIRS} -Xcompiler=\"-fPIC ${CMAKE_CXX_FLAGS}\"" CACHE STRING "" FORCE)
43 | set(CMAKE_CUDA_COMPILER_FORCED TRUE)
44 |
45 | set(CUDA_LIBS -L${CUDA_ROOT}/lib)
46 |
47 | set(ADDITIONAL_PLATFORM_LIB_FLAGS ${CUDA_LIBS} -lcublas -lcudart -lstdc++ -lm)
48 |
49 | set(DISABLE_SWIG TRUE)
50 |
--------------------------------------------------------------------------------
/resource_for_billibilli/debug_plugin/cmake/toolchains/cmake_ppc64le.toolchain:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
17 | set(CMAKE_SYSTEM_NAME Linux)
18 | set(CMAKE_SYSTEM_PROCESSOR ppc64le)
19 |
20 | set(CMAKE_C_COMPILER powerpc64le-linux-gnu-gcc)
21 | set(CMAKE_CXX_COMPILER powerpc64le-linux-gnu-g++)
22 |
23 | set(CMAKE_C_COMPILER_TARGET ppc64le)
24 | set(CMAKE_CXX_COMPILER_TARGET ppc64le)
25 |
26 | set(CMAKE_CUDA_HOST_COMPILER ${CMAKE_CXX_COMPILER} CACHE STRING "" FORCE)
27 | set(CMAKE_CUDA_FLAGS "-I${CUDA_ROOT}/include -Xcompiler=\"-fPIC ${CMAKE_CXX_FLAGS}\"" CACHE STRING "" FORCE)
28 | set(CMAKE_CUDA_COMPILER_FORCED TRUE)
29 |
30 | if(DEFINED CUDA_ROOT)
31 | set(CUDA_TOOLKIT_ROOT_DIR ${CUDA_ROOT})
32 | endif()
33 |
34 | set(CUDA_INCLUDE_DIRS ${CUDA_ROOT}/include)
35 |
36 | set(TRT_PLATFORM_ID "ppc64le")
37 |
--------------------------------------------------------------------------------
/resource_for_billibilli/debug_plugin/cmake/toolchains/cmake_qnx.toolchain:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
17 | set(CMAKE_SYSTEM_NAME qnx)
18 | set(CMAKE_SYSTEM_PROCESSOR aarch64)
19 |
20 | if(DEFINED ENV{QNX_BASE})
21 | set(QNX_BASE $ENV{QNX_BASE})
22 | message(STATUS "Found QNX_BASE = ${QNX_BASE}")
23 | elseif(DEFINED ENV{TOOLS_BASE})
24 | set(QNX_BASE $ENV{TOOLS_BASE}/embedded/qnx/qnx700-ga4)
25 | message(STATUS "Found QNX_BASE = ${QNX_BASE}")
26 | else()
27 | message(FATAL_ERROR "QNX_BASE was not found")
28 | endif()
29 |
30 | set(ENV{QNX_HOST} ${QNX_BASE}/host/linux/x86_64)
31 | set(ENV{QNX_TARGET} ${QNX_BASE}/target/qnx7)
32 |
33 | set(QNX_HOST $ENV{QNX_HOST})
34 | set(QNX_TARGET $ENV{QNX_TARGET})
35 |
36 | message(STATUS "QNX_HOST = ${QNX_HOST}")
37 | message(STATUS "QNX_TARGET = ${QNX_TARGET}")
38 |
39 | set(CMAKE_C_COMPILER ${QNX_HOST}/usr/bin/aarch64-unknown-nto-qnx7.0.0-gcc)
40 | set(CMAKE_CXX_COMPILER ${QNX_HOST}/usr/bin/aarch64-unknown-nto-qnx7.0.0-g++)
41 |
42 | set(CMAKE_C_COMPILER_TARGET aarch64)
43 | set(CMAKE_CXX_COMPILER_TARGET aarch64)
44 |
45 | set(CMAKE_C_COMPILER_FORCED TRUE)
46 | set(CMAKE_CXX_COMPILER_FORCED TRUE)
47 |
48 | set(CUDA_TOOLKIT_ROOT_DIR ${CUDA_ROOT})
49 | set(CUDA_INCLUDE_DIRS ${CUDA_ROOT}/include)
50 |
51 | set(CMAKE_CUDA_HOST_COMPILER ${CMAKE_CXX_COMPILER} CACHE STRING "" FORCE)
52 | set(CMAKE_CUDA_FLAGS "-I${CUDA_INCLUDE_DIRS} -Xcompiler -fPIC" CACHE STRING "" FORCE)
53 | set(CMAKE_CUDA_COMPILER_FORCED TRUE)
54 |
55 | set(CUDA_LIBS -L${CUDA_ROOT}/lib)
56 |
57 | set(ADDITIONAL_PLATFORM_LIB_FLAGS ${CUDA_LIBS} -lcublas -lcudart)
58 | #...Disable swig
59 | set(DISABLE_SWIG TRUE)
60 |
61 | set(TRT_PLATFORM_ID "aarch64-qnx")
62 |
--------------------------------------------------------------------------------
/resource_for_billibilli/debug_plugin/cmake/toolchains/cmake_x64_win.toolchain:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
17 | set(CMAKE_SYSTEM_NAME WindowsStore)
18 | set(CMAKE_SYSTEM_VERSION 10.0)
19 |
20 | set(CMAKE_C_COMPILER ${CC})
21 | set(CMAKE_CXX_COMPILER ${CC})
22 |
23 | if(DEFINED CUDA_TOOLKIT)
24 | set(CUDA_TOOLKIT_ROOT_DIR ${CUDA_TOOLKIT})
25 | endif()
26 |
27 | set(CMAKE_CUDA_COMPILER ${CUDA_TOOLKIT_ROOT_DIR}/bin/nvcc.exe)
28 | set(CMAKE_CUDA_COMPILER_ID "NVIDIA")
29 |
30 | set(CMAKE_C_COMPILER_FORCED TRUE)
31 | set(CMAKE_CXX_COMPILER_FORCED TRUE)
32 | set(CMAKE_CUDA_COMPILER_FORCED TRUE)
33 |
34 | set(NV_TOOLS ${NV_TOOLS})
35 | set(W10_LIBRARY_SUFFIXES .lib .dll)
36 | set(W10_CUDA_ROOT ${CUDA_TOOLKIT_ROOT_DIR})
37 | set(W10_LINKER ${MSVC_COMPILER_DIR}/bin/amd64/link)
38 |
39 |
40 | set(CMAKE_CUDA_HOST_COMPILER ${CMAKE_NVCC_COMPILER} CACHE STRING "" FORCE)
41 |
42 | set(ADDITIONAL_PLATFORM_INCL_FLAGS "-I${MSVC_COMPILER_DIR}/include -I${MSVC_COMPILER_DIR}/../ucrt/include")
43 | set(ADDITIONAL_PLATFORM_LIB_FLAGS ${ADDITIONAL_PLATFORM_LIB_FLAGS} "-LIBPATH:${NV_TOOLS}/ddk/wddmv2/official/17134/Lib/10.0.17134.0/um/x64")
44 | set(ADDITIONAL_PLATFORM_LIB_FLAGS ${ADDITIONAL_PLATFORM_LIB_FLAGS} "-LIBPATH:${MSVC_COMPILER_DIR}/lib/amd64" )
45 | set(ADDITIONAL_PLATFORM_LIB_FLAGS ${ADDITIONAL_PLATFORM_LIB_FLAGS} "-LIBPATH:${MSVC_COMPILER_DIR}/../ucrt/lib/x64")
46 | set(ADDITIONAL_PLATFORM_LIB_FLAGS ${ADDITIONAL_PLATFORM_LIB_FLAGS} "-LIBPATH:${W10_CUDA_ROOT}/lib/x64 cudart.lib cublas.lib")
47 |
48 | set(TRT_PLATFORM_ID "win10")
49 |
--------------------------------------------------------------------------------
/resource_for_billibilli/debug_plugin/cmake/toolchains/cmake_x86_64.toolchain:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
17 | set(CMAKE_SYSTEM_NAME Linux)
18 | set(CMAKE_SYSTEM_PROCESSOR x86_64)
19 |
20 | set(CMAKE_C_COMPILER gcc)
21 | set(CMAKE_CXX_COMPILER g++)
22 |
23 | if(DEFINED CUDA_ROOT)
24 | set(CUDA_TOOLKIT_ROOT_DIR ${CUDA_ROOT})
25 | endif()
26 |
27 | set(CUDA_INCLUDE_DIRS ${CUDA_ROOT}/include)
28 |
29 | set(TRT_PLATFORM_ID "x86_64")
30 |
--------------------------------------------------------------------------------
/resource_for_billibilli/debug_plugin/include/NvInferPlugin.h:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | #ifndef NV_INFER_PLUGIN_H
18 | #define NV_INFER_PLUGIN_H
19 |
20 | #include "NvInfer.h"
21 | #include "NvInferPluginUtils.h"
22 | //!
23 | //! \file NvInferPlugin.h
24 | //!
25 | //! This is the API for the Nvidia provided TensorRT plugins.
26 | //!
27 |
28 | extern "C"
29 | {
30 | //!
31 | //! \brief Initialize and register all the existing TensorRT plugins to the Plugin Registry with an optional
32 | //! namespace. The plugin library author should ensure that this function name is unique to the library. This
33 | //! function should be called once before accessing the Plugin Registry. \param logger Logger object to print plugin
34 | //! registration information \param libNamespace Namespace used to register all the plugins in this library
35 | //!
36 | TENSORRTAPI bool initLibNvInferPlugins(void* logger, const char* libNamespace);
37 | } // extern "C"
38 |
39 | #endif // NV_INFER_PLUGIN_H
40 |
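As the comment above notes, `initLibNvInferPlugins` must be called once before the plugin registry is queried. A minimal call sequence, assuming `logger.h` exposes a global `gLogger` the way the TensorRT samples do, would be:

```cpp
#include "NvInferPlugin.h"
#include "logger.h"  // assumed to declare a global gLogger, as in the TensorRT samples

bool registerDebugPlugins()
{
    // Registers every plugin compiled into this library (including DebugPlugin)
    // under the default namespace so getPluginRegistry()->getPluginCreator() can find them.
    return initLibNvInferPlugins(&gLogger, "");
}
```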
--------------------------------------------------------------------------------
/resource_for_billibilli/debug_plugin/include/NvInferVersion.h:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | //!
18 | //! \file NvInferVersion.h
19 | //!
20 | //! Defines the TensorRT version
21 | //!
22 |
23 | #ifndef NV_INFER_VERSION_H
24 | #define NV_INFER_VERSION_H
25 |
26 | #define NV_TENSORRT_MAJOR 7 //!< TensorRT major version.
27 | #define NV_TENSORRT_MINOR 2 //!< TensorRT minor version.
28 | #define NV_TENSORRT_PATCH 1 //!< TensorRT patch version.
29 | #define NV_TENSORRT_BUILD 6 //!< TensorRT build number.
30 |
31 | #define NV_TENSORRT_SONAME_MAJOR 7 //!< Shared object library major version number.
32 | #define NV_TENSORRT_SONAME_MINOR 2 //!< Shared object library minor version number.
33 | #define NV_TENSORRT_SONAME_PATCH 1 //!< Shared object library patch version number.
34 |
35 | #endif // NV_INFER_VERSION_H
36 |
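Because the debug plugin targets both TensorRT 6 and 7 (see the README), these macros are the natural hook for compile-time switches; a small sketch:

```cpp
#include "NvInferVersion.h"

// Collapse the version into one comparable number, e.g. 7.2.1 -> 7201.
#define TRT_VERSION_INT \
    (NV_TENSORRT_MAJOR * 1000 + NV_TENSORRT_MINOR * 100 + NV_TENSORRT_PATCH)

#if TRT_VERSION_INT >= 7000
// Code paths that rely on TensorRT 7 behavior go here.
#else
// TensorRT 6 fallback goes here.
#endif
```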
--------------------------------------------------------------------------------
/resource_for_billibilli/debug_plugin/include/NvOnnxConfig.h:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | #ifndef NV_OnnxConfig_H
18 | #define NV_OnnxConfig_H
19 |
20 | #include "NvInfer.h"
21 |
22 | namespace nvonnxparser
23 | {
24 |
25 | //!
26 | //! \mainpage
27 | //!
28 | //! This is the API documentation for the Configuration Manager for Open Neural Network Exchange (ONNX) Parser for Nvidia TensorRT Inference Engine.
29 | //! It provides information on individual functions, classes
30 | //! and methods. Use the index on the left to navigate the documentation.
31 | //!
32 | //! Please see the accompanying user guide and samples for higher-level information and general advice on using ONNX Parser and TensorRT.
33 | //!
34 |
35 | //!
36 | //! \file NvOnnxConfig.h
37 | //!
38 | //! This is the API file for the Configuration Manager for ONNX Parser for Nvidia TensorRT.
39 | //!
40 |
41 | //!
42 | //! \class IOnnxConfig
43 | //! \brief Configuration Manager Class.
44 | //!
45 | class IOnnxConfig
46 | {
47 | protected:
48 | virtual ~IOnnxConfig() {}
49 |
50 | public:
51 | //!
52 | //! \typedef Verbosity
53 | //! \brief Defines Verbosity level.
54 | //!
55 | typedef int Verbosity;
56 |
57 | //!
58 | //! \brief Set the Model Data Type.
59 | //!
60 | //! Sets the Model DataType, one of the following: float -d 32 (default), half precision -d 16, and int8 -d 8 data types.
61 | //!
62 | //! \see getModelDtype()
63 | //!
64 | virtual void setModelDtype(const nvinfer1::DataType) TRTNOEXCEPT = 0;
65 |
66 | //!
67 | //! \brief Get the Model Data Type.
68 | //!
69 | //! \return DataType nvinfer1::DataType
70 | //!
71 | //! \see setModelDtype() and #DataType
72 | //!
73 | virtual nvinfer1::DataType getModelDtype() const TRTNOEXCEPT = 0;
74 |
75 | //!
76 | //! \brief Get the Model FileName.
77 | //!
78 | //! \return Return the Model Filename, as a pointer to a NULL-terminated character sequence.
79 | //!
80 | //! \see setModelFileName()
81 | //!
82 | virtual const char* getModelFileName() const TRTNOEXCEPT = 0;
83 |
84 | //!
85 | //! \brief Set the Model File Name.
86 | //!
87 | //! The Model File name contains the Network Description in ONNX pb format.
88 | //!
89 | //! This method copies the name string.
90 | //!
91 | //! \param onnxFilename The name.
92 | //!
93 | //! \see getModelFileName()
94 | //!
95 | virtual void setModelFileName(const char* onnxFilename) TRTNOEXCEPT = 0;
96 |
97 | //!
98 | //! \brief Get the Verbosity Level.
99 | //!
100 | //! \return The Verbosity Level.
101 | //!
102 | //! \see addVerbosity(), reduceVerbosity()
103 | //!
104 | virtual Verbosity getVerbosityLevel() const TRTNOEXCEPT = 0;
105 |
106 | //!
107 | //! \brief Increase the Verbosity Level.
108 | //!
109 | //! \return The Verbosity Level.
110 | //!
111 | //! \see addVerbosity(), reduceVerbosity(), setVerbosity(Verbosity)
112 | //!
113 | virtual void addVerbosity() TRTNOEXCEPT = 0; //!< Increase verbosity Level.
114 | virtual void reduceVerbosity() TRTNOEXCEPT = 0; //!< Decrease verbosity Level.
115 | virtual void setVerbosityLevel(Verbosity) TRTNOEXCEPT = 0; //!< Set to specific verbosity Level.
116 |
117 | //!
118 | //! \brief Returns the File Name of the Network Description as a Text File.
119 | //!
120 | //! \return Return the name of the file containing the network description converted to a plain text, used for debugging purposes.
121 | //!
122 | //! \see setTextFilename()
123 | //!
124 | virtual const char* getTextFileName() const TRTNOEXCEPT = 0;
125 |
126 | //!
127 | //! \brief Set the File Name of the Network Description as a Text File.
128 | //!
129 | //! This API allows setting a file name for the network description in plain text, equivalent of the ONNX protobuf.
130 | //!
131 | //! This method copies the name string.
132 | //!
133 | //! \param textFileName Name of the file.
134 | //!
135 | //! \see getTextFilename()
136 | //!
137 | virtual void setTextFileName(const char* textFileName) TRTNOEXCEPT = 0;
138 |
139 | //!
140 | //! \brief Get the File Name of the Network Description as a Text File, including the weights.
141 | //!
142 | //! \return Return the name of the file containing the network description converted to a plain text, used for debugging purposes.
143 | //!
144 | //! \see setFullTextFilename()
145 | //!
146 | virtual const char* getFullTextFileName() const TRTNOEXCEPT = 0;
147 |
148 | //!
149 | //! \brief Set the File Name of the Network Description as a Text File, including the weights.
150 | //!
151 | //! This API allows setting a file name for the network description in plain text, equivalent of the ONNX protobuf.
152 | //!
153 | //! This method copies the name string.
154 | //!
155 | //! \param fullTextFileName Name of the file.
156 | //!
157 | //! \see getFullTextFilename()
158 | //!
159 | virtual void setFullTextFileName(const char* fullTextFileName) TRTNOEXCEPT = 0;
160 |
161 | //!
162 | //! \brief Get whether the layer information will be printed.
163 | //!
164 | //! \return Returns whether the layer information will be printed.
165 | //!
166 | //! \see setPrintLayerInfo()
167 | //!
168 | virtual bool getPrintLayerInfo() const TRTNOEXCEPT = 0;
169 |
170 | //!
171 | //! \brief Set whether the layer information will be printed.
172 | //!
173 | //! \see getPrintLayerInfo()
174 | //!
175 | virtual void setPrintLayerInfo(bool) TRTNOEXCEPT = 0;
176 |
177 | //!
178 | //! \brief Destroy IOnnxConfig object.
179 | //!
180 | virtual void destroy() TRTNOEXCEPT = 0;
181 |
182 | }; // class IOnnxConfig
183 |
184 | TENSORRTAPI IOnnxConfig* createONNXConfig();
185 |
186 | } // namespace nvonnxparser
187 |
188 | #endif
189 |
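A brief sketch of driving this configuration interface; the model path is a placeholder:

```cpp
#include "NvOnnxConfig.h"

void configureOnnxParsing()
{
    nvonnxparser::IOnnxConfig* config = nvonnxparser::createONNXConfig();
    config->setModelFileName("model.onnx");             // placeholder path
    config->setModelDtype(nvinfer1::DataType::kHALF);   // request FP16 parsing
    config->setPrintLayerInfo(true);                     // dump per-layer information
    config->addVerbosity();                              // bump the verbosity level
    config->destroy();                                   // released via destroy(), not delete
}
```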
--------------------------------------------------------------------------------
/resource_for_billibilli/debug_plugin/include/NvUtils.h:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | #ifndef NV_UTILS_H
18 | #define NV_UTILS_H
19 |
20 | #include "NvInfer.h"
21 |
22 | //!
23 | //! \file NvUtils.h
24 | //!
25 | //! This file includes various utility functions
26 | //!
27 |
28 | namespace nvinfer1
29 | {
30 | namespace utils
31 | {
32 |
33 | //!
34 | //! \param input The input weights to reshape.
35 | //! \param shape The shape of the weights.
36 | //! \param shapeOrder The order of the dimensions to process for the output.
37 | //! \param data The location where the output data is placed.
38 | //! \param nbDims The number of dimensions to process.
39 | //!
40 | //! \brief Reformat the input weights of the given shape based on the new
41 | //! order of dimensions.
42 | //!
43 | //! Take the weights specified by \p input with the dimensions specified by
44 | //! \p shape and re-order the weights based on the new dimensions specified
45 | //! by \p shapeOrder. The size of each dimension and the input data is not
46 | //! modified. The output volume pointed to by \p data must be the same as
47 | //! the \p input volume.
48 | //!
49 | //! Example usage:
50 | //! float *out = new float[N*C*H*W];
51 | //! Weights input{DataType::kFLOAT, {0 ... N*C*H*W-1}, N*C*H*W size};
52 | //! int32_t order[4]{1, 0, 3, 2};
53 | //! int32_t shape[4]{C, N, W, H};
54 | //! reshapeWeights(input, shape, order, out, 4);
55 | //! Weights reshaped{input.type, out, input.count};
56 | //!
57 | //! Input Matrix{3, 2, 3, 2}:
58 | //! { 0 1}, { 2 3}, { 4 5} <-- {0, 0, *, *}
59 | //! { 6 7}, { 8 9}, {10 11} <-- {0, 1, *, *}
60 | //! {12 13}, {14 15}, {16 17} <-- {1, 0, *, *}
61 | //! {18 19}, {20 21}, {22 23} <-- {1, 1, *, *}
62 | //! {24 25}, {26 27}, {28 29} <-- {2, 0, *, *}
63 | //! {30 31}, {32 33}, {34 35} <-- {2, 1, *, *}
64 | //!
65 | //! Output Matrix{2, 3, 2, 3}:
66 | //! { 0 2 4}, { 1 3 5} <-- {0, 0, *, *}
67 | //! {12 14 16}, {13 15 17} <-- {0, 1, *, *}
68 | //! {24 26 28}, {25 27 29} <-- {0, 2, *, *}
69 | //! { 6 8 10}, { 7 9 11} <-- {1, 0, *, *}
70 | //! {18 20 22}, {19 21 23} <-- {1, 1, *, *}
71 | //! {30 32 34}, {31 33 35} <-- {1, 2, *, *}
72 | //!
73 | //! \return True on success, false on failure.
74 | //!
75 | TENSORRTAPI bool reshapeWeights(
76 | const Weights& input, const int32_t* shape, const int32_t* shapeOrder, void* data, int32_t nbDims);
77 |
78 | //!
79 | //! \param input The input data to re-order.
80 | //! \param order The new order of the data sub-buffers.
81 | //! \param num The number of data sub-buffers to re-order.
82 | //! \param size The size of each data sub-buffer in bytes.
83 | //!
84 | //! \brief Takes an input stream and re-orders \p num chunks of the data
85 | //! given the \p size and \p order.
86 | //!
87 | //! In some frameworks, the ordering of the sub-buffers within a dimension
88 | //! is different than the way that TensorRT expects them.
89 | //! TensorRT expects the gate/bias sub-buffers for LSTM's to be in fico order.
90 | //! TensorFlow however formats the sub-buffers in icfo order.
91 | //! This helper function solves this in a generic fashion.
92 | //!
93 | //! Example usage output of reshapeWeights above:
94 | //! int32_t indir[1]{1, 0}
95 | //! int32_t stride = W*H;
96 | //! for (int32_t x = 0, y = N*C; x < y; ++x)
97 | //! reorderSubBuffers(out + x * stride, indir, H, W);
98 | //!
99 | //! Input Matrix{2, 3, 2, 3}:
100 | //! { 0 2 4}, { 1 3 5} <-- {0, 0, *, *}
101 | //! {12 14 16}, {13 15 17} <-- {0, 1, *, *}
102 | //! {24 26 28}, {25 27 29} <-- {0, 2, *, *}
103 | //! { 6 8 10}, { 7 9 11} <-- {1, 0, *, *}
104 | //! {18 20 22}, {19 21 23} <-- {1, 1, *, *}
105 | //! {30 32 34}, {31 33 35} <-- {1, 2, *, *}
106 | //!
107 | //! Output Matrix{2, 3, 2, 3}:
108 | //! { 1 3 5}, { 0 2 4} <-- {0, 0, *, *}
109 | //! {13 15 17}, {12 14 16} <-- {0, 1, *, *}
110 | //! {25 27 29}, {24 26 28} <-- {0, 2, *, *}
111 | //! { 7 9 11}, { 6 8 10} <-- {1, 0, *, *}
112 | //! {19 21 23}, {18 20 22} <-- {1, 1, *, *}
113 | //! {31 33 35}, {30 32 34} <-- {1, 2, *, *}
114 | //!
115 | //! \return True on success, false on failure.
116 | //!
117 | //! \see reshapeWeights()
118 | //!
119 | TENSORRTAPI bool reorderSubBuffers(void* input, const int32_t* order, int32_t num, int32_t size);
120 |
121 | //!
122 | //! \param input The input data to transpose.
123 | //! \param type The type of the data to transpose.
124 | //! \param num The number of data sub-buffers to transpose.
125 | //! \param height The size of the height dimension to transpose.
126 | //! \param width The size of the width dimension to transpose.
127 | //!
128 | //! \brief Transpose \p num sub-buffers of \p height * \p width.
129 | //!
130 | //! \return True on success, false on failure.
131 | //!
132 | TENSORRTAPI bool transposeSubBuffers(void* input, DataType type, int32_t num, int32_t height, int32_t width);
133 |
134 | } // namespace utils
135 | } // namespace nvinfer1
136 | #endif // NV_UTILS_H
137 |
--------------------------------------------------------------------------------
/resource_for_billibilli/debug_plugin/plugin/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 | add_custom_target(plugin)
17 |
18 | set(TARGET_NAME nvinfer_plugin)
19 | set(SHARED_TARGET ${TARGET_NAME})
20 | set(STATIC_TARGET ${TARGET_NAME}_static)
21 |
22 | set(TARGET_DIR ${CMAKE_CURRENT_SOURCE_DIR})
23 | #set(PLUGIN_EXPORT_MAP ${TARGET_DIR}/exports.map)
24 |
25 | if(${CMAKE_BUILD_TYPE} MATCHES "Debug")
26 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g")
27 | endif()
28 |
29 | set(PLUGIN_SOURCES)
30 | set(PLUGIN_CU_SOURCES)
31 |
32 | set(PLUGIN_LISTS
33 | debug_plugin
34 | )
35 |
36 | include_directories(common common/kernels ../samples/common)
37 |
38 | foreach(PLUGIN_ITER ${PLUGIN_LISTS})
39 | include_directories(${PLUGIN_ITER})
40 | add_subdirectory(${PLUGIN_ITER})
41 | endforeach(PLUGIN_ITER)
42 |
43 | # Set gencodes
44 | set_source_files_properties(${PLUGIN_CU_SOURCES} PROPERTIES COMPILE_FLAGS ${GENCODES})
45 | list(APPEND PLUGIN_SOURCES "${PLUGIN_CU_SOURCES}")
46 |
47 | list(APPEND PLUGIN_SOURCES "${CMAKE_CURRENT_SOURCE_DIR}/infer_plugin_api.cc")
48 | list(APPEND PLUGIN_SOURCES "${CMAKE_CURRENT_SOURCE_DIR}/logger.cc")
49 |
50 | ################################## SHARED LIBRARY #######################################
51 |
52 | add_library(${SHARED_TARGET} SHARED
53 | ${PLUGIN_SOURCES}
54 | )
55 |
56 | target_include_directories(${SHARED_TARGET}
57 | PUBLIC ${PROJECT_SOURCE_DIR}/include
58 | PUBLIC ${CUB_ROOT_DIR}
59 | PRIVATE ${PROJECT_SOURCE_DIR}/common
60 | PUBLIC ${CUDA_INSTALL_DIR}/include
61 | PRIVATE ${TARGET_DIR}
62 | )
63 |
64 | set_target_properties(${SHARED_TARGET} PROPERTIES
65 | CXX_STANDARD "11"
66 | CXX_STANDARD_REQUIRED "YES"
67 | CXX_EXTENSIONS "NO"
68 | ARCHIVE_OUTPUT_DIRECTORY "${TRT_OUT_DIR}"
69 | LIBRARY_OUTPUT_DIRECTORY "${TRT_OUT_DIR}"
70 | RUNTIME_OUTPUT_DIRECTORY "${TRT_OUT_DIR}"
71 | )
72 |
73 | #set_target_properties(${SHARED_TARGET} PROPERTIES LINK_FLAGS "-Wl,--exclude-libs,ALL -Wl,--version-script=${PLUGIN_EXPORT_MAP} -Wl,--no-undefined")
74 |
75 | #set_target_properties(${SHARED_TARGET} PROPERTIES DEBUG_POSTFIX ${TRT_DEBUG_POSTFIX})
76 |
77 | set_target_properties(${SHARED_TARGET} PROPERTIES VERSION ${TRT_VERSION} SOVERSION ${TRT_SOVERSION} )
78 |
79 | set_property(TARGET ${SHARED_TARGET} PROPERTY CUDA_STANDARD 11)
80 |
81 | target_link_libraries(${SHARED_TARGET}
82 | ${CUBLAS_LIB}
83 | ${CUBLASLT_LIB}
84 | ${CUDART_LIB}
85 | ${CUDNN_LIB}
86 | nvinfer
87 | ${CMAKE_DL_LIBS}
88 | )
89 |
90 | ################################## STATIC LIBRARY #######################################
91 |
92 | add_library(${STATIC_TARGET} STATIC
93 | ${PLUGIN_SOURCES}
94 | )
95 |
96 | target_include_directories(${STATIC_TARGET}
97 | PUBLIC ${PROJECT_SOURCE_DIR}/include
98 | PUBLIC ${CUB_ROOT_DIR}
99 | PRIVATE ${PROJECT_SOURCE_DIR}/common
100 | PUBLIC ${CUDA_INSTALL_DIR}/include
101 | PRIVATE ${TARGET_DIR}
102 | )
103 |
104 | set_target_properties(${STATIC_TARGET} PROPERTIES
105 | CXX_STANDARD "11"
106 | CXX_STANDARD_REQUIRED "YES"
107 | CXX_EXTENSIONS "NO"
108 | ARCHIVE_OUTPUT_DIRECTORY "${TRT_OUT_DIR}"
109 | LIBRARY_OUTPUT_DIRECTORY "${TRT_OUT_DIR}"
110 | RUNTIME_OUTPUT_DIRECTORY "${TRT_OUT_DIR}"
111 | )
112 |
113 | set_target_properties(${STATIC_TARGET} PROPERTIES LINK_FLAGS "-Wl,--exclude-libs,ALL")
114 |
115 | set_target_properties(${STATIC_TARGET} PROPERTIES DEBUG_POSTFIX ${TRT_DEBUG_POSTFIX})
116 |
117 | set_target_properties(${STATIC_TARGET} PROPERTIES VERSION ${TRT_VERSION} SOVERSION ${TRT_SOVERSION} )
118 |
119 | set_property(TARGET ${STATIC_TARGET} PROPERTY CUDA_STANDARD 11)
120 |
121 | #########################################################################################
122 |
123 | add_dependencies(plugin ${SHARED_TARGET} ${STATIC_TARGET})
124 |
125 | ################################### INSTALLATION ########################################
126 |
127 | install(TARGETS ${TARGET_NAME}
128 | RUNTIME DESTINATION bin
129 | LIBRARY DESTINATION lib
130 | ARCHIVE DESTINATION lib
131 | )
132 |
--------------------------------------------------------------------------------
/resource_for_billibilli/debug_plugin/plugin/debug_plugin/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 | file(GLOB CU_SRCS *.cu)
17 | set(PLUGIN_CU_SOURCES ${PLUGIN_CU_SOURCES} ${CU_SRCS})
18 | set(PLUGIN_CU_SOURCES ${PLUGIN_CU_SOURCES} PARENT_SCOPE)
19 |
--------------------------------------------------------------------------------
/resource_for_billibilli/debug_plugin/plugin/debug_plugin/debug_dynamic_plugin.h:
--------------------------------------------------------------------------------
1 | #ifndef PLUGIN_DEBUG_PLUGIN_H
2 | #define PLUGIN_DEBUG_PLUGIN_H
3 |
4 | #include <string>
5 | #include <vector>
6 |
7 | #include "NvInfer.h"
8 |
9 | #include "plugin_common.h"
10 |
11 | namespace debug_plugin {
12 |
13 | class DebugPluginDynamic : public nvinfer1::IPluginV2DynamicExt {
14 | public:
15 | DebugPluginDynamic(const std::string name, const nvinfer1::DataType data_type, int input_num);
16 |
17 | DebugPluginDynamic(const std::string name, const void* data, size_t length);
18 |
19 | // It doesn't make sense to make DebugPluginDynamic without arguments, so we delete
20 | // default constructor.
21 | DebugPluginDynamic() = delete;
22 |
23 | // IPluginV2DynamicExt Methods
24 | nvinfer1::IPluginV2DynamicExt* clone() const override;
25 | nvinfer1::DimsExprs getOutputDimensions(
26 | int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) override;
27 | bool supportsFormatCombination(
28 | int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, int nbOutputs) override;
29 | void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs,
30 | const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) override;
31 | size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs,
32 | const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const override;
33 | int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinfer1::PluginTensorDesc* outputDesc,
34 | const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) override;
35 |
36 | // IPluginV2Ext Methods
37 | nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const override;
38 |
39 | // IPluginV2 Methods
40 | const char* getPluginType() const override;
41 | const char* getPluginVersion() const override;
42 | int getNbOutputs() const override;
43 | int initialize() override;
44 | void terminate() override;
45 | size_t getSerializationSize() const override;
46 | void serialize(void* buffer) const override;
47 | void destroy() override;
48 | void setPluginNamespace(const char* pluginNamespace) override;
49 | const char* getPluginNamespace() const override;
50 |
51 | private:
52 | std::string layer_name_;
53 | std::string namespace_;
54 |
55 | nvinfer1::DataType data_type_;
56 | size_t num_inputs_;
57 |
58 | protected:
59 | // To prevent compiler warnings.
60 | using nvinfer1::IPluginV2DynamicExt::getOutputDimensions;
61 | using nvinfer1::IPluginV2DynamicExt::isOutputBroadcastAcrossBatch;
62 | using nvinfer1::IPluginV2DynamicExt::canBroadcastInputAcrossBatch;
63 | using nvinfer1::IPluginV2DynamicExt::supportsFormat;
64 | using nvinfer1::IPluginV2DynamicExt::configurePlugin;
65 | using nvinfer1::IPluginV2DynamicExt::getWorkspaceSize;
66 | using nvinfer1::IPluginV2DynamicExt::enqueue;
67 | };
68 |
69 | class DebugPluginDynamicCreator : public nvinfer1::IPluginCreator {
70 | public:
71 | DebugPluginDynamicCreator();
72 | const char* getPluginName() const override;
73 | const char* getPluginVersion() const override;
74 | const nvinfer1::PluginFieldCollection* getFieldNames() override;
75 | nvinfer1::IPluginV2* createPlugin(const char* name, const nvinfer1::PluginFieldCollection* fc) override;
76 | nvinfer1::IPluginV2* deserializePlugin(const char* name, const void* serialData, size_t serialLength) override;
77 | void setPluginNamespace(const char* pluginNamespace) override;
78 | const char* getPluginNamespace() const override;
79 |
80 | private:
81 | static nvinfer1::PluginFieldCollection mFC;
82 | static std::vector<nvinfer1::PluginField> mPluginAttributes;
83 | std::string namespace_;
84 | };
85 |
86 | } // debug_plugin
87 |
88 | #endif // PLUGIN_DEBUG_PLUGIN_H
89 |
--------------------------------------------------------------------------------
/resource_for_billibilli/debug_plugin/plugin/debug_plugin/debug_kernel.h:
--------------------------------------------------------------------------------
1 | #ifndef PLUGIN_DEBUG_KERNEL_H_
2 | #define PLUGIN_DEBUG_KERNEL_H_
3 |
4 | #include <cuda_runtime.h>
5 | #include <cuda_fp16.h>
6 | #include <stdio.h>
7 | #include <string>
8 | #include <vector>
9 |
10 | #include "plugin_common.h"
11 |
12 | namespace debug_plugin {
13 |
14 | void p(const float *data, std::vector<int>& dims);
15 |
16 | void p_sum(const float *data, std::vector<int>& dims, std::string message);
17 |
18 | } // debug_plugin
19 |
20 | #endif // PLUGIN_DEBUG_KERNEL_H_
21 |
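These two helpers are what the debug plugins call after copying a tensor back to the host. A standalone sketch of using them; the shape and message are made up, and `p_sum` is assumed (from its name) to print an aggregate of the buffer:

```cpp
#include <string>
#include <vector>
#include "debug_kernel.h"

void dumpHostTensor(const float* hostData)
{
    std::vector<int> dims{1, 4, 8};  // made-up shape: batch=1, 4 rows, 8 columns
    debug_plugin::p(hostData, dims);                            // pretty-print the values
    debug_plugin::p_sum(hostData, dims, "sum after layer X");   // presumably prints a reduction, tagged with the message
}
```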
--------------------------------------------------------------------------------
/resource_for_billibilli/debug_plugin/plugin/debug_plugin/debug_plugin.cu:
--------------------------------------------------------------------------------
1 | #include "debug_plugin.h"
2 |
3 | #include <cassert>
4 | #include <cstring>
5 | #include <vector>
6 |
7 | #include "NvInfer.h"
8 |
9 | #include "debug_kernel.h"
10 | #include "serialize.hpp"
11 |
12 | using namespace nvinfer1;
13 | using namespace std;
14 |
15 | namespace debug_plugin {
16 |
17 | // Clip plugin specific constants
18 | namespace
19 | {
20 | static const char* DEBUG_VERSION{"1"};
21 | static const char* DEBUG_NAME{"DebugPlugin"};
22 | } // namespace
23 |
24 | /*REGISTER_TENSORRT_PLUGIN(DebugPluginCreator);*/
25 |
26 | DebugPlugin::DebugPlugin(const std::string &name, const DataType data_type, int input_num,
27 | std::vector<nvinfer1::Dims> outputs_dims)
28 | : layer_name_(name)
29 | , data_type_(data_type)
30 | , num_inputs_(input_num)
31 | , outputs_dims_(outputs_dims)
32 | { }
33 |
34 | DebugPlugin::DebugPlugin(const std::string &name, const void* data, size_t length)
35 | : layer_name_(name) {
36 | deserialize_value(&data, &length, &data_type_);
37 | deserialize_value(&data, &length, &num_inputs_);
38 | size_t name_len = 0;
39 | deserialize_value(&data, &length, &name_len);
40 |
41 | // deserialize dims
42 | size_t outputs_dims_size = 0;
43 | deserialize_value(&data, &length, &outputs_dims_size);
44 |
45 | outputs_dims_.resize(outputs_dims_size);
46 | const char *d = static_cast<const char*>(data);
47 |
48 | for (int i = 0; i < outputs_dims_size; i++) {
49 | deserNvDimsToHost(d, outputs_dims_[i]);
50 | }
51 |
52 | char tmp[name_len];
53 | deserToHost(d, tmp, name_len);
54 | layer_name_.resize(name_len);
55 | layer_name_ = std::string(tmp, name_len);
56 | gLogVerbose << "Starting to deserialize DEBUG plugin: " << layer_name_ << std::endl;
57 | }
58 |
59 | IPluginV2Ext* DebugPlugin::clone() const {
60 | auto p = new DebugPlugin(layer_name_, data_type_, num_inputs_, outputs_dims_);
61 | return p;
62 | }
63 |
64 | Dims DebugPlugin::getOutputDimensions(int index, const Dims* inputs, int nbInputDims) {
65 | outputs_dims_.push_back(inputs[index]);
66 | return inputs[index];
67 | }
68 |
69 | bool DebugPlugin::supportsFormatCombination(int pos, const PluginTensorDesc* inOut,
70 | int nbInputs, int nbOutputs) const {
71 | return true;
72 | }
73 |
74 | void DebugPlugin::configurePlugin(const PluginTensorDesc* in, int nbInput,
75 | const PluginTensorDesc* out, int nbOutput)
76 | { }
77 |
78 | size_t DebugPlugin::getWorkspaceSize(int maxBatchSize) const {
79 | return 0;
80 | }
81 |
82 | int DebugPlugin::enqueue(int batchSize, const void* const* inputs, void** outputs,
83 | void* workspace, cudaStream_t stream) {
84 |
85 | for (size_t n = 0; n < num_inputs_; n++) {
86 | auto dims = outputs_dims_[n];
87 | const int inputVolume = volume(dims) * batchSize;
88 | // remove dim = 1 or 0
89 | vector<int> v_dims;
90 | v_dims.push_back(batchSize);
91 | for (int i = 0; i < dims.nbDims; i++) {
92 | int d = dims.d[i];
93 | if (d > 1) v_dims.push_back(d);
94 | }
95 |
96 | if (data_type_ == DataType::kFLOAT) {
97 | const float* input = static_cast<const float*>(inputs[n]);
98 | float *arr = new float[inputVolume];
99 | memset(arr, 0, inputVolume*sizeof(float));
100 |
101 | cudaMemcpy(arr, input, inputVolume*sizeof(float), cudaMemcpyDeviceToHost);
102 | printf("layer_name=%s, dims=%s\n",
103 | layer_name_.c_str(), dims2String(dims).c_str());
104 |
105 | p(arr, v_dims);
106 | delete [] arr;
107 |
108 | float* output = static_cast<float*>(outputs[n]);
109 | cudaMemcpy(output, input, inputVolume*sizeof(float), cudaMemcpyDeviceToDevice);
110 |
111 | } else if (data_type_ == DataType::kHALF) {
112 | #ifdef __SCORE_HALF__
113 | const half* input = static_cast<const half*>(inputs[0]);
114 | #endif
115 | } else {
116 | assert(false);
117 | }
118 | }
119 |
120 | return 0;
121 | }
122 |
123 | // IPluginV2Ext Methods
124 | DataType DebugPlugin::getOutputDataType(int index, const DataType* inputTypes, int nbInputs) const {
125 | assert(inputTypes[0] == DataType::kFLOAT || inputTypes[0] == DataType::kHALF);
126 | return inputTypes[0];
127 | }
128 |
129 | const char* DebugPlugin::getPluginType() const {
130 | return DEBUG_NAME;
131 | }
132 |
133 | const char* DebugPlugin::getPluginVersion() const {
134 | return DEBUG_VERSION;
135 | }
136 |
137 | int DebugPlugin::getNbOutputs() const {
138 | return num_inputs_;
139 | }
140 |
141 | int DebugPlugin::initialize() {
142 | return 0;
143 | }
144 |
145 | void DebugPlugin::terminate()
146 | { }
147 |
148 | size_t DebugPlugin::getSerializationSize() const
149 | {
150 | return sizeof(data_type_) + sizeof(num_inputs_) +
151 | sizeof(int) * outputs_dims_.size() * (nvinfer1::Dims::MAX_DIMS+ 1) +
152 | sizeof(layer_name_.size()) + layer_name_.size() + 10;
153 | }
154 |
155 | void DebugPlugin::serialize(void* buffer) const {
156 | serialize_value(&buffer, data_type_);
157 | serialize_value(&buffer, num_inputs_);
158 | serialize_value(&buffer, layer_name_.size());
159 |
160 | serialize_value(&buffer, outputs_dims_.size());
161 | char *d = static_cast<char*>(buffer);
162 | for (size_t i = 0; i < outputs_dims_.size(); i++) {
163 | serNvDimsFromHost(d, outputs_dims_[i]);
164 | }
165 |
166 | serFromHost(d, layer_name_, (size_t)layer_name_.size());
167 | }
168 |
169 | void DebugPlugin::destroy() {
170 | delete this;
171 | }
172 |
173 | void DebugPlugin::setPluginNamespace(const char* libNamespace) {
174 | namespace_ = libNamespace;
175 | }
176 |
177 | const char* DebugPlugin::getPluginNamespace() const {
178 | return namespace_.c_str();
179 | }
180 |
181 | bool DebugPlugin::isOutputBroadcastAcrossBatch(int outputIndex, const bool* inputIsBroadcasted, int nbInputs) const {
182 | return false;
183 | }
184 |
185 | bool DebugPlugin::canBroadcastInputAcrossBatch(int inputIndex) const {
186 | return false;
187 | }
188 |
189 | const char* DebugPluginCreator::getPluginName() const {
190 | return DEBUG_NAME;
191 | }
192 |
193 | const char* DebugPluginCreator::getPluginVersion() const {
194 | return DEBUG_VERSION;
195 | }
196 |
197 | const PluginFieldCollection* DebugPluginCreator::getFieldNames() {
198 | return &field_collection_;
199 | }
200 |
201 | IPluginV2* DebugPluginCreator::createPlugin(const char* name, const PluginFieldCollection* fc) {
202 | gLogVerbose << "Creating DebugPlugin...\n";
203 |
204 | int typeId = -1;
205 | int input_num = 0;
206 | for (int i = 0; i < fc->nbFields; i++) {
207 | std::string field_name(fc->fields[i].name);
208 |
209 | if (field_name.compare("type_id") == 0) {
210 | typeId = *static_cast<const int*>(fc->fields[i].data);
211 | gLogVerbose << "Building typeId: " << typeId << std::endl;
212 | }
213 | if (field_name.compare("input_num") == 0) {
214 | input_num = *static_cast<const int*>(fc->fields[i].data);
215 | gLogVerbose << "Building input_num: " << input_num << std::endl;
216 | }
217 | }
218 |
219 | if (typeId < 0 || typeId > 2) {
220 | gLogError << "DEBUG: invalid typeId " << typeId << std::endl;
221 | return nullptr;
222 | }
223 | DataType type = static_cast<DataType>(typeId);
224 | gLogVerbose << "Creating DebugPlugin...\n";
225 | return new DebugPlugin(name, type, input_num);
226 | }
227 |
228 | IPluginV2* DebugPluginCreator::deserializePlugin(const char* name, const void* serialData, size_t serialLength) {
229 | return new DebugPlugin(name, serialData, serialLength);
230 | }
231 |
232 | void DebugPluginCreator::setPluginNamespace(const char* libNamespace) {
233 | namespace_ = libNamespace;
234 | }
235 |
236 | const char* DebugPluginCreator::getPluginNamespace() const {
237 | return namespace_.c_str();
238 | }
239 |
240 | } // debug_plugin
241 |
242 |
--------------------------------------------------------------------------------
/resource_for_billibilli/debug_plugin/plugin/debug_plugin/debug_plugin.h:
--------------------------------------------------------------------------------
1 | #ifndef PLUGIN_DEBUG_PLUGIN_H_
2 | #define PLUGIN_DEBUG_PLUGIN_H_
3 |
4 | #include "NvInferPlugin.h"
5 | #include "NvInferRuntime.h"
6 |
7 | #include <string>
8 | #include <vector>
9 |
10 | #include "plugin_common.h"
11 |
12 | namespace debug_plugin {
13 |
14 | class DebugPlugin : public nvinfer1::IPluginV2IOExt {
15 | public:
16 | DebugPlugin(const std::string &name, const nvinfer1::DataType type, int input_num,
17 | std::vector<nvinfer1::Dims> outputs_dims = std::vector<nvinfer1::Dims>());
18 |
19 | DebugPlugin(const std::string &name, const void* data, size_t length);
20 |
21 | // It makes no sense to construct DebugPlugin without arguments.
22 | DebugPlugin() = delete;
23 |
24 | virtual ~DebugPlugin() {}
25 |
26 | public:
27 | int getNbOutputs() const override;
28 | nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims* inputs, int nbInputDims) override;
29 | int initialize() override;
30 | void terminate() override;
31 | size_t getWorkspaceSize(int maxBatchSize) const override;
32 | int enqueue(int batchSize, const void* const* inputs, void** outputs,
33 | void* workspace, cudaStream_t stream) override;
34 | size_t getSerializationSize() const override;
35 | void serialize(void* buffer) const override;
36 | void configurePlugin(const nvinfer1::PluginTensorDesc* in, int nbInput,
37 | const nvinfer1::PluginTensorDesc* out, int nbOutput) override;
38 | bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc* inOut,
39 | int nbInputs, int nbOutputs) const override;
40 | nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType* inputTypes,
41 | int nbInputs) const override;
42 | const char* getPluginType() const override;
43 | const char* getPluginVersion() const override;
44 | void destroy() override;
45 | nvinfer1::IPluginV2Ext* clone() const override;
46 | void setPluginNamespace(const char* libNamespace) override;
47 | const char* getPluginNamespace() const override;
48 | bool isOutputBroadcastAcrossBatch(int outputIndex, const bool* inputIsBroadcasted,
49 | int nbInputs) const override;
50 | bool canBroadcastInputAcrossBatch(int inputIndex) const override;
51 |
52 | private:
53 | std::string layer_name_;
54 | std::string namespace_;
55 | nvinfer1::DataType data_type_;
56 |
57 | size_t num_inputs_;
58 |
59 | std::vector<nvinfer1::Dims> outputs_dims_;
60 |
61 | protected:
62 | // To prevent compiler warnings.
63 | using nvinfer1::IPluginV2IOExt::configurePlugin;
64 | };
65 |
66 | class DebugPluginCreator: public nvinfer1::IPluginCreator {
67 | public:
68 | const char* getPluginName() const override;
69 | const char* getPluginVersion() const override;
70 | const nvinfer1::PluginFieldCollection* getFieldNames() override;
71 | nvinfer1::IPluginV2* createPlugin(const char* name, const nvinfer1::PluginFieldCollection* fc) override;
72 | nvinfer1::IPluginV2* deserializePlugin(const char* name, const void* serialData, size_t serialLength) override;
73 | void setPluginNamespace(const char* libNamespace) override;
74 | const char* getPluginNamespace() const override;
75 |
76 | private:
77 | std::string namespace_;
78 | std::string plugin_name_;
79 | nvinfer1::PluginFieldCollection field_collection_{0, nullptr};
80 | };
81 |
82 | } // debug_plugin
83 |
84 | #endif // PLUGIN_DEBUG_PLUGIN_H_
85 |
--------------------------------------------------------------------------------
/resource_for_billibilli/debug_plugin/plugin/infer_plugin_api.cc:
--------------------------------------------------------------------------------
1 | #include "infer_plugin_api.h"
2 |
3 | #include
4 | #include
5 | #include