├── lib
│   └── libflattenconcat.so
├── coco.py
├── config
│   ├── model_ssd_inception_v2_coco_2017_11_17.py
│   ├── model_ssd_mobilenet_v2_coco_2018_03_29.py
│   └── model_ssd_mobilenet_v1_coco_2018_01_28.py
├── README.md
└── main.py

--------------------------------------------------------------------------------
/lib/libflattenconcat.so:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AastaNV/TRT_object_detection/HEAD/lib/libflattenconcat.so

--------------------------------------------------------------------------------
/coco.py:
--------------------------------------------------------------------------------
COCO_CLASSES_LIST = [
    'unlabeled',
    'person',
    'bicycle',
    'car',
    'motorcycle',
    'airplane',
    'bus',
    'train',
    'truck',
    'boat',
    'traffic light',
    'fire hydrant',
    'street sign',
    'stop sign',
    'parking meter',
    'bench',
    'bird',
    'cat',
    'dog',
    'horse',
    'sheep',
    'cow',
    'elephant',
    'bear',
    'zebra',
    'giraffe',
    'hat',
    'backpack',
    'umbrella',
    'shoe',
    'eye glasses',
    'handbag',
    'tie',
    'suitcase',
    'frisbee',
    'skis',
    'snowboard',
    'sports ball',
    'kite',
    'baseball bat',
    'baseball glove',
    'skateboard',
    'surfboard',
    'tennis racket',
    'bottle',
    'plate',
    'wine glass',
    'cup',
    'fork',
    'knife',
    'spoon',
    'bowl',
    'banana',
    'apple',
    'sandwich',
    'orange',
    'broccoli',
    'carrot',
    'hot dog',
    'pizza',
    'donut',
    'cake',
    'chair',
    'couch',
    'potted plant',
    'bed',
    'mirror',
    'dining table',
    'window',
    'desk',
    'toilet',
    'door',
    'tv',
    'laptop',
    'mouse',
    'remote',
    'keyboard',
    'cell phone',
    'microwave',
    'oven',
    'toaster',
    'sink',
    'refrigerator',
    'blender',
    'book',
    'clock',
    'vase',
    'scissors',
    'teddy bear',
    'hair drier',
    'toothbrush',
]

COCO_CLASSES_SET = set(COCO_CLASSES_LIST)

--------------------------------------------------------------------------------
/config/model_ssd_inception_v2_coco_2017_11_17.py:
--------------------------------------------------------------------------------
import graphsurgeon as gs

path = 'model/ssd_inception_v2_coco_2017_11_17/frozen_inference_graph.pb'
TRTbin = 'TRT_ssd_inception_v2_coco_2017_11_17.bin'
output_name = ['NMS']
dims = [3, 300, 300]
layout = 7

def add_plugin(graph):
    all_assert_nodes = graph.find_nodes_by_op("Assert")
    graph.remove(all_assert_nodes, remove_exclusive_dependencies=True)

    all_identity_nodes = graph.find_nodes_by_op("Identity")
    graph.forward_inputs(all_identity_nodes)

    Input = gs.create_plugin_node(
        name="Input",
        op="Placeholder",
        shape=[1, 3, 300, 300]
    )

    PriorBox = gs.create_plugin_node(
        name="GridAnchor",
        op="GridAnchor_TRT",
        minSize=0.2,
        maxSize=0.95,
        aspectRatios=[1.0, 2.0, 0.5, 3.0, 0.33],
        variance=[0.1, 0.1, 0.2, 0.2],
        featureMapShapes=[19, 10, 5, 3, 2, 1],
        numLayers=6
    )

    NMS = gs.create_plugin_node(
        name="NMS",
        op="NMS_TRT",
        shareLocation=1,
        varianceEncodedInTarget=0,
        backgroundLabelId=0,
        confidenceThreshold=1e-8,
        nmsThreshold=0.6,
        topK=100,
        keepTopK=100,
        numClasses=91,
        inputOrder=[0, 2, 1],
        confSigmoid=1,
        isNormalized=1,
        scoreConverter="SIGMOID"
    )

    concat_priorbox = gs.create_node(
        "concat_priorbox",
        op="ConcatV2",
        axis=2
    )

    concat_box_loc = gs.create_plugin_node(
        "concat_box_loc",
        op="FlattenConcat_TRT",
    )

    concat_box_conf = gs.create_plugin_node(
        "concat_box_conf",
        op="FlattenConcat_TRT",
    )

    namespace_plugin_map = {
        "MultipleGridAnchorGenerator": PriorBox,
        "Postprocessor": NMS,
        "Preprocessor": Input,
        "ToFloat": Input,
        "image_tensor": Input,
        "MultipleGridAnchorGenerator/Concatenate": concat_priorbox,
        "concat": concat_box_loc,
        "concat_1": concat_box_conf
    }

    graph.collapse_namespaces(namespace_plugin_map)
    graph.remove(graph.graph_outputs, remove_exclusive_dependencies=False)

    return graph

--------------------------------------------------------------------------------
/config/model_ssd_mobilenet_v2_coco_2018_03_29.py:
--------------------------------------------------------------------------------
import graphsurgeon as gs

path = 'model/ssd_mobilenet_v2_coco_2018_03_29/frozen_inference_graph.pb'
TRTbin = 'TRT_ssd_mobilenet_v2_coco_2018_03_29.bin'
output_name = ['NMS']
dims = [3, 300, 300]
layout = 7

def add_plugin(graph):
    all_assert_nodes = graph.find_nodes_by_op("Assert")
    graph.remove(all_assert_nodes, remove_exclusive_dependencies=True)

    all_identity_nodes = graph.find_nodes_by_op("Identity")
    graph.forward_inputs(all_identity_nodes)

    Input = gs.create_plugin_node(
        name="Input",
        op="Placeholder",
        shape=[1, 3, 300, 300]
    )

    PriorBox = gs.create_plugin_node(
        name="GridAnchor",
        op="GridAnchor_TRT",
        minSize=0.2,
        maxSize=0.95,
        aspectRatios=[1.0, 2.0, 0.5, 3.0, 0.33],
        variance=[0.1, 0.1, 0.2, 0.2],
        featureMapShapes=[19, 10, 5, 3, 2, 1],
        numLayers=6
    )

    NMS = gs.create_plugin_node(
        name="NMS",
        op="NMS_TRT",
        shareLocation=1,
        varianceEncodedInTarget=0,
        backgroundLabelId=0,
        confidenceThreshold=1e-8,
        nmsThreshold=0.6,
        topK=100,
        keepTopK=100,
        numClasses=91,
        inputOrder=[1, 0, 2],
        confSigmoid=1,
        isNormalized=1
    )

    concat_priorbox = gs.create_node(
        "concat_priorbox",
        op="ConcatV2",
        axis=2
    )

    concat_box_loc = gs.create_plugin_node(
        "concat_box_loc",
        op="FlattenConcat_TRT",
    )

    concat_box_conf = gs.create_plugin_node(
        "concat_box_conf",
        op="FlattenConcat_TRT",
    )

    namespace_plugin_map = {
        "MultipleGridAnchorGenerator": PriorBox,
        "Postprocessor": NMS,
        "Preprocessor": Input,
        "ToFloat": Input,
        "image_tensor": Input,
        "Concatenate": concat_priorbox,
        "concat": concat_box_loc,
        "concat_1": concat_box_conf
    }

    graph.collapse_namespaces(namespace_plugin_map)
    graph.remove(graph.graph_outputs, remove_exclusive_dependencies=False)
    graph.find_nodes_by_op("NMS_TRT")[0].input.remove("Input")

    return graph

--------------------------------------------------------------------------------
/config/model_ssd_mobilenet_v1_coco_2018_01_28.py:
--------------------------------------------------------------------------------
import graphsurgeon as gs

path = 'model/ssd_mobilenet_v1_coco_2018_01_28/frozen_inference_graph.pb'
TRTbin = 'TRT_ssd_mobilenet_v1_coco_2018_01_28.bin'
output_name = ['Postprocessor']
dims = [3, 300, 300]
layout = 7

def add_plugin(graph):
    all_assert_nodes = graph.find_nodes_by_op("Assert")
    graph.remove(all_assert_nodes, remove_exclusive_dependencies=True)

    all_identity_nodes = graph.find_nodes_by_op("Identity")
    graph.forward_inputs(all_identity_nodes)

    Input = gs.create_plugin_node(
        name="Input",
        op="Placeholder",
        shape=[1, 3, 300, 300]
    )

    PriorBox = gs.create_plugin_node(
        name="MultipleGridAnchorGenerator",
        op="GridAnchor_TRT",
        minSize=0.2,
        maxSize=0.95,
        aspectRatios=[1.0, 2.0, 0.5, 3.0, 0.33],
        variance=[0.1, 0.1, 0.2, 0.2],
        featureMapShapes=[19, 10, 5, 3, 2, 1],
        numLayers=6
    )

    Postprocessor = gs.create_plugin_node(
        name="Postprocessor",
        op="NMS_TRT",
        shareLocation=1,
        varianceEncodedInTarget=0,
        backgroundLabelId=0,
        confidenceThreshold=1e-8,
        nmsThreshold=0.6,
        topK=100,
        keepTopK=100,
        numClasses=91,
        inputOrder=[0, 2, 1],
        confSigmoid=1,
        isNormalized=1
    )

    concat_priorbox = gs.create_node(
        "concat_priorbox",
        op="ConcatV2",
        axis=2
    )

    concat_box_loc = gs.create_plugin_node(
        "concat_box_loc",
        op="FlattenConcat_TRT",
    )

    concat_box_conf = gs.create_plugin_node(
        "concat_box_conf",
        op="FlattenConcat_TRT",
    )

    namespace_plugin_map = {
        "MultipleGridAnchorGenerator": PriorBox,
        "Postprocessor": Postprocessor,
        "Preprocessor": Input,
        "ToFloat": Input,
        "image_tensor": Input,
        "MultipleGridAnchorGenerator/Concatenate": concat_priorbox,
        "concat": concat_box_loc,
        "concat_1": concat_box_conf
    }

    graph.collapse_namespaces(namespace_plugin_map)
    graph.remove(graph.graph_outputs, remove_exclusive_dependencies=False)
    graph.find_nodes_by_op("NMS_TRT")[0].input.remove("Input")
    graph.find_nodes_by_name("Input")[0].input.remove("image_tensor:0")

    return graph

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
TensorRT Python Sample for Object Detection
======================================

Performance includes memcpy and inference.

| Model | Input Size | TRT Nano |
|:------|:----------:|---------:|
| ssd_inception_v2_coco (2017) | 300x300 | 49ms |
| ssd_mobilenet_v1_coco | 300x300 | 36ms |
| ssd_mobilenet_v2_coco | 300x300 | 46ms |

Since preprocessing is not yet optimized, image read/write time is not included in these numbers.

## Install dependencies

```bash
$ sudo apt-get install python3-pip libhdf5-serial-dev hdf5-tools
$ pip3 install --extra-index-url https://developer.download.nvidia.com/compute/redist/jp/v42 tensorflow-gpu==1.13.1+nv19.5 --user
$ pip3 install numpy pycuda --user
```
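
JetPack already ships TensorRT together with its `uff` and `graphsurgeon` Python bindings. As a quick sanity check that everything is importable (this check is a suggestion, not part of the original steps):

```bash
$ python3 -c "import tensorrt, uff, graphsurgeon, pycuda; print(tensorrt.__version__)"
```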

## Download model

Please download the object detection model from the TensorFlow model zoo (an example download command is shown below).

```bash
$ git clone https://github.com/AastaNV/TRT_object_detection.git
$ cd TRT_object_detection
$ mkdir model
$ cp [model].tar.gz model/
$ tar zxvf model/[model].tar.gz -C model/
```

##### Supported models:

- ssd_inception_v2_coco_2017_11_17
- ssd_mobilenet_v1_coco
- ssd_mobilenet_v2_coco

We will keep adding new models to the supported list.
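
For example, assuming the standard TensorFlow model-zoo URL scheme (the URL is an assumption, not given in the original README), ssd_mobilenet_v2_coco can be fetched directly:

```bash
$ wget http://download.tensorflow.org/models/object_detection/ssd_mobilenet_v2_coco_2018_03_29.tar.gz -P model/
$ tar zxvf model/ssd_mobilenet_v2_coco_2018_03_29.tar.gz -C model/
```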

## Update graphsurgeon converter

Edit /usr/lib/python3.6/dist-packages/graphsurgeon/node_manipulation.py so that newly created nodes default to a float dtype:

```diff
diff --git a/node_manipulation.py b/node_manipulation.py
index d2d012a..1ef30a0 100644
--- a/node_manipulation.py
+++ b/node_manipulation.py
@@ -30,6 +30,7 @@ def create_node(name, op=None, _do_suffix=False, **kwargs):
     node = NodeDef()
     node.name = name
     node.op = op if op else name
+    node.attr["dtype"].type = 1
     for key, val in kwargs.items():
         if key == "dtype":
             node.attr["dtype"].type = val.as_datatype_enum
```
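
To verify the patch took effect, a node created without an explicit dtype should now report type 1 (DT_FLOAT). A minimal check (a suggestion, not from the original README):

```python
import graphsurgeon as gs

node = gs.create_node("dtype_check")
print(node.attr["dtype"].type)  # expect 1 (DT_FLOAT) after the patch
```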

## Run

**1. Maximize the Nano performance**

```bash
$ sudo nvpmodel -m 0
$ sudo jetson_clocks
```
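
To confirm the clocks are actually pinned, you can watch `tegrastats` (bundled with JetPack) in another terminal; this check is a suggestion, not part of the original steps:

```bash
$ tegrastats
```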

**2. Update main.py based on the model you used**

Keep the import for your model and comment out the others, as in main.py:

```python
from config import model_ssd_inception_v2_coco_2017_11_17 as model
from config import model_ssd_mobilenet_v1_coco_2018_01_28 as model
from config import model_ssd_mobilenet_v2_coco_2018_03_29 as model
```

**3. Execute**

```bash
$ python3 main.py [image]
```

It takes some time to compile the TensorRT engine on the first launch.
After that, the engine can be created directly from the serialized .bin file.
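
Concretely, once the `TRT_*.bin` file exists, later runs skip the UFF conversion and builder entirely. A minimal sketch of that fast path, mirroring what main.py does (the .bin filename comes from the chosen config module):

```python
import ctypes
import tensorrt as trt

# the custom FlattenConcat plugin must be loaded before deserializing
ctypes.CDLL("lib/libflattenconcat.so")
TRT_LOGGER = trt.Logger(trt.Logger.INFO)
trt.init_libnvinfer_plugins(TRT_LOGGER, '')

with open('TRT_ssd_mobilenet_v2_coco_2018_03_29.bin', 'rb') as f:
    engine = trt.Runtime(TRT_LOGGER).deserialize_cuda_engine(f.read())
```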

To free more memory, it is recommended to turn off the X server.

--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import os
import sys
import cv2
import time
import ctypes
import numpy as np
import pycuda.autoinit
import pycuda.driver as cuda

import coco
import uff
import tensorrt as trt
import graphsurgeon as gs
#from config import model_ssd_inception_v2_coco_2017_11_17 as model
#from config import model_ssd_mobilenet_v1_coco_2018_01_28 as model
from config import model_ssd_mobilenet_v2_coco_2018_03_29 as model

# load the custom FlattenConcat plugin before building/deserializing the engine
ctypes.CDLL("lib/libflattenconcat.so")
COCO_LABELS = coco.COCO_CLASSES_LIST


# initialize
TRT_LOGGER = trt.Logger(trt.Logger.INFO)
trt.init_libnvinfer_plugins(TRT_LOGGER, '')
runtime = trt.Runtime(TRT_LOGGER)


# compile the model into TensorRT (only when no serialized engine exists yet)
if not os.path.isfile(model.TRTbin):
    dynamic_graph = model.add_plugin(gs.DynamicGraph(model.path))
    uff_model = uff.from_tensorflow(dynamic_graph.as_graph_def(), model.output_name, output_filename='tmp.uff')

    with trt.Builder(TRT_LOGGER) as builder, builder.create_network() as network, trt.UffParser() as parser:
        builder.max_workspace_size = 1 << 28
        builder.max_batch_size = 1
        builder.fp16_mode = True

        parser.register_input('Input', model.dims)
        parser.register_output('MarkOutput_0')
        parser.parse('tmp.uff', network)
        engine = builder.build_cuda_engine(network)

        buf = engine.serialize()
        with open(model.TRTbin, 'wb') as f:
            f.write(buf)


# create engine from the serialized file
with open(model.TRTbin, 'rb') as f:
    buf = f.read()
    engine = runtime.deserialize_cuda_engine(buf)


# create buffers
host_inputs = []
cuda_inputs = []
host_outputs = []
cuda_outputs = []
bindings = []
stream = cuda.Stream()

for binding in engine:
    size = trt.volume(engine.get_binding_shape(binding)) * engine.max_batch_size
    host_mem = cuda.pagelocked_empty(size, np.float32)
    cuda_mem = cuda.mem_alloc(host_mem.nbytes)

    bindings.append(int(cuda_mem))
    if engine.binding_is_input(binding):
        host_inputs.append(host_mem)
        cuda_inputs.append(cuda_mem)
    else:
        host_outputs.append(host_mem)
        cuda_outputs.append(cuda_mem)
context = engine.create_execution_context()


# inference
#TODO enable video pipeline
#TODO use pyCUDA for preprocessing
ori = cv2.imread(sys.argv[1])
image = cv2.cvtColor(ori, cv2.COLOR_BGR2RGB)
image = cv2.resize(image, (model.dims[2], model.dims[1]))
image = (2.0 / 255.0) * image - 1.0   # normalize to [-1, 1]
image = image.transpose((2, 0, 1))    # HWC -> CHW
np.copyto(host_inputs[0], image.ravel())

start_time = time.time()
cuda.memcpy_htod_async(cuda_inputs[0], host_inputs[0], stream)
context.execute_async(bindings=bindings, stream_handle=stream.handle)
cuda.memcpy_dtoh_async(host_outputs[1], cuda_outputs[1], stream)
cuda.memcpy_dtoh_async(host_outputs[0], cuda_outputs[0], stream)
stream.synchronize()
print("execution time " + str(time.time() - start_time))

# each detection occupies model.layout (= 7) floats:
# [image_id, label, confidence, xmin, ymin, xmax, ymax]
output = host_outputs[0]
height, width, channels = ori.shape
for i in range(len(output) // model.layout):
    prefix = i * model.layout
    index = int(output[prefix+0])
    label = int(output[prefix+1])
    conf = output[prefix+2]
    xmin = int(output[prefix+3] * width)
    ymin = int(output[prefix+4] * height)
    xmax = int(output[prefix+5] * width)
    ymax = int(output[prefix+6] * height)

    if conf > 0.7:
        print("Detected {} with confidence {}".format(COCO_LABELS[label], "{0:.0%}".format(conf)))
        cv2.rectangle(ori, (xmin, ymin), (xmax, ymax), (0, 0, 255), 3)
        cv2.putText(ori, COCO_LABELS[label], (xmin+10, ymin+10), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA)

cv2.imwrite("result.jpg", ori)
cv2.imshow("result", ori)
cv2.waitKey(0)
--------------------------------------------------------------------------------