├── .gitignore
├── README.md
├── backend
│   ├── VERSION.txt
│   ├── setup.cfg
│   ├── setup.py
│   └── trtis
│       ├── README.md
│       ├── __init__.py
│       ├── onnx_backend
│       │   ├── __init__.py
│       │   ├── onnxsim.py
│       │   └── torch2onnx.py
│       ├── set_config.py
│       ├── tf_backend
│       │   ├── __init__.py
│       │   └── tf2graphdef.py
│       └── trt_backend
│           ├── __init__.py
│           ├── calibrator.py
│           ├── tf2trt.py
│           └── torch2trt.py
├── client_py
│   ├── VERSION.txt
│   ├── setup.cfg
│   ├── setup.py
│   └── trt_client
│       ├── __init__.py
│       ├── client.py
│       └── client_grpc.py
├── example
│   ├── detection
│   │   ├── calibrator_files
│   │   │   ├── 1.jpg
│   │   │   ├── 2.jpg
│   │   │   ├── 3.jpg
│   │   │   ├── 4.jpg
│   │   │   ├── 5.jpg
│   │   │   ├── 6.jpg
│   │   │   ├── 7.jpg
│   │   │   ├── 8.jpg
│   │   │   ├── 9.jpg
│   │   │   └── 10.jpg
│   │   ├── client.py
│   │   ├── client.sh
│   │   ├── config.pbtxt
│   │   ├── convert.sh
│   │   ├── network.py
│   │   ├── network
│   │   │   ├── __init__.py
│   │   │   ├── dla34.py
│   │   │   └── resnet.py
│   │   ├── post_process.py
│   │   └── pre_process.py
│   └── test-data
│       └── widerface.jpg
├── install.sh
└── start.sh
/.gitignore:
--------------------------------------------------------------------------------
*.pyc
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------

# TensorRT Inference Server Beginner's Tutorial

A simple, easy-to-follow, and quick tutorial for deploying a complete deep learning model, enough to satisfy a fair share of industrial requirements. For teams that do not need to rewrite the serving interface themselves, running TensorRT Inference Server as the service is also sufficient.

The case chosen here is CenterNet detection. The SSD and YOLO families are fairly dated by now; although tutorials for them are plentiful, they are neither concise nor modern in approach, so this tutorial uses something a little newer.

For now this tutorial does not provide a model zoo for the detection model, mainly because the officially released models all contain a DCN module. That module depends on a C++-level library and is very inconvenient for a beginner to deploy; you can train a DCN-free model yourself from the official CenterNet code.

This tutorial uses the DLA34 network as its example. Model file location:

Link: https://pan.baidu.com/s/1gcC7qcBi68W0hzJO8IeB3w  extraction code: rsut

Then place it under ./example/detection/network


#### Performance evaluation

When deployed on a P40 GPU, most of the time is spent on network communication at the service level and on dispatching requests to the GPUs in round-robin fashion; the model computation itself is very fast.

1. With 16 instances on one card (about 2 GB of GPU memory), a single client issuing asynchronous requests reaches roughly 100 QPS.
2. With 4 cards and 16 instances per card (about 2 GB of GPU memory each), a single client issuing asynchronous requests reaches roughly 400-500 QPS.

#### Project layout

```sh
./
├── README.md
├── backend              # conversion library
│   ├── VERSION.txt
│   ├── setup.cfg
│   ├── setup.py
│   └── trtis
├── client_py            # Python client utilities
│   ├── VERSION.txt
│   ├── setup.cfg
│   ├── setup.py
│   └── trt_client
├── example
│   ├── detection        # detection pre/post-processing, network, client, etc.
│   └── test-data        # test data
├── install.sh
├── model_repository
└── start.sh
```


## Preface

Almost every deep learning deployment problem consists of the same basic operations: pre-processing, neural network computation, and post-processing.

Note that pre-processing does more than the usual data decoding and normalization; it may also need to record some global information about the input, such as the original image size or string annotations. This meta information has to be handed to post-processing for all kinds of task-specific handling. For CenterNet, the meta information is the affine transform.

```python
nn_inputs, meta = preprocess(raw_image)
nn_outputs = model(nn_inputs)
result = postprocess(nn_outputs, meta)
```
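To make the meta hand-off concrete, here is a minimal, self-contained sketch of the pattern. The helper names and the letterbox-style transform are illustrative assumptions; the real CenterNet versions live in example/detection/pre_process.py and post_process.py.

```python
import cv2
import numpy as np

def preprocess(img, size=512):
    h, w = img.shape[:2]
    scale = size / max(h, w)
    # 2x3 affine matrix mapping original pixels into the centered square input
    M = np.array([[scale, 0, (size - scale * w) / 2],
                  [0, scale, (size - scale * h) / 2]], dtype=np.float32)
    nn_input = cv2.warpAffine(img, M, (size, size))
    meta = {"affine": M}  # everything postprocess needs to undo the resize
    return nn_input, meta

def postprocess(bboxes, meta):
    # map predicted boxes back to original image coordinates by inverting
    # the affine transform saved during preprocessing
    inv = cv2.invertAffineTransform(meta["affine"])
    pts = bboxes.reshape(-1, 2).astype(np.float32)   # (x1,y1),(x2,y2) points
    pts = np.hstack([pts, np.ones((len(pts), 1), np.float32)])
    return (pts @ inv.T).reshape(-1, 4)
```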
The implementation path of this tutorial is as follows:

1. Pre-processing is written in TensorFlow (image decoding, resizing, computing the affine transform matrix, normalization, etc.) and saved as a TensorFlow GraphDef (pb) file.
2. The network itself is PyTorch. The Torch model is first converted to ONNX, simplified further with onnx-simplifier, and then handed to TensorRT for additional optimization and INT8 quantization.
   - The purpose of onnx-simplifier is to make ONNX-to-TensorRT conversion failures less likely; it cannot, however, guarantee that every network converts successfully. Shape-level operations in Torch such as unsqueeze still have potential problems and may require edits in model.py.
   - The ONNX path can occasionally cost a few accuracy points, for reasons that are still unclear. When ONNX parses the Torch computation graph, operators do not map one-to-one, and there are some well-hidden issues such as inconsistent hyperparameters.
3. Post-processing is written in Torch, converted to ONNX, and scheduled by ONNX Runtime.
4. TensorRT Inference Server provides an ensemble mode that can jointly schedule TensorFlow GraphDef files, TensorRT plan files, and ONNX files. Pre-processing, NN computation, and post-processing can therefore all be turned into services, sparing engineers complicated build work and C++ code: the whole deployment only requires writing Python, which is extremely general and efficient.



## Server setup

```sh
docker pull nvcr.io/nvidia/tensorrtserver:19.12-py3
```

Note that this requires an NVIDIA driver version greater than 418, and the required CUDA version is 10.1; see the detailed requirements at:

https://docs.nvidia.com/deeplearning/sdk/inference-release-notes/rel_19-12.html#rel_19-12



## Client setup

```sh
docker pull nvcr.io/nvidia/tensorrtserver:19.12-py3-clientsdk
```

In principle the gRPC interface does not depend on the system environment, so there is no need to start the client through Docker. After `docker run` on the image above, copy the wheel at /workspace/install/python/tensorrtserver-1.9.0-py2.py3-none-linux_x86_64.whl out of the container and simply `pip install` it on any machine:

```sh
# docker run --rm nvcr.io/nvidia/tensorrtserver:19.12-py3-clientsdk /bin/bash
# copy `/workspace/install/python/tensorrtserver-1.9.0-py2.py3-none-linux_x86_64.whl` to any linux machine
# run the following command
pip install tensorrtserver-1.9.0-py2.py3-none-linux_x86_64.whl
```

For C++, extract the client-side SDK and build your own binaries against it somewhere; this is rather tedious, so no example is given for now.



## Inference Server Backend Installation

Install the various backends used to produce the following conversion targets:

- onnx=1.6.0
- tensorRT=6.0.1.5
- tensorflow=1.15.0
- pytorch=1.3.0

To install TensorRT-6.0.1.5, see https://docs.nvidia.com/deeplearning/sdk/tensorrt-install-guide/index.html

Install the other backend libraries; for now only the Python packages are needed:

```sh
pip install onnx==1.6.0 onnxruntime==1.1.0 onnx-simplifier==0.2.2
pip install tensorflow-gpu==1.15.0
pip install torch==1.3.0 torchvision==0.4.1
pip install opencv-python pillow pycuda
```



## Getting started

Install the tutorial's conversion scripts and client interface. The interface not only performs the conversion but also generates the config files that TensorRT Inference Server requires, so it is applicable to converting other models as well; the only remaining problem is that ONNX-to-TensorRT conversion still cannot be made 100% seamless.

```sh
cd backend
python setup.py install
cd ../client_py
python setup.py install
```

Run the tutorial's example. It generates a complete model_repository; the rest is left to TensorRT Inference Server.

```sh
cd example/detection
./convert.sh
```

The resulting model_repository layout:

```sh
./model_repository/
├── detection
│   ├── 1
│   └── config.pbtxt
├── detection-network
│   ├── 1
│   │   └── model.plan
│   └── config.pbtxt
├── detection-postprocess
│   ├── 1
│   │   └── model.onnx
│   └── config.pbtxt
└── detection-preprocess
    ├── 1
    │   └── model.graphdef
    └── config.pbtxt
```

Start the server:

```sh
#!/bin/bash
HTTP_PORT=7000
GRPC_PORT=7001
METRIC_PORT=7002
DOCKER_IMAGE=nvcr.io/nvidia/tensorrtserver:19.12-py3
MODEL_REPOSITORY=./model_repository

docker run --rm \
    --runtime nvidia \
    --name trt_server \
    --shm-size=4g \
    --ulimit memlock=-1 \
    --ulimit stack=67108864 \
    -p${HTTP_PORT}:8000 \
    -p${GRPC_PORT}:8001 \
    -p${METRIC_PORT}:8002 \
    -v${MODEL_REPOSITORY}/:/models \
    ${DOCKER_IMAGE} \
    trtserver --model-repository=/models
```

Run the client:

```sh
cd example/detection
./client.sh
```
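Before moving on to the Python client below, it is worth sanity-checking that the server actually loaded the ensemble. A quick status query using the same tensorrtserver API that trt_client wraps (assuming the server runs locally with the ports from start.sh above):

```python
from tensorrtserver.api import ProtocolType, ServerStatusContext

# 7001 is the gRPC port mapped in start.sh
ctx = ServerStatusContext("localhost:7001", ProtocolType.from_str("gRPC"),
                          "detection", False)
server_status = ctx.get_server_status()
print(server_status.model_status["detection"])
```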
#### Using the Python client

Single synchronous request:

```python
from trt_client import client
import numpy as np

raw_image = open("./xxx.jpg", "rb").read()
raw_image = np.array([raw_image], dtype=bytes)

runner = client.Inference(
    url="xx.xxx.xxx.xxx:7001",  # grpc
    model_name="detection",
    model_version="1"
)
results = runner.run(input={"raw_image": raw_image})
print(results)
```

Asynchronous, non-blocking requests:

```python
from trt_client import client
import numpy as np

runner = client.Inference(
    url="xx.xxx.xxx.xxx:7001",  # grpc
    model_name="detection",
    model_version="1"
)

for i in range(10):
    raw_image = open("./{}.jpg".format(i), "rb").read()
    raw_image = np.array([raw_image], dtype=bytes)
    runner.async_run(
        input={"raw_image": raw_image},
        input_id="image_{}".format(i)
    )
for i in range(10):
    input_id, results = runner.get_result(block=True)
```
--------------------------------------------------------------------------------
/backend/VERSION.txt:
--------------------------------------------------------------------------------
0.1.0
--------------------------------------------------------------------------------
/backend/setup.cfg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/layerism/TensorRT-Inference-Server-Tutorial/c40d485f69c1349c01b40a4eb7c225b01bf8dbe0/backend/setup.cfg
--------------------------------------------------------------------------------
/backend/setup.py:
--------------------------------------------------------------------------------
from os.path import dirname, join

from setuptools import find_packages, setup

# from pip.req import parse_requirements


def parse_requirements(filename):
    """ load requirements from a pip requirements file """
    lineiter = (line.strip() for line in open(filename))
    return [line for line in lineiter if line and not line.startswith("#")]


# read the package version from VERSION.txt instead of hard-coding it twice
with open(join(dirname(__file__), './VERSION.txt'), 'rb') as f:
    version = f.read().decode('ascii').strip()


setup(
    name='trtis',
    version=version,
    keywords='363246',
    description='a library for DS CAA Developer',
    license='MIT License',
    url='',
    author='layersim',
    author_email='',
    packages=find_packages(),
    include_package_data=True,
    platforms='any',
    install_requires=[
        "numpy>=1.16.0",
        "protobuf>=3.8.0",
        "onnx>=1.6.0",
        "pycuda>=2018.1.2",
        "tensorflow-gpu>=1.15.0"
    ],
    python_requires='>=3.6'
)
--------------------------------------------------------------------------------
/backend/trtis/README.md:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/layerism/TensorRT-Inference-Server-Tutorial/c40d485f69c1349c01b40a4eb7c225b01bf8dbe0/backend/trtis/README.md
--------------------------------------------------------------------------------
/backend/trtis/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/layerism/TensorRT-Inference-Server-Tutorial/c40d485f69c1349c01b40a4eb7c225b01bf8dbe0/backend/trtis/__init__.py
--------------------------------------------------------------------------------
/backend/trtis/onnx_backend/__init__.py:
--------------------------------------------------------------------------------
from .torch2onnx import *
-------------------------------------------------------------------------------- /backend/trtis/onnx_backend/onnxsim.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from typing import Dict, List, Optional, Union 3 | 4 | import numpy as np # type: ignore 5 | 6 | import onnx # type: ignore 7 | import onnx.helper # type: ignore 8 | import onnx.optimizer # type: ignore 9 | import onnx.shape_inference # type: ignore 10 | import onnxruntime as rt # type: ignore 11 | 12 | TensorShape = List[int] 13 | TensorShapes = Dict[Optional[str], TensorShape] 14 | 15 | 16 | def add_features_to_output(m: onnx.ModelProto) -> None: 17 | """ 18 | Add features to output in pb, so that ONNX Runtime will output them. 19 | :param m: the model that will be run in ONNX Runtime 20 | """ 21 | for node in m.graph.node: 22 | for output in node.output: 23 | m.graph.output.extend([onnx.ValueInfoProto(name=output)]) 24 | 25 | 26 | def get_shape_from_value_info_proto(v: onnx.ValueInfoProto) -> List[int]: 27 | return [dim.dim_value for dim in v.type.tensor_type.shape.dim] 28 | 29 | 30 | def get_value_info_all(m: onnx.ModelProto, name: str) -> Optional[onnx.ValueInfoProto]: 31 | for v in m.graph.value_info: 32 | if v.name == name: 33 | return v 34 | 35 | for v in m.graph.input: 36 | if v.name == name: 37 | return v 38 | 39 | for v in m.graph.output: 40 | if v.name == name: 41 | return v 42 | 43 | return None 44 | 45 | 46 | def get_shape(m: onnx.ModelProto, name: str) -> TensorShape: 47 | """ 48 | Note: This method relies on onnx shape inference, 49 | which is not reliable. So only use it on input or output tensors 50 | """ 51 | v = get_value_info_all(m, name) 52 | if v is not None: 53 | return get_shape_from_value_info_proto(v) 54 | raise RuntimeError('Cannot get shape of "{}"'.format(name)) 55 | 56 | 57 | def get_elem_type(m: onnx.ModelProto, name: str) -> int: 58 | v = get_value_info_all(m, name) 59 | if v is not None: 60 | return v.type.tensor_type.elem_type 61 | raise RuntimeError('Cannot get type of "{}"'.format(name)) 62 | 63 | 64 | def get_np_type_from_elem_type(elem_type: int) -> int: 65 | sizes = ( 66 | None, 67 | np.float32, 68 | np.uint8, 69 | np.int8, 70 | np.uint16, 71 | np.int16, 72 | np.int32, 73 | np.int64, 74 | None, 75 | None, 76 | np.float16, 77 | np.double, 78 | np.uint32, 79 | np.uint64, 80 | None, 81 | None, 82 | np.float16 83 | ) 84 | assert len(sizes) == 17 85 | size = sizes[elem_type] 86 | assert size is not None 87 | return size 88 | 89 | 90 | def get_input_names(model: onnx.ModelProto) -> List[str]: 91 | input_names = list( 92 | set([ipt.name for ipt in model.graph.input]) - 93 | set([x.name for x in model.graph.initializer]) 94 | ) 95 | return input_names 96 | 97 | 98 | def add_initializers_into_inputs(model: onnx.ModelProto) -> onnx.ModelProto: 99 | # Due to a onnx bug, https://github.com/onnx/onnx/issues/2417, 100 | # we need to add missing initializers into inputs 101 | for x in model.graph.initializer: 102 | input_names = [x.name for x in model.graph.input] 103 | if x.name not in input_names: 104 | shape = onnx.TensorShapeProto() 105 | for dim in x.dims: 106 | shape.dim.extend([onnx.TensorShapeProto.Dimension(dim_value=dim)]) 107 | model.graph.input.extend( 108 | [ 109 | onnx.ValueInfoProto( 110 | name=x.name, 111 | type=onnx.TypeProto( 112 | tensor_type=onnx.TypeProto. 
113 | Tensor(elem_type=x.data_type, shape=shape) 114 | ) 115 | ) 116 | ] 117 | ) 118 | return model 119 | 120 | 121 | def generate_rand_input(model, input_shapes: TensorShapes = {}): 122 | input_names = get_input_names(model) 123 | full_input_shapes = {ipt: get_shape(model, ipt) for ipt in input_names} 124 | assert None not in input_shapes 125 | full_input_shapes.update(input_shapes) # type: ignore 126 | for key in full_input_shapes: 127 | if np.prod(full_input_shapes[key]) <= 0: 128 | raise RuntimeError( 129 | 'The shape of input "{}" has dynamic size, ' 130 | 'please determine the input size manually by --input-shape xxx'. 131 | format(key) 132 | ) 133 | 134 | inputs = { 135 | ipt: np.array( 136 | np.random.rand(*full_input_shapes[ipt]), 137 | dtype=get_np_type_from_elem_type(get_elem_type(model, ipt)) 138 | ) for ipt in input_names 139 | } 140 | return inputs 141 | 142 | 143 | def get_constant_nodes(m: onnx.ModelProto) -> List[onnx.NodeProto]: 144 | const_nodes = [] 145 | const_tensors = [x.name for x in m.graph.initializer] 146 | const_tensors.extend( 147 | [node.output[0] for node in m.graph.node if node.op_type == 'Constant'] 148 | ) 149 | 150 | for node in m.graph.node: 151 | if node.op_type == 'Shape': 152 | const_nodes.append(node) 153 | const_tensors.extend(node.output) 154 | elif all([x in const_tensors for x in node.input]): 155 | const_nodes.append(node) 156 | const_tensors.extend(node.output) 157 | return const_nodes 158 | 159 | 160 | def forward(model, inputs=None, input_shapes: TensorShapes = {}) -> Dict[str, np.ndarray]: 161 | sess = rt.InferenceSession(model.SerializeToString()) 162 | if inputs is None: 163 | inputs = generate_rand_input(model, input_shapes=input_shapes) 164 | outputs = [x.name for x in sess.get_outputs()] 165 | res = OrderedDict(zip(outputs, sess.run(outputs, inputs))) 166 | return res 167 | 168 | 169 | def forward_all(model: onnx.ModelProto, 170 | input_shapes: TensorShapes = {}) -> Dict[str, np.ndarray]: 171 | import copy 172 | model = copy.deepcopy(model) 173 | add_features_to_output(model) 174 | res = forward(model, input_shapes=input_shapes) 175 | return res 176 | 177 | 178 | def eliminate_const_nodes( 179 | model: onnx.ModelProto, const_nodes: List[onnx.NodeProto], res: Dict[str, np.ndarray] 180 | ) -> onnx.ModelProto: 181 | """ 182 | :param model: the original onnx model 183 | :param const_nodes: const nodes detected by `get_constant_nodes` 184 | :param res: The dict containing all tensors, got by `forward_all` 185 | :return: the simplified onnx model. Redundant ops are all removed. 186 | """ 187 | for node in model.graph.node[:]: 188 | if node in const_nodes: 189 | assert len(node.output) == 1 190 | node.op_type = 'Constant' 191 | elem_type = get_elem_type(model, node.output[0]) 192 | shape = res[node.output[0]].shape 193 | new_attr = onnx.helper.make_attribute( 194 | 'value', 195 | onnx.helper.make_tensor( 196 | name=node.output[0], 197 | data_type=elem_type, 198 | dims=shape, 199 | vals=np.array(res[node.output[0]]).flatten().astype( 200 | get_np_type_from_elem_type(elem_type) 201 | ) 202 | ) 203 | ) 204 | del node.input[:] 205 | del node.attribute[:] 206 | node.attribute.extend([new_attr]) 207 | return model 208 | 209 | 210 | def optimize(model: onnx.ModelProto) -> onnx.ModelProto: 211 | """ 212 | :param model: The onnx model. 213 | :return: The optimized onnx model. 
214 | Before simplifying, use this method to generate value_info, which is used in `forward_all` 215 | After simplifying, use this method to fold constants generated in previous step into initializer, 216 | and eliminate unused constants. 217 | """ 218 | onnx.helper.strip_doc_string(model) 219 | model = onnx.optimizer.optimize( 220 | model, 221 | [ 222 | 'eliminate_deadend', 223 | 'eliminate_identity', 224 | 'eliminate_nop_dropout', 225 | 'eliminate_nop_monotone_argmax', 226 | 'eliminate_nop_pad', 227 | 'extract_constant_to_initializer', 228 | 'eliminate_unused_initializer', 229 | 'eliminate_nop_transpose', 230 | 'fuse_add_bias_into_conv', 231 | 'fuse_bn_into_conv', 232 | # https://github.com/daquexian/onnx-simplifier/issues/31 233 | # 'fuse_consecutive_concats', 234 | 'fuse_consecutive_log_softmax', 235 | 'fuse_consecutive_reduce_unsqueeze', 236 | 'fuse_consecutive_squeezes', 237 | 'fuse_consecutive_transposes', 238 | 'fuse_matmul_add_bias_into_gemm', 239 | 'fuse_pad_into_conv', 240 | 'fuse_transpose_into_gemm' 241 | ], 242 | fixed_point=True 243 | ) 244 | return model 245 | 246 | 247 | def check( 248 | model_opt: onnx.ModelProto, 249 | model_ori: onnx.ModelProto, 250 | n_times: int = 5, 251 | input_shapes: TensorShapes = {} 252 | ) -> None: 253 | """ 254 | Warning: Some models (e.g., MobileNet) may fail this check by a small magnitude. 255 | Just ignore if it happens. 256 | :param model_opt: The simplified ONNX model 257 | :param model_ori: The original ONNX model 258 | :param n_times: Generate n random inputs 259 | """ 260 | onnx.checker.check_model(model_opt) 261 | for i in range(n_times): 262 | print("Checking {}/{}...".format(i, n_times)) 263 | rand_input = generate_rand_input(model_opt, input_shapes=input_shapes) 264 | res_opt = forward(model_opt, inputs=rand_input) 265 | res_ori = forward(model_ori, inputs=rand_input) 266 | 267 | for name in res_opt.keys(): 268 | if not np.allclose(res_opt[name], res_ori[name], rtol=1e-4, atol=1e-5): 269 | print( 270 | "Tensor {} changes after simplifying. 
The max diff is {}.".format( 271 | name, np.max(np.abs(res_opt[name] - res_ori[name])) 272 | ) 273 | ) 274 | print("Note that the checking is not always correct.") 275 | 276 | 277 | def clean_constant_nodes(const_nodes: List[onnx.NodeProto], res: Dict[str, np.ndarray]): 278 | """ 279 | It seems not needed since commit 6f2a72, but maybe it still prevents some unknown bug 280 | :param const_nodes: const nodes detected by `get_constant_nodes` 281 | :param res: The dict containing all tensors, got by `forward_all` 282 | :return: The constant nodes which have an output in res 283 | """ 284 | return [node for node in const_nodes if node.output[0] in res] 285 | 286 | 287 | def check_and_update_input_shapes( 288 | model: onnx.ModelProto, input_shapes: TensorShapes 289 | ) -> TensorShapes: 290 | input_names = get_input_names(model) 291 | if None in input_shapes: 292 | if len(input_names) == 1: 293 | input_shapes[input_names[0]] = input_shapes[None] 294 | del input_shapes[None] 295 | else: 296 | raise RuntimeError( 297 | 'The model has more than 1 inputs, please use the format "input_name:dim0,dim1,...,dimN" in --input-shape' 298 | ) 299 | for x in input_shapes: 300 | if x not in input_names: 301 | raise RuntimeError('The model doesn\'t have input named "{}"'.format(x)) 302 | return input_shapes 303 | 304 | 305 | def simplify( 306 | model_ori: Union[str, onnx.ModelProto], 307 | check_n: int = 0, 308 | perform_optimization: bool = True, 309 | input_shapes: TensorShapes = {} 310 | ) -> onnx.ModelProto: 311 | if type(model_ori) == str: 312 | model_ori = onnx.load(model_ori) 313 | 314 | onnx.checker.check_model(model_ori) 315 | model_ori = add_initializers_into_inputs(model_ori) 316 | 317 | input_shapes = check_and_update_input_shapes(model_ori, input_shapes) 318 | 319 | model_opt = onnx.shape_inference.infer_shapes(model_ori) 320 | if perform_optimization: 321 | model_opt = optimize(model_opt) 322 | 323 | const_nodes = get_constant_nodes(model_opt) 324 | res = forward_all(model_opt, input_shapes=input_shapes) 325 | const_nodes = clean_constant_nodes(const_nodes, res) 326 | model_opt = eliminate_const_nodes(model_opt, const_nodes, res) 327 | 328 | if perform_optimization: 329 | model_opt = optimize(model_opt) 330 | 331 | check(model_opt, model_ori, check_n, input_shapes=input_shapes) 332 | 333 | return model_opt 334 | -------------------------------------------------------------------------------- /backend/trtis/onnx_backend/torch2onnx.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | from functools import partial 5 | 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | 10 | import onnx as onnx 11 | from onnx import shape_inference, optimizer 12 | 13 | from trtis.set_config import TORCH_DTYPE, generate_trtis_config 14 | from trtis.onnx_backend import onnxsim 15 | 16 | 17 | class MergeModel(nn.Module): 18 | 19 | def __init__(self, model, preprocess, postprocess): 20 | super(MergeModel, self).__init__() 21 | self.model = model 22 | self.preprocess = preprocess 23 | self.postprocess = postprocess 24 | 25 | def forward(self, x): 26 | x = self.preprocess(x) 27 | x = self.model(x) 28 | x = self.postprocess(x) 29 | return x 30 | 31 | 32 | class WrapperFunc(nn.Module): 33 | 34 | def __init__(self, func): 35 | super(WrapperFunc, self).__init__() 36 | self.func = func 37 | 38 | def forward(self, x): 39 | return self.func(x) 40 | 41 | 42 | class WrapperModel(nn.Module): 43 | 44 | def __init__(self, model): 45 
| super(WrapperModel, self).__init__() 46 | self.model = model 47 | 48 | def forward(self, x): 49 | x = self.model(x) 50 | return x 51 | 52 | 53 | def modify_to_dynamic_shape(model, input_names, output_names): 54 | for input in model.graph.input: 55 | if input.name not in input_names: 56 | continue 57 | input.type.tensor_type.shape.dim[0].dim_param = '?' 58 | #input.type.tensor_type.shape.dim[1].dim_param = '?' 59 | #input.type.tensor_type.shape.dim[2].dim_param = '?' 60 | #input.type.tensor_type.shape.dim[3].dim_param = '?' 61 | for output in model.graph.output: 62 | if output.name not in output_names: 63 | continue 64 | #for i in range(len(output.type.tensor_type.shape.dim)): 65 | # output.type.tensor_type.shape.dim[i].dim_param = '?' 66 | 67 | return model 68 | 69 | 70 | def optim_onnx(onnx_path, verbose=True): 71 | model = onnx.load(onnx_path) 72 | print("Begin Simplify ONNX Model ...") 73 | passes = [ 74 | "eliminate_deadend", 75 | "eliminate_identity", 76 | "extract_constant_to_initializer", 77 | "eliminate_unused_initializer", 78 | "fuse_add_bias_into_conv", 79 | "fuse_bn_into_conv", 80 | "fuse_matmul_add_bias_into_gemm" 81 | ] 82 | model = optimizer.optimize(model, passes) 83 | #model = shape_inference.infer_shapes(model) 84 | #model = onnxsim.simplify(model) 85 | 86 | if verbose: 87 | for m in onnx.helper.printable_graph(model.graph).split("\n"): 88 | print(m) 89 | 90 | return model 91 | 92 | 93 | def torch2onnx( 94 | computation_graph, 95 | graph_name="model", 96 | model_file=None, 97 | inputs_def=[{ 98 | "name": None, "shape": [] 99 | }], 100 | outputs_def=[{ 101 | "name": None, "shape": [] 102 | }], 103 | instances=1, 104 | gpus=[0], 105 | version=1, 106 | export_path="./", 107 | opset_version=10, 108 | max_batch_size=0, 109 | device="cuda", 110 | gen_trtis_config=True, 111 | verbose=True 112 | ): 113 | 114 | if not isinstance(computation_graph, nn.Module): 115 | model = WrapperFunc(computation_graph) 116 | else: 117 | model = WrapperModel(computation_graph) 118 | 119 | if device == "cuda": 120 | model = model.cuda() 121 | else: 122 | model = model.cpu() 123 | 124 | # model key : A.B.x.x.x.x 125 | # pth-file key : C.A.E.x.x.x.x 126 | # load the model by non-exactly match pth-file 127 | if model_file is not None: 128 | checkpoint = torch.load(model_file)['state_dict'] 129 | print("loading pt-file finishing .... ") 130 | state_dict = {} 131 | required_keys = list(model.state_dict().keys()) 132 | # for required_key in required_keys: 133 | # value = checkpoint.get(required_key.split(".", 1)[1], None) 134 | # if value is None: 135 | # print("missing key: {}".format(required_key)) 136 | # continue 137 | # state_dict[required_key] = value 138 | 139 | for required_key, (key, value) in zip(required_keys, checkpoint.items()): 140 | if required_key.endswith(key.split(".", 1)[1]): 141 | print("pth-key: [{:60s}] ---> model-key: [{}]".format(key, required_key)) 142 | else: 143 | print("pth-key: [{:60s}] -\\-> model-key: [{}]".format(key, required_key)) 144 | continue 145 | 146 | state_dict[required_key] = value 147 | model.load_state_dict(state_dict) 148 | print("loading model finishing .... 
") 149 | 150 | dummy_inputs = [] 151 | input_names = [] 152 | for i, input in enumerate(inputs_def): 153 | name = input.get("name", None) 154 | shape = input.get("dims", None) 155 | dtype = TORCH_DTYPE[input.get("data_type", None)] 156 | dummy_input = torch.ones(shape).to(dtype) 157 | #dummy_input = (torch.rand(shape) * 255).to(torch.uint8) 158 | if device == "cuda": 159 | dummy_input = dummy_input.cuda() 160 | dummy_inputs.append(dummy_input) 161 | input_names.append(name) 162 | 163 | dummy_inputs = dummy_inputs[0] if len(dummy_inputs) == 1 else dummy_inputs 164 | 165 | output_names = [] 166 | for i, output in enumerate(outputs_def): 167 | shape = output.get("dims", None) 168 | name = output.get("name", None) 169 | output_names.append(name) 170 | 171 | export_path = os.path.join(export_path, graph_name) 172 | os.system("mkdir -p {}".format(export_path)) 173 | os.system("mkdir -p {}/{}".format(export_path, version)) 174 | 175 | onnx_path = "{}/{}/{}.onnx".format(export_path, version, "model") 176 | 177 | torch.onnx.export( 178 | model, 179 | dummy_inputs, 180 | onnx_path, 181 | verbose=False, 182 | input_names=input_names, 183 | output_names=output_names, 184 | opset_version=opset_version, 185 | keep_initializers_as_inputs=True, 186 | #dynamic_axes=dict(dynamic_batches) 187 | ) 188 | 189 | model = optim_onnx(onnx_path) 190 | onnx.save(model, onnx_path) 191 | os.system("python -m onnxsim {} {}".format(onnx_path, onnx_path)) 192 | 193 | if gen_trtis_config: 194 | generate_trtis_config( 195 | graph_name=graph_name, 196 | platform="onnxruntime_onnx", 197 | inputs_def=inputs_def, 198 | outputs_def=outputs_def, 199 | max_batch_size=max_batch_size, 200 | instances=instances, 201 | gpus=gpus, 202 | export_path=export_path, 203 | verbose=verbose 204 | ) 205 | 206 | return model, onnx_path 207 | -------------------------------------------------------------------------------- /backend/trtis/set_config.py: -------------------------------------------------------------------------------- 1 | #!/python 2 | import argparse 3 | import json 4 | import os 5 | import torch 6 | import tensorflow as tf 7 | 8 | 9 | TORCH_DTYPE = { 10 | "TYPE_FP32": torch.float32, 11 | "TYPE_INT32": torch.int32, 12 | "TYPE_UINT8": torch.uint8, 13 | "TYPE_FP16": torch.float16 14 | } 15 | 16 | 17 | TF_DTYPE = { 18 | "TYPE_FP32": tf.float32, 19 | "TYPE_INT32": tf.int32, 20 | "TYPE_UINT8": tf.uint8, 21 | "TYPE_FP16": tf.float16, 22 | "TYPE_STRING": tf.string 23 | } 24 | 25 | 26 | TENSOR = """\ 27 | {{ 28 | name: "{name}", 29 | dims: {dims}, 30 | data_type: {data_type} 31 | }}\ 32 | """ 33 | 34 | 35 | TENSOR_RESHAPE = """\ 36 | {{ 37 | name: "{name}", 38 | dims: {dims}, 39 | data_type: {data_type}, 40 | reshape: {{ shape: {reshape} }} 41 | }}\ 42 | """ 43 | 44 | 45 | CONFIG_PBTXT = """\ 46 | name: "{name}" 47 | platform: "{platform}" 48 | version_policy: {{ all {{ }} }} 49 | max_batch_size: {max_batch_size} 50 | input {input} 51 | output {output} 52 | instance_group [ 53 | {{ 54 | count: {instances} 55 | kind: KIND_GPU 56 | gpus: {gpus} 57 | }} 58 | ]\ 59 | """ 60 | 61 | 62 | def data_def_dumps(node_def, remove_batch_dim=False): 63 | data = [] 64 | for node in node_def: 65 | name = node.get("name", None) 66 | dims = node.get("dims", None) 67 | data_type = node.get("data_type", "TYPE_FP32") 68 | #format = node.get("format", "FORMAT_NONE") 69 | reshape = node.get("reshape", dims) 70 | if remove_batch_dim is False: 71 | format_output = TENSOR.format( 72 | name=name, 73 | dims=dims, 74 | data_type=data_type 75 | ) 76 | else: 77 | 
format_output = TENSOR_RESHAPE.format(
                name=name,
                dims=dims,
                data_type=data_type,
                reshape=dims[1:]
            )
        data.append(format_output)

    data = "[\n" + ",\n".join(data) + "\n]"
    return data


def generate_trtis_config(
    graph_name="dd",
    platform="tensorrt_plan",
    inputs_def={},
    outputs_def={},
    max_batch_size=1,
    instances=1,
    gpus=[0],
    export_path="./",
    verbose=True
):

    # use == for string comparison; `is` only checks object identity
    remove_batch_dim = True if platform == "tensorrt_plan" else False

    config_path = "{}/config.pbtxt".format(export_path)
    config_content = CONFIG_PBTXT.format(
        name=str(graph_name),
        platform=str(platform),
        max_batch_size=max_batch_size,
        input=data_def_dumps(inputs_def, remove_batch_dim),
        output=data_def_dumps(outputs_def, remove_batch_dim),
        instances=instances,
        gpus=gpus
    )

    with open(config_path, "w") as cfg:
        cfg.write(config_content)

    if verbose:
        print(config_content)
--------------------------------------------------------------------------------
/backend/trtis/tf_backend/__init__.py:
--------------------------------------------------------------------------------
from .tf2graphdef import *
--------------------------------------------------------------------------------
/backend/trtis/tf_backend/tf2graphdef.py:
--------------------------------------------------------------------------------
import os

import numpy as np
import tensorflow as tf
from tensorflow.python.framework import graph_util

from trtis.set_config import TF_DTYPE, generate_trtis_config

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'


def tf2graphdef(
    computation_graph,
    graph_name="model",
    model_file=None,
    inputs_def=[{
        "name": None, "dims": []
    }],
    outputs_def=[{
        "name": None, "dims": []
    }],
    instances=1,
    gpus=[0],
    version=1,
    export_path="./",
    max_batch_size=0,
    device="cuda",
    gen_trtis_config=True,
    verbose=True
):

    export_path = os.path.join(export_path, graph_name)
    model_path = os.path.join(export_path, str(version), "model.graphdef")
    os.system("rm -rf {}".format(export_path))
    os.system("mkdir -p {}".format(export_path))
    os.system("mkdir -p {}/{}".format(export_path, version))

    graph = tf.Graph()
    with graph.as_default() as g:
        dummy_inputs = []
        input_names = []
        for i, input in enumerate(inputs_def):
            name = input.get("name", None)
            shape = input.get("dims", None)
            dtype = TF_DTYPE[input.get("data_type", None)]
            # -1 marks a dynamic dimension; use == (not `is`) to compare ints
            tf_shape = list(map(lambda dim: None if dim == -1 else dim, shape))
            dummy_input = tf.placeholder(dtype, tf_shape, name=name)
            #dummy_inputs.append({name: dummy_input})
            dummy_inputs.append(dummy_input)
            input_names.append(name)

        dummy_outputs = computation_graph(*dummy_inputs)
        if type(dummy_outputs) not in [list, tuple]:
            dummy_outputs = [dummy_outputs]

        output_names = []
        for i, output in enumerate(outputs_def):
            name = output.get("name", None)
            shape = output.get("dims", None)
            dtype = TF_DTYPE[output.get("data_type", None)]
            output = tf.identity(dummy_outputs[i], name=name)
            output_names.append(name)

    config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)
config.gpu_options.allow_growth = True 67 | config.gpu_options.visible_device_list = ",".join(map(str, gpus)) 68 | with tf.Session(graph=graph, config=config) as sess: 69 | frozen_graph_def = graph_util.convert_variables_to_constants( 70 | sess, sess.graph_def, output_names 71 | ) 72 | 73 | with open(model_path, 'wb') as f: 74 | f.write(frozen_graph_def.SerializeToString()) 75 | 76 | if gen_trtis_config: 77 | generate_trtis_config( 78 | graph_name=graph_name, 79 | platform="tensorflow_graphdef", 80 | inputs_def=inputs_def, 81 | outputs_def=outputs_def, 82 | max_batch_size=max_batch_size, 83 | instances=instances, 84 | gpus=gpus, 85 | export_path=export_path, 86 | verbose=verbose 87 | ) 88 | -------------------------------------------------------------------------------- /backend/trtis/trt_backend/__init__.py: -------------------------------------------------------------------------------- 1 | from .torch2trt import * 2 | from .calibrator import * 3 | -------------------------------------------------------------------------------- /backend/trtis/trt_backend/calibrator.py: -------------------------------------------------------------------------------- 1 | import ctypes 2 | import glob 3 | import os 4 | import random 5 | 6 | import numpy as np 7 | 8 | import pycuda.autoinit 9 | import pycuda.driver as cuda 10 | import tensorrt as trt 11 | 12 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 13 | 14 | 15 | def create_calibration_dataset(datasets_path, n=100): 16 | # Create list of calibration images (filename) 17 | # This sample code picks 100 images at random from training set 18 | datasets_path = os.path.join(datasets_path, "*") 19 | calibration_files = glob.glob(datasets_path) 20 | random.shuffle(calibration_files) 21 | return calibration_files[:n] 22 | 23 | 24 | class ImageBatchStream(object): 25 | 26 | def __init__(self, calibration_path, batch_size, preprocessor): 27 | self.batch_size = batch_size 28 | self.preprocessor = preprocessor 29 | 30 | self.batch = 0 31 | calibration_files = create_calibration_dataset(calibration_path) 32 | self.max_batches = len(calibration_files) // batch_size 33 | self.max_batches += 1 if (len(calibration_files) % batch_size) else 0 34 | self.files = calibration_files 35 | 36 | def reset(self): 37 | self.batch = 0 38 | 39 | def next_batch(self, verbose=True): 40 | if self.batch < self.max_batches: 41 | start = self.batch_size * self.batch 42 | end = self.batch_size * (self.batch + 1) 43 | 44 | calibration_data = [] 45 | for image_path in self.files[start:end]: 46 | if verbose: 47 | print("[ImageBatchStream] Processing ", image_path) 48 | raw_image = open(image_path, "rb").read() 49 | img = self.preprocessor(raw_image) 50 | calibration_data.append(img) 51 | 52 | self.batch += 1 53 | return np.concatenate(calibration_data, 0) 54 | else: 55 | return np.array([]) 56 | 57 | 58 | class IInt8EntropyCalibrator(trt.IInt8EntropyCalibrator): 59 | 60 | def __init__(self, input_def, stream, cache_file="./calibration_cache.bin"): 61 | trt.IInt8EntropyCalibrator.__init__(self) 62 | 63 | self.input_layers = [] 64 | for input in input_def: 65 | name = input.get("name", None) 66 | self.input_layers.append(name) 67 | 68 | self.stream = stream 69 | calibration_data = self.stream.next_batch(verbose=False) 70 | self.d_input = cuda.mem_alloc(calibration_data.nbytes) 71 | self.stream.reset() 72 | 73 | self.cache_file = cache_file 74 | os.system("rm -f {}".format(self.cache_file)) 75 | 76 | def get_batch_size(self): 77 | return self.stream.batch_size 78 | 79 | def get_batch(self, names): 
        # return None (end of calibration) before touching the device buffer
        data = self.stream.next_batch()
        if data.size == 0:
            return None
        cuda.memcpy_htod(self.d_input, data)
        return [int(self.d_input)]

    def read_calibration_cache(self):
        # If there is a cache, use it instead of calibrating again.
        # Otherwise, implicitly return None.
        if os.path.exists(self.cache_file):
            with open(self.cache_file, "rb") as f:
                return f.read()

    def write_calibration_cache(self, cache):
        with open(self.cache_file, "wb") as f:
            f.write(cache)


class IInt8EntropyCalibrator2(trt.IInt8EntropyCalibrator2):

    def __init__(self, input_def, stream, cache_file="./calibration_cache.bin"):
        trt.IInt8EntropyCalibrator2.__init__(self)

        self.input_layers = []
        for input in input_def:
            name = input.get("name", None)
            self.input_layers.append(name)

        self.stream = stream
        calibration_data = self.stream.next_batch(verbose=False)
        self.d_input = cuda.mem_alloc(calibration_data.nbytes)
        self.stream.reset()

        self.cache_file = cache_file
        os.system("rm -f {}".format(self.cache_file))

    def get_batch_size(self):
        return self.stream.batch_size

    def get_batch(self, names):
        # return None (end of calibration) before touching the device buffer
        data = self.stream.next_batch()
        if data.size == 0:
            return None
        cuda.memcpy_htod(self.d_input, data)
        return [int(self.d_input)]

    def read_calibration_cache(self):
        # If there is a cache, use it instead of calibrating again.
        # Otherwise, implicitly return None.
        if os.path.exists(self.cache_file):
            with open(self.cache_file, "rb") as f:
                return f.read()

    def write_calibration_cache(self, cache):
        with open(self.cache_file, "wb") as f:
            f.write(cache)
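
# Choosing between the calibrator variants in this file is a judgment call;
# the notes below summarize the TensorRT documentation rather than anything
# this repo itself pins down:
#   - IInt8EntropyCalibrator2 is the variant generally recommended for
#     CNN-style models, and is the one example/detection/network.py uses.
#   - IInt8MinMaxCalibrator calibrates on the full activation range instead
#     of an entropy-based threshold.
#   - IInt8LegacyCalibrator additionally exposes quantile and regression
#     cutoff parameters (see get_quantile / get_regression_cutoff below).
# All four classes share the same ImageBatchStream plumbing.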

class IInt8MinMaxCalibrator(trt.IInt8MinMaxCalibrator):

    def __init__(self, input_def, stream, cache_file="./calibration_cache.bin"):
        trt.IInt8MinMaxCalibrator.__init__(self)

        self.input_layers = []
        for input in input_def:
            name = input.get("name", None)
            self.input_layers.append(name)

        self.stream = stream
        calibration_data = self.stream.next_batch(verbose=False)
        self.d_input = cuda.mem_alloc(calibration_data.nbytes)
        self.stream.reset()

        self.cache_file = cache_file
        os.system("rm -f {}".format(self.cache_file))

    def get_batch_size(self):
        return self.stream.batch_size

    def get_batch(self, names):
        # return None (end of calibration) before touching the device buffer
        data = self.stream.next_batch()
        if data.size == 0:
            return None
        cuda.memcpy_htod(self.d_input, data)
        return [int(self.d_input)]

    def read_calibration_cache(self):
        # If there is a cache, use it instead of calibrating again.
        # Otherwise, implicitly return None.
        if os.path.exists(self.cache_file):
            with open(self.cache_file, "rb") as f:
                return f.read()

    def write_calibration_cache(self, cache):
        with open(self.cache_file, "wb") as f:
            f.write(cache)


class IInt8LegacyCalibrator(trt.IInt8LegacyCalibrator):

    def __init__(self, input_def, stream, cache_file="./calibration_cache.bin"):
        trt.IInt8LegacyCalibrator.__init__(self)

        self.input_layers = []
        for input in input_def:
            name = input.get("name", None)
            self.input_layers.append(name)

        self.stream = stream
        calibration_data = self.stream.next_batch(verbose=False)
        self.d_input = cuda.mem_alloc(calibration_data.nbytes)
        self.stream.reset()

        self.cache_file = cache_file
        os.system("rm -f {}".format(self.cache_file))

        self.quantile = 0.0
        self.regression_cutoff = 0.0

    def get_batch_size(self):
        return self.stream.batch_size

    def get_batch(self, names):
        # return None (end of calibration) before touching the device buffer
        data = self.stream.next_batch()
        if data.size == 0:
            return None
        cuda.memcpy_htod(self.d_input, data)
        return [int(self.d_input)]

    def get_quantile(self):
        return self.quantile

    def get_regression_cutoff(self):
        return self.regression_cutoff

    def read_calibration_cache(self):
        # If there is a cache, use it instead of calibrating again.
        # Otherwise, implicitly return None.
        if os.path.exists(self.cache_file):
            with open(self.cache_file, "rb") as f:
                return f.read()

    def write_calibration_cache(self, cache):
        with open(self.cache_file, "wb") as f:
            f.write(cache)
--------------------------------------------------------------------------------
/backend/trtis/trt_backend/tf2trt.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/layerism/TensorRT-Inference-Server-Tutorial/c40d485f69c1349c01b40a4eb7c225b01bf8dbe0/backend/trtis/trt_backend/tf2trt.py
--------------------------------------------------------------------------------
/backend/trtis/trt_backend/torch2trt.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
import argparse
import json
import os
from functools import partial

import numpy as np
import torch
import torch.nn as nn

import onnx
import tensorrt as trt

from trtis.set_config import generate_trtis_config
from trtis import onnx_backend


class MergeModel(nn.Module):

    def __init__(self, model, preprocess, postprocess):
        super(MergeModel, self).__init__()
        self.model = model
        self.preprocess = preprocess
        self.postprocess = postprocess

    def forward(self, x):
        x = self.preprocess(x)
        x = self.model(x)
        x = self.postprocess(x)
        return x


class WrapperFunc(nn.Module):

    def __init__(self, func):
        super(WrapperFunc, self).__init__()
        self.func = func

    def forward(self, x):
        return self.func(x)


class WrapperModel(nn.Module):

    def __init__(self, model):
        super(WrapperModel, self).__init__()
        self.model = model

    def forward(self, x):
        x = self.model(x)
        return x


def GiB(val):
    return val * 1 << 30
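
# Note on TensorRT versions (an assumption to verify against your install,
# not something this repo handles): the code below targets the TensorRT 6
# Python API, where builder.create_network() defaults to implicit-batch mode.
# Newer TensorRT releases require an explicit-batch network for the ONNX
# parser, i.e. builder.create_network(
#     1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)),
# and move settings such as max_workspace_size onto an IBuilderConfig.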
"""Takes an ONNX file and creates a TensorRT engine to run inference with""" 60 | trt_logger = trt.Logger() 61 | with trt.Builder(trt_logger) as builder: 62 | with builder.create_network() as network: 63 | with trt.OnnxParser(network, trt_logger) as parser: 64 | builder.max_workspace_size = GiB(1) # 1GB 65 | builder.max_batch_size = 1 66 | if int8_calibrator is not None: 67 | builder.int8_mode = True 68 | builder.int8_calibrator = int8_calibrator 69 | 70 | # Parse model file 71 | if not os.path.exists(onnx_path): 72 | print('ONNX file {} not found'.format(onnx_path)) 73 | exit(0) 74 | 75 | print('Loading ONNX file from path {}...'.format(onnx_path)) 76 | with open(onnx_path, 'rb') as model: 77 | print('Beginning ONNX file parsing') 78 | parser.parse(model.read()) 79 | print('Completed parsing of ONNX file') 80 | 81 | print('Building an engine from file {} ...'.format(onnx_path)) 82 | engine = builder.build_cuda_engine(network) 83 | 84 | print("Completed creating Engine") 85 | with open(export_path, "wb") as f: 86 | f.write(engine.serialize()) 87 | return engine 88 | 89 | 90 | def torch2trt( 91 | computation_graph, 92 | graph_name="model", 93 | model_file=None, 94 | inputs_def=[{ 95 | "name": None, "shape": [] 96 | }], 97 | outputs_def=[{ 98 | "name": None, "shape": [] 99 | }], 100 | instances=1, 101 | gpus=[0], 102 | version=1, 103 | max_batch_size=1, 104 | int8_calibrator=None, 105 | export_path="./", 106 | onnx_opset_version=10, 107 | device="cuda", 108 | gen_trtis_config=True, 109 | verbose=True 110 | ): 111 | 112 | onnx_model, onnx_path = onnx_backend.torch2onnx( 113 | computation_graph=computation_graph, 114 | graph_name=graph_name, 115 | model_file=model_file, 116 | inputs_def=inputs_def, 117 | outputs_def=outputs_def, 118 | instances=instances, 119 | gpus=gpus, 120 | version=version, 121 | export_path=export_path, 122 | opset_version=onnx_opset_version, 123 | max_batch_size=max_batch_size, 124 | device=device, 125 | gen_trtis_config=False, 126 | verbose=verbose 127 | ) 128 | 129 | export_path = os.path.join(export_path, graph_name) 130 | os.system("mkdir -p {}".format(export_path)) 131 | os.system("mkdir -p {}/{}".format(export_path, version)) 132 | trt_engine_path = "{}/{}/{}.plan".format(export_path, version, "model") 133 | build_engine( 134 | onnx_path=onnx_path, 135 | export_path=trt_engine_path, 136 | int8_calibrator=int8_calibrator 137 | ) 138 | os.system("rm -rf {}".format(onnx_path)) 139 | 140 | if gen_trtis_config: 141 | generate_trtis_config( 142 | graph_name=graph_name, 143 | platform="tensorrt_plan", 144 | inputs_def=inputs_def, 145 | outputs_def=outputs_def, 146 | max_batch_size=max_batch_size, 147 | instances=instances, 148 | gpus=gpus, 149 | export_path=export_path, 150 | verbose=verbose 151 | ) 152 | -------------------------------------------------------------------------------- /client_py/VERSION.txt: -------------------------------------------------------------------------------- 1 | 0.1.0 2 | -------------------------------------------------------------------------------- /client_py/setup.cfg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/layerism/TensorRT-Inference-Server-Tutorial/c40d485f69c1349c01b40a4eb7c225b01bf8dbe0/client_py/setup.cfg -------------------------------------------------------------------------------- /client_py/setup.py: -------------------------------------------------------------------------------- 1 | from os.path import dirname, join 2 | 3 | from setuptools import 

def parse_requirements(filename):
    """ load requirements from a pip requirements file """
    lineiter = (line.strip() for line in open(filename))
    return [line for line in lineiter if line and not line.startswith("#")]


# read the package version from VERSION.txt instead of hard-coding it twice
with open(join(dirname(__file__), './VERSION.txt'), 'rb') as f:
    version = f.read().decode('ascii').strip()


setup(
    name='trt_client',
    version=version,
    keywords='23456',
    description='a library for DS CAA Developer',
    license='MIT License',
    url='',
    author='layersim',
    author_email='',
    packages=find_packages(),
    include_package_data=True,
    platforms='any',
    install_requires=[
        "numpy>=1.16.0",
        "opencv-python>=3.4.1",
        "Pillow>=6.0.0",
        "tensorrtserver>=1.9.0",
        "protobuf>=3.8.0",
        "onnx>=1.6.0"
    ],
    #scripts=['scripts/s3'],
)
--------------------------------------------------------------------------------
/client_py/trt_client/__init__.py:
--------------------------------------------------------------------------------
from .client import *
--------------------------------------------------------------------------------
/client_py/trt_client/client.py:
--------------------------------------------------------------------------------
import argparse
import multiprocessing as mp
import os
import time
from builtins import range
from functools import partial

import numpy as np

import tensorrtserver.api.model_config_pb2 as model_config
import tensorrtserver.cuda_shared_memory as cudashm
from tensorrtserver.api import *


# model_config DataType enum -> numpy dtype (2=UINT8, 8=INT32, 11=FP32)
DTYPE = {2: np.uint8, 11: np.float32, 8: np.int32}


def create_cuda_shm(data, name, url, protocol, is_input=True):
    #c, h, w = shape
    shared_memory_ctx = SharedMemoryControlContext(url, protocol)
    byte_size = data.size * data.itemsize
    shm_handle = cudashm.create_shared_memory_region(name, byte_size, 3)

    if is_input:
        cudashm.set_shared_memory_region(shm_handle, [data])
        shared_memory_ctx.cuda_register(shm_handle)
    else:
        shared_memory_ctx.cuda_register(shm_handle)


def get_server_status(url, protocol, model_name, verbose=False):
    protocol = ProtocolType.from_str(protocol)
    ctx = ServerStatusContext(url, protocol, model_name, verbose)
    server_status = ctx.get_server_status()

    return server_status


def parse_model(url, protocol, model_name, verbose=False):
    protocol = ProtocolType.from_str(protocol)
    ctx = ServerStatusContext(url, protocol, model_name, verbose)
    server_status = ctx.get_server_status()

    if model_name not in server_status.model_status:
        raise Exception("unable to get status for '" + model_name + "'")

    status = server_status.model_status[model_name]
    config = status.config

    return config


class Inference(object):

    def __init__(self, url, model_name, model_version, protocol='gRPC'):
        model = parse_model(url, protocol, model_name)
        protocol = ProtocolType.from_str(protocol)
        self.model = model
        self.ctx = InferContext(
            url=url,
            protocol=protocol,
            model_name=model_name,
            model_version=model_version,
            verbose=False,
            streaming=False
        )
        self.result_queue = mp.Queue()

        self.outputs = {}
        for output in self.model.output:
            self.outputs[output.name] =
InferContext.ResultFormat.RAW 70 | 71 | def callback(self, input_id, result_queue, infer_ctx, request_id): 72 | result_queue.put((request_id, input_id)) 73 | 74 | def async_run(self, input, input_id): 75 | for key in input.keys(): 76 | if type(input[key]) not in [list, tuple]: 77 | input[key] = [input[key]] 78 | 79 | callback_fn = partial(self.callback, input_id, self.result_queue) 80 | self.ctx.async_run(callback_fn, input, self.outputs) 81 | 82 | def run(self, input): 83 | for key in input.keys(): 84 | if type(input[key]) not in [list, tuple]: 85 | input[key] = [input[key]] 86 | 87 | results = self.ctx.run(input, self.outputs, batch_size=1) 88 | return results 89 | 90 | def get_time(self): 91 | stat = self.ctx.get_stat() 92 | 93 | count = stat["completed_request_count"] 94 | request_dt = stat["cumulative_total_request_time_ns"] / 1.0e6 95 | send_dt = stat["cumulative_send_time_ns"] / 1.0e6 96 | receive_dt = stat["cumulative_receive_time_ns"] / 1.0e6 97 | inference_dt = (request_dt - send_dt - receive_dt) 98 | stat = { 99 | "completed_request_count": stat["completed_request_count"], 100 | "send_time_ms": float("{:5.3f}".format(send_dt / count)), 101 | "inference_time_ms": float("{:5.3f}".format(inference_dt / count)), 102 | "receive_time_ms": float("{:5.3f}".format(receive_dt / count)) 103 | } 104 | return stat 105 | 106 | def get_result(self, block=True): 107 | (request_id, input_id) = self.result_queue.get(block=block) 108 | results = self.ctx.get_async_run_results(request_id) 109 | 110 | return input_id, results 111 | 112 | 113 | class ManagerWatchdog(object): 114 | 115 | def __init__(self): 116 | self.manager_pid = os.getppid() 117 | self.manager_dead = False 118 | 119 | def is_alive(self): 120 | if not self.manager_dead: 121 | self.manager_dead = os.getppid() != self.manager_pid 122 | return not self.manager_dead 123 | 124 | 125 | # class Parallel(object): 126 | # 127 | # def __init__(self, worker=8): 128 | # self.inference = Inference() 129 | # 130 | # def run(self, ) 131 | # def start(self, ) 132 | -------------------------------------------------------------------------------- /client_py/trt_client/client_grpc.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from builtins import range 4 | from functools import partial 5 | import struct 6 | 7 | import grpc 8 | import numpy as np 9 | import sys 10 | 11 | from tensorrtserver.api import ProtocolType, ServerStatusContext 12 | from tensorrtserver.api import api_pb2 13 | from tensorrtserver.api import grpc_service_pb2 14 | from tensorrtserver.api import grpc_service_pb2_grpc 15 | import tensorrtserver.api.model_config_pb2 as model_config 16 | 17 | 18 | DTYPE = { 19 | model_config.TYPE_BOOL: np.bool, 20 | model_config.TYPE_FP16: np.float16, 21 | model_config.TYPE_FP32: np.float32, 22 | model_config.TYPE_FP64: np.float64, 23 | model_config.TYPE_INT8: np.int8, 24 | model_config.TYPE_INT32: np.int32, 25 | model_config.TYPE_INT64: np.int64, 26 | model_config.TYPE_UINT8: np.uint8, 27 | model_config.TYPE_UINT16: np.uint16, 28 | model_config.TYPE_UINT32: np.uint32, 29 | model_config.TYPE_STRING: np.str 30 | } 31 | 32 | 33 | def _parse_model(url, model_name, verbose=False): 34 | protocol = ProtocolType.from_str('gRPC') 35 | ctx = ServerStatusContext(url, protocol, model_name, verbose) 36 | server_status = ctx.get_server_status() 37 | 38 | if model_name not in server_status.model_status: 39 | raise Exception("unable to get status for '" + model_name + "'") 40 | 41 | status = 
server_status.model_status[model_name]
    config = status.config

    return config


class Inference(object):

    def __init__(self, url, model_name, model_version):
        model = _parse_model(url, model_name)
        self.url = url
        self.model = model
        self.model_name = model_name
        self.model_version = model_version

    def _to_bytes(self, input_value):
        input_value = np.array([input_value], dtype=bytes)
        flattened = bytes()
        for obj in np.nditer(input_value, flags=["refs_ok"], order='C'):
            # If directly passing bytes to STRING type,
            # don't convert it to str as Python will encode the
            # bytes which may distort the meaning
            if obj.dtype.type == np.bytes_:
                s = bytes(obj)
            else:
                s = str(obj).encode('utf-8')
            # STRING tensors are serialized as a 4-byte little-endian
            # length prefix followed by the raw bytes
            flattened += struct.pack("<I", len(s))
            flattened += s
--------------------------------------------------------------------------------
/example/detection/client.py:
--------------------------------------------------------------------------------
        if score >= 0.4:
            cv2.rectangle(img, (bbox[0], bbox[1]), (bbox[2], bbox[3]), (0, 255, 0), 2)
    cv2.imwrite('./xxx.jpg', img)
--------------------------------------------------------------------------------
/example/detection/client.sh:
--------------------------------------------------------------------------------
#!/bin/bash

IMAGE="../test-data/widerface.jpg"
python client.py \
    --model-name detection \
    --model-version 1 \
    --url "x.x.x.x:7001" \
    --image ${IMAGE}
--------------------------------------------------------------------------------
/example/detection/config.pbtxt:
--------------------------------------------------------------------------------
name: "detection"
platform: "ensemble"
max_batch_size: 0
input [
  {
    name: "raw_image"
    data_type: TYPE_STRING
    dims: [ 1 ]
  }
]
output [
  {
    name: "score"
    data_type: TYPE_FP32
    dims: [ 1, 100 ]
  },
  {
    name: "category"
    data_type: TYPE_INT32
    dims: [ 1, 100 ]
  },
  {
    name: "bbox"
    data_type: TYPE_FP32
    dims: [ 1, 100, 4 ]
  }
]
ensemble_scheduling {
  step [
    {
      model_name: "detection-preprocess"
      model_version: 1
      input_map {
        key: "raw_image"
        value: "raw_image"
      }
      output_map {
        key: "process_img"
        value: "process_img"
      }
      output_map {
        key: "affine_trans_mat"
        value: "affine_trans_mat"
      }
    },
    {
      model_name: "detection-network"
      model_version: 1
      input_map {
        key: "process_img"
        value: "process_img"
      }
      output_map {
        key: "heatmap"
        value: "heatmap"
      }
      output_map {
        key: "bbox_wh"
        value: "bbox_wh"
      }
      output_map {
        key: "center_shift"
        value: "center_shift"
      }
    },
    {
      model_name: "detection-postprocess"
      model_version: 1
      input_map {
        key: "heatmap"
        value: "heatmap"
      }
      input_map {
        key: "bbox_wh"
        value: "bbox_wh"
      }
      input_map {
        key: "center_shift"
        value: "center_shift"
      }
      input_map {
        key: "affine_trans_mat"
        value: "affine_trans_mat"
      }
      output_map {
        key: "score"
        value: "score"
      }
      output_map {
        key: "category"
        value: "category"
      }
      output_map {
        key: "bbox"
        value: "bbox"
      }
    }
  ]
}
--------------------------------------------------------------------------------
/example/detection/convert.sh:
--------------------------------------------------------------------------------
#!/bin/bash

REPO=../../model_repository
MODEL_NAME=detection

rm -rf
${REPO}/${MODEL_NAME}* 7 | 8 | mkdir -p ${REPO}/${MODEL_NAME}/1 9 | cp -r config.pbtxt ${REPO}/${MODEL_NAME} 10 | 11 | python pre_process.py 12 | python network.py 13 | python post_process.py 14 | -------------------------------------------------------------------------------- /example/detection/network.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import time 4 | from builtins import range 5 | from functools import partial 6 | 7 | import numpy as np 8 | 9 | 10 | if __name__ == "__main__": 11 | from trtis import onnx_backend 12 | from trtis import trt_backend 13 | 14 | from network import dla34, resnet 15 | model = dla34.get_pose_net(34, {'hm': 1, 'wh': 2, 'reg': 2}, 256) 16 | #model = resnet.get_pose_net(18, {'hm': 1, 'wh': 2, 'reg': 2}, 64) 17 | 18 | inputs_def = [ 19 | { 20 | "name": "process_img", 21 | "dims": [1, 3, 512, 512], 22 | "data_type": "TYPE_FP32" 23 | } 24 | ] 25 | 26 | outputs_def = [ 27 | { 28 | "name": "heatmap", 29 | "dims": [1, 1, 128, 128], 30 | "data_type": "TYPE_FP32" 31 | }, 32 | { 33 | "name": "bbox_wh", 34 | "dims": [1, 2, 128, 128], 35 | "data_type": "TYPE_FP32" 36 | }, 37 | { 38 | "name": "center_shift", 39 | "dims": [1, 2, 128, 128], 40 | "data_type": "TYPE_FP32" 41 | } 42 | ] 43 | 44 | from pre_process import preprocess 45 | import tensorflow as tf 46 | sess = tf.Session() 47 | preprocess_fn = lambda x: sess.run(preprocess([x]))[0] 48 | 49 | stream = trt_backend.ImageBatchStream("./calibrator_files", 5, preprocess_fn) 50 | int8_calibrator = trt_backend.IInt8EntropyCalibrator2(inputs_def, stream) 51 | 52 | trt_backend.torch2trt( 53 | computation_graph=model, 54 | graph_name="detection-network", 55 | model_file="./network/dla34.pth", 56 | inputs_def=inputs_def, 57 | outputs_def=outputs_def, 58 | instances=16, 59 | gpus=[0, 1, 2, 3], 60 | version=1, 61 | export_path="../../model_repository", 62 | int8_calibrator=int8_calibrator 63 | ) 64 | 65 | # alternative: export the network as onnx instead of a tensorRT plan 66 | # onnx_backend.torch2onnx( 67 | # computation_graph=model, 68 | # graph_name="detection-network", 69 | # model_file="./network/dla34.pth", 70 | # inputs_def=inputs_def, 71 | # outputs_def=outputs_def, 72 | # instances=16, 73 | # gpus=[2], 74 | # version=1, 75 | # export_path="../../model_repository" 76 | # ) 77 | -------------------------------------------------------------------------------- /example/detection/network/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/layerism/TensorRT-Inference-Server-Tutorial/c40d485f69c1349c01b40a4eb7c225b01bf8dbe0/example/detection/network/__init__.py -------------------------------------------------------------------------------- /example/detection/network/dla34.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | from __future__ import absolute_import, division, print_function 4 | 5 | import math 6 | from os.path import join 7 | from functools import partial 8 | 9 | import numpy as np 10 | import torch 11 | import torch.utils.model_zoo as model_zoo 12 | from torch import nn 13 | 14 | BatchNorm = nn.BatchNorm2d 15 | 16 | 17 | def get_model_url(data='imagenet', name='dla34', hash='ba72cf86'): 18 | return join('http://dl.yf.io/dla/models', data, '{}-{}.pth'.format(name, hash)) 19 | 20 | 21 | def conv3x3(in_planes, out_planes, stride=1): 22 | "3x3 convolution with padding" 23 | return nn.Conv2d( 24 | in_planes, out_planes, kernel_size=3, 
stride=stride, padding=1, bias=False 25 | ) 26 | 27 | 28 | class BasicBlock(nn.Module): 29 | 30 | def __init__(self, inplanes, planes, stride=1, dilation=1): 31 | super(BasicBlock, self).__init__() 32 | self.conv1 = nn.Conv2d( 33 | inplanes, 34 | planes, 35 | kernel_size=3, 36 | stride=stride, 37 | padding=dilation, 38 | bias=False, 39 | dilation=dilation 40 | ) 41 | self.bn1 = BatchNorm(planes) 42 | self.relu = nn.ReLU(inplace=True) 43 | self.conv2 = nn.Conv2d( 44 | planes, 45 | planes, 46 | kernel_size=3, 47 | stride=1, 48 | padding=dilation, 49 | bias=False, 50 | dilation=dilation 51 | ) 52 | self.bn2 = BatchNorm(planes) 53 | self.stride = stride 54 | 55 | def forward(self, x, residual=None): 56 | if residual is None: 57 | residual = x 58 | 59 | out = self.conv1(x) 60 | out = self.bn1(out) 61 | out = self.relu(out) 62 | 63 | out = self.conv2(out) 64 | out = self.bn2(out) 65 | 66 | out += residual 67 | out = self.relu(out) 68 | 69 | return out 70 | 71 | 72 | class Bottleneck(nn.Module): 73 | expansion = 2 74 | 75 | def __init__(self, inplanes, planes, stride=1, dilation=1): 76 | super(Bottleneck, self).__init__() 77 | expansion = Bottleneck.expansion 78 | bottle_planes = planes // expansion 79 | self.conv1 = nn.Conv2d(inplanes, bottle_planes, kernel_size=1, bias=False) 80 | self.bn1 = BatchNorm(bottle_planes) 81 | self.conv2 = nn.Conv2d( 82 | bottle_planes, 83 | bottle_planes, 84 | kernel_size=3, 85 | stride=stride, 86 | padding=dilation, 87 | bias=False, 88 | dilation=dilation 89 | ) 90 | self.bn2 = BatchNorm(bottle_planes) 91 | self.conv3 = nn.Conv2d(bottle_planes, planes, kernel_size=1, bias=False) 92 | self.bn3 = BatchNorm(planes) 93 | self.relu = nn.ReLU(inplace=True) 94 | self.stride = stride 95 | 96 | def forward(self, x, residual=None): 97 | if residual is None: 98 | residual = x 99 | 100 | out = self.conv1(x) 101 | out = self.bn1(out) 102 | out = self.relu(out) 103 | 104 | out = self.conv2(out) 105 | out = self.bn2(out) 106 | out = self.relu(out) 107 | 108 | out = self.conv3(out) 109 | out = self.bn3(out) 110 | 111 | out += residual 112 | out = self.relu(out) 113 | 114 | return out 115 | 116 | 117 | class BottleneckX(nn.Module): 118 | expansion = 2 119 | cardinality = 32 120 | 121 | def __init__(self, inplanes, planes, stride=1, dilation=1): 122 | super(BottleneckX, self).__init__() 123 | cardinality = BottleneckX.cardinality 124 | # dim = int(math.floor(planes * (BottleneckV5.expansion / 64.0))) 125 | # bottle_planes = dim * cardinality 126 | bottle_planes = planes * cardinality // 32 127 | self.conv1 = nn.Conv2d(inplanes, bottle_planes, kernel_size=1, bias=False) 128 | self.bn1 = BatchNorm(bottle_planes) 129 | self.conv2 = nn.Conv2d( 130 | bottle_planes, 131 | bottle_planes, 132 | kernel_size=3, 133 | stride=stride, 134 | padding=dilation, 135 | bias=False, 136 | dilation=dilation, 137 | groups=cardinality 138 | ) 139 | self.bn2 = BatchNorm(bottle_planes) 140 | self.conv3 = nn.Conv2d(bottle_planes, planes, kernel_size=1, bias=False) 141 | self.bn3 = BatchNorm(planes) 142 | self.relu = nn.ReLU(inplace=True) 143 | self.stride = stride 144 | 145 | def forward(self, x, residual=None): 146 | if residual is None: 147 | residual = x 148 | 149 | out = self.conv1(x) 150 | out = self.bn1(out) 151 | out = self.relu(out) 152 | 153 | out = self.conv2(out) 154 | out = self.bn2(out) 155 | out = self.relu(out) 156 | 157 | out = self.conv3(out) 158 | out = self.bn3(out) 159 | 160 | out += residual 161 | out = self.relu(out) 162 | 163 | return out 164 | 165 | 166 | class Root(nn.Module): 167 | 
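# Root is the aggregation node of a DLA tree: forward() concatenates its child feature maps along the channel axis, fuses them with the 1x1 conv + BN defined below, optionally adds a residual from the first child, then applies ReLU.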
168 | def __init__(self, in_channels, out_channels, kernel_size, residual): 169 | super(Root, self).__init__() 170 | self.conv = nn.Conv2d( 171 | in_channels, 172 | out_channels, 173 | 1, 174 | stride=1, 175 | bias=False, 176 | padding=(kernel_size - 1) // 2 177 | ) 178 | self.bn = BatchNorm(out_channels) 179 | self.relu = nn.ReLU(inplace=True) 180 | self.residual = residual 181 | 182 | def forward(self, *x): 183 | children = x 184 | x = self.conv(torch.cat(x, 1)) 185 | x = self.bn(x) 186 | if self.residual: 187 | x += children[0] 188 | x = self.relu(x) 189 | 190 | return x 191 | 192 | 193 | class Tree(nn.Module): 194 | 195 | def __init__( 196 | self, 197 | levels, 198 | block, 199 | in_channels, 200 | out_channels, 201 | stride=1, 202 | level_root=False, 203 | root_dim=0, 204 | root_kernel_size=1, 205 | dilation=1, 206 | root_residual=False 207 | ): 208 | super(Tree, self).__init__() 209 | if root_dim == 0: 210 | root_dim = 2 * out_channels 211 | if level_root: 212 | root_dim += in_channels 213 | if levels == 1: 214 | self.tree1 = block(in_channels, out_channels, stride, dilation=dilation) 215 | self.tree2 = block(out_channels, out_channels, 1, dilation=dilation) 216 | else: 217 | self.tree1 = Tree( 218 | levels - 1, 219 | block, 220 | in_channels, 221 | out_channels, 222 | stride, 223 | root_dim=0, 224 | root_kernel_size=root_kernel_size, 225 | dilation=dilation, 226 | root_residual=root_residual 227 | ) 228 | self.tree2 = Tree( 229 | levels - 1, 230 | block, 231 | out_channels, 232 | out_channels, 233 | root_dim=root_dim + out_channels, 234 | root_kernel_size=root_kernel_size, 235 | dilation=dilation, 236 | root_residual=root_residual 237 | ) 238 | if levels == 1: 239 | self.root = Root(root_dim, out_channels, root_kernel_size, root_residual) 240 | self.level_root = level_root 241 | self.root_dim = root_dim 242 | self.downsample = None 243 | self.project = None 244 | self.levels = levels 245 | if stride > 1: 246 | self.downsample = nn.MaxPool2d(stride, stride=stride) 247 | if in_channels != out_channels: 248 | self.project = nn.Sequential( 249 | nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False), 250 | BatchNorm(out_channels) 251 | ) 252 | 253 | def forward(self, x, residual=None, children=None): 254 | children = [] if children is None else children 255 | bottom = self.downsample(x) if self.downsample else x 256 | residual = self.project(bottom) if self.project else bottom 257 | if self.level_root: 258 | children.append(bottom) 259 | x1 = self.tree1(x, residual) 260 | if self.levels == 1: 261 | x2 = self.tree2(x1) 262 | x = self.root(x2, x1, *children) 263 | else: 264 | children.append(x1) 265 | x = self.tree2(x1, children=children) 266 | return x 267 | 268 | 269 | class DLA(nn.Module): 270 | 271 | def __init__( 272 | self, 273 | levels, 274 | channels, 275 | num_classes=1000, 276 | block=BasicBlock, 277 | residual_root=False, 278 | return_levels=False, 279 | pool_size=7, 280 | linear_root=False 281 | ): 282 | super(DLA, self).__init__() 283 | self.channels = channels 284 | self.return_levels = return_levels 285 | self.num_classes = num_classes 286 | self.base_layer = nn.Sequential( 287 | nn.Conv2d(3, channels[0], kernel_size=7, stride=1, padding=3, bias=False), 288 | BatchNorm(channels[0]), 289 | nn.ReLU(inplace=True) 290 | ) 291 | self.level0 = self._make_conv_level(channels[0], channels[0], levels[0]) 292 | self.level1 = self._make_conv_level(channels[0], channels[1], levels[1], stride=2) 293 | self.level2 = Tree( 294 | levels[2], 295 | block, 296 | 
channels[1], 297 | channels[2], 298 | 2, 299 | level_root=False, 300 | root_residual=residual_root 301 | ) 302 | self.level3 = Tree( 303 | levels[3], 304 | block, 305 | channels[2], 306 | channels[3], 307 | 2, 308 | level_root=True, 309 | root_residual=residual_root 310 | ) 311 | self.level4 = Tree( 312 | levels[4], 313 | block, 314 | channels[3], 315 | channels[4], 316 | 2, 317 | level_root=True, 318 | root_residual=residual_root 319 | ) 320 | self.level5 = Tree( 321 | levels[5], 322 | block, 323 | channels[4], 324 | channels[5], 325 | 2, 326 | level_root=True, 327 | root_residual=residual_root 328 | ) 329 | 330 | self.avgpool = nn.AvgPool2d(pool_size) 331 | self.fc = nn.Conv2d( 332 | channels[-1], num_classes, kernel_size=1, stride=1, padding=0, bias=True 333 | ) 334 | 335 | for m in self.modules(): 336 | if isinstance(m, nn.Conv2d): 337 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 338 | m.weight.data.normal_(0, math.sqrt(2. / n)) 339 | elif isinstance(m, BatchNorm): 340 | m.weight.data.fill_(1) 341 | m.bias.data.zero_() 342 | 343 | def _make_level(self, block, inplanes, planes, blocks, stride=1): 344 | downsample = None 345 | if stride != 1 or inplanes != planes: 346 | downsample = nn.Sequential( 347 | nn.MaxPool2d(stride, stride=stride), 348 | nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, bias=False), 349 | BatchNorm(planes), 350 | ) 351 | 352 | layers = [] 353 | layers.append(block(inplanes, planes, stride, downsample=downsample)) 354 | for i in range(1, blocks): 355 | layers.append(block(inplanes, planes)) 356 | 357 | return nn.Sequential(*layers) 358 | 359 | def _make_conv_level(self, inplanes, planes, convs, stride=1, dilation=1): 360 | modules = [] 361 | for i in range(convs): 362 | modules.extend( 363 | [ 364 | nn.Conv2d( 365 | inplanes, 366 | planes, 367 | kernel_size=3, 368 | stride=stride if i == 0 else 1, 369 | padding=dilation, 370 | bias=False, 371 | dilation=dilation 372 | ), 373 | BatchNorm(planes), 374 | nn.ReLU(inplace=True) 375 | ] 376 | ) 377 | inplanes = planes 378 | return nn.Sequential(*modules) 379 | 380 | def forward(self, x): 381 | y = [] 382 | x = self.base_layer(x) 383 | for i in range(6): 384 | x = getattr(self, 'level{}'.format(i))(x) 385 | y.append(x) 386 | if self.return_levels: 387 | return y 388 | else: 389 | x = self.avgpool(x) 390 | x = self.fc(x) 391 | x = x.view(x.size(0), -1) 392 | 393 | return x 394 | 395 | def load_pretrained_model(self, data='imagenet', name='dla34', hash='ba72cf86'): 396 | fc = self.fc 397 | if name.endswith('.pth'): 398 | model_weights = torch.load(data + name) 399 | else: 400 | model_url = get_model_url(data, name, hash) 401 | model_weights = model_zoo.load_url(model_url) 402 | num_classes = len(model_weights[list(model_weights.keys())[-1]]) 403 | self.fc = nn.Conv2d( 404 | self.channels[-1], num_classes, kernel_size=1, stride=1, padding=0, bias=True 405 | ) 406 | self.load_state_dict(model_weights) 407 | self.fc = fc 408 | 409 | 410 | def dla34(pretrained, **kwargs): # DLA-34 411 | model = DLA( 412 | [1, 1, 1, 2, 2, 1], [16, 32, 64, 128, 256, 512], block=BasicBlock, **kwargs 413 | ) 414 | if pretrained: 415 | model.load_pretrained_model(data='imagenet', name='dla34', hash='ba72cf86') 416 | return model 417 | 418 | 419 | def dla46_c(pretrained=None, **kwargs): # DLA-46-C 420 | Bottleneck.expansion = 2 421 | model = DLA( 422 | [1, 1, 1, 2, 2, 1], [16, 32, 64, 64, 128, 256], block=Bottleneck, **kwargs 423 | ) 424 | if pretrained is not None: 425 | model.load_pretrained_model(pretrained, 'dla46_c') 426 | 
return model 427 | 428 | 429 | def dla46x_c(pretrained=None, **kwargs): # DLA-X-46-C 430 | BottleneckX.expansion = 2 431 | model = DLA( 432 | [1, 1, 1, 2, 2, 1], [16, 32, 64, 64, 128, 256], block=BottleneckX, **kwargs 433 | ) 434 | if pretrained is not None: 435 | model.load_pretrained_model(pretrained, 'dla46x_c') 436 | return model 437 | 438 | 439 | def dla60x_c(pretrained, **kwargs): # DLA-X-60-C 440 | BottleneckX.expansion = 2 441 | model = DLA( 442 | [1, 1, 1, 2, 3, 1], [16, 32, 64, 64, 128, 256], block=BottleneckX, **kwargs 443 | ) 444 | if pretrained: 445 | model.load_pretrained_model(data='imagenet', name='dla60x_c', hash='b870c45c') 446 | return model 447 | 448 | 449 | def dla60(pretrained=None, **kwargs): # DLA-60 450 | Bottleneck.expansion = 2 451 | model = DLA( 452 | [1, 1, 1, 2, 3, 1], [16, 32, 128, 256, 512, 1024], block=Bottleneck, **kwargs 453 | ) 454 | if pretrained is not None: 455 | model.load_pretrained_model(pretrained, 'dla60') 456 | return model 457 | 458 | 459 | def dla60x(pretrained=None, **kwargs): # DLA-X-60 460 | BottleneckX.expansion = 2 461 | model = DLA( 462 | [1, 1, 1, 2, 3, 1], [16, 32, 128, 256, 512, 1024], block=BottleneckX, **kwargs 463 | ) 464 | if pretrained is not None: 465 | model.load_pretrained_model(pretrained, 'dla60x') 466 | return model 467 | 468 | 469 | def dla102(pretrained=None, **kwargs): # DLA-102 470 | Bottleneck.expansion = 2 471 | model = DLA( 472 | [1, 1, 1, 3, 4, 1], [16, 32, 128, 256, 512, 1024], 473 | block=Bottleneck, 474 | residual_root=True, 475 | **kwargs 476 | ) 477 | if pretrained is not None: 478 | model.load_pretrained_model(pretrained, 'dla102') 479 | return model 480 | 481 | 482 | def dla102x(pretrained=None, **kwargs): # DLA-X-102 483 | BottleneckX.expansion = 2 484 | model = DLA( 485 | [1, 1, 1, 3, 4, 1], [16, 32, 128, 256, 512, 1024], 486 | block=BottleneckX, 487 | residual_root=True, 488 | **kwargs 489 | ) 490 | if pretrained is not None: 491 | model.load_pretrained_model(pretrained, 'dla102x') 492 | return model 493 | 494 | 495 | def dla102x2(pretrained=None, **kwargs): # DLA-X-102 64 496 | BottleneckX.cardinality = 64 497 | model = DLA( 498 | [1, 1, 1, 3, 4, 1], [16, 32, 128, 256, 512, 1024], 499 | block=BottleneckX, 500 | residual_root=True, 501 | **kwargs 502 | ) 503 | if pretrained is not None: 504 | model.load_pretrained_model(pretrained, 'dla102x2') 505 | return model 506 | 507 | 508 | def dla169(pretrained=None, **kwargs): # DLA-169 509 | Bottleneck.expansion = 2 510 | model = DLA( 511 | [1, 1, 2, 3, 5, 1], [16, 32, 128, 256, 512, 1024], 512 | block=Bottleneck, 513 | residual_root=True, 514 | **kwargs 515 | ) 516 | if pretrained is not None: 517 | model.load_pretrained_model(pretrained, 'dla169') 518 | return model 519 | 520 | 521 | def set_bn(bn): 522 | global BatchNorm 523 | BatchNorm = bn 524 | # BatchNorm is a module-level global in this standalone file, so rebinding it above is sufficient 525 | 526 | 527 | class Identity(nn.Module): 528 | 529 | def __init__(self): 530 | super(Identity, self).__init__() 531 | 532 | def forward(self, x): 533 | return x 534 | 535 | 536 | def fill_up_weights(up): 537 | w = up.weight.data 538 | f = math.ceil(w.size(2) / 2) 539 | c = (2 * f - 1 - f % 2) / (2. 
* f) 540 | for i in range(w.size(2)): 541 | for j in range(w.size(3)): 542 | w[0, 0, i, j] = \ 543 | (1 - math.fabs(i / f - c)) * (1 - math.fabs(j / f - c)) 544 | for c in range(1, w.size(0)): 545 | w[c, 0, :, :] = w[0, 0, :, :] 546 | 547 | 548 | class IDAUp(nn.Module): 549 | 550 | def __init__(self, node_kernel, out_dim, channels, up_factors): 551 | super(IDAUp, self).__init__() 552 | self.channels = channels 553 | self.out_dim = out_dim 554 | for i, c in enumerate(channels): 555 | if c == out_dim: 556 | proj = Identity() 557 | else: 558 | proj = nn.Sequential( 559 | nn.Conv2d(c, out_dim, kernel_size=1, stride=1, bias=False), 560 | BatchNorm(out_dim), 561 | nn.ReLU(inplace=True) 562 | ) 563 | f = int(up_factors[i]) 564 | if f == 1: 565 | up = Identity() 566 | else: 567 | up = nn.ConvTranspose2d( 568 | out_dim, 569 | out_dim, 570 | f * 2, 571 | stride=f, 572 | padding=f // 2, 573 | output_padding=0, 574 | groups=out_dim, 575 | bias=False 576 | ) 577 | fill_up_weights(up) 578 | setattr(self, 'proj_' + str(i), proj) 579 | setattr(self, 'up_' + str(i), up) 580 | 581 | for i in range(1, len(channels)): 582 | node = nn.Sequential( 583 | nn.Conv2d( 584 | out_dim * 2, 585 | out_dim, 586 | kernel_size=node_kernel, 587 | stride=1, 588 | padding=node_kernel // 2, 589 | bias=False 590 | ), 591 | BatchNorm(out_dim), 592 | nn.ReLU(inplace=True) 593 | ) 594 | setattr(self, 'node_' + str(i), node) 595 | 596 | for m in self.modules(): 597 | if isinstance(m, nn.Conv2d): 598 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 599 | m.weight.data.normal_(0, math.sqrt(2. / n)) 600 | elif isinstance(m, BatchNorm): 601 | m.weight.data.fill_(1) 602 | m.bias.data.zero_() 603 | 604 | def forward(self, layers): 605 | assert len(self.channels) == len(layers), \ 606 | '{} vs {} layers'.format(len(self.channels), len(layers)) 607 | layers = list(layers) 608 | for i, l in enumerate(layers): 609 | upsample = getattr(self, 'up_' + str(i)) 610 | project = getattr(self, 'proj_' + str(i)) 611 | layers[i] = upsample(project(l)) 612 | x = layers[0] 613 | y = [] 614 | for i in range(1, len(layers)): 615 | node = getattr(self, 'node_' + str(i)) 616 | x = node(torch.cat([x, layers[i]], 1)) 617 | y.append(x) 618 | return x, y 619 | 620 | 621 | class DLAUp(nn.Module): 622 | 623 | def __init__(self, channels, scales=(1, 2, 4, 8, 16), in_channels=None): 624 | super(DLAUp, self).__init__() 625 | if in_channels is None: 626 | in_channels = channels 627 | self.channels = channels 628 | channels = list(channels) 629 | scales = np.array(scales, dtype=int) 630 | for i in range(len(channels) - 1): 631 | j = -i - 2 632 | setattr( 633 | self, 634 | 'ida_{}'.format(i), 635 | IDAUp(3, channels[j], in_channels[j:], scales[j:] // scales[j]) 636 | ) 637 | scales[j + 1:] = scales[j] 638 | in_channels[j + 1:] = [channels[j] for _ in channels[j + 1:]] 639 | 640 | def forward(self, layers): 641 | layers = list(layers) 642 | assert len(layers) > 1 643 | for i in range(len(layers) - 1): 644 | ida = getattr(self, 'ida_{}'.format(i)) 645 | x, y = ida(layers[-i - 2:]) 646 | layers[-i - 1:] = y 647 | return x 648 | 649 | 650 | def fill_fc_weights(layers): 651 | for m in layers.modules(): 652 | if isinstance(m, nn.Conv2d): 653 | nn.init.normal_(m.weight, std=0.001) 654 | # torch.nn.init.kaiming_normal_(m.weight.data, nonlinearity='relu') 655 | # torch.nn.init.xavier_normal_(m.weight.data) 656 | if m.bias is not None: 657 | nn.init.constant_(m.bias, 0) 658 | 659 | 660 | class DLASeg(nn.Module): 661 | 662 | def __init__(self, base_name, heads, 
pretrained=True, down_ratio=4, head_conv=256): 663 | super(DLASeg, self).__init__() 664 | assert down_ratio in [2, 4, 8, 16] 665 | self.heads = heads 666 | self.first_level = int(np.log2(down_ratio)) 667 | self.base = globals()[base_name](pretrained=pretrained, return_levels=True) 668 | channels = self.base.channels 669 | scales = [2**i for i in range(len(channels[self.first_level:]))] 670 | self.dla_up = DLAUp(channels[self.first_level:], scales=scales) 671 | ''' 672 | self.fc = nn.Sequential( 673 | nn.Conv2d(channels[self.first_level], classes, kernel_size=1, 674 | stride=1, padding=0, bias=True) 675 | ) 676 | ''' 677 | 678 | for head in self.heads: 679 | classes = self.heads[head] 680 | if head_conv > 0: 681 | fc = nn.Sequential( 682 | nn.Conv2d( 683 | channels[self.first_level], 684 | head_conv, 685 | kernel_size=3, 686 | padding=1, 687 | bias=True 688 | ), 689 | nn.ReLU(inplace=True), 690 | nn.Conv2d( 691 | head_conv, classes, kernel_size=1, stride=1, padding=0, bias=True 692 | ) 693 | ) 694 | if 'hm' in head: 695 | fc[-1].bias.data.fill_(-2.19) 696 | else: 697 | fill_fc_weights(fc) 698 | else: 699 | fc = nn.Conv2d( 700 | channels[self.first_level], 701 | classes, 702 | kernel_size=1, 703 | stride=1, 704 | padding=0, 705 | bias=True 706 | ) 707 | if 'hm' in head: 708 | fc.bias.data.fill_(-2.19) 709 | else: 710 | fill_fc_weights(fc) 711 | self.__setattr__(head, fc) 712 | ''' 713 | up_factor = 2 ** self.first_level 714 | if up_factor > 1: 715 | up = nn.ConvTranspose2d(classes, classes, up_factor * 2, 716 | stride=up_factor, padding=up_factor // 2, 717 | output_padding=0, groups=classes, 718 | bias=False) 719 | fill_up_weights(up) 720 | up.weight.requires_grad = False 721 | else: 722 | up = Identity() 723 | self.up = up 724 | self.softmax = nn.LogSoftmax(dim=1) 725 | 726 | for m in self.fc.modules(): 727 | if isinstance(m, nn.Conv2d): 728 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 729 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 730 | elif isinstance(m, BatchNorm): 731 | m.weight.data.fill_(1) 732 | m.bias.data.zero_() 733 | ''' 734 | 735 | def forward(self, x): 736 | x = self.base(x) 737 | x = self.dla_up(x[self.first_level:]) 738 | # x = self.fc(x) 739 | # y = self.softmax(self.up(x)) 740 | ret = [] 741 | for head in self.heads: 742 | ret.append(self.__getattr__(head)(x)) 743 | return ret 744 | 745 | ''' 746 | def optim_parameters(self, memo=None): 747 | for param in self.base.parameters(): 748 | yield param 749 | for param in self.dla_up.parameters(): 750 | yield param 751 | for param in self.fc.parameters(): 752 | yield param 753 | ''' 754 | 755 | 756 | ''' 757 | def dla34up(classes, pretrained_base=None, **kwargs): 758 | model = DLASeg('dla34', classes, pretrained_base=pretrained_base, **kwargs) 759 | return model 760 | 761 | 762 | def dla60up(classes, pretrained_base=None, **kwargs): 763 | model = DLASeg('dla60', classes, pretrained_base=pretrained_base, **kwargs) 764 | return model 765 | 766 | 767 | def dla102up(classes, pretrained_base=None, **kwargs): 768 | model = DLASeg('dla102', classes, 769 | pretrained_base=pretrained_base, **kwargs) 770 | return model 771 | 772 | 773 | def dla169up(classes, pretrained_base=None, **kwargs): 774 | model = DLASeg('dla169', classes, 775 | pretrained_base=pretrained_base, **kwargs) 776 | return model 777 | ''' 778 | 779 | 780 | def get_pose_net(num_layers, heads, head_conv=256, down_ratio=4): 781 | model = DLASeg( 782 | 'dla{}'.format(num_layers), 783 | heads, 784 | pretrained=False, 785 | down_ratio=down_ratio, 786 | head_conv=head_conv 787 | ) 788 | return model 789 | -------------------------------------------------------------------------------- /example/detection/network/resnet.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Copyright (c) Microsoft 3 | # Licensed under the MIT License. 
4 | # Written by Bin Xiao (Bin.Xiao@microsoft.com) 5 | # Modified by Xingyi Zhou 6 | # ------------------------------------------------------------------------------ 7 | 8 | from __future__ import absolute_import, division, print_function 9 | 10 | import os 11 | 12 | import torch 13 | import torch.nn as nn 14 | import torch.utils.model_zoo as model_zoo 15 | 16 | BN_MOMENTUM = 0.1 17 | 18 | model_urls = { 19 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 20 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 21 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 22 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 23 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 24 | } 25 | 26 | 27 | def conv3x3(in_planes, out_planes, stride=1): 28 | """3x3 convolution with padding""" 29 | return nn.Conv2d( 30 | in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False 31 | ) 32 | 33 | 34 | class BasicBlock(nn.Module): 35 | expansion = 1 36 | 37 | def __init__(self, inplanes, planes, stride=1, downsample=None): 38 | super(BasicBlock, self).__init__() 39 | self.conv1 = conv3x3(inplanes, planes, stride) 40 | self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) 41 | self.relu = nn.ReLU(inplace=True) 42 | self.conv2 = conv3x3(planes, planes) 43 | self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) 44 | self.downsample = downsample 45 | self.stride = stride 46 | 47 | def forward(self, x): 48 | residual = x 49 | 50 | out = self.conv1(x) 51 | out = self.bn1(out) 52 | out = self.relu(out) 53 | 54 | out = self.conv2(out) 55 | out = self.bn2(out) 56 | 57 | if self.downsample is not None: 58 | residual = self.downsample(x) 59 | 60 | out += residual 61 | out = self.relu(out) 62 | 63 | return out 64 | 65 | 66 | class Bottleneck(nn.Module): 67 | expansion = 4 68 | 69 | def __init__(self, inplanes, planes, stride=1, downsample=None): 70 | super(Bottleneck, self).__init__() 71 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 72 | self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) 73 | self.conv2 = nn.Conv2d( 74 | planes, planes, kernel_size=3, stride=stride, padding=1, bias=False 75 | ) 76 | self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) 77 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) 78 | self.bn3 = nn.BatchNorm2d(planes * self.expansion, momentum=BN_MOMENTUM) 79 | self.relu = nn.ReLU(inplace=True) 80 | self.downsample = downsample 81 | self.stride = stride 82 | 83 | def forward(self, x): 84 | residual = x 85 | 86 | out = self.conv1(x) 87 | out = self.bn1(out) 88 | out = self.relu(out) 89 | 90 | out = self.conv2(out) 91 | out = self.bn2(out) 92 | out = self.relu(out) 93 | 94 | out = self.conv3(out) 95 | out = self.bn3(out) 96 | 97 | if self.downsample is not None: 98 | residual = self.downsample(x) 99 | 100 | out += residual 101 | out = self.relu(out) 102 | 103 | return out 104 | 105 | 106 | class PoseResNet(nn.Module): 107 | 108 | def __init__(self, block, layers, heads, head_conv, **kwargs): 109 | self.inplanes = 64 110 | self.deconv_with_bias = False 111 | self.heads = heads 112 | 113 | super(PoseResNet, self).__init__() 114 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) 115 | self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM) 116 | self.relu = nn.ReLU(inplace=True) 117 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 118 | self.layer1 = 
self._make_layer(block, 64, layers[0]) 119 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 120 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 121 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 122 | 123 | # used for deconv layers 124 | self.deconv_layers = self._make_deconv_layer( 125 | 3, 126 | [256, 256, 256], 127 | [4, 4, 4], 128 | ) 129 | # self.final_layer = [] 130 | 131 | for head in sorted(self.heads): 132 | num_output = self.heads[head] 133 | if head_conv > 0: 134 | fc = nn.Sequential( 135 | nn.Conv2d(256, head_conv, kernel_size=3, padding=1, bias=True), 136 | nn.ReLU(inplace=True), 137 | nn.Conv2d(head_conv, num_output, kernel_size=1, stride=1, padding=0) 138 | ) 139 | else: 140 | fc = nn.Conv2d( 141 | in_channels=256, 142 | out_channels=num_output, 143 | kernel_size=1, 144 | stride=1, 145 | padding=0 146 | ) 147 | self.__setattr__(head, fc) 148 | 149 | # self.final_layer = nn.ModuleList(self.final_layer) 150 | 151 | def _make_layer(self, block, planes, blocks, stride=1): 152 | downsample = None 153 | if stride != 1 or self.inplanes != planes * block.expansion: 154 | downsample = nn.Sequential( 155 | nn.Conv2d( 156 | self.inplanes, 157 | planes * block.expansion, 158 | kernel_size=1, 159 | stride=stride, 160 | bias=False 161 | ), 162 | nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM), 163 | ) 164 | 165 | layers = [] 166 | layers.append(block(self.inplanes, planes, stride, downsample)) 167 | self.inplanes = planes * block.expansion 168 | for i in range(1, blocks): 169 | layers.append(block(self.inplanes, planes)) 170 | 171 | return nn.Sequential(*layers) 172 | 173 | def _get_deconv_cfg(self, deconv_kernel, index): 174 | if deconv_kernel == 4: 175 | padding = 1 176 | output_padding = 0 177 | elif deconv_kernel == 3: 178 | padding = 1 179 | output_padding = 1 180 | elif deconv_kernel == 2: 181 | padding = 0 182 | output_padding = 0 183 | 184 | return deconv_kernel, padding, output_padding 185 | 186 | def _make_deconv_layer(self, num_layers, num_filters, num_kernels): 187 | assert num_layers == len(num_filters), \ 188 | 'ERROR: num_deconv_layers is different len(num_deconv_filters)' 189 | assert num_layers == len(num_kernels), \ 190 | 'ERROR: num_deconv_layers is different len(num_deconv_filters)' 191 | 192 | layers = [] 193 | for i in range(num_layers): 194 | kernel, padding, output_padding = \ 195 | self._get_deconv_cfg(num_kernels[i], i) 196 | 197 | planes = num_filters[i] 198 | layers.append( 199 | nn.ConvTranspose2d( 200 | in_channels=self.inplanes, 201 | out_channels=planes, 202 | kernel_size=kernel, 203 | stride=2, 204 | padding=padding, 205 | output_padding=output_padding, 206 | bias=self.deconv_with_bias 207 | ) 208 | ) 209 | layers.append(nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)) 210 | layers.append(nn.ReLU(inplace=True)) 211 | self.inplanes = planes 212 | 213 | return nn.Sequential(*layers) 214 | 215 | def forward(self, x): 216 | x = self.conv1(x) 217 | x = self.bn1(x) 218 | x = self.relu(x) 219 | x = self.maxpool(x) 220 | 221 | x = self.layer1(x) 222 | x = self.layer2(x) 223 | x = self.layer3(x) 224 | x = self.layer4(x) 225 | 226 | x = self.deconv_layers(x) 227 | ret = [] 228 | for head in self.heads: 229 | ret.append(self.__getattr__(head)(x)) 230 | return ret 231 | 232 | def init_weights(self, num_layers, pretrained=True): 233 | if pretrained: 234 | # print('=> init resnet deconv weights from normal distribution') 235 | for _, m in self.deconv_layers.named_modules(): 236 | if isinstance(m, 
nn.ConvTranspose2d): 237 | # print('=> init {}.weight as normal(0, 0.001)'.format(name)) 238 | # print('=> init {}.bias as 0'.format(name)) 239 | nn.init.normal_(m.weight, std=0.001) 240 | if self.deconv_with_bias: 241 | nn.init.constant_(m.bias, 0) 242 | elif isinstance(m, nn.BatchNorm2d): 243 | # print('=> init {}.weight as 1'.format(name)) 244 | # print('=> init {}.bias as 0'.format(name)) 245 | nn.init.constant_(m.weight, 1) 246 | nn.init.constant_(m.bias, 0) 247 | # print('=> init final conv weights from normal distribution') 248 | for head in self.heads: 249 | final_layer = self.__getattr__(head) 250 | for i, m in enumerate(final_layer.modules()): 251 | if isinstance(m, nn.Conv2d): 252 | # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 253 | # print('=> init {}.weight as normal(0, 0.001)'.format(name)) 254 | # print('=> init {}.bias as 0'.format(name)) 255 | if m.weight.shape[0] == self.heads[head]: 256 | if 'hm' in head: 257 | nn.init.constant_(m.bias, -2.19) 258 | else: 259 | nn.init.normal_(m.weight, std=0.001) 260 | nn.init.constant_(m.bias, 0) 261 | #pretrained_state_dict = torch.load(pretrained) 262 | # url = model_urls['resnet{}'.format(num_layers)] 263 | # pretrained_state_dict = model_zoo.load_url(url) 264 | # print('=> loading pretrained model {}'.format(url)) 265 | # self.load_state_dict(pretrained_state_dict, strict=False) 266 | # else: 267 | # print('=> imagenet pretrained model dose not exist') 268 | # print('=> please download it first') 269 | # raise ValueError('imagenet pretrained model does not exist') 270 | 271 | 272 | resnet_spec = { 273 | 18: (BasicBlock, [2, 2, 2, 2]), 274 | 34: (BasicBlock, [3, 4, 6, 3]), 275 | 50: (Bottleneck, [3, 4, 6, 3]), 276 | 101: (Bottleneck, [3, 4, 23, 3]), 277 | 152: (Bottleneck, [3, 8, 36, 3]) 278 | } 279 | 280 | 281 | def get_pose_net(num_layers, heads, head_conv): 282 | block_class, layers = resnet_spec[num_layers] 283 | 284 | model = PoseResNet(block_class, layers, heads, head_conv=head_conv) 285 | model.init_weights(num_layers, pretrained=True) 286 | return model 287 | -------------------------------------------------------------------------------- /example/detection/post_process.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.nn import functional 6 | 7 | from torchvision import transforms 8 | 9 | TOP_K = 100 10 | HEATMAP_SIZE = 128 11 | N_CATEGORY = 1 12 | 13 | 14 | def pixel_nms(heatmap, ksize=3, cutoff=1e-7): 15 | pool_heatmap = functional.max_pool2d( 16 | heatmap, kernel_size=ksize, stride=1, padding=ksize // 2 17 | ) 18 | zeros = torch.zeros_like(heatmap) 19 | score = torch.sigmoid(heatmap) 20 | pool_score = torch.sigmoid(pool_heatmap) 21 | diff = torch.abs(score - pool_score) 22 | score = torch.where(diff < cutoff, score, zeros) 23 | return score 24 | 25 | 26 | def get_coord(score, topk=TOP_K): 27 | # score.shape = [N, C, 128, 128] 28 | # score, category, indices = [N, topk] 29 | score = score.flatten(2, 3) # [N, C, 128 * 128] 30 | score, category = score.max(1) # [N, 128 * 128] 31 | score, indices = score.topk(topk, dim=1) 32 | indices = indices.flatten() 33 | category = category.flatten().index_select(0, indices).to(torch.int32) 34 | return score, category, indices 35 | 36 | 37 | def get_bbox_wh(bbox_wh, indices): 38 | # bbox_wh.shape = [N, 2, 128, 128] 39 | bbox_wh = bbox_wh.flatten(2, 3) 40 | bbox_w, bbox_h = bbox_wh[:, 0, :], bbox_wh[:, 1, :] 41 | bbox_w = 
bbox_w.flatten().index_select(0, indices) 42 | bbox_h = bbox_h.flatten().index_select(0, indices) 43 | return (bbox_w, bbox_h) 44 | 45 | 46 | def get_center_shift(center_shift, indices): 47 | bsize = center_shift.shape[0] 48 | center_shift = center_shift.flatten(2, 3) 49 | shift_w, shift_h = center_shift[:, 0, :], center_shift[:, 1, :] 50 | shift_w = shift_w.flatten().index_select(0, indices) 51 | shift_h = shift_h.flatten().index_select(0, indices) 52 | return (shift_w, shift_h) 53 | 54 | 55 | def remap_bbox_to_raw(trans_mat, x, y): 56 | # trans_mat.shape = [2, 3] 57 | # x.shape = [100] 58 | # y.shape = [100] 59 | x_ = x * trans_mat[0, 0] + y * trans_mat[0, 1] + trans_mat[0, 2] 60 | y_ = x * trans_mat[1, 0] + y * trans_mat[1, 1] + trans_mat[1, 2] 61 | return (x_, y_) 62 | 63 | 64 | def cvt_to_coord(indices, bbox_wh, shift_wh, trans_mat, shape): 65 | # indices, bbox_wh, shift_wh = [N, topk] 66 | # indices = indices.to(torch.float32) 67 | cy_coord = (indices // shape[2]).to(torch.float32) 68 | cx_coord = (indices - cy_coord * shape[2]).to(torch.float32) 69 | cx_coord += shift_wh[0] 70 | cy_coord += shift_wh[1] 71 | 72 | x0 = cx_coord - bbox_wh[0] / 2 73 | y0 = cy_coord - bbox_wh[1] / 2 74 | x1 = cx_coord + bbox_wh[0] / 2 75 | y1 = cy_coord + bbox_wh[1] / 2 76 | 77 | x0, y0 = remap_bbox_to_raw(trans_mat, x0, y0) 78 | x1, y1 = remap_bbox_to_raw(trans_mat, x1, y1) 79 | 80 | bbox = torch.stack([x0, y0, x1, y1], 1) 81 | 82 | return bbox 83 | 84 | 85 | def postprocess(inputs): 86 | heatmap, bbox_wh, center_shift, trans_mat = inputs 87 | score = pixel_nms(heatmap) 88 | score, category, indices = get_coord(score) 89 | bbox_wh = get_bbox_wh(bbox_wh, indices) 90 | shift_wh = get_center_shift(center_shift, indices) 91 | 92 | shape = torch.tensor(heatmap.shape, dtype=torch.long) 93 | bbox = cvt_to_coord(indices, bbox_wh, shift_wh, trans_mat, shape) 94 | 95 | score = score.reshape(1, TOP_K) 96 | category = category.reshape(1, TOP_K) 97 | bbox = bbox.reshape(1, TOP_K, 4) 98 | 99 | return score, category, bbox 100 | 101 | 102 | if __name__ == "__main__": 103 | from trtis import onnx_backend 104 | 105 | inputs_def = [ 106 | { 107 | "name": "heatmap", 108 | "dims": [1, N_CATEGORY, HEATMAP_SIZE, HEATMAP_SIZE], 109 | "data_type": "TYPE_FP32" 110 | }, 111 | { 112 | "name": "bbox_wh", 113 | "dims": [1, 2, HEATMAP_SIZE, HEATMAP_SIZE], 114 | "data_type": "TYPE_FP32" 115 | }, 116 | { 117 | "name": "center_shift", 118 | "dims": [1, 2, HEATMAP_SIZE, HEATMAP_SIZE], 119 | "data_type": "TYPE_FP32" 120 | }, 121 | { 122 | "name": "affine_trans_mat", 123 | "dims": [2, 3], 124 | "data_type": "TYPE_FP32" 125 | } 126 | ] 127 | 128 | outputs_def = [ 129 | { 130 | "name": "score", 131 | "dims": [1, TOP_K], 132 | "data_type": "TYPE_FP32", 133 | }, 134 | { 135 | "name": "category", 136 | "dims": [1, TOP_K], 137 | "data_type": "TYPE_INT32", 138 | }, 139 | { 140 | "name": "bbox", 141 | "dims": [1, TOP_K, 4], 142 | "data_type": "TYPE_FP32", 143 | } 144 | ] 145 | 146 | onnx_backend.torch2onnx( 147 | computation_graph=postprocess, 148 | graph_name="detection-postprocess", 149 | model_file=None, 150 | inputs_def=inputs_def, 151 | outputs_def=outputs_def, 152 | instances=16, 153 | gpus=[0, 1, 2, 3], 154 | version=1, 155 | export_path="../../model_repository" 156 | ) 157 | -------------------------------------------------------------------------------- /example/detection/pre_process.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import tensorflow as tf 5 
| 6 | HEATMAP_SIZE = 128 7 | MEAN = (0.408, 0.447, 0.470) 8 | STD = (0.274, 0.274, 0.274) 9 | 10 | 11 | def backward_affine_transform(raw_shape, dst_shape=(HEATMAP_SIZE, HEATMAP_SIZE)): 12 | """ 13 | shape: 14 | dst(3, 3) x T(3, 2) = src(3, 2) 15 | solve T matrix: 16 | |d11, d12, 1| |T11, T12| |s11, s12| 17 | |d21, d22, 1| x |T21, T22| = |s21, s22| 18 | |d31, d32, 1| |T31, T32| |s31, s32| 19 | output: 20 | T(3, 2).transpose = T(2, 3) 21 | """ 22 | raw_shape = tf.cast(raw_shape, tf.float32) 23 | src_cy, src_cx = raw_shape[0] / 2.0, raw_shape[1] / 2.0 24 | src_dc = tf.math.maximum(src_cx, src_cy) 25 | src_points = tf.stack([ 26 | [src_cx, src_cy], 27 | [src_cx, src_cy - src_dc], 28 | [src_cx - src_dc, src_cy] 29 | ]) 30 | 31 | dst_shape = tf.cast(dst_shape, tf.float32) 32 | dst_cy, dst_cx = dst_shape[0] / 2.0, dst_shape[1] / 2.0 33 | dst_dc = tf.math.maximum(dst_cx, dst_cy) 34 | dst_points = tf.stack([ 35 | [dst_cx, dst_cy, 1], 36 | [dst_cx, dst_cy - dst_dc, 1], 37 | [dst_cx - dst_dc, dst_cy, 1] 38 | ]) 39 | 40 | trans_mat = tf.linalg.solve(dst_points, src_points, name="affine") 41 | trans_mat = tf.transpose(trans_mat, perm=(1, 0)) 42 | 43 | return trans_mat 44 | 45 | 46 | def image_decoder(image_bytes): 47 | with tf.name_scope("image_decoder") as scope: 48 | image = tf.io.decode_image( 49 | image_bytes[0], 50 | channels=None, 51 | dtype=tf.dtypes.uint8, 52 | name="decode_image", 53 | expand_animations=False 54 | ) 55 | image = tf.expand_dims(image, 0) # [1, H, W, 3] 56 | raw_shape = tf.shape(image)[1:3] 57 | affine_trans_mat = backward_affine_transform(raw_shape) 58 | return image, affine_trans_mat 59 | 60 | 61 | def _resize_and_pad(image, size=512): 62 | with tf.name_scope("image_resize") as scope: 63 | image = tf.image.resize_image_with_pad( 64 | image, target_height=size, target_width=size 65 | ) 66 | image = tf.transpose(image, perm=[0, 3, 1, 2]) 67 | return image 68 | 69 | 70 | def _normalize(tensor, mean=MEAN, std=STD): 71 | with tf.name_scope("normalize") as scope: 72 | tensor = tf.cast(tensor, tf.float32) 73 | tensor = tensor / 255.0 74 | mean = tf.constant(mean, dtype=tf.float32, name="mean") 75 | mean = tf.reshape(mean, [1, 3, 1, 1]) 76 | std = tf.constant(std, dtype=tf.float32, name="std") 77 | std = tf.reshape(std, [1, 3, 1, 1]) 78 | tensor = (tensor - mean) / std 79 | return tensor 80 | 81 | 82 | def preprocess(raw_image): 83 | image, affine_trans_mat = image_decoder(raw_image) 84 | image = _resize_and_pad(image) 85 | tensor = _normalize(image) 86 | return tensor, affine_trans_mat 87 | 88 | 89 | if __name__ == "__main__": 90 | from trtis import tf_backend 91 | 92 | inputs_def = [ 93 | { 94 | "name": "raw_image", 95 | "dims": [1], 96 | "data_type": "TYPE_STRING", 97 | } 98 | ] 99 | 100 | outputs_def = [ 101 | { 102 | "name": "process_img", 103 | "dims": [1, 3, 512, 512], 104 | "data_type": "TYPE_FP32" 105 | }, 106 | { 107 | "name": "affine_trans_mat", 108 | "dims": [2, 3], 109 | "data_type": "TYPE_FP32" 110 | } 111 | ] 112 | 113 | tf_backend.tf2graphdef( 114 | computation_graph=preprocess, 115 | graph_name="detection-preprocess", 116 | model_file=None, 117 | inputs_def=inputs_def, 118 | outputs_def=outputs_def, 119 | instances=16, 120 | gpus=[0, 1, 2, 3], 121 | version=1, 122 | export_path="../../model_repository" 123 | ) 124 | -------------------------------------------------------------------------------- /example/test-data/widerface.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/layerism/TensorRT-Inference-Server-Tutorial/c40d485f69c1349c01b40a4eb7c225b01bf8dbe0/example/test-data/widerface.jpg -------------------------------------------------------------------------------- /install.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd backend 4 | python setup.py install 5 | cd - 6 | 7 | cd client_py 8 | python setup.py install 9 | cd - 10 | -------------------------------------------------------------------------------- /start.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | HTTP_PORT=7000 4 | GRPC_PORT=7001 5 | METRIC_PORT=7002 6 | DOCKER_IMAGE=nvcr.io/nvidia/tensorrtserver:19.12-py3 7 | 8 | docker run --rm \ 9 | --runtime nvidia \ 10 | --name trt_server \ 11 | --shm-size=4g \ 12 | --ulimit memlock=-1 \ 13 | --ulimit stack=67108864 \ 14 | -p${HTTP_PORT}:8000 \ 15 | -p${GRPC_PORT}:8001 \ 16 | -p${METRIC_PORT}:8002 \ 17 | -v`pwd`/model_repository/:/models \ 18 | ${DOCKER_IMAGE} \ 19 | trtserver --model-repository=/models 20 | --------------------------------------------------------------------------------
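启动 start.sh 以后,在用 client.sh 发请求之前,可以先确认模型确实加载成功。下面是一个简单的检查方式,仅作示意,假设使用 TensorRT Inference Server 1.x 的 HTTP 接口以及上面的端口映射(宿主机 7000 → 容器 8000):

```sh
# 服务 live 且 model repository 加载完成后返回 HTTP 200
curl -i localhost:7000/api/health/ready

# 查看 "detection" ensemble 的状态;去掉模型名可以列出全部四个模型
curl localhost:7000/api/status/detection
```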