├── triton ├── model_repository │ ├── simple_yolov5 │ │ ├── 1 │ │ │ └── .keep │ │ └── config.pbtxt │ ├── simple_yolov5_ensemble │ │ ├── 1 │ │ │ └── .keep │ │ └── config.pbtxt │ ├── yolov5_dynamic_batched_nms │ │ ├── 1 │ │ │ └── .keep │ │ ├── labels.txt │ │ └── config.pbtxt │ ├── nms │ │ ├── 1 │ │ │ ├── utils.py │ │ │ ├── model.py │ │ │ └── triton_python_backend_utils.py │ │ └── config.pbtxt │ └── simple_yolov5_bls │ │ ├── 1 │ │ ├── utils.py │ │ ├── model.py │ │ └── triton_python_backend_utils.py │ │ └── config.pbtxt └── generate_input.py ├── assets ├── bls.png ├── after.png ├── before.png ├── bls_arc.png ├── ensemble.png ├── bls_ensemble.png ├── radar_plot.png ├── python_backend.png ├── simple_output.png ├── bls_arc_official.png ├── thoughput_latency.png └── python_backend_official.png ├── docker ├── requirements_nvidia.txt ├── build.sh ├── Dockerfile ├── sources.list └── requirements.txt ├── LICENSE ├── README_CN.md ├── docs ├── pipelines.md ├── pipelines_EN.md ├── bls_vs_ensemble.md ├── bls_vs_ensemble_EN.md ├── batchedNMS.md ├── batchedNMS_EN.md ├── custom_yolov5_detect_layer.md └── custom_yolov5_detect_layer_EN.md ├── README.md ├── common.py ├── trt_infer.py └── export.py /triton/model_repository/simple_yolov5/1/.keep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /triton/model_repository/simple_yolov5_ensemble/1/.keep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /triton/model_repository/yolov5_dynamic_batched_nms/1/.keep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /assets/bls.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bug-developer021/YOLOV5_optimization_on_triton/HEAD/assets/bls.png -------------------------------------------------------------------------------- /assets/after.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bug-developer021/YOLOV5_optimization_on_triton/HEAD/assets/after.png -------------------------------------------------------------------------------- /assets/before.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bug-developer021/YOLOV5_optimization_on_triton/HEAD/assets/before.png -------------------------------------------------------------------------------- /assets/bls_arc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bug-developer021/YOLOV5_optimization_on_triton/HEAD/assets/bls_arc.png -------------------------------------------------------------------------------- /assets/ensemble.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bug-developer021/YOLOV5_optimization_on_triton/HEAD/assets/ensemble.png -------------------------------------------------------------------------------- /assets/bls_ensemble.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bug-developer021/YOLOV5_optimization_on_triton/HEAD/assets/bls_ensemble.png 
-------------------------------------------------------------------------------- /assets/radar_plot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bug-developer021/YOLOV5_optimization_on_triton/HEAD/assets/radar_plot.png -------------------------------------------------------------------------------- /assets/python_backend.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bug-developer021/YOLOV5_optimization_on_triton/HEAD/assets/python_backend.png -------------------------------------------------------------------------------- /assets/simple_output.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bug-developer021/YOLOV5_optimization_on_triton/HEAD/assets/simple_output.png -------------------------------------------------------------------------------- /assets/bls_arc_official.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bug-developer021/YOLOV5_optimization_on_triton/HEAD/assets/bls_arc_official.png -------------------------------------------------------------------------------- /assets/thoughput_latency.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bug-developer021/YOLOV5_optimization_on_triton/HEAD/assets/thoughput_latency.png -------------------------------------------------------------------------------- /assets/python_backend_official.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bug-developer021/YOLOV5_optimization_on_triton/HEAD/assets/python_backend_official.png -------------------------------------------------------------------------------- /docker/requirements_nvidia.txt: -------------------------------------------------------------------------------- 1 | # nvidia 2 | nvidia-pyindex 3 | onnx-graphsurgeon 4 | # match the docker's trt version 5 | nvidia-tensorrt==8.2.3 6 | pycuda -------------------------------------------------------------------------------- /docker/build.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | 4 | HTTP_PROXY= 5 | HTTPS_PROXY= 6 | NO_PROXY=localhost,127.0.0.1 7 | UBUNTU_VERSION=2004 8 | 9 | DOCKERFILE=Dockerfile 10 | TAG=nvcr.io/nvidia/tritonserver:22.03-py3-custom 11 | 12 | docker build -f $DOCKERFILE --network host --build-arg UBUNTU_VERSION=$UBUNTU_VERSION --build-arg HTTP_PROXY=$HTTP_PROXY --build-arg HTTPS_PROXY=$HTTPS_PROXY --build-arg NO_PROXY=$NO_PROXY -t $TAG . 
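
After building, the custom image can serve the model repository in this repo. Below is a minimal launch sketch, assuming the image tag defined above, that the TensorRT engines (`model.plan`) have already been exported into each model's version directory, and that `triton/model_repository` is mounted at `/models`; the paths, ports, and shared-memory size are illustrative assumptions, not values taken from this repo:

```shell
# Launch the custom Triton image built by build.sh (tag assumed from the TAG variable above).
# The mount path, port mappings, and --shm-size are illustrative; python backend
# model instances exchange tensors via shared memory, so allow a generous size.
docker run --gpus all --rm --shm-size=1g \
  -p 8000:8000 -p 8001:8001 -p 8002:8002 \
  -v "$(pwd)/triton/model_repository:/models" \
  nvcr.io/nvidia/tritonserver:22.03-py3-custom \
  tritonserver --model-repository=/models
```
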
-------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvcr.io/nvidia/tritonserver:22.03-py3 2 | USER root 3 | 4 | 5 | ARG HTTP_PROXY 6 | ARG HTTPS_PROXY 7 | ARG NO_PROXY 8 | 9 | 10 | WORKDIR /opt/tritonserver/myapps 11 | 12 | 13 | COPY sources.list sources.list 14 | RUN mv sources.list /etc/apt/sources.list 15 | 16 | RUN apt-get update && apt-get install -yq --no-install-recommends \ 17 | python3-scipy && \ 18 | rm -rf /var/lib/apt/lists/* 19 | 20 | 21 | COPY requirements* ./ 22 | RUN pip3 install --no-cache --upgrade pip setuptools && \ 23 | pip3 install --no-cache --upgrade -r requirements.txt && \ 24 | pip3 install --no-cache --upgrade -r requirements_nvidia.txt --extra-index-url https://pypi.ngc.nvidia.com 25 | 26 | # USER triton-server -------------------------------------------------------------------------------- /triton/model_repository/nms/config.pbtxt: -------------------------------------------------------------------------------- 1 | name: "nms" 2 | backend: "python" 3 | max_batch_size: 8 4 | input [ 5 | { 6 | name: "candidate_boxes" 7 | data_type: TYPE_FP32 8 | dims: [ 1000, 6 ] 9 | } 10 | ] 11 | 12 | output [ 13 | { 14 | name: "BBOXES" 15 | data_type: TYPE_FP32 16 | # padding number of bboxes to 300 17 | dims: [ 300, 6 ] 18 | } 19 | ] 20 | 21 | 22 | 23 | model_warmup [ 24 | { 25 | batch_size: 8 26 | name: "warmup_requests" 27 | inputs: { 28 | key: "candidate_boxes" 29 | value: { 30 | random_data: true 31 | dims: [1000, 6] 32 | data_type: TYPE_FP32 33 | } 34 | } 35 | } 36 | ] 37 | 38 | 39 | parameters: { 40 | key: "FORCE_CPU_ONLY_INPUT_TENSORS" 41 | value: {string_value:"no"} 42 | } 43 | 44 | -------------------------------------------------------------------------------- /docker/sources.list: -------------------------------------------------------------------------------- 1 | deb https://mirrors.ustc.edu.cn/ubuntu/ focal main restricted universe multiverse 2 | deb-src https://mirrors.ustc.edu.cn/ubuntu/ focal main restricted universe multiverse 3 | deb https://mirrors.ustc.edu.cn/ubuntu/ focal-updates main restricted universe multiverse 4 | deb-src https://mirrors.ustc.edu.cn/ubuntu/ focal-updates main restricted universe multiverse 5 | deb https://mirrors.ustc.edu.cn/ubuntu/ focal-backports main restricted universe multiverse 6 | deb-src https://mirrors.ustc.edu.cn/ubuntu/ focal-backports main restricted universe multiverse 7 | deb https://mirrors.ustc.edu.cn/ubuntu/ focal-security main restricted universe multiverse 8 | deb-src https://mirrors.ustc.edu.cn/ubuntu/ focal-security main restricted universe multiverse 9 | deb https://mirrors.ustc.edu.cn/ubuntu/ focal-proposed main restricted universe multiverse 10 | deb-src https://mirrors.ustc.edu.cn/ubuntu/ focal-proposed main restricted universe multi -------------------------------------------------------------------------------- /triton/model_repository/simple_yolov5_bls/config.pbtxt: -------------------------------------------------------------------------------- 1 | name: "simple_yolov5_bls" 2 | backend: "python" 3 | max_batch_size: 8 4 | input [ 5 | { 6 | name: "images" 7 | data_type: TYPE_FP32 8 | format: FORMAT_NCHW 9 | dims: [ 3, 640, 640 ] 10 | } 11 | ] 12 | 13 | output [ 14 | { 15 | name: "BBOXES" 16 | data_type: TYPE_FP32 17 | # padding number of bboxes to 300 18 | dims: [ 300, 6 ] 19 | } 20 | ] 21 | 22 | # instance_group [ 23 | # { 24 | # name: "simple_yolov5_bls" 25 | # count: 1 
26 | # kind: KIND_GPU 27 | # gpus: [0] 28 | # } 29 | # ] 30 | 31 | model_warmup [ 32 | { 33 | batch_size: 8 34 | name: "warmup_requests" 35 | inputs: { 36 | key: "images" 37 | value: { 38 | random_data: true 39 | dims: [3 ,640, 640] 40 | data_type: TYPE_FP32 41 | } 42 | } 43 | } 44 | ] 45 | 46 | 47 | parameters: { 48 | key: "FORCE_CPU_ONLY_INPUT_TENSORS" 49 | value: {string_value:"no"} 50 | } 51 | 52 | -------------------------------------------------------------------------------- /triton/model_repository/yolov5_dynamic_batched_nms/labels.txt: -------------------------------------------------------------------------------- 1 | person 2 | bicycle 3 | car 4 | motorbike 5 | aeroplane 6 | bus 7 | train 8 | truck 9 | boat 10 | traffic light 11 | fire hydrant 12 | stop sign 13 | parking meter 14 | bench 15 | bird 16 | cat 17 | dog 18 | horse 19 | sheep 20 | cow 21 | elephant 22 | bear 23 | zebra 24 | giraffe 25 | backpack 26 | umbrella 27 | handbag 28 | tie 29 | suitcase 30 | frisbee 31 | skis 32 | snowboard 33 | sports ball 34 | kite 35 | baseball bat 36 | baseball glove 37 | skateboard 38 | surfboard 39 | tennis racket 40 | bottle 41 | wine glass 42 | cup 43 | fork 44 | knife 45 | spoon 46 | bowl 47 | banana 48 | apple 49 | sandwich 50 | orange 51 | broccoli 52 | carrot 53 | hot dog 54 | pizza 55 | donut 56 | cake 57 | chair 58 | sofa 59 | pottedplant 60 | bed 61 | diningtable 62 | toilet 63 | tvmonitor 64 | laptop 65 | mouse 66 | remote 67 | keyboard 68 | cell phone 69 | microwave 70 | oven 71 | toaster 72 | sink 73 | refrigerator 74 | book 75 | clock 76 | vase 77 | scissors 78 | teddy bear 79 | hair drier 80 | toothbrush -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 bug-developer021 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /docker/requirements.txt: -------------------------------------------------------------------------------- 1 | # pip install -r requirements.txt 2 | 3 | # Base ---------------------------------------- 4 | matplotlib>=3.2.2 5 | numpy>=1.18.5 6 | #opencv-python>=4.1.2 7 | Pillow>=7.1.2 8 | PyYAML>=5.3.1 9 | requests>=2.23.0 10 | scipy>=1.4.1 11 | #torch>=1.7.0 12 | #torchvision>=0.8.1 13 | tqdm>=4.41.0 14 | 15 | 16 | # Logging ------------------------------------- 17 | tensorboard>=2.4.1 18 | # wandb 19 | 20 | # Plotting ------------------------------------ 21 | pandas>=1.1.4 22 | seaborn>=0.11.0 23 | 24 | 25 | 26 | # Export -------------------------------------- 27 | # coremltools>=4.1 # CoreML export 28 | onnx>=1.9.0 # ONNX export 29 | onnx-simplifier>=0.3.6 # ONNX simplifier 30 | scikit-learn==0.19.2 # CoreML quantization 31 | # tensorflow>=2.4.1 # TFLite export 32 | # tensorflowjs>=3.9.0 # TF.js export 33 | # openvino-dev # OpenVINO export 34 | 35 | 36 | # Extras -------------------------------------- 37 | # albumentations>=1.0.3 38 | # Cython # for pycocotools https://github.com/cocodataset/cocoapi/issues/172 39 | pycocotools>=2.0 # COCO mAP 40 | # roboflow 41 | # thop # FLOPs computation 42 | 43 | # notebook 44 | notebook 45 | jupyterlab 46 | -------------------------------------------------------------------------------- /triton/model_repository/simple_yolov5/config.pbtxt: -------------------------------------------------------------------------------- 1 | name: "simple_yolov5" 2 | platform: "tensorrt_plan" 3 | backend: "tensorrt" 4 | default_model_filename: "model.plan" 5 | max_batch_size: 32 6 | 7 | input: [ 8 | { 9 | name: "images" 10 | data_type: TYPE_FP32 11 | format: FORMAT_NONE 12 | dims: [3, 640, 640] 13 | } 14 | ] 15 | 16 | output: [ 17 | { 18 | name: "output" 19 | data_type: TYPE_FP32 20 | dims: [1000, 6] 21 | label_filename: "" 22 | } 23 | ] 24 | 25 | 26 | instance_group: [ 27 | { 28 | name: "simple_yolov5" 29 | kind: KIND_GPU 30 | count: 1 31 | gpus: [0] 32 | } 33 | ] 34 | 35 | version_policy { 36 | latest: { 37 | num_versions: 1 38 | } 39 | } 40 | 41 | # dynamic_batching { 42 | # max_queue_delay_microseconds: 100 43 | # } 44 | 45 | model_warmup [ 46 | { 47 | batch_size: 8 48 | name: "warmup_requests" 49 | inputs: { 50 | key: "images" 51 | value: { 52 | random_data: true 53 | dims: [3 ,640, 640] 54 | data_type: TYPE_FP32 55 | } 56 | } 57 | } 58 | ] 59 | -------------------------------------------------------------------------------- /triton/model_repository/simple_yolov5_ensemble/config.pbtxt: -------------------------------------------------------------------------------- 1 | name: "simple_yolov5_ensemble" 2 | platform: "ensemble" 3 | max_batch_size: 8 4 | input [ 5 | { 6 | name: "ENSEMBLE_INPUT_0" 7 | data_type: TYPE_FP32 8 | dims: [3, 640, 640] 9 | } 10 | ] 11 | 12 | output [ 13 | { 14 | name: "ENSEMBLE_OUTPUT_0" 15 | data_type: TYPE_FP32 16 | dims: [ 300, 6 ] 17 | } 18 | ] 19 | 20 | ensemble_scheduling { 21 | step [ 22 | { 23 | model_name: "simple_yolov5" 24 | model_version: 1 25 | input_map: { 26 | key: "images" 27 | value: "ENSEMBLE_INPUT_0" 28 | } 29 | output_map: { 30 | key: "output" 31 | value: "FILTER_BBOXES" 32 | } 33 | }, 34 | { 35 | model_name: "nms" 36 | model_version: 1 37 | input_map: { 38 | key: "candidate_boxes" 39 | value: "FILTER_BBOXES" 40 | } 41 | output_map: { 42 | key: "BBOXES" 43 | value: "ENSEMBLE_OUTPUT_0" 44 | } 45 | } 46 | ] 47 | } 48 | 49 | 50 | 
version_policy { 51 | latest: { 52 | num_versions: 1 53 | } 54 | } 55 | 56 | parameters: { 57 | key: "FORCE_CPU_ONLY_INPUT_TENSORS" 58 | value: {string_value:"no"} 59 | } 60 | 61 | -------------------------------------------------------------------------------- /triton/model_repository/yolov5_dynamic_batched_nms/config.pbtxt: -------------------------------------------------------------------------------- 1 | 2 | name: "yolov5_batched_nms_dynamic" 3 | platform: "tensorrt_plan" 4 | backend: "tensorrt" 5 | default_model_filename: "model.plan" 6 | max_batch_size: 8 7 | 8 | input: [ 9 | { 10 | name: "images" 11 | data_type: TYPE_FP32 12 | format: FORMAT_NONE 13 | dims: [3, 640, 640] 14 | } 15 | ] 16 | 17 | output: [ 18 | { 19 | name: "BatchedNMS" 20 | data_type: TYPE_INT32 21 | dims: [1] 22 | label_filename: "" 23 | }, 24 | { 25 | name: "BatchedNMS_1" 26 | data_type: TYPE_FP32 27 | dims: [300, 4] 28 | label_filename: "" 29 | }, 30 | { 31 | name: "BatchedNMS_2" 32 | data_type: TYPE_FP32 33 | dims: [300] 34 | label_filename: "" 35 | }, 36 | { 37 | name: "BatchedNMS_3" 38 | data_type: TYPE_FP32 39 | dims: [300] 40 | label_filename: "" 41 | } 42 | ] 43 | 44 | # batch_input: [] 45 | # batch_output: [] 46 | 47 | instance_group: [ 48 | { 49 | name: "yolov5_batched_nms_dynamic" 50 | kind: KIND_GPU 51 | count: 1 52 | gpus: [0] 53 | } 54 | ] 55 | 56 | version_policy { 57 | latest: { 58 | num_versions: 1 59 | } 60 | } 61 | 62 | # dynamic_batching { 63 | # preferred_batch_size: [8] 64 | # max_queue_delay_microseconds: 1000 65 | # } 66 | 67 | model_warmup [ 68 | { 69 | batch_size: 1 70 | name: "warmup_requests" 71 | inputs: { 72 | key: "images" 73 | value: { 74 | random_data: true 75 | dims: [3 ,640, 640] 76 | data_type: TYPE_FP32 77 | } 78 | } 79 | } 80 | ] -------------------------------------------------------------------------------- /triton/generate_input.py: -------------------------------------------------------------------------------- 1 | # generate real input data for triton perf_analyzer 2 | 3 | import sys 4 | 5 | sys.path.append('../') 6 | from trt_infer import load_images_cv 7 | import argparse 8 | import numpy as np 9 | import json 10 | import os 11 | from tqdm import tqdm 12 | 13 | 14 | 15 | 16 | 17 | if __name__ == '__main__': 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument('--input_images', type=str, help='input images', required=True) 20 | parser.add_argument('--output_file', type=str, help='output file', required=True) 21 | parser.add_argument('--input_size', type=int, default=640, help="Input Size") 22 | parser.add_argument('--input_tensor_name', type=str, default='images', help="Input Tensor Name") 23 | 24 | 25 | args = parser.parse_args() 26 | 27 | img_root = args.input_images 28 | output_file = args.output_file 29 | input_size=args.input_size 30 | input_tensor_name = args.input_tensor_name 31 | new_shape = (input_size, input_size) 32 | 33 | triton_input_list = [] 34 | for img_name in tqdm(sorted(os.listdir(img_root))): 35 | if os.path.splitext(img_name)[-1] not in ['.jpg', '.png', '.jpeg']: 36 | continue 37 | img_path = os.path.join(img_root, img_name) 38 | 39 | input_image, _ = load_images_cv(img_path, new_shape) 40 | input_image = np.squeeze(input_image) 41 | triton_input_shape = input_image.shape 42 | 43 | # flatten_img = input_image.flatten().astype(np.float16) 44 | flatten_img = input_image.flatten() 45 | 46 | triton_input = { 47 | input_tensor_name: 48 | { 49 | "content": flatten_img.tolist(), 50 | "shape": list(triton_input_shape) 51 | } 52 | } 53 | 
triton_input_list.append(triton_input) 54 | 55 | with open(output_file, "w") as f: 56 | json.dump( 57 | {"data": triton_input_list}, f 58 | ) 59 | 60 | 61 | -------------------------------------------------------------------------------- /triton/model_repository/nms/1/utils.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import torch 4 | import torchvision 5 | from torch.nn import functional as F 6 | import time 7 | 8 | def xywh2xyxy(x): 9 | # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right 10 | y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x) 11 | y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x 12 | y[:, 1] = x[:, 1] - x[:, 3] / 2 # top left y 13 | y[:, 2] = x[:, 0] + x[:, 2] / 2 # bottom right x 14 | y[:, 3] = x[:, 1] + x[:, 3] / 2 # bottom right y 15 | return y 16 | 17 | def postprocess(output, conf_th=0.25, nms_threshold=0.45, max_det=300): 18 | """Postprocess TensorRT outputs. 19 | # Args 20 | output: list of detections with schema 21 | [batch_size, num_detections, xywh + conf + cls_id] 22 | 23 | conf_th: confidence threshold 24 | nms_threshold: nms threshold 25 | 26 | # Returns 27 | list of bounding boxes with all detections above threshold and after nms, see class BoundingBox 28 | [num_detections , xyxy + conf + cls_id] * batch_size 29 | """ 30 | 31 | # Get the num of boxes detected 32 | output_candidates = output[..., 4] > conf_th 33 | output_bboxes = [torch.zeros((0, 6), device=output.device)] * output.shape[0] 34 | for xi, x in enumerate(output): 35 | # Apply confidence constraints 36 | x = x[output_candidates[xi]] 37 | 38 | if not x.shape[0]: 39 | continue 40 | 41 | boxes = xywh2xyxy(x[:, :4]) 42 | scores = x[:, 4] 43 | i = torchvision.ops.nms(boxes, scores, nms_threshold) 44 | 45 | # padding boxes to 300 46 | if i.shape[0] > max_det: # limit detections 47 | i = i[:max_det] 48 | 49 | bbox_pad_nums = max_det - i.shape[0] 50 | 51 | output_bboxes[xi] = F.pad(x[i], (0,0,0, bbox_pad_nums), value=0) 52 | # output_bboxes[xi] = x[i] 53 | 54 | return torch.stack(output_bboxes, dim=0) 55 | 56 | def time_sync(): 57 | # pytorch-accurate time 58 | if torch.cuda.is_available(): 59 | torch.cuda.synchronize() 60 | return time.time() -------------------------------------------------------------------------------- /triton/model_repository/simple_yolov5_bls/1/utils.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import torch 4 | import torchvision 5 | from torch.nn import functional as F 6 | import time 7 | 8 | def xywh2xyxy(x): 9 | # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right 10 | y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x) 11 | y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x 12 | y[:, 1] = x[:, 1] - x[:, 3] / 2 # top left y 13 | y[:, 2] = x[:, 0] + x[:, 2] / 2 # bottom right x 14 | y[:, 3] = x[:, 1] + x[:, 3] / 2 # bottom right y 15 | return y 16 | 17 | def postprocess(output, conf_th=0.25, nms_threshold=0.45, max_det=300): 18 | """Postprocess TensorRT outputs. 
19 | # Args 20 | output: list of detections with schema 21 | [batch_size, num_detections, xywh + conf + cls_id] 22 | 23 | conf_th: confidence threshold 24 | nms_threshold: nms threshold 25 | 26 | # Returns 27 | list of bounding boxes with all detections above threshold and after nms, see class BoundingBox 28 | [num_detections , xyxy + conf + cls_id] * batch_size 29 | """ 30 | 31 | # Get the num of boxes detected 32 | output_candidates = output[..., 4] > conf_th 33 | output_bboxes = [torch.zeros((0, 6), device=output.device)] * output.shape[0] 34 | for xi, x in enumerate(output): 35 | # Apply confidence constraints 36 | x = x[output_candidates[xi]] 37 | 38 | if not x.shape[0]: 39 | continue 40 | 41 | boxes = xywh2xyxy(x[:, :4]) 42 | scores = x[:, 4] 43 | i = torchvision.ops.nms(boxes, scores, nms_threshold) 44 | 45 | # padding boxes to 300 46 | if i.shape[0] > max_det: # limit detections 47 | i = i[:max_det] 48 | 49 | bbox_pad_nums = max_det - i.shape[0] 50 | 51 | output_bboxes[xi] = F.pad(x[i], (0,0,0, bbox_pad_nums), value=0) 52 | # output_bboxes[xi] = x[i] 53 | 54 | return torch.stack(output_bboxes, dim=0) 55 | 56 | def time_sync(): 57 | # pytorch-accurate time 58 | if torch.cuda.is_available(): 59 | torch.cuda.synchronize() 60 | return time.time() -------------------------------------------------------------------------------- /README_CN.md: -------------------------------------------------------------------------------- 1 | Zh-CN| [English](README.md) 2 | 3 | # YOLOV5 optimization on Triton Inference Server 4 | 5 | 6 | 在Triton中部署yolov5目标检测服务, 并分别进行了如下优化: 7 | 1. [轻量化Detect层的Output](./docs/custom_yolov5_detect_layer.md) 8 | 2. [集成TensorRT的BatchedNMSPlugin到engine中](./docs/batchedNMS.md) 9 | 3. [通过Triton Pipelines部署](./docs/pipelines.md) 10 | 11 | 其中Pipelines分别通过`Ensemble`和`BLS`两种方式来实现,Pipelines的infer模块是基于上述1中精简后的TensorRT Engine部署, Postprocess模块则通过Python Backend实现, 工作流参考[如何部署Triton Pipelines](./docs/pipelines.md#3-如何部署triton-pipelines) 12 | 13 | --- 14 | ## Environment 15 | - CPU: 4cores 16GB 16 | - GPU: Nvidia Tesla T4 17 | - Cuda: 11.6 18 | - TritonServer: 2.20.0 19 | - TensorRT: 8.2.3 20 | - Yolov5: v6.1 21 | 22 | 23 | 24 | 25 | --- 26 | 27 | ## Benchmark 28 | 一台机器部署Triton Inference Server, 在另外一台机器上通过Perf_analyzer通过gRPC调用接口, 对比测试`BLS Pipelines`、`Ensemble Pipelines`、`BatchedNMS`这三种部署方式在并发数逐渐增加条件下的性能表现。 29 | 30 | - [生成真实数据](https://github.com/triton-inference-server/server/blob/main/docs/user_guide/perf_analyzer.md#real-input-data) 31 | 32 | ```shell 33 | python generate_input.py --input_images ----output_file .json 34 | ``` 35 | 36 | 37 | - 利用真实数据进行测试 38 | ```shell 39 | perf_analyzer -m -b 8 --input-data .json --concurrency-range 1:10 --measurement-interval 10000 -u -i gRPC -f .csv 40 | ``` 41 | 42 | 43 | 数据显示`BatchedNMS`这一方式整体性相对更好,更快在并发数较大的情况下收敛到最优性能,在低时延下达到较高的吞吐; 而`Ensemble Pipelines`和`BLS Pipelines`则在并发数较小时性能更好,但是随着并发数的增加,性能下降的幅度更大。 44 | 45 | ![](./assets/thoughput_latency.png) 46 | 47 | 48 | 49 | 选取了六个指标进行对比,每个指标均通过[处理](./triton/plot.ipynb#metrics-process),并归一化到0-1区间,数值越大表示性能越好。每个指标的原始释义如下: 50 | 51 | - Server Queue: 数据在Triton队列中的等待时间 52 | - Server Compute Input: Triton处理Input Tensor的时间 53 | - Server Compute Infer: Triton执行推理的时间 54 | - Server Compute Output: Triton处理Output Tensor的时间 55 | - latency: 端到端延迟的90分位数 56 | - throughput: 吞吐 57 | 58 | ![](./assets/radar_plot.png) 59 | 60 | 结果分析[参考](./docs/bls_vs_ensemble.md#4-性能分析) 61 | 62 | --- 63 | 64 | ## REFERENCES 65 | 66 | 67 | - [Ultralytics Yolov5](https://github.com/ultralytics/yolov5.git) 68 | - [Yolov5 GPU 
Optimization](https://github.com/NVIDIA-AI-IOT/yolov5_gpu_optimization.git) 69 | - [TensorRT BatchedNMSPlugin ](https://github.com/NVIDIA/TensorRT/tree/main/plugin/batchedNMSPlugin) 70 | - [Perf Analyzer](https://github.com/triton-inference-server/server/blob/main/docs/user_guide/perf_analyzer.md) 71 | - [Ensemble models](https://github.com/triton-inference-server/server/blob/main/docs/user_guide/architecture.md#ensemble-models) 72 | - [Business Logic Scripting](https://github.com/triton-inference-server/python_backend#business-logic-scripting) 73 | 74 | 75 | 76 | 77 | -------------------------------------------------------------------------------- /triton/model_repository/nms/1/model.py: -------------------------------------------------------------------------------- 1 | 2 | import json 3 | 4 | import triton_python_backend_utils as pb_utils 5 | import numpy as np 6 | from torch.utils.dlpack import from_dlpack, to_dlpack 7 | import utils 8 | 9 | 10 | 11 | 12 | class TritonPythonModel: 13 | 14 | def initialize(self, args): 15 | 16 | """`initialize` is called only once when the model is being loaded. 17 | Implementing `initialize` function is optional. This function allows 18 | the model to intialize any state associated with this model. 19 | Parameters 20 | ---------- 21 | args : dict 22 | Both keys and values are strings. The dictionary keys and values are: 23 | * model_config: A JSON string containing the model configuration 24 | * model_instance_kind: A string containing model instance kind 25 | * model_instance_device_id: A string containing model instance device ID 26 | * model_repository: Model repository path 27 | * model_version: Model version 28 | * model_name: Model name 29 | """ 30 | print('Initializing...') 31 | self.model_config = model_config = json.loads(args['model_config']) 32 | output_config = pb_utils.get_output_config_by_name(model_config, "BBOXES") 33 | self.output_dtype = pb_utils.triton_string_to_numpy(output_config['data_type']) 34 | self.max_det = output_config['dims'][0] 35 | 36 | # print(f'output_dims {self.output_dims} type is {type(self.output_dims)}', flush=True) 37 | 38 | def execute(self, requests): 39 | # output_dtype = self.output_dtype 40 | max_det = self.max_det 41 | responses = [] 42 | for request in requests: 43 | 44 | before_nms = pb_utils.get_input_tensor_by_name( 45 | request, 'candidate_boxes') 46 | 47 | # before_nms_torch_tensor = self.pb_tensor_transform(before_nms) 48 | print (f'nms pb_tensor is from cpu {before_nms.is_cpu()}', flush=True) 49 | before_nms_torch_tensor = from_dlpack(before_nms.to_dlpack()) 50 | 51 | 52 | bboxes = utils.postprocess(before_nms_torch_tensor, max_det=max_det) 53 | 54 | # print(f'bls bboxes shape is {bboxes.shape}', flush=True) 55 | 56 | # encoding pytorch tensor boxes to pb_tensor 57 | # out_tensor = pb_utils.Tensor('BBOXES', bboxes.astype(output_dtype)) 58 | out_tensor = pb_utils.Tensor.from_dlpack('BBOXES', to_dlpack(bboxes)) 59 | 60 | inference_response = pb_utils.InferenceResponse( 61 | output_tensors=[out_tensor]) 62 | responses.append(inference_response) 63 | 64 | return responses 65 | 66 | 67 | 68 | def finalize(self): 69 | print('Cleaning up...') 70 | 71 | -------------------------------------------------------------------------------- /docs/pipelines.md: -------------------------------------------------------------------------------- 1 | # 部署yolov5 Triton Pipelines 2 | 3 | ## 1. 
为什么使用Triton pipelines 4 | 5 | 众所周知,模型服务不仅包含 GPU based Inference,还包括preprocess和postprocess。Triton Pipelines是一种workflow, 它可以组合不同的模型服务组合成一个完整的应用, 同一个模型服务还可以被不同的workflow使用。 6 | 因此可以单独将preprocess或postprocess单独部署,然后通过Pipeline将它们和infer模块串联起来。这样做的好处是: 7 | - 每个子模块都可以分别申请不同种类和大小的资源、配置不同的参数,以达到最大化模型服务效率的同时,充分利用计算资源。 8 | 9 | - 可以避免传输中间张量的开销,减小通过网络传输的数据大小,并最大限度地减少必须发送到 Triton 的请求数量。 10 | 11 | --- 12 | 13 | ## 2. Triton Pipelines的实现方式 14 | Nvidia Triton提供了两种Pipleline的部署方式:分别为Business Logic Scripting(BLS)和Ensemble。下面简单介绍一下这两种方式。 15 | 16 | - [Ensemble](https://github.com/triton-inference-server/server/blob/main/docs/user_guide/architecture.md#ensemble-models) 17 | 通过组合model repository里的各类模型成为一个workflow。是一种pipeline调度策略,而不是具体的model。 ensemble效率更高,但无法加入条件逻辑判断,数据只能按照设定的pipeline流动,适合pipeline结构固定的场景 18 | ![](..//assets/ensemble.png) 19 | 20 | - [BLS](https://github.com/triton-inference-server/python_backend#business-logic-scripting) 21 | 一种特殊的python backend,通过python code调用其他model instance。BLS灵活性更高,可以加入一些逻辑和循环来动态组合不同的模型,从而控制数据的流动方向。 22 | ![](../assets/bls.png) 23 | 24 | --- 25 | 26 | ## 3. 如何部署Triton Pipelines 27 | 28 | 通过Pipelines部署process模块的一个出发点是,减小通过网络传输的数据大小。在目标检测模型服务中,输入端的raw_image和nms之前的candidate bboxes的数据量都是相对较大,因此一个合适的方案就是将nms这一postprecess模块单独通过python backend部署,通过pipelines连接infer和nms模块,client则需要对raw_data进行必要的resize等preprocess操作。 29 | 30 | 31 | ### 3.1 工作流 32 | Pipleine配置及python backend参考Model Repository的[ensemble](../triton/model_repository/simple_yolov5_ensemble/)和[bls](../triton/model_repository/simple_yolov5_bls/) 33 | 34 | 两种部署方式的工作流如下: 35 | 36 | ![](../assets/bls_ensemble.png) 37 | 38 | 39 | 40 | ### 3.2 BLS 41 | 42 | - 数据流向 43 | 1. 通过http/gRPC发送resize后的image到BLS模型服务 44 | 2. BLS服务通过C API调用yolov5 tensorrt模型服务 45 | 3. Triton Server将candidate bboxes返回给BLS服务 46 | 4. BLS服务对candidate bboxes进行nms操作,将最终的bboxes通过http/gRPC返回给client 47 | 48 | 49 | 50 | ### 3.3 Ensemble 51 | 52 | - 数据流向 53 | 1. 通过http/gRPC发送resize后的image到ensemble模型服务 54 | 2. ensemble模型服务通过memory copy将yolov5 tensorrt的输出的candidate bboxes传递给nms模型服务 55 | 3. ensemble模型服务将nms输出的bboxes通过http/gRPC返回给client 56 | 57 | ### 3.3 Notice 58 | 59 | NMS输出的bboxes数量不固定,一般有三种处理方式: 60 | 61 | 1. 对bboxes做padding, 例如规定输出是 `[batch_size, padding_count, xywh or xyxy]`, 其中 pandding_count 根据实际场景来确定 62 | 2. 将模型的输出结果放到一个 json, 以 `json string ([N, 1])` 的形式返回 63 | 3. 
采用将 response [解耦的方式](https://github.com/triton-inference-server/python_backend#decoupled-mode) 64 | 65 | 本文采用padding的方式来解决该问题 66 | ```python 67 | from torch.nn import functional as F 68 | i = torchvision.ops.nms(boxes, scores, nms_threshold) 69 | # padding boxes to 300 70 | if i.shape[0] > max_det: # limit detections 71 | i = i[:max_det] 72 | bbox_pad_nums = max_det - i.shape[0] 73 | output_bboxes[xi] = F.pad(x[i], (0,0,0, bbox_pad_nums), value=0) 74 | ``` 75 | 76 | --- 77 | ## REFERENCES 78 | 79 | 80 | - [Ultralytics Yolov5](https://github.com/ultralytics/yolov5.git) 81 | - [Ensemble models](https://github.com/triton-inference-server/server/blob/main/docs/user_guide/architecture.md#ensemble-models) 82 | - [Business Logic Scripting](https://github.com/triton-inference-server/python_backend#business-logic-scripting) 83 | - [Triton 从入门到精通](https://space.bilibili.com/1320140761/channel/collectiondetail?sid=493256) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | English | [Zh-CN](README_CN.md) 2 | 3 | # YOLOV5 optimization on Triton Inference Server 4 | 5 | Deploy Yolov5 object detection service in Triton. The following optimizations have been made: 6 | 7 | 1. [Lightweight Detect layer Output](./docs/custom_yolov5_detect_layer_EN.md) 8 | 9 | 2. [Integrate TensorRT BatchedNMSPlugin into engine](./docs/batchedNMS_EN.md) 10 | 11 | 3. [Deploy via Triton Pipelines](./docs/pipelines_EN.md) 12 | 13 | Pipelines are implemented through `Ensemble` and `BLS` respectively. The infer module in Pipelines is based on the optimized TensorRT Engine in item 1 above, and the Postprocess module is implemented through Python Backend. The workflow refers to [How to deploy Triton Pipelines](./docs/pipelines_EN.md#3-How_to_deploy_Triton_Pipelines). 14 | 15 | --- 16 | ## Environment 17 | - CPU: 4cores 16GB 18 | - GPU: Nvidia Tesla T4 19 | - Cuda: 11.6 20 | - TritonServer: 2.20.0 21 | - TensorRT: 8.2.3 22 | - Yolov5: v6.1 23 | 24 | --- 25 | 26 | ## Benchmark 27 | Triton Inference Server is deployed on one machine. Perf_analyzer is used to call the gRPC interface on another machine to compare the performance of `BLS Pipelines`, `Ensemble Pipelines`, and `BatchedNMS` under gradually increasing concurrency. 28 | 29 | - [Generate real data](https://github.com/triton-inference-server/server/blob/main/docs/user_guide/perf_analyzer.md#real-input-data) 30 | 31 | ```shell 32 | python generate_input.py --input_images ----output_file .json 33 | ``` 34 | 35 | - Test with real data 36 | ```shell 37 | perf_analyzer -m -b 8 --input-data .json --concurrency-range 1:10 --measurement-interval 10000 -u -i gRPC -f .csv 38 | ``` 39 | 40 | The data shows that `BatchedNMS` performs better overall, converging to optimal performance faster under high concurrency, and achieving higher throughput at lower latency. `Ensemble Pipelines` and `BLS Pipelines` perform better at lower concurrency, but performance degrades more as concurrency increases. 41 | 42 | ![](./assets/thoughput_latency.png) 43 | 44 | 45 | Six metrics are selected for comparison. Each metric is [processed](./triton/plot.ipynb#metrics-process) and normalized to the 0-1 interval. 
The original meaning of each metric is as follows: 46 | 47 | - Server Queue: Data waiting time in Triton queue 48 | - Server Compute Input: Triton input tensor processing time 49 | - Server Compute Infer: Triton inference execution time 50 | - Server Compute Output: Triton output tensor processing time 51 | - latency: 90th percentile end-to-end latency 52 | - throughput: throughput 53 | 54 | ![](./assets/radar_plot.png) 55 | 56 | See [here](./docs/bls_vs_ensemble_EN.md#4-performance-analysis) for results analysis. 57 | --- 58 | 59 | ## REFERENCES 60 | 61 | 62 | - [Ultralytics Yolov5](https://github.com/ultralytics/yolov5.git) 63 | - [Yolov5 GPU Optimization](https://github.com/NVIDIA-AI-IOT/yolov5_gpu_optimization.git) 64 | - [TensorRT BatchedNMSPlugin](https://github.com/NVIDIA/TensorRT/tree/main/plugin/batchedNMSPlugin) 65 | - [Perf Analyzer](https://github.com/triton-inference-server/server/blob/main/docs/user_guide/perf_analyzer.md) 66 | - [Ensemble models](https://github.com/triton-inference-server/server/blob/main/docs/user_guide/architecture.md#ensemble-models) 67 | - [Business Logic Scripting](https://github.com/triton-inference-server/python_backend#business-logic-scripting) -------------------------------------------------------------------------------- /common.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import tensorrt as trt 4 | import pycuda.autoinit 5 | import pycuda.driver as cuda 6 | import time 7 | tensorrt_loggers = [] 8 | 9 | 10 | def create_tensorrt_logger(verbose=False): 11 | """Create a TensorRT logger. 12 | 13 | Args: 14 | verbose (bool): whether to make the logger verbose. 15 | """ 16 | if verbose: 17 | # trt_verbosity = trt.Logger.Severity.INFO 18 | trt_verbosity = trt.Logger.Severity.VERBOSE 19 | else: 20 | trt_verbosity = trt.Logger.Severity.WARNING 21 | tensorrt_logger = trt.Logger(trt_verbosity) 22 | tensorrt_loggers.append(tensorrt_logger) 23 | return tensorrt_logger 24 | 25 | 26 | 27 | class HostDeviceMem(object): 28 | def __init__(self, host_mem, device_mem, binding_name, shape=None): 29 | self.host = host_mem 30 | self.device = device_mem 31 | self.binding_name = binding_name 32 | self.shape = shape 33 | 34 | def __str__(self): 35 | return "Host:\n" + str(self.host) + "\nDevice\n" + str(self.device) + "Shape: " + str(self.shape) 36 | 37 | def __repr__(self): 38 | return self.__str__() 39 | 40 | 41 | def allocate_buffers(engine, context): 42 | 43 | inputs = [] 44 | outputs = [] 45 | bindings = [] 46 | stream = cuda.Stream() 47 | for binding in engine: 48 | binding_id = engine.get_binding_index(str(binding)) 49 | size = trt.volume(context.get_binding_shape(binding_id)) * engine.max_batch_size 50 | print("{}:{}".format(binding, size)) 51 | dtype = trt.nptype(engine.get_binding_dtype(binding)) 52 | host_mem = cuda.pagelocked_empty(size, dtype) 53 | device_mem = cuda.mem_alloc(host_mem.nbytes) 54 | bindings.append(int(device_mem)) 55 | if engine.binding_is_input(binding): 56 | inputs.append(HostDeviceMem(host_mem, device_mem, binding)) 57 | else: 58 | output_shape = engine.get_binding_shape(binding) 59 | if len(output_shape) == 3: 60 | dims = trt.Dims3(engine.get_binding_shape(binding)) 61 | output_shape = (engine.max_batch_size, dims[0], dims[1], dims[2]) 62 | elif len(output_shape) == 2: 63 | dims = trt.Dims2(output_shape) 64 | output_shape = (engine.max_batch_size, dims[0], dims[1]) 65 | outputs.append(HostDeviceMem(host_mem, device_mem, binding, output_shape)) 66 | 67 | return inputs, 
outputs, bindings, stream 68 | # return inputs, outputs, bindings 69 | 70 | def do_inference(batch, context, bindings, inputs, outputs, stream): 71 | batch_size = batch.shape[0] 72 | assert len(inputs) == 1 73 | 74 | inputs[0].host = np.ascontiguousarray(batch, dtype=np.float32) 75 | [cuda.memcpy_htod_async(inp.device, inp.host, stream) for inp in inputs] 76 | 77 | 78 | # time calculation 79 | #------------------# 80 | stream.synchronize() 81 | t1 = time.time() 82 | context.execute_async(batch_size=batch_size, bindings=bindings, stream_handle=stream.handle) 83 | stream.synchronize() 84 | t2 = time.time() 85 | cost = t2-t1 86 | #------------------# 87 | 88 | [cuda.memcpy_dtoh_async(out.host, out.device, stream) for out in outputs] 89 | stream.synchronize() 90 | 91 | outputs_dict = {} 92 | outputs_shape = {} 93 | for out in outputs: 94 | outputs_dict[out.binding_name] = np.reshape(out.host, out.shape) 95 | outputs_shape[out.binding_name] = out.shape 96 | 97 | return outputs_shape, outputs_dict, cost 98 | -------------------------------------------------------------------------------- /docs/pipelines_EN.md: -------------------------------------------------------------------------------- 1 | # Deploy yolov5 Triton Pipelines 2 | 3 | ## 1. Why use Triton pipelines 4 | 5 | It is well known that model services include not only GPU based Inference, but also preprocess and postprocess. Triton Pipelines are workflows that can combine different model services into a complete application. The same model service can also be used by different workflows. 6 | 7 | Therefore, preprocess or postprocess can be deployed separately, and then connected with the infer module through Pipeline. The benefits of doing this are: 8 | 9 | - Each submodule can apply for different types and sizes of resources separately, and be configured with different parameters, in order to maximize model serving efficiency while making full use of computing resources. 10 | 11 | - Overhead of transferring intermediate tensors can be avoided, reducing the amount of data transferred over the network, and minimizing the number of requests that need to be sent to Triton. 12 | 13 | --- 14 | 15 | ## 2. Triton Pipeline implementation methods 16 | 17 | Nvidia Triton provides two ways to deploy Piplelines: Business Logic Scripting(BLS) and Ensemble. Below is a brief introduction of the two methods. 18 | 19 | - [Ensemble](https://github.com/triton-inference-server/server/blob/main/docs/user_guide/architecture.md#ensemble-models) 20 | A workflow formed by combining various models in the model repository. It is a pipeline scheduling strategy rather than a specific model. Ensemble is more efficient, but cannot incorporate conditional logic judgments. Data can only flow according to the fixed pipeline, suitable for scenarios with fixed pipeline structure. 21 | 22 | ![](../assets/ensemble.png) 23 | 24 | - [BLS](https://github.com/triton-inference-server/python_backend#business-logic-scripting) 25 | A special python backend that calls other model instances through python code. BLS is more flexible, can incorporate some logic and loops to dynamically combine different models and thus control the data flow direction. 26 | ![](../assets/bls.png) 27 | 28 | --- 29 | 30 | ## 3. How to deploy Triton Pipelines 31 | 32 | One motivation for deploying process modules through Pipelines is to reduce the amount of data transmitted over the network. 
In object detection model services, both the raw_image on the input side and the candidate bboxes before nms have relatively large data volume. Therefore, a reasonable approach is to deploy the nms postprocess module separately through python backend, and connect the infer and nms modules through pipelines. The client only needs to do necessary resize and other preprocess operations on the raw_data. 33 | 34 | 35 | ### 3.1 Workflow 36 | 37 | Pipeline configuration and python backend refer to [ensemble](../triton/model_repository/simple_yolov5_ensemble/) and [bls](../triton/model_repository/simple_yolov5_bls/) in Model Repository. 38 | 39 | Workflow of the two deployment methods is as follows: 40 | 41 | ![](../assets/bls_ensemble.png) 42 | 43 | 44 | ### 3.2 BLS 45 | 46 | - Data flow 47 | 1. Send resized image to BLS model service through http/gRPC 48 | 2. BLS service calls yolov5 tensorrt model service through C API 49 | 3. Triton Server returns candidate bboxes to BLS service 50 | 4. BLS service performs nms operation on candidate bboxes and returns final bboxes to client through http/gRPC 51 | 52 | 53 | ### 3.3 Ensemble 54 | 55 | - Data flow 56 | 1. Send resized image to ensemble model service through http/gRPC 57 | 2. Ensemble model service copies yolov5 tensorrt output candidate bboxes to nms model service through memory 58 | 3. Ensemble model service returns bboxes after nms to client through http/gRPC 59 | 60 | ### 3.3 Notice 61 | 62 | The output bboxes number is not fixed. There are usually three ways to handle this: 63 | 64 | 1. Pad bboxes, for example, specify output as `[batch_size, padding_count, xywh or xyxy]`, where pandding_count is determined according to actual scenario 65 | 2. Put model output results into a json, returned as `json string ([N, 1])` 66 | 3. Use [decoupled response](https://github.com/triton-inference-server/python_backend#decoupled-mode) 67 | 68 | Padding is used here to solve this problem 69 | ```python 70 | from torch.nn import functional as F 71 | i = torchvision.ops.nms(boxes, scores, nms_threshold) 72 | # padding boxes to 300 73 | if i.shape[0] > max_det: # limit detections 74 | i = i[:max_det] 75 | bbox_pad_nums = max_det - i.shape[0] 76 | output_bboxes[xi] = F.pad(x[i], (0,0,0, bbox_pad_nums), value=0) 77 | ``` 78 | 79 | --- 80 | ## REFERENCES 81 | 82 | - [Ultralytics Yolov5](https://github.com/ultralytics/yolov5.git) 83 | - [Ensemble models](https://github.com/triton-inference-server/server/blob/main/docs/user_guide/architecture.md#ensemble-models) 84 | - [Business Logic Scripting](https://github.com/triton-inference-server/python_backend#business-logic-scripting) 85 | - [Triton Tutorials](https://space.bilibili.com/1320140761/channel/collectiondetail?sid=493256) -------------------------------------------------------------------------------- /triton/model_repository/simple_yolov5_bls/1/model.py: -------------------------------------------------------------------------------- 1 | 2 | import json 3 | 4 | import triton_python_backend_utils as pb_utils 5 | import numpy as np 6 | from torch.utils.dlpack import from_dlpack, to_dlpack 7 | import utils 8 | 9 | 10 | 11 | 12 | class TritonPythonModel: 13 | 14 | def initialize(self, args): 15 | 16 | """`initialize` is called only once when the model is being loaded. 17 | Implementing `initialize` function is optional. This function allows 18 | the model to intialize any state associated with this model. 19 | Parameters 20 | ---------- 21 | args : dict 22 | Both keys and values are strings. 
The dictionary keys and values are: 23 | * model_config: A JSON string containing the model configuration 24 | * model_instance_kind: A string containing model instance kind 25 | * model_instance_device_id: A string containing model instance device ID 26 | * model_repository: Model repository path 27 | * model_version: Model version 28 | * model_name: Model name 29 | """ 30 | print('Initializing...') 31 | self.model_config = model_config = json.loads(args['model_config']) 32 | output_config = pb_utils.get_output_config_by_name(model_config, "BBOXES") 33 | self.output_dtype = pb_utils.triton_string_to_numpy(output_config['data_type']) 34 | self.max_det = output_config['dims'][0] 35 | 36 | # print(f'output_dims {self.output_dims} type is {type(self.output_dims)}', flush=True) 37 | 38 | def execute(self, requests): 39 | output_dtype = self.output_dtype 40 | max_det = self.max_det 41 | responses = [] 42 | for request in requests: 43 | 44 | # Get Model Name 45 | # hard code temporarily 46 | model_name_string = 'simple_yolov5' 47 | 48 | # model_name = pb_utils.get_input_tensor_by_name( 49 | # request, 'MODEL_NAME') 50 | 51 | # Model Name string 52 | # batch_size = 1 53 | # model_name_string = model_name.as_numpy()[0].item() 54 | 55 | input = pb_utils.get_input_tensor_by_name( 56 | request, 'images') 57 | 58 | 59 | # for fix TritonModelException: 60 | t1 = utils.time_sync() 61 | before_nms = self.request_real_engine( 62 | input, model_name_string) 63 | t2 = utils.time_sync() 64 | 65 | 66 | 67 | print(f'bls request_real_engine time: {(t2-t1)*1000} ms', flush=True) 68 | before_nms_torch_tensor = self.pb_tensor_transform(before_nms) 69 | 70 | t3 = utils.time_sync() 71 | print(f'bls pb_tensor_transform time: {(t3-t2)*1000} ms', flush=True) 72 | bboxes = utils.postprocess(before_nms_torch_tensor, max_det=max_det) 73 | t4 = utils.time_sync() 74 | # 3~10ms 75 | print(f'bls postprocess time: {(t4 - t3)*1000} ms', flush=True) 76 | # print(f'bls bboxes shape is {bboxes.shape}', flush=True) 77 | 78 | # encoding pytorch tensor boxes to pb_tensor 79 | # out_tensor = pb_utils.Tensor('BBOXES', bboxes.astype(output_dtype)) 80 | out_tensor = pb_utils.Tensor.from_dlpack('BBOXES', to_dlpack(bboxes)) 81 | # false 82 | # print(f'out_tensor is on cpu: {out_tensor.is_cpu()}', flush=True) 83 | 84 | inference_response = pb_utils.InferenceResponse( 85 | output_tensors=[out_tensor]) 86 | responses.append(inference_response) 87 | # t3 = utils.time_sync() 88 | # print(f'output time: {(t3 - t2)*1000} ms', flush=True) 89 | return responses 90 | 91 | # BLS 92 | def request_real_engine(self, frames_tensor, model_name_string): 93 | # frames_tensor: tensor 94 | 95 | inference_request = pb_utils.InferenceRequest( 96 | model_name=model_name_string, 97 | requested_output_names=['output'], 98 | inputs=[frames_tensor] 99 | ) 100 | inference_response = inference_request.exec() 101 | if inference_response.has_error(): 102 | raise pb_utils.TritonModelException(inference_response.error().message()) 103 | 104 | # tensor 105 | before_nms = pb_utils.get_output_tensor_by_name(inference_response, 'output') 106 | 107 | print (f'bls pb_tensor is from cpu {before_nms.is_cpu()}', flush=True) 108 | 109 | return before_nms 110 | 111 | 112 | def finalize(self): 113 | print('Cleaning up...') 114 | 115 | 116 | def pb_tensor_transform(self, pb_tensor): 117 | if pb_tensor.is_cpu(): 118 | # print(f'bls pb_tensor is from cpu', flush=True) 119 | return pb_tensor.as_numpy() 120 | else: 121 | pytorch_tensor = from_dlpack(pb_tensor.to_dlpack()) 122 | # 
print(f'bls pb_tensor is from {pytorch_tensor.device}', flush=True) 123 | return pytorch_tensor 124 | # return pytorch_tensor.cpu().numpy() -------------------------------------------------------------------------------- /docs/bls_vs_ensemble.md: -------------------------------------------------------------------------------- 1 | # Triton Pipeines的实现方式及对比 2 | 3 | 在[部署yolov5 Triton Pipelines](pipelines.md#2-triton-pipelines的实现方式)中,简单介绍了BLS和Ensemble这两种实现Triton Pipelines的方式,同时在[Benchmark](../README_CN.md#benchmark)中,对两种Pipelines和[All in TensorRT Engine](./batchedNMS.md)的部署方式进行了性能测试,本文将对比介绍一下BLS和Ensemble, 同时对性能测试的结果进行解读 4 | 5 | ## 1 Python Backend 6 | 7 | ### 1.1 实现方式及结构 8 | BLS是一种特殊的python backend,通过在python backend里调用其他模型服务来完成Pipelines。python backend的结构如下 9 | 10 | ![](../assets/python_backend.png) 11 | 12 | 13 | 14 | - 进程间通信IPC 15 | 16 | 由于GIL的限制,python backend通过对每个model instance起一个单独的进程(`python stub process(C++)`)来支持多实例部署。既然是多进程,那么就需要通过`shared memory`来完成python model instance和Triton主进程之间的通信,具体为给每个python stub process在`shared memory里分配一个shm block`, shm block连接`python backend agent(C++)`来进行通信。 17 | 18 | 19 | - 数据流向 20 | 21 | `shm block`通过`Request MessageQ` 和 `Response MessageQ`调度和中转Input和Output, 上述两个队列均通过生产者-消费者模型的逻辑实现 22 | 1. 发送到Triton server的request被`python backend agent(C++)`放到`Request MessageQ` 23 | 2. python stub process从`Request MessageQ`取出Input, 给到python model instance执行完推理后,将Output放到`Response MessageQ` 24 | 3. `python backend agent(C++)`再从`Response MessageQ`中取出Output,打包成response返回给Triton server主进程 25 | 26 | 27 | 示例如下: 28 | ```python 29 | responses = [] 30 | for request in requests: 31 | input_tensor = pb_utils.get_input_tensor_by_name( 32 | request, 'input') 33 | 34 | # INFER_FUNC is python backend core logic 35 | output_tensor = INFER_FUNC(input_tensor) 36 | 37 | inference_response = pb_utils.InferenceResponse( 38 | output_tensors=[out_tensor]) 39 | responses.append(inference_response) 40 | ``` 41 | 42 | ### 1.2 Notice 43 | 44 | - 需要手动管理Tensor在CPU还是GPU上,config中的`instance_group {kind: KIND_GPU}`不起作用 45 | - 输入不会自动打batch, 需要手动将request列表转化为batch, 这点和所有backend一样 46 | - 默认情况下,python backend主动将input tensor移动到CPU, 再提供给模型推理,将`FORCE_CPU_ONLY_INPUT_TENSORS`设置为`no`可以尽可能的避免host-device之间的内存拷贝 47 | - python backend model instance与Triton server交换数据都是通过shared memory完成的,因此每个instance需要较大的shared memory, 至少64MB 48 | - 如果性能成为瓶颈,特别是包含许多循环时,需要换成C++ backend 49 | 50 | 51 | --- 52 | 53 | ## 2 BLS 54 | 一种特殊的python backend,通过python code调用其他model service。使用场景:通过一些逻辑判断来动态组合已部署的模型服务 55 | 56 | ### 2.1 BLS流程 57 | 58 | ![](../assets/bls_arc.png) 59 | 60 | 虚线上方表示调用python backend的一般方式, 虚线下方表示在python backend里调用其他model service。整体流程可以总结为: 61 | 62 | 63 | 1. python model instance处理接受到的Input tensor 64 | 2. python model instance通过BLS call发起request, 65 | 3. request经过python stub process放到shm block 66 | 4. python backend agent将shm block里的BLS input拿出来, 并通过Triton C API将BLS input将input送到指定model上去执行推理 67 | 5. Triton python backend angent将推理得到的输出送到shm block 68 | 6. 
BLS Output 经过python stub process从shm block中取出,封装成BLS response并返回给python model instance 69 | 70 | ### 2.2 Notice 71 | 72 | - Input tensor的位置 73 | 默认情况下,python backend主动将input tensor移动到CPU, 再提供给模型推理,将`FORCE_CPU_ONLY_INPUT_TENSORS`设置为`no`可以避免这一行为,input tensor的位置取决于它最后是如何被处理的,因此开启此设置后,需要python backend能够同时处理CPU和GPU tensor 74 | 75 | - 模块执行顺序 76 | BLS不支持step并行,step必须是顺序执行,前一个step执行完之后才执行后一个step 77 | 78 | - 数据传输 79 | 通过`DLPack`来编解码tensor,完成tensor在不同framework与python backend之间的数据传输,这一步是零拷贝,速度非常快 80 | 81 | 82 | 83 | --- 84 | ## 3 Ensemble 85 | 86 | ### 3.1 Ensemble概述 87 | 使用Ensemble来实现Pipelines可以避免传输中间张量的开销,并最大限度地减少必须发送到 Triton server的请求数量, 相对于BLS,Ensemble的优势在于可以将多个模型(step)的执行过程并行化(即每个step异步执行,真正意义的Pipelines),从而提高整体性能。 88 | 89 | 一个典型的Ensemble Pipelines如下: 90 | ``` 91 | name: "simple_yolov5_ensemble" 92 | platform: "ensemble" 93 | max_batch_size: 8 94 | input [ 95 | { 96 | name: "ENSEMBLE_INPUT_0" 97 | data_type: TYPE_FP32 98 | dims: [3, 640, 640] 99 | } 100 | ] 101 | 102 | output [ 103 | { 104 | name: "ENSEMBLE_OUTPUT_0" 105 | data_type: TYPE_FP32 106 | dims: [ 300, 6 ] 107 | } 108 | ] 109 | 110 | ensemble_scheduling { 111 | step [ 112 | { 113 | model_name: "simple_yolov5" 114 | model_version: 1 115 | input_map: { 116 | key: "images" 117 | value: "ENSEMBLE_INPUT_0" 118 | } 119 | output_map: { 120 | key: "output" 121 | value: "FILTER_BBOXES" 122 | } 123 | }, 124 | { 125 | model_name: "nms" 126 | model_version: 1 127 | input_map: { 128 | key: "candidate_boxes" 129 | value: "FILTER_BBOXES" 130 | } 131 | output_map: { 132 | key: "BBOXES" 133 | value: "ENSEMBLE_OUTPUT_0" 134 | } 135 | } 136 | ] 137 | } 138 | ``` 139 | 以上Pipelines包含[simple_yolov5](../triton/model_repository/simple_yolov5/config.pbtxt)和[nms](../triton/model_repository/nms/config.pbtxt)两个独立部署的model service,通过Ensemble将两个model service[串联起来](./pipelines.md#31-工作流),simple_yolov5的输出作为nms的输入,nms的输出作为整个Pipelines的输出。每个input_map和output_map都是一个key-value对,key是每个model service的input/output name,value是Ensemble的input/output name。 140 | 141 | ### 3.2 Ensemble数据传输 142 | 143 | - 如果Ensemble的所有子模型都是基于Triton内置framework backend部署的,子模型之间的数据可以通过CUDA API来进行点对点传输,不需要经过CPU内存拷贝 144 | 145 | - 如果Ensemble的子模型使用了custom backend或python backend,则子模型之间的张量通信都是通过系统(CPU)的内存拷贝完成的, 即使python backend将`FORCE_CPU_ONLY_INPUT_TENSORS`设置为`no`,也无法避免这种内存拷贝。如下step,上一个step是通过tensorrt backend输出的output, 位于GPU上,在python backend中打印出来的input始终位于cpu,即这里发生了一步Device to Host的内存拷贝 146 | ```python 147 | for request in requests: 148 | 149 | before_nms = pb_utils.get_input_tensor_by_name( 150 | request, 'candidate_boxes') 151 | 152 | # always true 153 | print (f'nms pb_tensor is from cpu {before_nms.is_cpu()}', flush=True) 154 | ``` 155 | --- 156 | ## 4 性能分析 157 | 数据来源: [Benchmark](../README_CN.md#benchmark) 158 | 159 | 吞吐和时延是主要考虑的两个性能指标,时延三者差别不大,而在吞吐量上,`batched_nms_dynamic > Ensemble > BLS`, 原因为: 160 | - batched_nms_dynamic的inference和nms全都包含在trt engine中了,layer之间通过CUDA API来传输,效率最高 161 | - Ensemble和BLS的inference和nms都是两个独立的model instance,其中BLS中python backend的Input tensor位于GPU上,而ensemble中的Input tensor被强制转换到CPU上,内存拷贝带来的开销比step并行执行的收益要大。因此在包含python backend的情况下,BLS的性能优于Ensemble -------------------------------------------------------------------------------- /docs/bls_vs_ensemble_EN.md: -------------------------------------------------------------------------------- 1 | # Comparison of Triton Pipeline implementation methods 2 | 3 | In [Deploying yolov5 Triton Pipelines](pipelines_EN.md#2-triton-pipeline-implementation-methods), BLS and Ensemble, two ways of implementing Triton Pipelines, are briefly introduced. 
In [Benchmark](../README.md#benchmark), the three deployment methods of `BLS Pipelines`, `Ensemble Pipelines`, and [All in TensorRT Engine](./batchedNMS_EN.md) are performance tested under gradually increasing concurrency. This article will compare and introduce BLS and Ensemble, and interpret the performance test results. 4 | 5 | ## 1 Python Backend 6 | 7 | ### 1.1 Implementation and structure 8 | BLS is a special python backend that completes Pipelines by calling other model services in python backend. The structure of python backend is as follows: 9 | 10 | ![](../assets/python_backend.png) 11 | 12 | - Inter-process communication IPC 13 | 14 | Due to GIL limitations, python backend supports multi-instance deployment by starting a separate process (`python stub process(C++)`) for each model instance. Since it is multi-process, `shared memory` is used to complete the communication between the python model instance and the Triton main process. Specifically, a shm block is allocated in the `shared memory` for each python stub process, and the shm block connects the `python backend agent(C++)` for communication. 15 | 16 | - Data flow 17 | 18 | `shm block` schedules and forwards Input and Output through `Request MessageQ` and `Response MessageQ`. Both queues are implemented using producer-consumer model logic. 19 | 1. The request sent to Triton server is put into `Request MessageQ` by `python backend agent(C++)` 20 | 2. The python stub process takes the Input from the `Request MessageQ`, passes it to the python model instance for inference, and then puts the Output into the `Response MessageQ` 21 | 3. `python backend agent(C++)` takes the Output from the `Response MessageQ` and packages it into a response returned to Triton server main process 22 | 23 | For example: 24 | ```python 25 | responses = [] 26 | for request in requests: 27 | input_tensor = pb_utils.get_input_tensor_by_name( 28 | request, 'input') 29 | 30 | # INFER_FUNC is python backend core logic 31 | output_tensor = INFER_FUNC(input_tensor) 32 | 33 | inference_response = pb_utils.InferenceResponse( 34 | output_tensors=[out_tensor]) 35 | responses.append(inference_response) 36 | ``` 37 | 38 | ### 1.2 Notice 39 | 40 | - Need to manually manage whether Tensors are on CPU or GPU, `instance_group {kind: KIND_GPU}` in config does not work 41 | - Input is not automatically batched, requests list needs to be manually converted to batch, same for all backends 42 | - By default, python backend actively moves input tensor to CPU before inference, set `FORCE_CPU_ONLY_INPUT_TENSORS` to `no` to avoid host-device memory copies as much as possible 43 | - Python backend model instance exchanges data with Triton server through shared memory, so each instance requires a large shared memory, at least 64MB 44 | - If performance becomes a bottleneck, especially with many loops, switch to C++ backend 45 | 46 | --- 47 | 48 | ## 2 BLS 49 | A special python backend that calls other model services through python code. Use cases: dynamically combine deployed model services through some logic judgments. 50 | 51 | ### 2.1 BLS workflow 52 | 53 | ![](../assets/bls_arc.png) 54 | 55 | The part above the dotted line is the general way to call the python backend. The part below the dotted line is to call other model services in the python backend. The overall workflow can be summarized as: 56 | 57 | 1. The python model instance processes the received Input tensor 58 | 2. The python model instance initiates a request through BLS call 59 | 3. 
The request goes through the python stub process into the shm block 60 | 4. The python backend agent takes the BLS input from the shm block and sends it to the specified model for inference through Triton C API 61 | 5. The Triton python backend agent sends the inferred output to the shm block 62 | 6. The BLS Output goes through the python stub process, taken from the shm block, packaged into a BLS response, and returned to the python model instance 63 | 64 | ### 2.2 Notice 65 | 66 | - Location of Input tensor 67 | By default, python backend actively moves input tensor to CPU before providing it for inference. Set `FORCE_CPU_ONLY_INPUT_TENSORS` to `no` to avoid this behavior. The location of the input tensor depends on how it is finally processed. After enabling this setting, the python backend needs to be able to handle both CPU and GPU tensors at the same time. 68 | 69 | - Execution order of modules 70 | BLS does not support step parallelism, steps must be executed sequentially, the next step is executed only after the previous step is completed. 71 | 72 | - Data transfer 73 | Use `DLPack` for tensor encoding/decoding between different frameworks and python backend. This step has zero copy and is very fast. 74 | 75 | --- 76 | 77 | ## 3 Ensemble 78 | 79 | ### 3.1 Overview of Ensemble 80 | Using Ensemble to implement Pipelines can avoid the overhead of intermediate tensor transfer and minimize the number of requests that must be sent to Triton server. Compared to BLS, the advantage of Ensemble is that it can parallelize the execution of multiple models (steps), thereby improving overall performance. 81 | 82 | A typical Ensemble Pipeline is as follows: 83 | ``` 84 | name: "simple_yolov5_ensemble" 85 | platform: "ensemble" 86 | max_batch_size: 8 87 | input [ 88 | { 89 | name: "ENSEMBLE_INPUT_0" 90 | data_type: TYPE_FP32 91 | dims: [3, 640, 640] 92 | } 93 | ] 94 | 95 | output [ 96 | { 97 | name: "ENSEMBLE_OUTPUT_0" 98 | data_type: TYPE_FP32 99 | dims: [ 300, 6 ] 100 | } 101 | ] 102 | 103 | ensemble_scheduling { 104 | step [ 105 | { 106 | model_name: "simple_yolov5" 107 | model_version: 1 108 | input_map: { 109 | key: "images" 110 | value: "ENSEMBLE_INPUT_0" 111 | } 112 | output_map: { 113 | key: "output" 114 | value: "FILTER_BBOXES" 115 | } 116 | }, 117 | { 118 | model_name: "nms" 119 | model_version: 1 120 | input_map: { 121 | key: "candidate_boxes" 122 | value: "FILTER_BBOXES" 123 | } 124 | output_map: { 125 | key: "BBOXES" 126 | value: "ENSEMBLE_OUTPUT_0" 127 | } 128 | } 129 | ] 130 | } 131 | ``` 132 | The above Pipeline contains two independently deployed model services [simple_yolov5](../triton/model_repository/simple_yolov5/config.pbtxt) and [nms](../triton/model_repository/nms/config.pbtxt) connected by Ensemble. The output of simple_yolov5 is the input of nms, and the output of nms is the output of the entire Pipeline. Each input_map and output_map is a key-value pair, where key is the input/output name of each model service, and value is the input/output name of Ensemble. 133 | 134 | ### 3.2 Ensemble data transfer 135 | 136 | - If all child models of Ensemble are deployed based on Triton built-in framework backends, data between child models can be transferred point-to-point via CUDA API without CPU memory copy. 137 | 138 | - If child models of Ensemble use custom backends or python backends, tensor communication between child models is completed by system (CPU) memory copy, even if `FORCE_CPU_ONLY_INPUT_TENSORS` is set to `no` in python backend. 
As in the following step, the output of the previous step is from tensorrt backend on GPU, but the input printed in python backend is always on CPU, meaning a Device to Host memory copy happened here. 139 | 140 | ```python 141 | for request in requests: 142 | 143 | before_nms = pb_utils.get_input_tensor_by_name( 144 | request, 'candidate_boxes') 145 | 146 | # always true 147 | print (f'nms pb_tensor is from cpu {before_nms.is_cpu()}', flush=True) 148 | ``` 149 | --- 150 | 151 | ## 4 Performance analysis 152 | 153 | Data source: [Benchmark](../README.md#benchmark) 154 | 155 | Throughput and latency are the two main performance metrics considered. Latency difference between the three is not big, but in terms of throughput, `batched_nms_dynamic > Ensemble > BLS`. The reasons are: 156 | 157 | - inference and nms are all included in the trt engine for batched_nms_dynamic, communication between layers is via CUDA API, which is most efficient 158 | - For Ensemble and BLS, inference and nms are two separate model instances. For BLS, the input tensor is on GPU in python backend, while for Ensemble the input tensor is forced to CPU, the overhead of memory copy outweighs the benefits of step parallelism. Therefore, when python backend is involved, BLS performs better than Ensemble -------------------------------------------------------------------------------- /docs/batchedNMS.md: -------------------------------------------------------------------------------- 1 | # YOLOV5 TensorRT BatchedNMS 2 | 在[修改yolov5的detect层](custom_yolov5_detect_layer.md)一文中,介绍了对detect层的轻量化改造,以提高模型服务的效率。在[部署yolov5 Triton Pipelines](pipelines.md)一文中利用改造的模型文件分别通过BLS和Ensemble两种方式部署了Triton Pipelines。但是Pipelines中的infer engine和nms始终是两个相对独立的step,由于nms是通过python backend来完成的,无论是BLS还是Ensemble都在数据传输方面存在一些限制。 3 | 4 | 本文利用onnx_graphsurgeon改造原生detect层的输出张量,对接通过cuda实现的TensorRT batchedNMSPlugin,将yolov5的nms集成到tensorrt engine中,避免部分场景下device to host的数据拷贝,提高整体计算性能。 5 | 6 | ## 0. 前置条件 7 | ```shell 8 | # clone ultralytics repo 9 | git clone -b v6.1 https://github.com/ultralytics/yolov5.git 10 | # clone this repo 11 | git clone 12 | cp -r /* yolov5/ 13 | ``` 14 | 15 | 16 | --- 17 | ## 1. 
具体步骤 18 | 19 | 同[修改yolov5的detect层](custom_yolov5_detect_layer.md#3-具体步骤)大致类似,都是遵循: 20 | 21 | - 修改detect层的forward函数 22 | - 导出.onnx文件 23 | - 转换为trt engine 24 | 25 | 的步骤。只不过这里需要对导出onnx文件的函数进行一些修改,新增一个`BatchedNMSDynamic_TRT` node并追加到原始graph的末尾, 并按照[TensorRT batchedNMSPlugin的输入格式](https://github.com/NVIDIA/TensorRT/tree/main/plugin/batchedNMSPlugin#structure)调整node的属性 26 | 27 | 28 | ### 1.1 修改前后 29 | - infer模式下forward函数原始输出格式 30 | - squeezed boxes and classes: 31 | ``` 32 | [batch_size, number_boxes, box_xywh + c + number_classes] = [batch_size, 25200, 85] 33 | ``` 34 | 35 | ![](../assets/before.png) 36 | 37 | 38 | 39 | 40 | 41 | - 修改后的输出格式 42 | - boxes 43 | ``` 44 | [batch_size, number_boxes, 1, x1y1x2y2] 45 | ``` 46 | - cls_conf 47 | ``` 48 | [batch_size, number_boxes, number_classes] 49 | ``` 50 | 51 | ![](../assets/after.png) 52 | 53 | 54 | 根据[batchedNMSPlugin.cpp](https://github.com/NVIDIA/TensorRT/blob/main/plugin/batchedNMSPlugin/batchedNMSPlugin.cpp#L193)源码中的注释,boxes的输入形状为`[batch_size, num_boxes, num_classes, 4] or [batch_size, num_boxes, 1, 4]`,但`batchedNMSPlugin`的文档没有详细说明这二者的差别,在 55 | [efficientNMSPlugin](https://github.com/NVIDIA/TensorRT/tree/main/plugin/efficientNMSPlugin#boxes-input)的文档里可以找到相关的解释: 56 | 57 | > The boxes input can have 3 dimensions in case a single box prediction is produced for all classes (such as in EfficientDet or SSD), or 4 dimensions when separate box predictions are generated for each class (such as in FasterRCNN), in which case number_classes >= 1 and must match the number of classes in the scores input. The final dimension represents the four coordinates that define the bounding box prediction. 58 | 59 | 由于使用的是yolov5, 所以不会对每个类别去生成bouding box, 所以boxes的输入形状应该为`[batch_size, num_boxes, 1, 4]` 60 | 61 | ### 1.2 改造detect层 62 | 63 | yolov5 Detect层forward函数的输出改成[TensorRT batchedNMSPlugin的输入格式](https://github.com/NVIDIA/TensorRT/tree/main/plugin/batchedNMSPlugin#structure) 64 | 65 | ```python 66 | def forward(self, x): 67 | z = [] # inference output 68 | for i in range(self.nl): 69 | x[i] = self.m[i](x[i]) # conv 70 | bs, _, ny, nx = x[i].shape # x(bs,255,20,20) to x(bs,3,20,20,85) 71 | x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous() 72 | 73 | if not self.training: # inference 74 | if self.onnx_dynamic or self.grid[i].shape[2:4] != x[i].shape[2:4]: 75 | self.grid[i], self.anchor_grid[i] = self._make_grid(nx, ny, i) 76 | 77 | y = x[i].sigmoid() 78 | if self.inplace: 79 | y[..., 0:2] = (y[..., 0:2] * 2 - 0.5 + self.grid[i]) * self.stride[i] # xy 80 | y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh 81 | else: # for YOLOv5 on AWS Inferentia https://github.com/ultralytics/yolov5/pull/2953 82 | xy = (y[..., 0:2] * 2 - 0.5 + self.grid[i]) * self.stride[i] # xy 83 | wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh 84 | # custom output >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> 85 | conf = y[..., 4:] 86 | xmin = xy[..., 0:1] - wh[..., 0:1] / 2 87 | ymin = xy[..., 1:2] - wh[..., 1:2] / 2 88 | xmax = xy[..., 0:1] + wh[..., 0:1] / 2 89 | ymax = xy[..., 1:2] + wh[..., 1:2] / 2 90 | obj_conf = conf[..., 0:1] 91 | cls_conf = conf[..., 1:] 92 | cls_conf *= obj_conf 93 | # y = torch.cat((xy, wh, y[..., 4:]), -1) 94 | y = torch.cat((xmin, ymin, xmax, ymax, cls_conf), 4) 95 | # z.append(y.view(bs, -1, self.no)) 96 | z.append(y.view(bs, -1, self.no - 1)) 97 | 98 | z = torch.cat(z, 1) 99 | bbox = z[..., 0:4].view(bs, -1, 1, 4) 100 | cls_conf = z[..., 4:] 101 | 102 | return bbox, cls_conf 103 | # return x if 
self.training else (torch.cat(z, 1), x) 104 | # custom output >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> 105 | ``` 106 | 107 | ### 1.3 修改export_onnx 108 | export onnx时,修改output,满足[TensorRT batchedNMSPlugin的输入格式](https://github.com/NVIDIA/TensorRT/tree/main/plugin/batchedNMSPlugin#structure) 109 | 110 | 这里介绍一下关键点,详细代码见[export.py中的export_onnx函数](../export.py) 111 | 112 | 113 | - onnx simplify的时候避免导出成static shape 114 | ```python 115 | model_onnx, check = onnxsim.simplify( 116 | model_onnx, 117 | dynamic_input_shape=dynamic 118 | # 必须注释 119 | #input_shapes={'images': list(im.shape)} if dynamic else None 120 | ) 121 | ``` 122 | 123 | 124 | - 利用onnx-graphsurgeon创建一个`BatchedNMSDynamic_TRT` node,并添加到原有计算图的末尾 125 | 126 | ```python 127 | # add batch NMS: 128 | yolo_graph = onnx_gs.import_onnx(model_onnx) 129 | box_data = yolo_graph.outputs[0] 130 | cls_data = yolo_graph.outputs[1] 131 | nms_out_0 = onnx_gs.Variable( 132 | "BatchedNMS", 133 | dtype=np.int32 134 | ) 135 | nms_out_1 = onnx_gs.Variable( 136 | "BatchedNMS_1", 137 | dtype=np.float32 138 | ) 139 | nms_out_2 = onnx_gs.Variable( 140 | "BatchedNMS_2", 141 | dtype=np.float32 142 | ) 143 | nms_out_3 = onnx_gs.Variable( 144 | "BatchedNMS_3", 145 | dtype=np.float32 146 | ) 147 | nms_attrs = dict() 148 | # ........ 149 | 150 | nms_plugin = onnx_gs.Node( 151 | op="BatchedNMSDynamic_TRT", 152 | name="BatchedNMS_N", 153 | inputs=[box_data, cls_data], 154 | outputs=[nms_out_0, nms_out_1, nms_out_2, nms_out_3], 155 | attrs=nms_attrs 156 | ) 157 | yolo_graph.nodes.append(nms_plugin) 158 | yolo_graph.outputs = nms_plugin.outputs 159 | yolo_graph.cleanup().toposort() 160 | model_onnx = onnx_gs.export_onnx(yolo_graph) 161 | ``` 162 | 163 | 164 | 165 | 166 | 167 | - 依次导出onnx和tensorrt engine 168 | ```shell 169 | # export onxx 170 | python export.py --weights yolov5s.pt --include onnx --simplify --dynamic 171 | 172 | # export trt engine 173 | /usr/src/tensorrt/bin/trtexec \ 174 | --onnx=yolov5s.onnx \ 175 | --minShapes=images:1x3x640x640 \ 176 | --optShapes=images:1x3x640x640 \ 177 | --maxShapes=images:1x3x640x640 \ 178 | --workspace=4096 \ 179 | --saveEngine= yolov5s_opt1_max1_fp16.engine \ 180 | --shapes=images:1x3x640x640 \ 181 | --verbose \ 182 | --fp16 \ 183 | > result-FP16-BatchedNMS.txt 184 | ``` 185 | --- 186 | 187 | ## 2. 性能测试 188 | 189 | ### 2.1 COCO17 validation数据集测试 190 | 191 | 对比测试infer + nms的耗时 192 | 193 | - original yolo 194 | 195 | ```shell 196 | python detect.py --weight original-yolov5s-fp16.engine --half --img 640 --source --device 0 197 | ``` 198 | 199 | Speed: 0.8ms pre-process, 4.4ms inference, 2.2ms NMS per image at shape (1, 3, 640, 640) 200 | 201 | - batchedNMSPlugin 202 | 203 | ```shell 204 | python trt_infer.py --model yolov5s_opt1_max1_fp16.engine --input_images_folder --output_images_folder --input_size 640 205 | ``` 206 | 207 | infer + nms: 208 | Inference: 5.4 ms per image at shape (1, 3, 640, 640) 209 | 210 | 211 | ### 2.2 trtexec 测试 212 | 213 | trtexec是本地测试结果,batch为1的情况下,整体差别不太大,将nms集成到trt engine后,Output的张量变小了很多,可以降低Device to Host的数据传输时间,代价是GPU Compute的时间增加 214 | 215 | metrics|BatchedNMSDynamic_TRT egine
infer+nms|ultralytics engine
only infer| 216 | :-:|:-:|:-:| 217 | Latency|3.97021 ms|4.08145 ms| 218 | End-to-End Host Latency|6.70715 ms|4.73285 ms| 219 | Enqueue Time|1.27597 ms|0.95929 ms| 220 | H2D Latency|0.563791 ms|0.316406 ms| 221 | GPU Compute Time|3.45068 ms| 2.41992 ms| 222 | D2H Latency|0.0100889 ms|1.34198 ms| 223 | 224 | --- 225 | ## REFERENCES 226 | 227 | 228 | - [Ultralytics Yolov5](https://github.com/ultralytics/yolov5.git) 229 | - [Yolov5 GPU Optimization](https://github.com/NVIDIA-AI-IOT/yolov5_gpu_optimization.git) 230 | - [TensorRT BatchedNMSPlugin ](https://github.com/NVIDIA/TensorRT/tree/main/plugin/batchedNMSPlugin) -------------------------------------------------------------------------------- /docs/batchedNMS_EN.md: -------------------------------------------------------------------------------- 1 | # YOLOV5 TensorRT BatchedNMS 2 | 3 | In [Modifying the yolov5 detect layer](custom_yolov5_detect_layer_EN.md), the lightweight optimization of the detect layer is introduced to improve model serving efficiency. In [Deploying yolov5 Triton Pipelines](pipelines_EN.md), Triton Pipelines are deployed through BLS and Ensemble respectively. However, the infer engine and NMS in Pipelines are two relatively independent steps, where NMS is completed through the python backend. Both BLS and Ensemble have some limitations in data transfer. 4 | 5 | This article utilizes onnx_graphsurgeon to modify the output tensor of the original detect layer, connects it to the TensorRT batchedNMSPlugin implemented by cuda, and integrates yolov5 NMS into the tensorrt engine, avoiding device to host data copies in some scenarios and improving overall computational performance. 6 | 7 | ## 0. Prerequisites 8 | 9 | ```shell 10 | # clone ultralytics repo 11 | git clone -b v6.1 https://github.com/ultralytics/yolov5.git 12 | # clone this repo 13 | git clone 14 | cp -r /* yolov5/ 15 | ``` 16 | 17 | --- 18 | 19 | ## 1. Specific Steps 20 | 21 | It is similar to [Modifying the yolov5 detect layer](custom_yolov5_detect_layer_EN.md#3-specific-steps), following the steps of: 22 | 23 | - Modify the detect layer's forward function 24 | - Export the .onnx file 25 | - Convert to trt engine 26 | 27 | The difference is that some modifications need to be made to the exported onnx file here. A `BatchedNMSDynamic_TRT` node is added to the end of the original graph, and the node attributes are adjusted according to the [TensorRT batchedNMSPlugin input format](https://github.com/NVIDIA/TensorRT/tree/main/plugin/batchedNMSPlugin#structure). 28 | 29 | ### 1.1 Modify before and after 30 | 31 | - Original forward function output format in infer mode 32 | 33 | - squeezed boxes and classes: 34 | 35 | ``` 36 | [batch_size, number_boxes, box_xywh + c + number_classes] = [batch_size, 25200, 85] 37 | ``` 38 | 39 | ![](../assets/before.png) 40 | 41 | - Modified output format 42 | 43 | - boxes 44 | 45 | ``` 46 | [batch_size, number_boxes, 1, x1y1x2y2] 47 | ``` 48 | 49 | - cls_conf 50 | 51 | ``` 52 | [batch_size, number_boxes, number_classes] 53 | ``` 54 | 55 | ![](../assets/after.png) 56 | 57 | According to the [batchedNMSPlugin.cpp](https://github.com/NVIDIA/TensorRT/blob/main/plugin/batchedNMSPlugin/batchedNMSPlugin.cpp#L193) source code comments, the input shape of boxes should be `[batch_size, num_boxes, num_classes, 4]` or `[batch_size, num_boxes, 1, 4]`. 
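For intuition, here is a minimal shape-only sketch (hypothetical tensors, not the repository's export code) of how the two Detect outputs listed above map onto the box layout the plugin accepts:

```python
import torch

batch_size, num_boxes, num_classes = 1, 25200, 80  # illustrative sizes

# Modified Detect outputs (shapes as listed above): one shared box per
# anchor plus per-class confidences.
boxes = torch.rand(batch_size, num_boxes, 4)                # x1, y1, x2, y2
cls_conf = torch.rand(batch_size, num_boxes, num_classes)   # obj_conf * cls_prob

# The plugin's boxes input carries an explicit class axis:
# [batch, num_boxes, 1, 4] when one box is shared by all classes (the YOLOv5 case),
# [batch, num_boxes, num_classes, 4] when each class gets its own box regression.
plugin_boxes = boxes.unsqueeze(2)   # -> [1, 25200, 1, 4]
plugin_scores = cls_conf            # -> [1, 25200, 80]

print(plugin_boxes.shape, plugin_scores.shape)
```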
58 | 59 | The related explanation can be found in the [efficientNMSPlugin](https://github.com/NVIDIA/TensorRT/tree/main/plugin/efficientNMSPlugin#boxes-input) documentation: 60 | 61 | > The boxes input can have 3 dimensions in case a single box prediction is produced for all classes (such as in EfficientDet or SSD), or 4 dimensions when separate box predictions are generated for each class (such as in FasterRCNN), in which case number_classes >= 1 and must match the number of classes in the scores input. The final dimension represents the four coordinates that define the bounding box prediction. 62 | 63 | Since YOLOv5 is used, bounding boxes will not be generated for each category, so the input shape of boxes should be `[batch_size, num_boxes, 1, 4]`. 64 | 65 | ### 1.2 Modify detect layer 66 | 67 | Change the output of the yolov5 Detect layer forward function to the [TensorRT batchedNMSPlugin input format](https://github.com/NVIDIA/TensorRT/tree/main/plugin/batchedNMSPlugin#structure). 68 | 69 | ```python 70 | def forward(self, x): 71 | z = [] # inference output 72 | for i in range(self.nl): 73 | x[i] = self.m[i](x[i]) # conv 74 | bs, _, ny, nx = x[i].shape # x(bs,255,20,20) to x(bs,3,20,20,85) 75 | x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous() 76 | 77 | if not self.training: # inference 78 | if self.onnx_dynamic or self.grid[i].shape[2:4] != x[i].shape[2:4]: 79 | self.grid[i], self.anchor_grid[i] = self._make_grid(nx, ny, i) 80 | 81 | y = x[i].sigmoid() 82 | if self.inplace: 83 | y[..., 0:2] = (y[..., 0:2] * 2 - 0.5 + self.grid[i]) * self.stride[i] # xy 84 | y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh 85 | else: # for YOLOv5 on AWS Inferentia https://github.com/ultralytics/yolov5/pull/2953 86 | xy = (y[..., 0:2] * 2 - 0.5 + self.grid[i]) * self.stride[i] # xy 87 | wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh 88 | # custom output >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> 89 | conf = y[..., 4:] 90 | xmin = xy[..., 0:1] - wh[..., 0:1] / 2 91 | ymin = xy[..., 1:2] - wh[..., 1:2] / 2 92 | xmax = xy[..., 0:1] + wh[..., 0:1] / 2 93 | ymax = xy[..., 1:2] + wh[..., 1:2] / 2 94 | obj_conf = conf[..., 0:1] 95 | cls_conf = conf[..., 1:] 96 | cls_conf *= obj_conf 97 | # y = torch.cat((xy, wh, y[..., 4:]), -1) 98 | y = torch.cat((xmin, ymin, xmax, ymax, cls_conf), 4) 99 | # z.append(y.view(bs, -1, self.no)) 100 | z.append(y.view(bs, -1, self.no - 1)) 101 | 102 | z = torch.cat(z, 1) 103 | bbox = z[..., 0:4].view(bs, -1, 1, 4) 104 | cls_conf = z[..., 4:] 105 | 106 | return bbox, cls_conf 107 | # return x if self.training else (torch.cat(z, 1), x) 108 | # custom output >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> 109 | ``` 110 | 111 | ### 1.3 Modify export_onnx 112 | 113 | When exporting onnx, modify the output to meet the [TensorRT batchedNMSPlugin input format](https://github.com/NVIDIA/TensorRT/tree/main/plugin/batchedNMSPlugin#structure). 114 | 115 | Key points are introduced here, see [export_onnx function in export.py](../export.py) for detailed code. 
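One of the pieces elided in the graph-surgeon snippet below is the plugin attribute dictionary (`nms_attrs = dict()` followed by `# ........`). As a hedged reference, a sketch of the attributes defined by the batchedNMSPlugin documentation is shown here; the values are illustrative placeholders, not the settings actually used in [export.py](../export.py):

```python
# Illustrative attributes for the BatchedNMSDynamic_TRT node -- names follow
# the TensorRT batchedNMSPlugin documentation, values are placeholders only.
nms_attrs = dict()
nms_attrs["shareLocation"] = True      # one box shared by all classes (the YOLOv5 case)
nms_attrs["backgroundLabelId"] = -1    # no dedicated background class
nms_attrs["numClasses"] = 80
nms_attrs["topK"] = 1000               # candidates kept per class before NMS
nms_attrs["keepTopK"] = 300            # final detections returned per image
nms_attrs["scoreThreshold"] = 0.25
nms_attrs["iouThreshold"] = 0.45
nms_attrs["isNormalized"] = False      # boxes here are in pixel coordinates
nms_attrs["clipBoxes"] = False
```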
116 | 117 | - Avoid exporting as static shape during onnx simplify 118 | 119 | ```python 120 | model_onnx, check = onnxsim.simplify( 121 | model_onnx, 122 | dynamic_input_shape=dynamic 123 | # must comment out 124 | #input_shapes={'images': list(im.shape)} if dynamic else None 125 | ) 126 | ``` 127 | 128 | - Use onnx-graphsurgeon to create a `BatchedNMSDynamic_TRT` node and append it to the end of the original graph 129 | 130 | ```python 131 | # add batch NMS: 132 | yolo_graph = onnx_gs.import_onnx(model_onnx) 133 | box_data = yolo_graph.outputs[0] 134 | cls_data = yolo_graph.outputs[1] 135 | nms_out_0 = onnx_gs.Variable( 136 | "BatchedNMS", 137 | dtype=np.int32 138 | ) 139 | nms_out_1 = onnx_gs.Variable( 140 | "BatchedNMS_1", 141 | dtype=np.float32 142 | ) 143 | nms_out_2 = onnx_gs.Variable( 144 | "BatchedNMS_2", 145 | dtype=np.float32 146 | ) 147 | nms_out_3 = onnx_gs.Variable( 148 | "BatchedNMS_3", 149 | dtype=np.float32 150 | ) 151 | nms_attrs = dict() 152 | # ........ 153 | 154 | nms_plugin = onnx_gs.Node( 155 | op="BatchedNMSDynamic_TRT", 156 | name="BatchedNMS_N", 157 | inputs=[box_data, cls_data], 158 | outputs=[nms_out_0, nms_out_1, nms_out_2, nms_out_3], 159 | attrs=nms_attrs 160 | ) 161 | yolo_graph.nodes.append(nms_plugin) 162 | yolo_graph.outputs = nms_plugin.outputs 163 | yolo_graph.cleanup().toposort() 164 | model_onnx = onnx_gs.export_onnx(yolo_graph) 165 | ``` 166 | 167 | - Export onnx and tensorrt engine sequentially 168 | 169 | ```shell 170 | # export onxx 171 | python export.py --weights yolov5s.pt --include onnx --simplify --dynamic 172 | 173 | # export trt engine 174 | /usr/src/tensorrt/bin/trtexec \ 175 | --onnx=yolov5s.onnx \ 176 | --minShapes=images:1x3x640x640 \ 177 | --optShapes=images:1x3x640x640 \ 178 | --maxShapes=images:1x3x640x640 \ 179 | --workspace=4096 \ 180 | --saveEngine=yolov5s_opt1_max1_fp16.engine \ 181 | --shapes=images:1x3x640x640 \ 182 | --verbose \ 183 | --fp16 \ 184 | > result-FP16-BatchedNMS.txt 185 | ``` 186 | 187 | --- 188 | 189 | ## 2. Performance Testing 190 | 191 | ### 2.1 COCO17 validation dataset 192 | 193 | Compare infer + nms elapsed time 194 | 195 | - Original yolov5 196 | 197 | ```shell 198 | python detect.py --weight original-yolov5s-fp16.engine --half --img 640 --source --device 0 199 | ``` 200 | 201 | Speed: 0.8ms pre-process, 4.4ms inference, 2.2ms NMS per image at shape (1, 3, 640, 640) 202 | 203 | - batchedNMSPlugin 204 | 205 | ```shell 206 | python trt_infer.py --model yolov5s_opt1_max1_fp16.engine --input_images_folder --output_images_folder --input_size 640 207 | ``` 208 | 209 | Inference + NMS: 5.4 ms per image at shape (1, 3, 640, 640) 210 | 211 | ### 2.2 trtexec 212 | 213 | trtexec is local test result. With batch size = 1, the overall difference is not significant. After integrating NMS into trt engine, the output tensor is much smaller, which can reduce device to host data transfer time. The cost is that GPU compute time is increased. 214 | 215 | | Metrics | BatchedNMSDynamic_TRT engine
infer+nms | ultralytics engine
only infer | 216 | | :-: | :-: | :-: | 217 | | Latency | 3.97021 ms | 4.08145 ms | 218 | | End-to-End Host Latency | 6.70715 ms | 4.73285 ms | 219 | | Enqueue Time | 1.27597 ms | 0.95929 ms | 220 | | H2D Latency | 0.563791 ms | 0.316406 ms | 221 | | GPU Compute Time | 3.45068 ms | 2.41992 ms | 222 | | D2H Latency | 0.0100889 ms | 1.34198 ms | 223 | 224 | --- 225 | ## REFERENCES 226 | 227 | - [Ultralytics Yolov5](https://github.com/ultralytics/yolov5.git) 228 | - [Yolov5 GPU Optimization](https://github.com/NVIDIA-AI-IOT/yolov5_gpu_optimization.git) 229 | - [TensorRT BatchedNMSPlugin](https://github.com/NVIDIA/TensorRT/tree/main/plugin/batchedNMSPlugin) -------------------------------------------------------------------------------- /trt_infer.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | 3 | import argparse 4 | import time 5 | import tensorrt as trt 6 | import pycuda.autoinit 7 | import pycuda.driver as cuda 8 | import numpy as np 9 | import cv2 10 | import os 11 | from tqdm import tqdm 12 | from common import (allocate_buffers, do_inference, create_tensorrt_logger) 13 | 14 | 15 | 16 | INPUT_SIZE = 640 17 | CLASSES = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', 18 | 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 19 | 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 20 | 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 21 | 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 22 | 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 23 | 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 24 | 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 25 | 'hair drier', 'toothbrush'] # class names 26 | 27 | CLASSES_IDS = [i for i in range(len(CLASSES))] 28 | 29 | CONF_THRESH = 0.25 30 | 31 | 32 | 33 | 34 | 35 | def preprocess_ds_nchw(batch_img): 36 | batch_img_array = np.array([np.array(img) for img in batch_img], dtype=np.float32) 37 | batch_img_array = batch_img_array / 255.0 38 | batch_transpose = np.transpose(batch_img_array, (0, 3, 1, 2)) 39 | 40 | return batch_transpose 41 | 42 | 43 | def decode(keep_k, boxes, scores, cls_id): 44 | results = [] 45 | for idx, k in enumerate(keep_k.reshape(-1)): 46 | bbox = boxes[idx].reshape((-1, 4))[:k] 47 | conf = scores[idx].reshape((-1, 1))[:k] 48 | cid = cls_id[idx].reshape((-1, 1))[:k] 49 | results.append(np.concatenate((cid, conf, bbox), axis=-1)) 50 | 51 | return results 52 | 53 | 54 | def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None): 55 | # Rescale coords (xyxy) from img1_shape to img0_shape 56 | if ratio_pad is None: # calculate from img0_shape 57 | gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new 58 | pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding 59 | else: 60 | gain = ratio_pad[0][0] 61 | pad = ratio_pad[1] 62 | 63 | coords[:, [0, 2]] -= pad[0] # x padding 64 | coords[:, [1, 3]] -= pad[1] # y padding 65 | coords[:, :4] /= gain 66 | clip_coords(coords, 
img0_shape) 67 | return coords 68 | 69 | 70 | def clip_coords(boxes, shape): 71 | # Clip bounding xyxy bounding boxes to image shape (height, width) 72 | boxes[:, [0, 2]] = boxes[:, [0, 2]].clip(0, shape[1]) # x1, x2 73 | boxes[:, [1, 3]] = boxes[:, [1, 3]].clip(0, shape[0]) # y1, y2 74 | 75 | 76 | def draw_bbox_cv(orig_img, infer_img, output_img_path, labels, ratio_pad=None, image_id=None, jlist=None): 77 | bboxes = labels[:, 2:] 78 | confs = labels[:, 1] 79 | cids = labels[:, 0] 80 | bboxes = scale_coords(infer_img.shape[2:], bboxes, orig_img.shape, ratio_pad=ratio_pad).round() 81 | 82 | for idx in range(len(labels)): 83 | 84 | bbox = bboxes[idx] 85 | p1, p2 = (int(bbox[0]), int(bbox[1])), (int(bbox[2]), int(bbox[3])) 86 | cid = int(cids[idx]) 87 | conf = confs[idx] 88 | # print("{}: {} {}".format(CLASSES[cid], conf, bbox)) 89 | if jlist is not None: 90 | if image_id is not None: 91 | b = [bbox[0], bbox[1], bbox[2] - bbox[0], bbox[3] - bbox[1]] 92 | jlist.append({ 93 | 'image_id': image_id, 94 | 'category_id': CLASSES_IDS[cid], 95 | 'bbox': [round(float(x), 3) for x in b], 96 | 'score': round(float(conf), 5)}) 97 | 98 | if conf < CONF_THRESH: 99 | continue 100 | 101 | cv2.rectangle(orig_img, p1, p2, (255, 0, 0), 2, cv2.LINE_AA) 102 | cv2.putText(orig_img, "{0}: {1:.2f}".format(CLASSES[cid], conf), p1, 0, 0.8, (255, 255, 0), 2) 103 | 104 | cv2.imwrite(output_img_path, orig_img) 105 | 106 | 107 | def letterbox(im, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True, stride=32): 108 | # new_shape = (height, width) 109 | # Resize and pad image while meeting stride-multiple constraints 110 | shape = im.shape[:2] # current shape [height, width] 111 | if isinstance(new_shape, int): 112 | new_shape = (new_shape, new_shape) 113 | 114 | # Scale ratio (new / old) 115 | r = min(new_shape[0] / shape[0], new_shape[1] / shape[1]) 116 | if not scaleup: # only scale down, do not scale up (for better val mAP) 117 | r = min(r, 1.0) 118 | 119 | # Compute padding 120 | ratio = r, r # width, height ratios 121 | new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r)) 122 | dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding 123 | if auto: # minimum rectangle 124 | dw, dh = np.mod(dw, stride), np.mod(dh, stride) # wh padding 125 | elif scaleFill: # stretch 126 | dw, dh = 0.0, 0.0 127 | new_unpad = (new_shape[1], new_shape[0]) 128 | ratio = new_shape[1] / shape[1], new_shape[0] / shape[0] # width, height ratios 129 | 130 | dw /= 2 # divide padding into 2 sides 131 | dh /= 2 132 | 133 | if shape[::-1] != new_unpad: # resize 134 | im = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR) 135 | top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1)) 136 | left, right = int(round(dw - 0.1)), int(round(dw + 0.1)) 137 | im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) # add border 138 | return im, ratio, (dw, dh) 139 | 140 | 141 | def load_images_cv(img_path, new_shape): 142 | orig_img = cv2.imread(img_path) 143 | img = letterbox(orig_img.copy(), new_shape, auto=False, scaleup=True)[0] 144 | img = img[..., [2, 1, 0]] # BGR -> RGB 145 | images = preprocess_ds_nchw([img]) 146 | 147 | return images, orig_img 148 | 149 | 150 | 151 | 152 | def square_inference(engine, img_root, output_img_root, input_size=640, jlist=None): 153 | with engine.create_execution_context() as context: 154 | context.set_binding_shape(0, (1, 3, input_size, input_size)) 155 | new_shape = (input_size, input_size) 156 | inputs, outputs, 
bindings, stream = allocate_buffers(engine, context) 157 | 158 | # calculate the speed of preprocess and inference per image 159 | infer_time, seen = 0.0, 0 160 | 161 | for img_name in sorted(os.listdir(img_root)): 162 | if os.path.splitext(img_name)[-1] not in ['.jpg', '.png', '.jpeg']: 163 | continue 164 | img_path = os.path.join(img_root, img_name) 165 | if jlist is not None: 166 | img_id = int(img_name.split(".")[0]) 167 | else: 168 | img_id = None 169 | 170 | images, orig_img = load_images_cv(img_path, new_shape) 171 | ratio_pad = None 172 | batch_images = images 173 | # Hard Coded For explicit_batch and the ONNX model's batch_size = 1 174 | batch_images = batch_images[np.newaxis, :, :, :] 175 | outputs_shape, outputs_data, cost = do_inference(batch=batch_images, context=context, 176 | bindings=bindings, inputs=inputs, 177 | outputs=outputs, stream=stream) 178 | 179 | 180 | results = decode(keep_k = outputs_data["BatchedNMS"], 181 | boxes = outputs_data["BatchedNMS_1"], 182 | scores = outputs_data["BatchedNMS_2"], 183 | cls_id = outputs_data["BatchedNMS_3"]) 184 | 185 | infer_time += cost 186 | seen += 1 187 | # visualize the bbox 188 | draw_bbox_cv(orig_img, images, os.path.join(output_img_root, img_name), 189 | results[0], image_id=img_id, jlist=jlist, ratio_pad=ratio_pad) 190 | t = (infer_time / seen) * 1E3 # speeds per image 191 | print(f'Inference: %.2f ms per image at shape {(1,3, input_size, input_size)}' % t) 192 | 193 | 194 | if __name__ == "__main__": 195 | 196 | parser = argparse.ArgumentParser(description='Do YOLOV5 inference using TRT') 197 | parser.add_argument('--input_images_folder', type=str, help='input images path', required=True) 198 | parser.add_argument('--output_images_folder', type=str, help='output images path', required=True) 199 | parser.add_argument('--input_size', type=int, default=640, help="Input Size") 200 | parser.add_argument('--model', type=str, default="yolov5s.engine", help="Model Path") 201 | 202 | 203 | args = parser.parse_args() 204 | 205 | img_root = args.input_images_folder 206 | output_img_root = args.output_images_folder 207 | input_size=args.input_size 208 | engine_file_path = args.model 209 | 210 | if not os.path.exists(output_img_root): 211 | print("Please create the output images directory: {output_img_root}") 212 | exit(0) 213 | 214 | trt_logger = create_tensorrt_logger(verbose=True) 215 | 216 | trt.init_libnvinfer_plugins(None, '') 217 | with open(engine_file_path, "rb") as f, trt.Runtime(trt_logger) as runtime: 218 | engine = runtime.deserialize_cuda_engine(f.read()) 219 | assert engine 220 | square_inference(engine, img_root, output_img_root, input_size=input_size) 221 | 222 | -------------------------------------------------------------------------------- /docs/custom_yolov5_detect_layer.md: -------------------------------------------------------------------------------- 1 | # 修改yolov5的detect层,提高Triton推理服务的性能 2 | 3 | Infer模式下, yolov5 默认的detect层输出的数据是一个形状为`[batches, 25200, 85]`的张量。如果部署在`Nvidia Triton`中,输出层的张量大小过大,处理输出的时间会变大,造成队列积压。 特别是在`Triton Server`和`Client`不在同一台机器,无法使用`shared memory`的情况下,通过网络将数据传输到client的时间还会变大,影响推理服务的性能。 4 | 5 | --- 6 | 7 | ## 1. 
测试方法 8 | 将模型转换为tensorrt engine, 并部署在Triton Inference Server,instance group数量为1,类型为GPU,在其他机器上通过Triton提供的perf_analyzer工具进行性能测试。 9 | 10 | - 将yolov5s.pt转换为onnx格式 11 | - 将onnx转换为tensorrt engine 12 | 13 | ```shell 14 | /usr/src/tensorrt/bin/trtexec \ 15 | --onnx=yolov5s.onnx \ 16 | --minShapes=images:1x3x640x640 \ 17 | --optShapes=images:8x3x640x640 \ 18 | --maxShapes=images:32x3x640x640 \ 19 | --workspace=4096 \ 20 | --saveEngine=yolov5s.engine \ 21 | --shapes=images:1x3x640x640 \ 22 | --verbose \ 23 | --fp16 \ 24 | > result-FP16.txt 25 | ``` 26 | 27 | - 部署在Triton Inference Server 28 | 29 | 模型上传到Triton server 设置的model repository路径,编写[模型服务配置](../triton/model_repository/simple_yolov5/config.pbtxt) 30 | 31 | 32 | - [生成真实数据](https://github.com/triton-inference-server/server/blob/main/docs/user_guide/perf_analyzer.md#real-input-data) 33 | 34 | ```shell 35 | python ./triton/generate_input.py --input_images ----output_file .json 36 | ``` 37 | 38 | - 利用真实数据进行性能测试 39 | 40 | ```shell 41 | perf_analyzer -m -b 1 --input-data .json --concurrency-range 1:10 --measurement-interval 10000 -u -i gRPC -f .csv 42 | ``` 43 | --- 44 | ## 2. 修改前的性能指标 45 | 46 | 如下为使用默认detect层的yolov5 trt engine, 部署在triton的性能测试结果,可以看到,使用默认的detect层,大量时间消耗在队列积压(`Server Queue`)和输出数据的处理(`Server Compute Output`),吞吐量甚至达不到 `1 infer/sec` 47 | 48 | > 除了吞吐,其余指标的单位均为us, 其中Client Send和Client Recv分别为gRPC序列化、反序列化数据的时间 49 | 50 | 51 | | Concurrency | Inferences/Second | Client Send | Network+Server Send/Recv | Server Queue | Server Compute Input | Server Compute Infer | Server Compute Output | p90 latency | 52 | | ----------- | ----------------- | ----------- | ------------------------ | ------------ | -------------------- | -------------------- | --------------------- | ----------- | 53 | | 1 | 0.7 | 1683 | 1517232 | 466 | 8003 | 4412 | 9311 | 1592936 | 54 | | 2 | 0.8 | 1464 | 1514475 | 393 | 10659 | 4616 | 956736 | 2583025 | 55 | | 3 | 0.7 | 2613 | 1485868 | 1013992 | 7370 | 4396 | 1268070 | 3879331 | 56 | | 4 | 0.7 | 2268 | 1463386 | 2230040 | 9933 | 5734 | 1250245 | 4983687 | 57 | | 5 | 0.6 | 2064 | 1540583 | 3512025 | 11057 | 4843 | 1226058 | 6512305 | 58 | | 6 | 0.6 | 2819 | 1573869 | 4802885 | 10134 | 4320 | 1234644 | 7888080 | 59 | | 7 | 0.5 | 1664 | 1507386 | 6007235 | 11197 | 4899 | 1244482 | 8854777 | 60 | | | | | | | | | | | 61 | 62 | 63 | 因此,改造的一个方案就是将数据层进行精简,在送入nms之前根据conf对bbox进行粗略的筛选, 最后参考tensorrtx中对detect层的处理,将输出改造成形状为`[batches, num_bboxes, 6]`的向量, 其中`num_bboxes=1000` 64 | > `6 = [cx,cy,w,h,conf,cls_id]`, 其中`conf = obj_conf * cls_prob` 65 | 66 | 67 | --- 68 | ## 3. 
具体步骤 69 | 70 | ### 3.1 clone ultralytics yolov5 repo 71 | `git clone -b v6.1 https://github.com/ultralytics/yolov5.git` 72 | 73 | 74 | ### 3.2 改造detect层 75 | 将detect的forward函数修改为 76 | 77 | ```python 78 | def forward(self, x): 79 | z = [] # inference output 80 | for i in range(self.nl): 81 | x[i] = self.m[i](x[i]) # conv 82 | bs, _, ny, nx = x[i].shape # x(bs,255,20,20) to x(bs,3,20,20,85) 83 | x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous() 84 | 85 | if not self.training: # inference 86 | if self.onnx_dynamic or self.grid[i].shape[2:4] != x[i].shape[2:4]: 87 | self.grid[i], self.anchor_grid[i] = self._make_grid(nx, ny, i) 88 | 89 | y = x[i].sigmoid() 90 | if self.inplace: 91 | y[..., 0:2] = (y[..., 0:2] * 2 - 0.5 + self.grid[i]) * self.stride[i] # xy 92 | y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh 93 | else: # for YOLOv5 on AWS Inferentia https://github.com/ultralytics/yolov5/pull/2953 94 | xy = (y[..., 0:2] * 2 - 0.5 + self.grid[i]) * self.stride[i] # xy 95 | wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh 96 | y = torch.cat((xy, wh, y[..., 4:]), -1) 97 | z.append(y.view(bs, -1, self.no)) 98 | 99 | # custom output >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> 100 | # [bs, 25200, 85] 101 | origin_output = torch.cat(z, 1) 102 | output_bboxes_nums = 1000 103 | # operator argsort to ONNX opset version 12 is not supported. 104 | # top_conf_index = origin_output[..., 4].argsort(descending=True)[:,:output_bboxes_nums] 105 | 106 | # [bs, 1000] 107 | top_conf_index =origin_output[..., 4].topk(k=output_bboxes_nums)[1] 108 | 109 | # torch.Size([bs, 1000, 85]) 110 | filter_output = origin_output.gather(1, top_conf_index.unsqueeze(-1).expand(-1, -1, 85)) 111 | 112 | filter_output[...,5:] *= filter_output[..., 4].unsqueeze(-1) # conf = obj_conf * cls_conf 113 | bboxes = filter_output[..., :4] 114 | conf, cls_id = filter_output[..., 5:].max(2, keepdim=True) 115 | # [bs, 1000, 6] 116 | filter_output = torch.cat((bboxes, conf, cls_id.float()), 2) 117 | 118 | return x if self.training else filter_output 119 | # custom output >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> 120 | 121 | # return x if self.training else (torch.cat(z, 1), x) 122 | ``` 123 | 124 | 125 | ### 3.3 导出onnx 126 | 127 | `onnx simplify`的时候,必须注释掉[下面的代码](https://github.com/ultralytics/yolov5/blob/v6.1/export.py#L145),否则导出的onnx模型仍然为`static shape` 128 | ```python 129 | model_onnx, check = onnxsim.simplify( 130 | model_onnx, 131 | dynamic_input_shape=dynamic 132 | # 必须注释 133 | #input_shapes={'images': list(im.shape)} if dynamic else None 134 | ) 135 | ``` 136 | 137 | 运行`python export.py --weight yolov5s.pt --dynamic --simplify --include onnx`导出onnx模型,导出的onnx结构如下: 138 | ![](../assets/simple_output.png) 139 | 140 | 141 | ### [3.4 导出tensorrt engine](#1-测试方法) 142 | 143 | 144 | --- 145 | ## 4. 
修改后的性能 146 | 147 | - batch size = 1 148 | 149 | 吞吐量提升了25倍以上,`Server Queue`和`Server Compute Output`的时间也显著降低 150 | 151 | | Concurrency | Inferences/Second | Client Send | Network+Server Send/Recv | Server Queue | Server Compute Input | Server Compute Infer | Server Compute Output | Client Recv | p90 latency | 152 | | ----------- | ----------------- | ----------- | ------------------------ | ------------ | -------------------- | -------------------- | --------------------- | ----------- | ----------- | 153 | | 1 | 11.9 | 1245 | 69472 | 286 | 7359 | 5022 | 340 | 3 | 93457 | 154 | | 2 | 19.2 | 1376 | 89804 | 341 | 7538 | 4997 | 161 | 3 | 118114 | 155 | | 3 | 20.2 | 1406 | 131265 | 1500 | 8240 | 4881 | 500 | 3 | 171370 | 156 | | 4 | 20 | 1382 | 180621 | 2769 | 9051 | 5184 | 496 | 3 | 235043 | 157 | | 5 | 20.5 | 1362 | 226046 | 2404 | 8112 | 5068 | 622 | 3 | 286810 | 158 | | 6 | 20.8 | 1487 | 271714 | 2034 | 8331 | 5076 | 506 | 3 | 406248 | 159 | | 7 | 20.1 | 1535 | 328144 | 2626 | 8444 | 5122 | 405 | 3 | 430850 | 160 | | 8 | 19.9 | 1512 | 384690 | 3511 | 8168 | 5018 | 581 | 5 | 465658 | 161 | | 9 | 20.2 | 1433 | 420893 | 3499 | 9034 | 5180 | 389 | 3 | 522285 | 162 | | 10 | 20.5 | 1476 | 469029 | 3369 | 8280 | 5165 | 442 | 3 | 622745 | 163 | | | | | | | | | | | | 164 | 165 | 166 | - batch size = 8 167 | 168 | 相对 batch size = 1, `Server Compute Input、Server Compute Infer, Server Compute Output`速度分别提升了约1.4倍、2倍、4倍,代价是随着batch增大,数据传输的耗时增大 169 | 170 | | Concurrency | Inferences/Second | Client Send | Network+Server Send/Recv | Server Queue | Server Compute Input | Server Compute Infer | Server Compute Output | Client Recv | p90 latency | 171 | | ----------- | ----------------- | ----------- | ------------------------ | ------------ | -------------------- | -------------------- | --------------------- | ----------- | ----------- | 172 | | 1 | 15.2 | 11202 | 527075 | 360 | 5386 | 2488 | 43 | 5 | 570189 | 173 | | 2 | 18.4 | 10424 | 829927 | 124 | 5780 | 2491 | 33 | 4 | 901743 | 174 | | 3 | 20 | 10203 | 1178111 | 2290 | 5640 | 2570 | 20 | 4 | 1267145 | 175 | | 4 | 20 | 10097 | 1595614 | 4843 | 5998 | 2454 | 104 | 5 | 1716309 | 176 | | 5 | 19.2 | 9117 | 1971608 | 2397 | 5376 | 2480 | 203 | 4 | 2518530 | 177 | | 6 | 20 | 8728 | 2338066 | 2914 | 6304 | 2496 | 96 | 4 | 2706257 | 178 | | 7 | 20 | 14785 | 2708292 | 6581 | 5556 | 2489 | 160 | 5 | 3170047 | 179 | | 8 | 20 | 13035 | 3052707 | 5067 | 6353 | 2492 | 62 | 4 | 3235293 | 180 | | 9 | 17.6 | 10870 | 3535601 | 7037 | 6307 | 2480 | 136 | 5 | 3856391 | 181 | | 10 | 18.4 | 9357 | 3953830 | 8044 | 5629 | 2520 | 64 | 3 | 4531638 | 182 | | | | | | | | | | | | 183 | 184 | 185 | 186 | --- 187 | 188 | ## REFERENCES 189 | 190 | 191 | - [Ultralytics Yolov5](https://github.com/ultralytics/yolov5.git) 192 | - [Perf Analyzer](https://github.com/triton-inference-server/server/blob/main/docs/user_guide/perf_analyzer.md) -------------------------------------------------------------------------------- /docs/custom_yolov5_detect_layer_EN.md: -------------------------------------------------------------------------------- 1 | # Modify yolov5 detect layer to improve Triton inference serving performance 2 | 3 | In infer mode, the default yolov5 detect layer outputs a tensor with shape `[batches, 25200, 85]`. When deployed in `Nvidia Triton`, the large output tensor size increases output processing time and causes queue backlogs. 
Especially when `Triton Server` and `Client` are not on the same machine and cannot use `shared memory`, the network transfer time to client is even longer, affecting inference serving performance. 4 | 5 | --- 6 | 7 | ## 1. Test method 8 | 9 | Convert the model to tensorrt engine and deploy on Triton Inference Server with 1 GPU instance group. Use the perf_analyzer tool provided by Triton to test performance on another machine. 10 | 11 | - Convert yolov5s.pt to onnx format 12 | - Convert onnx to tensorrt engine 13 | 14 | ```shell 15 | /usr/src/tensorrt/bin/trtexec \ 16 | --onnx=yolov5s.onnx \ 17 | --minShapes=images:1x3x640x640 \ 18 | --optShapes=images:8x3x640x640 \ 19 | --maxShapes=images:32x3x640x640 \ 20 | --workspace=4096 \ 21 | --saveEngine=yolov5s.engine \ 22 | --shapes=images:1x3x640x640 \ 23 | --verbose \ 24 | --fp16 \ 25 | > result-FP16.txt 26 | ``` 27 | 28 | - Deploy on Triton Inference Server 29 | 30 | Upload model to Triton server's configured model repository path and write [model config](../triton/model_repository/simple_yolov5/config.pbtxt) 31 | 32 | - [Generate real data](https://github.com/triton-inference-server/server/blob/main/docs/user_guide/perf_analyzer.md#real-input-data) 33 | 34 | ```shell 35 | python ./triton/generate_input.py --input_images ----output_file .json 36 | ``` 37 | 38 | - Performance testing with real data 39 | 40 | ```shell 41 | perf_analyzer -m -b 1 --input-data .json --concurrency-range 1:10 --measurement-interval 10000 -u -i gRPC -f .csv 42 | ``` 43 | 44 | --- 45 | 46 | ## 2. Performance metrics before modification 47 | 48 | Below are the performance test results of deploying default yolov5 trt engine on triton. It can be seen that with the default detect layer, a lot of time is spent on queue backlogs (`Server Queue`) and output data processing (`Server Compute Output`). The throughput is even less than `1 infer/sec`. 49 | 50 | > Except for throughput, the units for other metrics are us. Client Send and Client Recv are the times for gRPC serialization and deserialization. 51 | 52 | | Concurrency | Inferences/Second | Client Send | Network+Server Send/Recv | Server Queue | Server Compute Input | Server Compute Infer | Server Compute Output | p90 latency | 53 | | ----------- | ----------------- | ----------- | ------------------------ | ------------ | -------------------- | -------------------- | --------------------- | ----------- | 54 | | 1 | 0.7 | 1683 | 1517232 | 466 | 8003 | 4412 | 9311 | 1592936 | 55 | | 2 | 0.8 | 1464 | 1514475 | 393 | 10659 | 4616 | 956736 | 2583025 | 56 | | 3 | 0.7 | 2613 | 1485868 | 1013992 | 7370 | 4396 | 1268070 | 3879331 | 57 | | 4 | 0.7 | 2268 | 1463386 | 2230040 | 9933 | 5734 | 1250245 | 4983687 | 58 | | 5 | 0.6 | 2064 | 1540583 | 3512025 | 11057 | 4843 | 1226058 | 6512305 | 59 | | 6 | 0.6 | 2819 | 1573869 | 4802885 | 10134 | 4320 | 1234644 | 7888080 | 60 | | 7 | 0.5 | 1664 | 1507386 | 6007235 | 11197 | 4899 | 1244482 | 8854777 | 61 | | | | | | | | | | | 62 | 63 | Therefore, one optimization approach is to streamline the data layer and coarsely filter bboxes by conf before nms. Finally, refer to the processing of the detect layer in tensorrtx to change the output to a vector with shape `[batches, num_bboxes, 6]`, where `num_bboxes=1000`. 64 | 65 | > `6 = [cx,cy,w,h,conf,cls_id]`, where `conf = obj_conf * cls_prob` 66 | 67 | --- 68 | 69 | ## 3. 
Specific steps 70 | 71 | ### 3.1 Clone ultralytics yolov5 repo 72 | 73 | `git clone -b v6.1 https://github.com/ultralytics/yolov5.git` 74 | 75 | ### 3.2 Modify detect layer 76 | 77 | Change the detect forward function to: 78 | 79 | ```python 80 | def forward(self, x): 81 | z = [] # inference output 82 | for i in range(self.nl): 83 | x[i] = self.m[i](x[i]) # conv 84 | bs, _, ny, nx = x[i].shape # x(bs,255,20,20) to x(bs,3,20,20,85) 85 | x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous() 86 | 87 | if not self.training: # inference 88 | if self.onnx_dynamic or self.grid[i].shape[2:4] != x[i].shape[2:4]: 89 | self.grid[i], self.anchor_grid[i] = self._make_grid(nx, ny, i) 90 | 91 | y = x[i].sigmoid() 92 | if self.inplace: 93 | y[..., 0:2] = (y[..., 0:2] * 2 - 0.5 + self.grid[i]) * self.stride[i] # xy 94 | y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh 95 | else: # for YOLOv5 on AWS Inferentia https://github.com/ultralytics/yolov5/pull/2953 96 | xy = (y[..., 0:2] * 2 - 0.5 + self.grid[i]) * self.stride[i] # xy 97 | wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh 98 | y = torch.cat((xy, wh, y[..., 4:]), -1) 99 | z.append(y.view(bs, -1, self.no)) 100 | 101 | # custom output >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> 102 | # [bs, 25200, 85] 103 | origin_output = torch.cat(z, 1) 104 | output_bboxes_nums = 1000 105 | # operator argsort to ONNX opset version 12 is not supported. 106 | # top_conf_index = origin_output[..., 4].argsort(descending=True)[:,:output_bboxes_nums] 107 | 108 | # [bs, 1000] 109 | top_conf_index =origin_output[..., 4].topk(k=output_bboxes_nums)[1] 110 | 111 | # torch.Size([bs, 1000, 85]) 112 | filter_output = origin_output.gather(1, top_conf_index.unsqueeze(-1).expand(-1, -1, 85)) 113 | 114 | filter_output[...,5:] *= filter_output[..., 4].unsqueeze(-1) # conf = obj_conf * cls_conf 115 | bboxes = filter_output[..., :4] 116 | conf, cls_id = filter_output[..., 5:].max(2, keepdim=True) 117 | # [bs, 1000, 6] 118 | filter_output = torch.cat((bboxes, conf, cls_id.float()), 2) 119 | 120 | return x if self.training else filter_output 121 | # custom output >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> 122 | 123 | # return x if self.training else (torch.cat(z, 1), x) 124 | ``` 125 | 126 | ### 3.3 Export onnx 127 | 128 | When exporting onnx, comment out the following code in `onnx simplify` to avoid exporting static shape onnx model: 129 | 130 | ```python 131 | model_onnx, check = onnxsim.simplify( 132 | model_onnx, 133 | dynamic_input_shape=dynamic 134 | # must comment out 135 | #input_shapes={'images': list(im.shape)} if dynamic else None 136 | ) 137 | ``` 138 | 139 | Run `python export.py --weight yolov5s.pt --dynamic --simplify --include onnx` to export onnx model. The exported onnx structure is: 140 | 141 | ![](../assets/simple_output.png) 142 | 143 | ### [3.4 Export tensorrt engine](#1-test-method) 144 | 145 | --- 146 | 147 | ## 4. Performance after modification 148 | 149 | - Batch size = 1 150 | 151 | Throughput increased by more than 25 times. `Server Queue` and `Server Compute Output` times were also significantly reduced. 
152 | 153 | | Concurrency | Inferences/Second | Client Send | Network+Server Send/Recv | Server Queue | Server Compute Input | Server Compute Infer | Server Compute Output | Client Recv | p90 latency | 154 | | ----------- | ----------------- | ----------- | ------------------------ | ------------ | -------------------- | -------------------- | --------------------- | ----------- | ----------- | 155 | | 1 | 11.9 | 1245 | 69472 | 286 | 7359 | 5022 | 340 | 3 | 93457 | 156 | | 2 | 19.2 | 1376 | 89804 | 341 | 7538 | 4997 | 161 | 3 | 118114 | 157 | | 3 | 20.2 | 1406 | 131265 | 1500 | 8240 | 4881 | 500 | 3 | 171370 | 158 | | 4 | 20 | 1382 | 180621 | 2769 | 9051 | 5184 | 496 | 3 | 235043 | 159 | | 5 | 20.5 | 1362 | 226046 | 2404 | 8112 | 5068 | 622 | 3 | 286810 | 160 | | 6 | 20.8 | 1487 | 271714 | 2034 | 8331 | 5076 | 506 | 3 | 406248 | 161 | | 7 | 20.1 | 1535 | 328144 | 2626 | 8444 | 5122 | 405 | 3 | 430850 | 162 | | 8 | 19.9 | 1512 | 384690 | 3511 | 8168 | 5018 | 581 | 5 | 465658 | 163 | | 9 | 20.2 | 1433 | 420893 | 3499 | 9034 | 5180 | 389 | 3 | 522285 | 164 | | 10 | 20.5 | 1476 | 469029 | 3369 | 8280 | 5165 | 442 | 3 | 622745 | 165 | | | | | | | | | | | | 166 | 167 | - Batch size = 8 168 | 169 | Compared to batch size = 1, `Server Compute Input`, `Server Compute Infer`, and `Server Compute Output` speeds improved by about 1.4x, 2x, and 4x respectively. The cost is that data transfer time increases as batch size grows. 170 | 171 | | Concurrency | Inferences/Second | Client Send | Network+Server Send/Recv | Server Queue | Server Compute Input | Server Compute Infer | Server Compute Output | Client Recv | p90 latency | 172 | | ----------- | ----------------- | ----------- | ------------------------ | ------------ | -------------------- | -------------------- | --------------------- | ----------- | ----------- | 173 | | 1 | 15.2 | 11202 | 527075 | 360 | 5386 | 2488 | 43 | 5 | 570189 | 174 | | 2 | 18.4 | 10424 | 829927 | 124 | 5780 | 2491 | 33 | 4 | 901743 | 175 | | 3 | 20 | 10203 | 1178111 | 2290 | 5640 | 2570 | 20 | 4 | 1267145 | 176 | | 4 | 20 | 10097 | 1595614 | 4843 | 5998 | 2454 | 104 | 5 | 1716309 | 177 | | 5 | 19.2 | 9117 | 1971608 | 2397 | 5376 | 2480 | 203 | 4 | 2518530 | 178 | | 6 | 20 | 8728 | 2338066 | 2914 | 6304 | 2496 | 96 | 4 | 2706257 | 179 | | 7 | 20 | 14785 | 2708292 | 6581 | 5556 | 2489 | 160 | 5 | 3170047 | 180 | | 8 | 20 | 13035 | 3052707 | 5067 | 6353 | 2492 | 62 | 4 | 3235293 | 181 | | 9 | 17.6 | 10870 | 3535601 | 7037 | 6307 | 2480 | 136 | 5 | 3856391 | 182 | | 10 | 18.4 | 9357 | 3953830 | 8044 | 5629 | 2520 | 64 | 3 | 4531638 | 183 | | | | | | | | | | | | 184 | 185 | 186 | --- 187 | 188 | ## REFERENCES 189 | 190 | - [Ultralytics Yolov5](https://github.com/ultralytics/yolov5.git) 191 | - [Perf Analyzer](https://github.com/triton-inference-server/server/blob/main/docs/user_guide/perf_analyzer.md) -------------------------------------------------------------------------------- /triton/model_repository/nms/1/triton_python_backend_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # Redistribution and use in source and binary forms, with or without 4 | # modification, are permitted provided that the following conditions 5 | # are met: 6 | # * Redistributions of source code must retain the above copyright 7 | # notice, this list of conditions and the following disclaimer. 
8 | # * Redistributions in binary form must reproduce the above copyright 9 | # notice, this list of conditions and the following disclaimer in the 10 | # documentation and/or other materials provided with the distribution. 11 | # * Neither the name of NVIDIA CORPORATION nor the names of its 12 | # contributors may be used to endorse or promote products derived 13 | # from this software without specific prior written permission. 14 | # 15 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY 16 | # EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 17 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 18 | # PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR 19 | # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 20 | # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 21 | # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 22 | # PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY 23 | # OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 24 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 26 | 27 | import numpy as np 28 | import struct 29 | import json 30 | 31 | TRITON_STRING_TO_NUMPY = { 32 | 'TYPE_BOOL': bool, 33 | 'TYPE_UINT8': np.uint8, 34 | 'TYPE_UINT16': np.uint16, 35 | 'TYPE_UINT32': np.uint32, 36 | 'TYPE_UINT64': np.uint64, 37 | 'TYPE_INT8': np.int8, 38 | 'TYPE_INT16': np.int16, 39 | 'TYPE_INT32': np.int32, 40 | 'TYPE_INT64': np.int64, 41 | 'TYPE_FP16': np.float16, 42 | 'TYPE_FP32': np.float32, 43 | 'TYPE_FP64': np.float64, 44 | 'TYPE_STRING': np.object_ 45 | } 46 | 47 | 48 | def serialize_byte_tensor(input_tensor): 49 | """ 50 | Serializes a bytes tensor into a flat numpy array of length prepended 51 | bytes. The numpy array should use dtype of np.object_. For np.bytes_, 52 | numpy will remove trailing zeros at the end of byte sequence and because 53 | of this it should be avoided. 54 | Parameters 55 | ---------- 56 | input_tensor : np.array 57 | The bytes tensor to serialize. 58 | Returns 59 | ------- 60 | serialized_bytes_tensor : np.array 61 | The 1-D numpy array of type uint8 containing the serialized bytes in 'C' order. 62 | Raises 63 | ------ 64 | InferenceServerException 65 | If unable to serialize the given tensor. 66 | """ 67 | 68 | if input_tensor.size == 0: 69 | return () 70 | 71 | # If the input is a tensor of string/bytes objects, then must flatten those 72 | # into a 1-dimensional array containing the 4-byte byte size followed by the 73 | # actual element bytes. All elements are concatenated together in "C" order. 
74 | if (input_tensor.dtype == np.object_) or (input_tensor.dtype.type 75 | == np.bytes_): 76 | flattened_ls = [] 77 | for obj in np.nditer(input_tensor, flags=["refs_ok"], order='C'): 78 | # If directly passing bytes to BYTES type, 79 | # don't convert it to str as Python will encode the 80 | # bytes which may distort the meaning 81 | if input_tensor.dtype == np.object_: 82 | if type(obj.item()) == bytes: 83 | s = obj.item() 84 | else: 85 | s = str(obj.item()).encode('utf-8') 86 | else: 87 | s = obj.item() 88 | flattened_ls.append(struct.pack(" max_batch_size: 334 | raise ValueError( 335 | "configuration specified max_batch_size " + 336 | str(self._model_config["max_batch_size"]) + 337 | ", but in auto-complete-config function for model '" + 338 | self._model_config["name"] + "' specified max_batch_size " + 339 | str(max_batch_size)) 340 | else: 341 | self._model_config["max_batch_size"] = max_batch_size 342 | 343 | def set_dynamic_batching(self): 344 | """Set dynamic_batching as the scheduler for the model if no scheduler 345 | is set. If dynamic_batching is set in the model configuration, then no 346 | action is taken and return success. 347 | Raises 348 | ------ 349 | ValueError 350 | If the 'sequence_batching' or 'ensemble_scheduling' scheduler is 351 | set for this model configuration. 352 | """ 353 | found_scheduler = None 354 | if "sequence_batching" in self._model_config: 355 | found_scheduler = "sequence_batching" 356 | elif "ensemble_scheduling" in self._model_config: 357 | found_scheduler = "ensemble_scheduling" 358 | 359 | if found_scheduler != None: 360 | raise ValueError( 361 | "Configuration specified scheduling_choice as '" 362 | + found_scheduler + "', but auto-complete-config " 363 | "function for model '" + self._model_config["name"] 364 | + "' tries to set scheduling_choice as 'dynamic_batching'") 365 | 366 | if "dynamic_batching" not in self._model_config: 367 | self._model_config["dynamic_batching"] = {} 368 | 369 | def add_input(self, input): 370 | """Add the input for the model. 371 | Parameters 372 | ---------- 373 | input : dict 374 | The input to be added. 375 | Raises 376 | ------ 377 | ValueError 378 | If input contains property other than 'name', 'data_type' 379 | and 'dims' or any of the properties are not set, or if an 380 | input with the same name already exists in the configuration 381 | but has different data_type or dims property 382 | """ 383 | valid_properties = ['name', 'data_type', 'dims'] 384 | for current_property in input: 385 | if current_property not in valid_properties: 386 | raise ValueError( 387 | "input '" + input['name'] + 388 | "' in auto-complete-config function for model '" + 389 | self._model_config["name"] + 390 | "' contains property other than 'name', 'data_type' and 'dims'." 
391 | ) 392 | 393 | if 'name' not in input: 394 | raise ValueError( 395 | "input in auto-complete-config function for model '" + 396 | self._model_config["name"] + "' is missing 'name' property.") 397 | elif 'data_type' not in input: 398 | raise ValueError("input '" + input['name'] + 399 | "' in auto-complete-config function for model '" + 400 | self._model_config["name"] + 401 | "' is missing 'data_type' property.") 402 | elif 'dims' not in input: 403 | raise ValueError("input '" + input['name'] + 404 | "' in auto-complete-config function for model '" + 405 | self._model_config["name"] + 406 | "' is missing 'dims' property.") 407 | 408 | for current_input in self._model_config["input"]: 409 | if input['name'] == current_input['name']: 410 | if current_input[ 411 | 'data_type'] != "TYPE_INVALID" and current_input[ 412 | 'data_type'] != input['data_type']: 413 | raise ValueError("unable to load model '" + 414 | self._model_config["name"] + 415 | "', configuration expects datatype " + 416 | current_input['data_type'] + 417 | " for input '" + input['name'] + 418 | "', model provides " + input['data_type']) 419 | elif current_input[ 420 | 'dims'] and current_input['dims'] != input['dims']: 421 | raise ValueError( 422 | "model '" + self._model_config["name"] + "', tensor '" + 423 | input['name'] + "': the model expects dims " + 424 | str(input['dims']) + 425 | " but the model configuration specifies dims " + 426 | str(current_input['dims'])) 427 | else: 428 | current_input['data_type'] = input['data_type'] 429 | current_input['dims'] = input['dims'] 430 | return 431 | 432 | self._model_config["input"].append(input) 433 | 434 | def add_output(self, output): 435 | """Add the output for the model. 436 | Parameters 437 | ---------- 438 | output : dict 439 | The output to be added. 440 | Raises 441 | ------ 442 | ValueError 443 | If output contains property other than 'name', 'data_type' 444 | and 'dims' or any of the properties are not set, or if an 445 | output with the same name already exists in the configuration 446 | but has different data_type or dims property 447 | """ 448 | valid_properties = ['name', 'data_type', 'dims'] 449 | for current_property in output: 450 | if current_property not in valid_properties: 451 | raise ValueError( 452 | "output '" + output['name'] + 453 | "' in auto-complete-config function for model '" + 454 | self._model_config["name"] + 455 | "' contains property other than 'name', 'data_type' and 'dims'." 
456 | ) 457 | 458 | if 'name' not in output: 459 | raise ValueError( 460 | "output in auto-complete-config function for model '" + 461 | self._model_config["name"] + "' is missing 'name' property.") 462 | elif 'data_type' not in output: 463 | raise ValueError("output '" + output['name'] + 464 | "' in auto-complete-config function for model '" + 465 | self._model_config["name"] + 466 | "' is missing 'data_type' property.") 467 | elif 'dims' not in output: 468 | raise ValueError("output '" + output['name'] + 469 | "' in auto-complete-config function for model '" + 470 | self._model_config["name"] + 471 | "' is missing 'dims' property.") 472 | 473 | for current_output in self._model_config["output"]: 474 | if output['name'] == current_output['name']: 475 | if current_output[ 476 | 'data_type'] != "TYPE_INVALID" and current_output[ 477 | 'data_type'] != output['data_type']: 478 | raise ValueError("unable to load model '" + 479 | self._model_config["name"] + 480 | "', configuration expects datatype " + 481 | current_output['data_type'] + 482 | " for output '" + output['name'] + 483 | "', model provides " + output['data_type']) 484 | elif current_output[ 485 | 'dims'] and current_output['dims'] != output['dims']: 486 | raise ValueError( 487 | "model '" + self._model_config["name"] + "', tensor '" + 488 | output['name'] + "': the model expects dims " + 489 | str(output['dims']) + 490 | " but the model configuration specifies dims " + 491 | str(current_output['dims'])) 492 | else: 493 | current_output['data_type'] = output['data_type'] 494 | current_output['dims'] = output['dims'] 495 | return 496 | 497 | self._model_config["output"].append(output) 498 | 499 | 500 | TRITONSERVER_REQUEST_FLAG_SEQUENCE_START = 1 501 | TRITONSERVER_REQUEST_FLAG_SEQUENCE_END = 2 502 | TRITONSERVER_RESPONSE_COMPLETE_FINAL = 1 503 | -------------------------------------------------------------------------------- /triton/model_repository/simple_yolov5_bls/1/triton_python_backend_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # Redistribution and use in source and binary forms, with or without 4 | # modification, are permitted provided that the following conditions 5 | # are met: 6 | # * Redistributions of source code must retain the above copyright 7 | # notice, this list of conditions and the following disclaimer. 8 | # * Redistributions in binary form must reproduce the above copyright 9 | # notice, this list of conditions and the following disclaimer in the 10 | # documentation and/or other materials provided with the distribution. 11 | # * Neither the name of NVIDIA CORPORATION nor the names of its 12 | # contributors may be used to endorse or promote products derived 13 | # from this software without specific prior written permission. 14 | # 15 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY 16 | # EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 17 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 18 | # PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR 19 | # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 20 | # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 21 | # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 22 | # PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY 23 | # OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 24 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 26 | 27 | import numpy as np 28 | import struct 29 | import json 30 | 31 | TRITON_STRING_TO_NUMPY = { 32 | 'TYPE_BOOL': bool, 33 | 'TYPE_UINT8': np.uint8, 34 | 'TYPE_UINT16': np.uint16, 35 | 'TYPE_UINT32': np.uint32, 36 | 'TYPE_UINT64': np.uint64, 37 | 'TYPE_INT8': np.int8, 38 | 'TYPE_INT16': np.int16, 39 | 'TYPE_INT32': np.int32, 40 | 'TYPE_INT64': np.int64, 41 | 'TYPE_FP16': np.float16, 42 | 'TYPE_FP32': np.float32, 43 | 'TYPE_FP64': np.float64, 44 | 'TYPE_STRING': np.object_ 45 | } 46 | 47 | 48 | def serialize_byte_tensor(input_tensor): 49 | """ 50 | Serializes a bytes tensor into a flat numpy array of length prepended 51 | bytes. The numpy array should use dtype of np.object_. For np.bytes_, 52 | numpy will remove trailing zeros at the end of byte sequence and because 53 | of this it should be avoided. 54 | Parameters 55 | ---------- 56 | input_tensor : np.array 57 | The bytes tensor to serialize. 58 | Returns 59 | ------- 60 | serialized_bytes_tensor : np.array 61 | The 1-D numpy array of type uint8 containing the serialized bytes in 'C' order. 62 | Raises 63 | ------ 64 | InferenceServerException 65 | If unable to serialize the given tensor. 66 | """ 67 | 68 | if input_tensor.size == 0: 69 | return () 70 | 71 | # If the input is a tensor of string/bytes objects, then must flatten those 72 | # into a 1-dimensional array containing the 4-byte byte size followed by the 73 | # actual element bytes. All elements are concatenated together in "C" order. 74 | if (input_tensor.dtype == np.object_) or (input_tensor.dtype.type 75 | == np.bytes_): 76 | flattened_ls = [] 77 | for obj in np.nditer(input_tensor, flags=["refs_ok"], order='C'): 78 | # If directly passing bytes to BYTES type, 79 | # don't convert it to str as Python will encode the 80 | # bytes which may distort the meaning 81 | if input_tensor.dtype == np.object_: 82 | if type(obj.item()) == bytes: 83 | s = obj.item() 84 | else: 85 | s = str(obj.item()).encode('utf-8') 86 | else: 87 | s = obj.item() 88 | flattened_ls.append(struct.pack(" max_batch_size: 334 | raise ValueError( 335 | "configuration specified max_batch_size " + 336 | str(self._model_config["max_batch_size"]) + 337 | ", but in auto-complete-config function for model '" + 338 | self._model_config["name"] + "' specified max_batch_size " + 339 | str(max_batch_size)) 340 | else: 341 | self._model_config["max_batch_size"] = max_batch_size 342 | 343 | def set_dynamic_batching(self): 344 | """Set dynamic_batching as the scheduler for the model if no scheduler 345 | is set. If dynamic_batching is set in the model configuration, then no 346 | action is taken and return success. 347 | Raises 348 | ------ 349 | ValueError 350 | If the 'sequence_batching' or 'ensemble_scheduling' scheduler is 351 | set for this model configuration. 
352 | """ 353 | found_scheduler = None 354 | if "sequence_batching" in self._model_config: 355 | found_scheduler = "sequence_batching" 356 | elif "ensemble_scheduling" in self._model_config: 357 | found_scheduler = "ensemble_scheduling" 358 | 359 | if found_scheduler != None: 360 | raise ValueError( 361 | "Configuration specified scheduling_choice as '" 362 | + found_scheduler + "', but auto-complete-config " 363 | "function for model '" + self._model_config["name"] 364 | + "' tries to set scheduling_choice as 'dynamic_batching'") 365 | 366 | if "dynamic_batching" not in self._model_config: 367 | self._model_config["dynamic_batching"] = {} 368 | 369 | def add_input(self, input): 370 | """Add the input for the model. 371 | Parameters 372 | ---------- 373 | input : dict 374 | The input to be added. 375 | Raises 376 | ------ 377 | ValueError 378 | If input contains property other than 'name', 'data_type' 379 | and 'dims' or any of the properties are not set, or if an 380 | input with the same name already exists in the configuration 381 | but has different data_type or dims property 382 | """ 383 | valid_properties = ['name', 'data_type', 'dims'] 384 | for current_property in input: 385 | if current_property not in valid_properties: 386 | raise ValueError( 387 | "input '" + input['name'] + 388 | "' in auto-complete-config function for model '" + 389 | self._model_config["name"] + 390 | "' contains property other than 'name', 'data_type' and 'dims'." 391 | ) 392 | 393 | if 'name' not in input: 394 | raise ValueError( 395 | "input in auto-complete-config function for model '" + 396 | self._model_config["name"] + "' is missing 'name' property.") 397 | elif 'data_type' not in input: 398 | raise ValueError("input '" + input['name'] + 399 | "' in auto-complete-config function for model '" + 400 | self._model_config["name"] + 401 | "' is missing 'data_type' property.") 402 | elif 'dims' not in input: 403 | raise ValueError("input '" + input['name'] + 404 | "' in auto-complete-config function for model '" + 405 | self._model_config["name"] + 406 | "' is missing 'dims' property.") 407 | 408 | for current_input in self._model_config["input"]: 409 | if input['name'] == current_input['name']: 410 | if current_input[ 411 | 'data_type'] != "TYPE_INVALID" and current_input[ 412 | 'data_type'] != input['data_type']: 413 | raise ValueError("unable to load model '" + 414 | self._model_config["name"] + 415 | "', configuration expects datatype " + 416 | current_input['data_type'] + 417 | " for input '" + input['name'] + 418 | "', model provides " + input['data_type']) 419 | elif current_input[ 420 | 'dims'] and current_input['dims'] != input['dims']: 421 | raise ValueError( 422 | "model '" + self._model_config["name"] + "', tensor '" + 423 | input['name'] + "': the model expects dims " + 424 | str(input['dims']) + 425 | " but the model configuration specifies dims " + 426 | str(current_input['dims'])) 427 | else: 428 | current_input['data_type'] = input['data_type'] 429 | current_input['dims'] = input['dims'] 430 | return 431 | 432 | self._model_config["input"].append(input) 433 | 434 | def add_output(self, output): 435 | """Add the output for the model. 436 | Parameters 437 | ---------- 438 | output : dict 439 | The output to be added. 
440 | Raises 441 | ------ 442 | ValueError 443 | If output contains property other than 'name', 'data_type' 444 | and 'dims' or any of the properties are not set, or if an 445 | output with the same name already exists in the configuration 446 | but has different data_type or dims property 447 | """ 448 | valid_properties = ['name', 'data_type', 'dims'] 449 | for current_property in output: 450 | if current_property not in valid_properties: 451 | raise ValueError( 452 | "output '" + output['name'] + 453 | "' in auto-complete-config function for model '" + 454 | self._model_config["name"] + 455 | "' contains property other than 'name', 'data_type' and 'dims'." 456 | ) 457 | 458 | if 'name' not in output: 459 | raise ValueError( 460 | "output in auto-complete-config function for model '" + 461 | self._model_config["name"] + "' is missing 'name' property.") 462 | elif 'data_type' not in output: 463 | raise ValueError("output '" + output['name'] + 464 | "' in auto-complete-config function for model '" + 465 | self._model_config["name"] + 466 | "' is missing 'data_type' property.") 467 | elif 'dims' not in output: 468 | raise ValueError("output '" + output['name'] + 469 | "' in auto-complete-config function for model '" + 470 | self._model_config["name"] + 471 | "' is missing 'dims' property.") 472 | 473 | for current_output in self._model_config["output"]: 474 | if output['name'] == current_output['name']: 475 | if current_output[ 476 | 'data_type'] != "TYPE_INVALID" and current_output[ 477 | 'data_type'] != output['data_type']: 478 | raise ValueError("unable to load model '" + 479 | self._model_config["name"] + 480 | "', configuration expects datatype " + 481 | current_output['data_type'] + 482 | " for output '" + output['name'] + 483 | "', model provides " + output['data_type']) 484 | elif current_output[ 485 | 'dims'] and current_output['dims'] != output['dims']: 486 | raise ValueError( 487 | "model '" + self._model_config["name"] + "', tensor '" + 488 | output['name'] + "': the model expects dims " + 489 | str(output['dims']) + 490 | " but the model configuration specifies dims " + 491 | str(current_output['dims'])) 492 | else: 493 | current_output['data_type'] = output['data_type'] 494 | current_output['dims'] = output['dims'] 495 | return 496 | 497 | self._model_config["output"].append(output) 498 | 499 | 500 | TRITONSERVER_REQUEST_FLAG_SEQUENCE_START = 1 501 | TRITONSERVER_REQUEST_FLAG_SEQUENCE_END = 2 502 | TRITONSERVER_RESPONSE_COMPLETE_FINAL = 1 503 | -------------------------------------------------------------------------------- /export.py: -------------------------------------------------------------------------------- 1 | # YOLOv5 🚀 by Ultralytics, GPL-3.0 license 2 | """ 3 | Export a YOLOv5 PyTorch model to other formats. 
TensorFlow exports authored by https://github.com/zldrobit 4 | 5 | Format | `export.py --include` | Model 6 | --- | --- | --- 7 | PyTorch | - | yolov5s.pt 8 | TorchScript | `torchscript` | yolov5s.torchscript 9 | ONNX | `onnx` | yolov5s.onnx 10 | OpenVINO | `openvino` | yolov5s_openvino_model/ 11 | TensorRT | `engine` | yolov5s.engine 12 | CoreML | `coreml` | yolov5s.mlmodel 13 | TensorFlow SavedModel | `saved_model` | yolov5s_saved_model/ 14 | TensorFlow GraphDef | `pb` | yolov5s.pb 15 | TensorFlow Lite | `tflite` | yolov5s.tflite 16 | TensorFlow Edge TPU | `edgetpu` | yolov5s_edgetpu.tflite 17 | TensorFlow.js | `tfjs` | yolov5s_web_model/ 18 | 19 | Requirements: 20 | $ pip install -r requirements.txt coremltools onnx onnx-simplifier onnxruntime openvino-dev tensorflow-cpu # CPU 21 | $ pip install -r requirements.txt coremltools onnx onnx-simplifier onnxruntime-gpu openvino-dev tensorflow # GPU 22 | 23 | Usage: 24 | $ python path/to/export.py --weights yolov5s.pt --include torchscript onnx openvino engine coreml tflite ... 25 | 26 | Inference: 27 | $ python path/to/detect.py --weights yolov5s.pt # PyTorch 28 | yolov5s.torchscript # TorchScript 29 | yolov5s.onnx # ONNX Runtime or OpenCV DNN with --dnn 30 | yolov5s.xml # OpenVINO 31 | yolov5s.engine # TensorRT 32 | yolov5s.mlmodel # CoreML (macOS-only) 33 | yolov5s_saved_model # TensorFlow SavedModel 34 | yolov5s.pb # TensorFlow GraphDef 35 | yolov5s.tflite # TensorFlow Lite 36 | yolov5s_edgetpu.tflite # TensorFlow Edge TPU 37 | 38 | TensorFlow.js: 39 | $ cd .. && git clone https://github.com/zldrobit/tfjs-yolov5-example.git && cd tfjs-yolov5-example 40 | $ npm install 41 | $ ln -s ../../yolov5/yolov5s_web_model public/yolov5s_web_model 42 | $ npm start 43 | """ 44 | 45 | import argparse 46 | import json 47 | import os 48 | import platform 49 | import subprocess 50 | import sys 51 | import time 52 | import warnings 53 | from pathlib import Path 54 | 55 | import pandas as pd 56 | import torch 57 | import yaml 58 | from torch.utils.mobile_optimizer import optimize_for_mobile 59 | 60 | FILE = Path(__file__).resolve() 61 | ROOT = FILE.parents[0] # YOLOv5 root directory 62 | if str(ROOT) not in sys.path: 63 | sys.path.append(str(ROOT)) # add ROOT to PATH 64 | if platform.system() != 'Windows': 65 | ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative 66 | 67 | from models.experimental import attempt_load 68 | from models.yolo import Detect 69 | from utils.dataloaders import LoadImages 70 | from utils.general import (LOGGER, check_dataset, check_img_size, check_requirements, check_version, colorstr, 71 | file_size, print_args, url2file) 72 | from utils.torch_utils import select_device 73 | 74 | 75 | def export_formats(): 76 | # YOLOv5 export formats 77 | x = [ 78 | ['PyTorch', '-', '.pt', True], 79 | ['TorchScript', 'torchscript', '.torchscript', True], 80 | ['ONNX', 'onnx', '.onnx', True], 81 | ['OpenVINO', 'openvino', '_openvino_model', False], 82 | ['TensorRT', 'engine', '.engine', True], 83 | ['CoreML', 'coreml', '.mlmodel', False], 84 | ['TensorFlow SavedModel', 'saved_model', '_saved_model', True], 85 | ['TensorFlow GraphDef', 'pb', '.pb', True], 86 | ['TensorFlow Lite', 'tflite', '.tflite', False], 87 | ['TensorFlow Edge TPU', 'edgetpu', '_edgetpu.tflite', False], 88 | ['TensorFlow.js', 'tfjs', '_web_model', False],] 89 | return pd.DataFrame(x, columns=['Format', 'Argument', 'Suffix', 'GPU']) 90 | 91 | 92 | def export_torchscript(model, im, file, optimize, prefix=colorstr('TorchScript:')): 93 | # YOLOv5 TorchScript model export 94 | 
try: 95 | LOGGER.info(f'\n{prefix} starting export with torch {torch.__version__}...') 96 | f = file.with_suffix('.torchscript') 97 | 98 | ts = torch.jit.trace(model, im, strict=False) 99 | d = {"shape": im.shape, "stride": int(max(model.stride)), "names": model.names} 100 | extra_files = {'config.txt': json.dumps(d)} # torch._C.ExtraFilesMap() 101 | if optimize: # https://pytorch.org/tutorials/recipes/mobile_interpreter.html 102 | optimize_for_mobile(ts)._save_for_lite_interpreter(str(f), _extra_files=extra_files) 103 | else: 104 | ts.save(str(f), _extra_files=extra_files) 105 | 106 | LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)') 107 | return f 108 | except Exception as e: 109 | LOGGER.info(f'{prefix} export failure: {e}') 110 | 111 | 112 | def export_onnx(model, im, file, opset, train, dynamic, simplify, prefix=colorstr('ONNX:')): 113 | # YOLOv5 ONNX export 114 | # try: 115 | check_requirements(('onnx',)) 116 | import onnx 117 | 118 | LOGGER.info(f'\n{prefix} starting export with onnx {onnx.__version__}...') 119 | f = file.with_suffix('.onnx') 120 | print(f'training mode: {train}') 121 | torch.onnx.export( 122 | model, 123 | im, 124 | f, 125 | verbose=False, 126 | opset_version=opset, 127 | training=torch.onnx.TrainingMode.TRAINING if train else torch.onnx.TrainingMode.EVAL, 128 | do_constant_folding=not train, 129 | input_names=['images'], 130 | output_names=['bbox', 'cls_score'], 131 | dynamic_axes={ 132 | 'images': { 133 | 0: 'batch', 134 | 2: 'height', 135 | 3: 'width' 136 | }, # shape(1,3,640,640) 137 | 'bbox': { 138 | 0: 'batch', 139 | 1: 'anchors'}, # shape(1,25200, 1, 4) 140 | 'cls_score': { 141 | 0: 'batch', 142 | 1: 'anchors'} # shape(1,25200, 80) 143 | } if dynamic else None) 144 | 145 | # Checks 146 | model_onnx = onnx.load(f) # load onnx model 147 | onnx.checker.check_model(model_onnx) # check onnx model 148 | 149 | # Simplify 150 | if simplify: 151 | # try: 152 | check_requirements(('onnx-simplifier',)) 153 | import onnxsim 154 | 155 | LOGGER.info(f'{prefix} simplifying with onnx-simplifier {onnxsim.__version__}...') 156 | model_onnx, check = onnxsim.simplify(model_onnx, 157 | # not needed any more, onnxsim can now support dynamic input shapes natively 158 | #dynamic_input_shape=dynamic, 159 | # overwrite_input_shapes={'images': list(im.shape)} if dynamic else None 160 | ) 161 | assert check, 'assert check failed' 162 | onnx.save(model_onnx, f) 163 | # except Exception as e: 164 | # LOGGER.info(f'{prefix} simplifier failure: {e}') 165 | 166 | # add batch NMS: 167 | import onnx_graphsurgeon as onnx_gs 168 | import numpy as np 169 | yolo_graph = onnx_gs.import_onnx(model_onnx) 170 | box_data = yolo_graph.outputs[0] 171 | cls_data = yolo_graph.outputs[1] 172 | 173 | nms_out_0 = onnx_gs.Variable( 174 | "BatchedNMS", 175 | dtype=np.int32 176 | ) 177 | nms_out_1 = onnx_gs.Variable( 178 | "BatchedNMS_1", 179 | dtype=np.float32 180 | ) 181 | nms_out_2 = onnx_gs.Variable( 182 | "BatchedNMS_2", 183 | dtype=np.float32 184 | ) 185 | nms_out_3 = onnx_gs.Variable( 186 | "BatchedNMS_3", 187 | dtype=np.float32 188 | ) 189 | 190 | nms_attrs = dict() 191 | 192 | # If set to true, the boxes input are shared across all classes. 193 | # If set to false, the boxes input should account for per-class box data. 
194 | nms_attrs["shareLocation"] = 1 195 | nms_attrs["backgroundLabelId"] = -1 196 | # NMS parameters 197 | nms_attrs["scoreThreshold"] = 0.25 198 | nms_attrs["iouThreshold"] = 0.45 199 | # controls the number of bboxes going into and coming out of the plugin 200 | nms_attrs["topK"] = 2*300 201 | nms_attrs["keepTopK"] = 300 202 | nms_attrs["numClasses"] = 80 203 | # the bbox coordinates output by the yolov5 detect layer have already been scaled to the feature map size, so this should be set to 0 204 | nms_attrs["clipBoxes"] = 0 205 | nms_attrs["isNormalized"] = 0 206 | # nms_attrs["scoreBits"] = 16 207 | 208 | nms_plugin = onnx_gs.Node( 209 | op="BatchedNMSDynamic_TRT", 210 | name="BatchedNMS_N", 211 | inputs=[box_data, cls_data], 212 | outputs=[nms_out_0, nms_out_1, nms_out_2, nms_out_3], 213 | attrs=nms_attrs 214 | ) 215 | 216 | yolo_graph.nodes.append(nms_plugin) 217 | yolo_graph.outputs = nms_plugin.outputs 218 | yolo_graph.cleanup().toposort() 219 | model_onnx = onnx_gs.export_onnx(yolo_graph) 220 | # Metadata 221 | d = {'stride': int(max(model.stride)), 'names': model.names} 222 | for k, v in d.items(): 223 | meta = model_onnx.metadata_props.add() 224 | meta.key, meta.value = k, str(v) 225 | 226 | onnx.save(model_onnx, f) 227 | LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)') 228 | return f 229 | # except Exception as e: 230 | # LOGGER.info(f'{prefix} export failure: {e}') 231 | 232 | 233 | def export_openvino(model, file, half, prefix=colorstr('OpenVINO:')): 234 | # YOLOv5 OpenVINO export 235 | try: 236 | check_requirements(('openvino-dev',)) # requires openvino-dev: https://pypi.org/project/openvino-dev/ 237 | import openvino.inference_engine as ie 238 | 239 | LOGGER.info(f'\n{prefix} starting export with openvino {ie.__version__}...') 240 | f = str(file).replace('.pt', f'_openvino_model{os.sep}') 241 | 242 | cmd = f"mo --input_model {file.with_suffix('.onnx')} --output_dir {f} --data_type {'FP16' if half else 'FP32'}" 243 | subprocess.check_output(cmd.split()) # export 244 | with open(Path(f) / file.with_suffix('.yaml').name, 'w') as g: 245 | yaml.dump({'stride': int(max(model.stride)), 'names': model.names}, g) # add metadata.yaml 246 | 247 | LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)') 248 | return f 249 | except Exception as e: 250 | LOGGER.info(f'\n{prefix} export failure: {e}') 251 | 252 | 253 | def export_coreml(model, im, file, int8, half, prefix=colorstr('CoreML:')): 254 | # YOLOv5 CoreML export 255 | try: 256 | check_requirements(('coremltools',)) 257 | import coremltools as ct 258 | 259 | LOGGER.info(f'\n{prefix} starting export with coremltools {ct.__version__}...') 260 | f = file.with_suffix('.mlmodel') 261 | 262 | ts = torch.jit.trace(model, im, strict=False) # TorchScript model 263 | ct_model = ct.convert(ts, inputs=[ct.ImageType('image', shape=im.shape, scale=1 / 255, bias=[0, 0, 0])]) 264 | bits, mode = (8, 'kmeans_lut') if int8 else (16, 'linear') if half else (32, None) 265 | if bits < 32: 266 | if platform.system() == 'Darwin': # quantization only supported on macOS 267 | with warnings.catch_warnings(): 268 | warnings.filterwarnings("ignore", category=DeprecationWarning) # suppress numpy==1.20 float warning 269 | ct_model = ct.models.neural_network.quantization_utils.quantize_weights(ct_model, bits, mode) 270 | else: 271 | print(f'{prefix} quantization only supported on macOS, skipping...') 272 | ct_model.save(f) 273 | 274 | LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)') 275 | return ct_model, f 276 | except Exception as e: 277 | LOGGER.info(f'\n{prefix} export failure: {e}') 278 | return None, None 279 
| 280 | 281 | def export_engine(model, im, file, train, half, simplify, workspace=4, verbose=False, prefix=colorstr('TensorRT:')): 282 | # YOLOv5 TensorRT export https://developer.nvidia.com/tensorrt 283 | try: 284 | assert im.device.type != 'cpu', 'export running on CPU but must be on GPU, i.e. `python export.py --device 0`' 285 | try: 286 | import tensorrt as trt 287 | except Exception: 288 | if platform.system() == 'Linux': 289 | check_requirements(('nvidia-tensorrt',), cmds=('-U --index-url https://pypi.ngc.nvidia.com',)) 290 | import tensorrt as trt 291 | 292 | if trt.__version__[0] == '7': # TensorRT 7 handling https://github.com/ultralytics/yolov5/issues/6012 293 | grid = model.model[-1].anchor_grid 294 | model.model[-1].anchor_grid = [a[..., :1, :1, :] for a in grid] 295 | export_onnx(model, im, file, 12, train, False, simplify) # opset 12 296 | model.model[-1].anchor_grid = grid 297 | else: # TensorRT >= 8 298 | check_version(trt.__version__, '8.0.0', hard=True) # require tensorrt>=8.0.0 299 | export_onnx(model, im, file, 13, train, False, simplify) # opset 13 300 | onnx = file.with_suffix('.onnx') 301 | 302 | LOGGER.info(f'\n{prefix} starting export with TensorRT {trt.__version__}...') 303 | assert onnx.exists(), f'failed to export ONNX file: {onnx}' 304 | f = file.with_suffix('.engine') # TensorRT engine file 305 | logger = trt.Logger(trt.Logger.INFO) 306 | if verbose: 307 | logger.min_severity = trt.Logger.Severity.VERBOSE 308 | 309 | builder = trt.Builder(logger) 310 | config = builder.create_builder_config() 311 | config.max_workspace_size = workspace * 1 << 30 312 | # config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace << 30) # fix TRT 8.4 deprecation notice 313 | 314 | flag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)) 315 | network = builder.create_network(flag) 316 | parser = trt.OnnxParser(network, logger) 317 | if not parser.parse_from_file(str(onnx)): 318 | raise RuntimeError(f'failed to load ONNX file: {onnx}') 319 | 320 | inputs = [network.get_input(i) for i in range(network.num_inputs)] 321 | outputs = [network.get_output(i) for i in range(network.num_outputs)] 322 | LOGGER.info(f'{prefix} Network Description:') 323 | for inp in inputs: 324 | LOGGER.info(f'{prefix}\tinput "{inp.name}" with shape {inp.shape} and dtype {inp.dtype}') 325 | for out in outputs: 326 | LOGGER.info(f'{prefix}\toutput "{out.name}" with shape {out.shape} and dtype {out.dtype}') 327 | 328 | LOGGER.info(f'{prefix} building FP{16 if builder.platform_has_fast_fp16 and half else 32} engine in {f}') 329 | if builder.platform_has_fast_fp16 and half: 330 | config.set_flag(trt.BuilderFlag.FP16) 331 | with builder.build_engine(network, config) as engine, open(f, 'wb') as t: 332 | t.write(engine.serialize()) 333 | LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)') 334 | return f 335 | except Exception as e: 336 | LOGGER.info(f'\n{prefix} export failure: {e}') 337 | 338 | 339 | def export_saved_model(model, 340 | im, 341 | file, 342 | dynamic, 343 | tf_nms=False, 344 | agnostic_nms=False, 345 | topk_per_class=100, 346 | topk_all=100, 347 | iou_thres=0.45, 348 | conf_thres=0.25, 349 | keras=False, 350 | prefix=colorstr('TensorFlow SavedModel:')): 351 | # YOLOv5 TensorFlow SavedModel export 352 | try: 353 | import tensorflow as tf 354 | from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2 355 | 356 | from models.tf import TFDetect, TFModel 357 | 358 | LOGGER.info(f'\n{prefix} starting export with 
tensorflow {tf.__version__}...') 359 | f = str(file).replace('.pt', '_saved_model') 360 | batch_size, ch, *imgsz = list(im.shape) # BCHW 361 | 362 | tf_model = TFModel(cfg=model.yaml, model=model, nc=model.nc, imgsz=imgsz) 363 | im = tf.zeros((batch_size, *imgsz, ch)) # BHWC order for TensorFlow 364 | _ = tf_model.predict(im, tf_nms, agnostic_nms, topk_per_class, topk_all, iou_thres, conf_thres) 365 | inputs = tf.keras.Input(shape=(*imgsz, ch), batch_size=None if dynamic else batch_size) 366 | outputs = tf_model.predict(inputs, tf_nms, agnostic_nms, topk_per_class, topk_all, iou_thres, conf_thres) 367 | keras_model = tf.keras.Model(inputs=inputs, outputs=outputs) 368 | keras_model.trainable = False 369 | keras_model.summary() 370 | if keras: 371 | keras_model.save(f, save_format='tf') 372 | else: 373 | spec = tf.TensorSpec(keras_model.inputs[0].shape, keras_model.inputs[0].dtype) 374 | m = tf.function(lambda x: keras_model(x)) # full model 375 | m = m.get_concrete_function(spec) 376 | frozen_func = convert_variables_to_constants_v2(m) 377 | tfm = tf.Module() 378 | tfm.__call__ = tf.function(lambda x: frozen_func(x)[:4] if tf_nms else frozen_func(x)[0], [spec]) 379 | tfm.__call__(im) 380 | tf.saved_model.save(tfm, 381 | f, 382 | options=tf.saved_model.SaveOptions(experimental_custom_gradients=False) 383 | if check_version(tf.__version__, '2.6') else tf.saved_model.SaveOptions()) 384 | LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)') 385 | return keras_model, f 386 | except Exception as e: 387 | LOGGER.info(f'\n{prefix} export failure: {e}') 388 | return None, None 389 | 390 | 391 | def export_pb(keras_model, file, prefix=colorstr('TensorFlow GraphDef:')): 392 | # YOLOv5 TensorFlow GraphDef *.pb export https://github.com/leimao/Frozen_Graph_TensorFlow 393 | try: 394 | import tensorflow as tf 395 | from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2 396 | 397 | LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...') 398 | f = file.with_suffix('.pb') 399 | 400 | m = tf.function(lambda x: keras_model(x)) # full model 401 | m = m.get_concrete_function(tf.TensorSpec(keras_model.inputs[0].shape, keras_model.inputs[0].dtype)) 402 | frozen_func = convert_variables_to_constants_v2(m) 403 | frozen_func.graph.as_graph_def() 404 | tf.io.write_graph(graph_or_graph_def=frozen_func.graph, logdir=str(f.parent), name=f.name, as_text=False) 405 | 406 | LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)') 407 | return f 408 | except Exception as e: 409 | LOGGER.info(f'\n{prefix} export failure: {e}') 410 | 411 | 412 | def export_tflite(keras_model, im, file, int8, data, nms, agnostic_nms, prefix=colorstr('TensorFlow Lite:')): 413 | # YOLOv5 TensorFlow Lite export 414 | try: 415 | import tensorflow as tf 416 | 417 | LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...') 418 | batch_size, ch, *imgsz = list(im.shape) # BCHW 419 | f = str(file).replace('.pt', '-fp16.tflite') 420 | 421 | converter = tf.lite.TFLiteConverter.from_keras_model(keras_model) 422 | converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS] 423 | converter.target_spec.supported_types = [tf.float16] 424 | converter.optimizations = [tf.lite.Optimize.DEFAULT] 425 | if int8: 426 | from models.tf import representative_dataset_gen 427 | dataset = LoadImages(check_dataset(data)['train'], img_size=imgsz, auto=False) # representative data 428 | converter.representative_dataset = lambda: 
representative_dataset_gen(dataset, ncalib=100) 429 | converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] 430 | converter.target_spec.supported_types = [] 431 | converter.inference_input_type = tf.uint8 # or tf.int8 432 | converter.inference_output_type = tf.uint8 # or tf.int8 433 | converter.experimental_new_quantizer = True 434 | f = str(file).replace('.pt', '-int8.tflite') 435 | if nms or agnostic_nms: 436 | converter.target_spec.supported_ops.append(tf.lite.OpsSet.SELECT_TF_OPS) 437 | 438 | tflite_model = converter.convert() 439 | open(f, "wb").write(tflite_model) 440 | LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)') 441 | return f 442 | except Exception as e: 443 | LOGGER.info(f'\n{prefix} export failure: {e}') 444 | 445 | 446 | def export_edgetpu(file, prefix=colorstr('Edge TPU:')): 447 | # YOLOv5 Edge TPU export https://coral.ai/docs/edgetpu/models-intro/ 448 | try: 449 | cmd = 'edgetpu_compiler --version' 450 | help_url = 'https://coral.ai/docs/edgetpu/compiler/' 451 | assert platform.system() == 'Linux', f'export only supported on Linux. See {help_url}' 452 | if subprocess.run(f'{cmd} >/dev/null', shell=True).returncode != 0: 453 | LOGGER.info(f'\n{prefix} export requires Edge TPU compiler. Attempting install from {help_url}') 454 | sudo = subprocess.run('sudo --version >/dev/null', shell=True).returncode == 0 # sudo installed on system 455 | for c in ( 456 | 'curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key add -', 457 | 'echo "deb https://packages.cloud.google.com/apt coral-edgetpu-stable main" | sudo tee /etc/apt/sources.list.d/coral-edgetpu.list', 458 | 'sudo apt-get update', 'sudo apt-get install edgetpu-compiler'): 459 | subprocess.run(c if sudo else c.replace('sudo ', ''), shell=True, check=True) 460 | ver = subprocess.run(cmd, shell=True, capture_output=True, check=True).stdout.decode().split()[-1] 461 | 462 | LOGGER.info(f'\n{prefix} starting export with Edge TPU compiler {ver}...') 463 | f = str(file).replace('.pt', '-int8_edgetpu.tflite') # Edge TPU model 464 | f_tfl = str(file).replace('.pt', '-int8.tflite') # TFLite model 465 | 466 | cmd = f"edgetpu_compiler -s -o {file.parent} {f_tfl}" 467 | subprocess.run(cmd.split(), check=True) 468 | 469 | LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)') 470 | return f 471 | except Exception as e: 472 | LOGGER.info(f'\n{prefix} export failure: {e}') 473 | 474 | 475 | def export_tfjs(file, prefix=colorstr('TensorFlow.js:')): 476 | # YOLOv5 TensorFlow.js export 477 | try: 478 | check_requirements(('tensorflowjs',)) 479 | import re 480 | 481 | import tensorflowjs as tfjs 482 | 483 | LOGGER.info(f'\n{prefix} starting export with tensorflowjs {tfjs.__version__}...') 484 | f = str(file).replace('.pt', '_web_model') # js dir 485 | f_pb = file.with_suffix('.pb') # *.pb path 486 | f_json = f'{f}/model.json' # *.json path 487 | 488 | cmd = f'tensorflowjs_converter --input_format=tf_frozen_model ' \ 489 | f'--output_node_names=Identity,Identity_1,Identity_2,Identity_3 {f_pb} {f}' 490 | subprocess.run(cmd.split()) 491 | 492 | with open(f_json) as j: 493 | json = j.read() 494 | with open(f_json, 'w') as j: # sort JSON Identity_* in ascending order 495 | subst = re.sub( 496 | r'{"outputs": {"Identity.?.?": {"name": "Identity.?.?"}, ' 497 | r'"Identity.?.?": {"name": "Identity.?.?"}, ' 498 | r'"Identity.?.?": {"name": "Identity.?.?"}, ' 499 | r'"Identity.?.?": {"name": "Identity.?.?"}}}', r'{"outputs": {"Identity": {"name": "Identity"}, ' 
500 | r'"Identity_1": {"name": "Identity_1"}, ' 501 | r'"Identity_2": {"name": "Identity_2"}, ' 502 | r'"Identity_3": {"name": "Identity_3"}}}', json) 503 | j.write(subst) 504 | 505 | LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)') 506 | return f 507 | except Exception as e: 508 | LOGGER.info(f'\n{prefix} export failure: {e}') 509 | 510 | 511 | @torch.no_grad() 512 | def run( 513 | data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path' 514 | weights=ROOT / 'yolov5s.pt', # weights path 515 | imgsz=(640, 640), # image (height, width) 516 | batch_size=1, # batch size 517 | device='cpu', # cuda device, i.e. 0 or 0,1,2,3 or cpu 518 | include=('torchscript', 'onnx'), # include formats 519 | half=False, # FP16 half-precision export 520 | inplace=False, # set YOLOv5 Detect() inplace=True 521 | train=False, # model.train() mode 522 | keras=False, # use Keras 523 | optimize=False, # TorchScript: optimize for mobile 524 | int8=False, # CoreML/TF INT8 quantization 525 | dynamic=False, # ONNX/TF: dynamic axes 526 | simplify=False, # ONNX: simplify model 527 | opset=12, # ONNX: opset version 528 | verbose=False, # TensorRT: verbose log 529 | workspace=4, # TensorRT: workspace size (GB) 530 | nms=False, # TF: add NMS to model 531 | agnostic_nms=False, # TF: add agnostic NMS to model 532 | topk_per_class=100, # TF.js NMS: topk per class to keep 533 | topk_all=100, # TF.js NMS: topk for all classes to keep 534 | iou_thres=0.45, # TF.js NMS: IoU threshold 535 | conf_thres=0.25, # TF.js NMS: confidence threshold 536 | ): 537 | t = time.time() 538 | include = [x.lower() for x in include] # to lowercase 539 | fmts = tuple(export_formats()['Argument'][1:]) # --include arguments 540 | flags = [x in include for x in fmts] 541 | assert sum(flags) == len(include), f'ERROR: Invalid --include {include}, valid --include arguments are {fmts}' 542 | jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs = flags # export booleans 543 | file = Path(url2file(weights) if str(weights).startswith(('http:/', 'https:/')) else weights) # PyTorch weights 544 | 545 | # Load PyTorch model 546 | device = select_device(device) 547 | if half: 548 | assert device.type != 'cpu' or coreml or xml, '--half only compatible with GPU export, i.e. use --device 0' 549 | assert not dynamic, '--half not compatible with --dynamic, i.e. 
use either --half or --dynamic but not both' 550 | model = attempt_load(weights, device=device, inplace=True, fuse=True) # load FP32 model 551 | nc, names = model.nc, model.names # number of classes, class names 552 | 553 | # Checks 554 | imgsz *= 2 if len(imgsz) == 1 else 1 # expand 555 | assert nc == len(names), f'Model class count {nc} != len(names) {len(names)}' 556 | 557 | # Input 558 | gs = int(max(model.stride)) # grid size (max stride) 559 | imgsz = [check_img_size(x, gs) for x in imgsz] # verify img_size are gs-multiples 560 | im = torch.zeros(batch_size, 3, *imgsz).to(device) # image size(1,3,320,192) BCHW iDetection 561 | 562 | # Update model 563 | import torch.nn as nn 564 | if half and not coreml and not xml: 565 | im, model = im.half(), model.half() # to FP16 566 | model.train() if train else model.eval() # training mode = no Detect() layer grid construction 567 | for k, m in model.named_modules(): 568 | if isinstance(m, Detect): 569 | m.inplace = inplace 570 | m.onnx_dynamic = dynamic 571 | m.export = True 572 | elif isinstance(m, nn.Upsample): 573 | print(m) 574 | 575 | for _ in range(2): 576 | y = model(im) # dry runs 577 | shape = tuple(y[0].shape) # model output shape 578 | LOGGER.info(f"\n{colorstr('PyTorch:')} starting from {file} with output shape {shape} ({file_size(file):.1f} MB)") 579 | 580 | # Exports 581 | f = [''] * 10 # exported filenames 582 | warnings.filterwarnings(action='ignore', category=torch.jit.TracerWarning) # suppress TracerWarning 583 | if jit: 584 | f[0] = export_torchscript(model, im, file, optimize) 585 | if engine: # TensorRT required before ONNX 586 | f[1] = export_engine(model, im, file, train, half, simplify, workspace, verbose) 587 | if onnx or xml: # OpenVINO requires ONNX 588 | f[2] = export_onnx(model, im, file, opset, train, dynamic, simplify) 589 | if xml: # OpenVINO 590 | f[3] = export_openvino(model, file, half) 591 | if coreml: 592 | _, f[4] = export_coreml(model, im, file, int8, half) 593 | 594 | # TensorFlow Exports 595 | if any((saved_model, pb, tflite, edgetpu, tfjs)): 596 | if int8 or edgetpu: # TFLite --int8 bug https://github.com/ultralytics/yolov5/issues/5707 597 | check_requirements(('flatbuffers==1.12',)) # required before `import tensorflow` 598 | assert not tflite or not tfjs, 'TFLite and TF.js models must be exported separately, please pass only one type.' 
599 | model, f[5] = export_saved_model(model.cpu(), 600 | im, 601 | file, 602 | dynamic, 603 | tf_nms=nms or agnostic_nms or tfjs, 604 | agnostic_nms=agnostic_nms or tfjs, 605 | topk_per_class=topk_per_class, 606 | topk_all=topk_all, 607 | iou_thres=iou_thres, 608 | conf_thres=conf_thres, 609 | keras=keras) 610 | if pb or tfjs: # pb prerequisite to tfjs 611 | f[6] = export_pb(model, file) 612 | if tflite or edgetpu: 613 | f[7] = export_tflite(model, im, file, int8=int8 or edgetpu, data=data, nms=nms, agnostic_nms=agnostic_nms) 614 | if edgetpu: 615 | f[8] = export_edgetpu(file) 616 | if tfjs: 617 | f[9] = export_tfjs(file) 618 | 619 | # Finish 620 | f = [str(x) for x in f if x] # filter out '' and None 621 | if any(f): 622 | LOGGER.info(f'\nExport complete ({time.time() - t:.2f}s)' 623 | f"\nResults saved to {colorstr('bold', file.parent.resolve())}" 624 | f"\nDetect: python detect.py --weights {f[-1]}" 625 | f"\nPyTorch Hub: model = torch.hub.load('ultralytics/yolov5', 'custom', '{f[-1]}')" 626 | f"\nValidate: python val.py --weights {f[-1]}" 627 | f"\nVisualize: https://netron.app") 628 | return f # return list of exported files/dirs 629 | 630 | 631 | def parse_opt(): 632 | parser = argparse.ArgumentParser() 633 | parser.add_argument('--data', type=str, default=ROOT / 'data/coco128.yaml', help='dataset.yaml path') 634 | parser.add_argument('--weights', nargs='+', type=str, default=ROOT / 'yolov5s.pt', help='model.pt path(s)') 635 | parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[640, 640], help='image (h, w)') 636 | parser.add_argument('--batch-size', type=int, default=1, help='batch size') 637 | parser.add_argument('--device', default='cpu', help='cuda device, i.e. 0 or 0,1,2,3 or cpu') 638 | parser.add_argument('--half', action='store_true', help='FP16 half-precision export') 639 | parser.add_argument('--inplace', action='store_true', help='set YOLOv5 Detect() inplace=True') 640 | parser.add_argument('--train', action='store_true', help='model.train() mode') 641 | parser.add_argument('--keras', action='store_true', help='TF: use Keras') 642 | parser.add_argument('--optimize', action='store_true', help='TorchScript: optimize for mobile') 643 | parser.add_argument('--int8', action='store_true', help='CoreML/TF INT8 quantization') 644 | parser.add_argument('--dynamic', action='store_true', help='ONNX/TF: dynamic axes') 645 | parser.add_argument('--simplify', action='store_true', help='ONNX: simplify model') 646 | parser.add_argument('--opset', type=int, default=12, help='ONNX: opset version') 647 | parser.add_argument('--verbose', action='store_true', help='TensorRT: verbose log') 648 | parser.add_argument('--workspace', type=int, default=4, help='TensorRT: workspace size (GB)') 649 | parser.add_argument('--nms', action='store_true', help='TF: add NMS to model') 650 | parser.add_argument('--agnostic-nms', action='store_true', help='TF: add agnostic NMS to model') 651 | parser.add_argument('--topk-per-class', type=int, default=100, help='TF.js NMS: topk per class to keep') 652 | parser.add_argument('--topk-all', type=int, default=100, help='TF.js NMS: topk for all classes to keep') 653 | parser.add_argument('--iou-thres', type=float, default=0.45, help='TF.js NMS: IoU threshold') 654 | parser.add_argument('--conf-thres', type=float, default=0.25, help='TF.js NMS: confidence threshold') 655 | parser.add_argument('--include', 656 | nargs='+', 657 | default=['torchscript', 'onnx'], 658 | help='torchscript, onnx, openvino, engine, coreml, saved_model, pb, 
tflite, edgetpu, tfjs') 659 | opt = parser.parse_args() 660 | print_args(vars(opt)) 661 | return opt 662 | 663 | 664 | def main(opt): 665 | for opt.weights in (opt.weights if isinstance(opt.weights, list) else [opt.weights]): 666 | run(**vars(opt)) 667 | 668 | 669 | if __name__ == "__main__": 670 | opt = parse_opt() 671 | main(opt) 672 | --------------------------------------------------------------------------------
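For reference, the BatchedNMSDynamic_TRT node that export_onnx() grafts onto the graph replaces the raw detect-layer outputs with four tensors named BatchedNMS, BatchedNMS_1, BatchedNMS_2 and BatchedNMS_3. The helper below is a minimal sketch and not part of the repository: it assumes the usual TensorRT BatchedNMS output layout (per-image detection counts [batch, 1], boxes [batch, keepTopK, 4], scores [batch, keepTopK] and class ids [batch, keepTopK], with keepTopK = 300 as configured in export_onnx), and those shapes should be verified against the actual engine or Triton model config.

import numpy as np

def unpack_batched_nms(num_detections, boxes, scores, classes):
    # Return one list of (x1, y1, x2, y2, score, class_id) tuples per image,
    # keeping only the first num_detections[i] valid entries of each row.
    results = []
    for i in range(num_detections.shape[0]):
        n = int(np.asarray(num_detections[i]).reshape(-1)[0])
        dets = [(*boxes[i, j].tolist(), float(scores[i, j]), int(classes[i, j]))
                for j in range(n)]
        results.append(dets)
    return results

if __name__ == "__main__":
    # Dummy tensors shaped like the plugin outputs (keepTopK = 300) just to exercise the helper.
    batch, keep_top_k = 1, 300
    num = np.array([[2]], dtype=np.int32)                        # "BatchedNMS"
    boxes = np.zeros((batch, keep_top_k, 4), dtype=np.float32)   # "BatchedNMS_1"
    scores = np.zeros((batch, keep_top_k), dtype=np.float32)     # "BatchedNMS_2"
    classes = np.zeros((batch, keep_top_k), dtype=np.float32)    # "BatchedNMS_3"
    print(unpack_batched_nms(num, boxes, scores, classes))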
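Similarly, both copies of triton_python_backend_utils.py above serialize BYTES/string tensors by prepending every element with its byte length before concatenating them in C order. The snippet below is only a standalone illustration of that length-prefixed layout, written under the assumption of a 4-byte little-endian length prefix (struct format "<I"); it is not the library implementation.

import struct
import numpy as np

def pack_bytes_elements(items):
    # Flatten Python bytes/str elements into one uint8 buffer where each element
    # is stored as a 4-byte little-endian length followed by its raw bytes.
    chunks = []
    for item in items:
        b = item if isinstance(item, bytes) else str(item).encode('utf-8')
        chunks.append(struct.pack("<I", len(b)))
        chunks.append(b)
    return np.frombuffer(b''.join(chunks), dtype=np.uint8)

print(pack_bytes_elements([b"dog", "person"]))  # (4 + 3) + (4 + 6) = 17 bytes total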