├── .gitignore
├── CONTRIBUTE.pdf
├── LICENSE
├── README.md
├── centerface
│   ├── README.md
│   ├── centerface
│   │   ├── 1
│   │   │   ├── change_dim.py
│   │   │   └── run.sh
│   │   ├── centerface_labels.txt
│   │   └── config.pbtxt
│   ├── centerface_output.png
│   ├── config
│   │   ├── centerface.txt
│   │   └── source1_primary_detector.txt
│   └── customparser
│       ├── Makefile
│       ├── customparserbbox_centernet.cpp
│       └── libnvds_infercustomparser_centernet.so
└── faster_rcnn_inception_v2
    ├── README.md
    ├── config
    │   ├── config_infer_primary_faster_rcnn_inception_v2.txt
    │   ├── labels.txt
    │   └── source1_primary_faster_rcnn_inception_v2.txt
    ├── config.pbtxt
    ├── export_nms_only.py
    └── faster_rcnn_output.png
/.gitignore:
--------------------------------------------------------------------------------
1 | *.mp4
2 | *.pb
3 | *.graphdef
4 |
--------------------------------------------------------------------------------
/centerface/centerface/centerface_labels.txt:
--------------------------------------------------------------------------------
1 | face
--------------------------------------------------------------------------------
/CONTRIBUTE.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVIDIA-AI-IOT/deepstream_triton_model_deploy/HEAD/CONTRIBUTE.pdf
--------------------------------------------------------------------------------
/centerface/centerface_output.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVIDIA-AI-IOT/deepstream_triton_model_deploy/HEAD/centerface/centerface_output.png
--------------------------------------------------------------------------------
/faster_rcnn_inception_v2/faster_rcnn_output.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVIDIA-AI-IOT/deepstream_triton_model_deploy/HEAD/faster_rcnn_inception_v2/faster_rcnn_output.png
--------------------------------------------------------------------------------
/centerface/customparser/libnvds_infercustomparser_centernet.so:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVIDIA-AI-IOT/deepstream_triton_model_deploy/HEAD/centerface/customparser/libnvds_infercustomparser_centernet.so
--------------------------------------------------------------------------------
/centerface/centerface/1/run.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | echo "Downloading centerface.onnx model"
17 | wget https://github.com/Star-Clouds/CenterFace/raw/master/models/onnx/centerface.onnx
18 |
19 | echo "Changing input and output node dimensions"
20 | python3 change_dim.py
21 |
22 | # Remove the original ONNX model
23 |
24 | rm centerface.onnx
25 |
--------------------------------------------------------------------------------
/faster_rcnn_inception_v2/config/labels.txt:
--------------------------------------------------------------------------------
1 | unlabeled
2 | person
3 | bicycle
4 | car
5 | motorcycle
6 | airplane
7 | bus
8 | train
9 | truck
10 | boat
11 | traffic light
12 | fire hydrant
13 | street sign
14 | stop sign
15 | parking meter
16 | bench
17 | bird
18 | cat
19 | dog
20 | horse
21 | sheep
22 | cow
23 | elephant
24 | bear
25 | zebra
26 | giraffe
27 | hat
28 | backpack
29 | umbrella
30 | shoe
31 | eye glasses
32 | handbag
33 | tie
34 | suitcase
35 | frisbee
36 | skis
37 | snowboard
38 | sports ball
39 | kite
40 | baseball bat
41 | baseball glove
42 | skateboard
43 | surfboard
44 | tennis racket
45 | bottle
46 | plate
47 | wine glass
48 | cup
49 | fork
50 | knife
51 | spoon
52 | bowl
53 | banana
54 | apple
55 | sandwich
56 | orange
57 | broccoli
58 | carrot
59 | hot dog
60 | pizza
61 | donut
62 | cake
63 | chair
64 | couch
65 | potted plant
66 | bed
67 | mirror
68 | dining table
69 | window
70 | desk
71 | toilet
72 | door
73 | tv
74 | laptop
75 | mouse
76 | remote
77 | keyboard
78 | cell phone
79 | microwave
80 | oven
81 | toaster
82 | sink
83 | refrigerator
84 | blender
85 | book
86 | clock
87 | vase
88 | scissors
89 | teddy bear
90 | hair drier
91 | toothbrush
92 |
--------------------------------------------------------------------------------
/centerface/customparser/Makefile:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | CC:= g++
16 |
17 | CFLAGS:= -Wall -std=c++11
18 |
19 | CFLAGS+= -shared -fPIC
20 |
21 | CFLAGS+= -I/opt/nvidia/deepstream/deepstream-5.1/sources/includes
22 |
23 | LIBS:= -lnvinfer -lnvparsers
24 | LFLAGS:= -Wl,--start-group $(LIBS) -Wl,--end-group
25 |
26 | SRCFILES:= customparserbbox_centernet.cpp
27 | TARGET_LIB:= libnvds_infercustomparser_centernet.so
28 |
29 | all: $(TARGET_LIB)
30 |
31 | $(TARGET_LIB) : $(SRCFILES)
32 | $(CC) -o $@ $^ $(CFLAGS) $(LFLAGS)
33 |
34 | install: $(TARGET_LIB)
35 | cp $(TARGET_LIB) ../../../lib
36 |
37 | clean:
38 | rm -rf $(TARGET_LIB)
39 |
--------------------------------------------------------------------------------
/centerface/centerface/1/change_dim.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import onnx
16 |
17 | def update_dim(model):
18 |
19 | # Use -1 so the batch, height, and width dimensions can vary when the model is served by Triton Inference Server.
20 | value = -1
21 | inputs = model.graph.input
22 | outputs = model.graph.output
23 |
24 | inputs[0].type.tensor_type.shape.dim[0].dim_value = -1
25 | inputs[0].type.tensor_type.shape.dim[2].dim_value = value
26 | inputs[0].type.tensor_type.shape.dim[3].dim_value = value
27 |
28 | for output in outputs:
29 | output.type.tensor_type.shape.dim[0].dim_value = -1
30 | output.type.tensor_type.shape.dim[2].dim_value = value
31 | output.type.tensor_type.shape.dim[3].dim_value = value
32 |
33 | def change(update_dim, infile, outfile):
34 | model = onnx.load(infile)
35 | update_dim(model)
36 | onnx.save(model, outfile) # Save the new model with updated dimension
37 |
38 | ## Update the input and output dimension of model layers ##
39 | change(update_dim, "centerface.onnx", "model.onnx")
40 |
--------------------------------------------------------------------------------
/centerface/centerface/config.pbtxt:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | name: "centerface"
16 | platform: "onnxruntime_onnx"
17 | max_batch_size: 0
18 | input [
19 | {
20 | name: "input.1"
21 | data_type: TYPE_FP32
22 | # format: FORMAT_NCHW
23 | dims: [ -1, 3, 480, 640]
24 | # reshape { shape: [ 1, 3, 480, 640 ] }
25 | }
26 | ]
27 |
28 | output [
29 | {
30 | name: "537"
31 | data_type: TYPE_FP32
32 | dims: [ -1, 1, -1, -1 ]
33 | # reshape { shape: [ 1, 1, 1, 1 ] }
34 | label_filename: "centerface_labels.txt"
35 | },
36 | {
37 | name: "538"
38 | data_type: TYPE_FP32
39 | dims: [ -1, 2, -1, -1]
40 | label_filename: "centerface_labels.txt"
41 | },
42 |
43 | {
44 | name: "539"
45 | data_type: TYPE_FP32
46 | dims: [-1, 2, -1, -1]
47 | label_filename: "centerface_labels.txt"
48 | },
49 | {
50 | name: "540"
51 | data_type: TYPE_FP32
52 | dims: [-1, 10 , -1, -1]
53 | label_filename: "centerface_labels.txt"
54 | }
55 | ]
56 |
57 | instance_group {
58 | count: 1
59 | gpus: 0
60 | kind: KIND_GPU
61 | }
62 |
63 | # Enable TensorRT acceleration for the GPU instance. It might take several
64 | # minutes during initialization to generate TensorRT online caches.
65 |
66 | #optimization { execution_accelerators {
67 | # gpu_execution_accelerator : [ { name : "tensorrt" } ]
68 | # }}
69 |
--------------------------------------------------------------------------------
/faster_rcnn_inception_v2/config.pbtxt:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # tf_gpu_memory_fraction: 0.2 is specified for device with limited memory
16 | # resource such as Nano. Smaller value can limit Tensorflow GPU usage;
17 | # and larger value may increase performance but may also cause Out-Of-Memory
18 | # issues. Please tune a proper value.
19 |
20 | name: "faster_rcnn_inception_v2"
21 | platform: "tensorflow_graphdef"
22 | max_batch_size: 8
23 | input [
24 | {
25 | name: "image_tensor"
26 | data_type: TYPE_UINT8
27 | format: FORMAT_NHWC
28 | dims: [ 1920, 1080, 3 ]
29 | }
30 | ]
31 | output [
32 | {
33 | name: "detection_boxes"
34 | data_type: TYPE_FP32
35 | dims: [ 100, 4]
36 | reshape { shape: [100,4] }
37 | },
38 | {
39 | name: "detection_classes"
40 | data_type: TYPE_FP32
41 | dims: [ 100 ]
42 | },
43 | {
44 | name: "detection_scores"
45 | data_type: TYPE_FP32
46 | dims: [ 100 ]
47 | },
48 | {
49 | name: "num_detections"
50 | data_type: TYPE_FP32
51 | dims: [ 1 ]
52 | reshape { shape: [] }
53 | }
54 | ]
55 | version_policy: { specific {versions: 1}}
56 | instance_group [
57 | {
58 | kind: KIND_GPU
59 | count: 1
60 | gpus: [ 0 ]
61 | }
62 | ]
63 | #optimization { execution_accelerators {
64 | # gpu_execution_accelerator : [ {
65 | # name : "tensorrt"
66 | # parameters { key: "precision_mode" value: "FP16" }}]
67 | #}}
68 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------
2 | # This sample application is no longer maintained
3 | # ------------------------------------------------------
4 |
5 | # Deploying an open source model using NVIDIA DeepStream and Triton Inference Server
6 |
7 | This repository contains the code and configuration files required to deploy sample open source models for video analytics using Triton Inference Server and the DeepStream SDK 5.0.
8 |
9 | ## Getting Started ##
10 |
11 | ### Prerequisites: ###
12 |
13 | [DeepStream SDK 5.0](https://developer.nvidia.com/deepstream-sdk): download and install the SDK, or use the docker image (nvcr.io/nvidia/deepstream:5.0.1-20.09-triton) for x86 or (nvcr.io/nvidia/deepstream-l4t:5.0-20.07-samples) for NVIDIA Jetson.
14 |
15 | The following models have been deployed on DeepStream using Triton Inference Server.
16 |
17 | For further details, please see each project's README.
18 |
19 | ### TensorFlow Faster RCNN Inception V2 : [README](faster_rcnn_inception_v2/README.md) ###
20 | The project shows how to deploy [TensorFlow Faster RCNN Inception V2 network trained on MSCOCO dataset](https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/tf1_detection_zoo.md) for object detection.
21 | 
22 |
23 | ### ONNX CenterFace : [README](centerface/README.md) ###
24 | The project shows how to deploy [ONNX CenterFace](https://github.com/Star-Clouds/CenterFace) network for face detection and alignment.
25 | 
26 |
27 | Additional resources:
28 |
29 | Developer blog: [Building Intelligent Video Analytics Apps Using NVIDIA DeepStream 5.0](https://developer.nvidia.com/blog/building-iva-apps-using-deepstream-5-0-updated-for-ga/)
30 |
31 | Learn more about [Triton Inference Server](https://developer.nvidia.com/nvidia-triton-inference-server)
32 |
33 | Post your questions or feedback in the [DeepStream SDK developer forums](https://forums.developer.nvidia.com/c/accelerated-computing/intelligent-video-analytics/deepstream-sdk/15)
34 |
--------------------------------------------------------------------------------
/centerface/config/centerface.txt:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # tf_gpu_memory_fraction: 0.2 is specified for device with limited memory
16 | # resource such as Nano. Smaller value can limit Tensorflow GPU usage;
17 | # and larger value may increase performance but may also cause Out-Of-Memory
18 | # issues. Please tune a proper value.
19 |
20 | infer_config {
21 | unique_id: 1
22 | gpu_ids: 0
23 | max_batch_size: 1
24 | backend {
25 | inputs [
26 | {
27 | name: "input.1"
28 | dims: [3, 480, 640]
29 | }
30 | ]
31 | trt_is {
32 | model_name: "centerface"
33 | version: -1
34 | model_repo {
35 | root: "../"
36 | log_level: 1
37 | tf_gpu_memory_fraction: 0.2
38 | tf_disable_soft_placement: 0
39 | }
40 | }
41 | }
42 |
43 | preprocess {
44 | network_format: IMAGE_FORMAT_RGB
45 | tensor_order: TENSOR_ORDER_LINEAR
46 | maintain_aspect_ratio: 0
47 | normalize {
48 | scale_factor: 1.0
49 | channel_offsets: [0, 0, 0]
50 | }
51 | }
52 |
53 | postprocess {
54 | labelfile_path: "../centerface/centerface_labels.txt"
55 | detection {
56 | num_detected_classes: 1
57 | custom_parse_bbox_func: "NvDsInferParseCustomCenterNetFace"
58 | simple_cluster {
59 | threshold: 0.3
60 | }
61 | }
62 | }
63 |
64 | custom_lib {
65 | path: "../customparser/libnvds_infercustomparser_centernet.so"
66 | }
67 |
68 | extra {
69 | copy_input_to_host_buffers: false
70 | }
71 | }
72 | input_control {
73 | process_mode: PROCESS_MODE_FULL_FRAME
74 | interval: 0
75 | }
76 |
--------------------------------------------------------------------------------
/faster_rcnn_inception_v2/config/source1_primary_faster_rcnn_inception_v2.txt:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # tf_gpu_memory_fraction: 0.2 is specified for device with limited memory
16 | # resource such as Nano. Smaller value can limit Tensorflow GPU usage;
17 | # and larger value may increase performance but may also cause Out-Of-Memory
18 | # issues. Please tune a proper value.
19 |
20 | [application]
21 | enable-perf-measurement=1
22 | perf-measurement-interval-sec=1
23 | gie-kitti-output-dir=streamscl
24 |
25 | [tiled-display]
26 | enable=1
27 | rows=2
28 | columns=2
29 | width=1920
30 | height=1080
31 | gpu-id=0
32 | nvbuf-memory-type=0
33 |
34 | [source0]
35 | enable=1
36 | #Type - 1=CameraV4L2 2=URI 3=MultiURI 4=RTSP
37 | type=3
38 | num-sources=4
39 | uri=file:/opt/nvidia/deepstream/deepstream-5.0/models/vid.mp4
40 | gpu-id=0
41 | cudadec-memtype=0
42 |
43 | [streammux]
44 | gpu-id=0
45 | batch-size=4
46 | batched-push-timeout=40000
47 | enable-padding=0
48 | ## Set muxer output width and height
49 | width=1920
50 | height=1080
51 | nvbuf-memory-type=0
52 |
53 | [sink0]
54 | enable=1
55 | #Type - 1=FakeSink 2=EglSink 3=File
56 | type=2
57 | sync=0
58 | source-id=0
59 | gpu-id=0
60 | nvbuf-memory-type=0
61 | container=1
62 | bitrate=4000000
63 | output-file=/opt/nvidia/deepstream/deepstream-5.0/models/output.mp4
64 | codec=1
65 |
66 | [osd]
67 | enable=1
68 | gpu-id=0
69 | border-width=1
70 | text-size=15
71 | text-color=1;1;1;1;
72 | text-bg-color=0.3;0.3;0.3;1
73 | font=Serif
74 | show-clock=0
75 | clock-x-offset=800
76 | clock-y-offset=820
77 | clock-text-size=12
78 | clock-color=1;0;0;0
79 | nvbuf-memory-type=0
80 |
81 | [primary-gie]
82 | enable=1
83 | #(0): nvinfer; (1): nvinferserver
84 | plugin-type=1
85 | #infer-raw-output-dir=trtis-output
86 | batch-size=4
87 | interval=0
88 | gie-unique-id=1
89 | config-file=config_infer_primary_faster_rcnn_inception_v2.txt
90 |
--------------------------------------------------------------------------------
/faster_rcnn_inception_v2/config/config_infer_primary_faster_rcnn_inception_v2.txt:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # tf_gpu_memory_fraction: 0.2 is specified for device with limited memory
16 | # resource such as Nano. Smaller value can limit Tensorflow GPU usage;
17 | # and larger value may increase performance but may also cause Out-Of-Memory
18 | # issues. Please tune a proper value.
19 |
20 | infer_config {
21 | unique_id: 1
22 | gpu_ids: [0]
23 | backend {
24 | trt_is {
25 | model_name: "faster_rcnn_inception_v2"
26 | version: -1
27 | model_repo {
28 | root: "../../trtis_model_repo"
29 | log_level: 2
30 | tf_gpu_memory_fraction: 0
31 | tf_disable_soft_placement: 0
32 | }
33 | }
34 | }
35 |
36 | preprocess {
37 | network_format: IMAGE_FORMAT_RGB
38 | tensor_order: TENSOR_ORDER_NONE
39 | maintain_aspect_ratio: 0
40 | frame_scaling_hw: FRAME_SCALING_HW_DEFAULT
41 | frame_scaling_filter: 1
42 | normalize {
43 | scale_factor: 1.0
44 | channel_offsets: [0, 0, 0]
45 | }
46 | }
47 |
48 | postprocess {
49 | labelfile_path: "../../trtis_model_repo/faster_rcnn_inception_v2/labels.txt"
50 | detection {
51 | num_detected_classes: 91
52 | custom_parse_bbox_func: "NvDsInferParseCustomTfSSD"
53 | nms {
54 | confidence_threshold: 0.3
55 | iou_threshold: 0.6
56 | topk : 100
57 | }
58 | }
59 | }
60 |
61 | extra {
62 | copy_input_to_host_buffers: false
63 | }
64 |
65 | custom_lib {
66 | path: "/opt/nvidia/deepstream/deepstream-5.0/lib/libnvds_infercustomparser.so"
67 | }
68 | }
69 | input_control {
70 | process_mode: PROCESS_MODE_FULL_FRAME
71 | interval: 0
72 | }
73 |
74 | output_control {
75 | detect_control {
76 | default_filter { bbox_filter { min_width: 32, min_height: 32 } }
77 | }
78 | }
79 |
80 |
--------------------------------------------------------------------------------
/faster_rcnn_inception_v2/README.md:
--------------------------------------------------------------------------------
1 | # TensorFlow FasterRCNN Inception V2 Model with Deepstream #
2 |
3 | We are using Deepstream-5.0 with Triton Inference Server to deploy the FasterRCNN with Inception V2 model trained on the MSCOCO dataset for object detection.
4 |
5 | ### Prerequisites: ###
6 |
7 | [DeepStream SDK 5.0](https://developer.nvidia.com/deepstream-sdk)
8 |
9 | Download and install the DeepStream SDK, or use the DeepStream docker image (nvcr.io/nvidia/deepstream:5.0.1-20.09-triton) for x86 or (nvcr.io/nvidia/deepstream-l4t:5.0-20.07-samples) for NVIDIA Jetson.
10 |
11 | Follow the instructions mentioned in the quick start guide: (https://docs.nvidia.com/metropolis/deepstream/dev-guide/index.html#page/DeepStream_Development_Guide/deepstream_quick_start.html)
12 |
13 | ### Obtaining the model ###
14 |
15 | ```bash
16 | $ wget http://download.tensorflow.org/models/object_detection/faster_rcnn_inception_v2_coco_2018_01_28.tar.gz
17 | $ tar xvf faster_rcnn_inception_v2_coco_2018_01_28.tar.gz
18 | ```
19 |
20 | ### Optimizing the model with TF-TRT ###
21 |
22 | ```bash
23 | $ docker pull nvcr.io/nvidia/tensorflow:20.03-tf1-py3            # x86
24 | $ docker pull nvcr.io/nvidia/l4t-tensorflow:r32.4.3-tf1.15-py3   # NVIDIA Jetson
25 | $ docker run --gpus all -it --rm --shm-size=1g --ulimit memlock=-1 --ulimit stack=67108864 -v /home/$USER/triton_blog/:/workspace/triton_blog nvcr.io/nvidia/tensorflow:20.03-tf1-py3   # x86
26 | $ docker run --runtime=nvidia -it --rm --shm-size=1g --ulimit memlock=-1 --ulimit stack=67108864 -v /home/$USER/triton_blog/:/workspace/triton_blog nvcr.io/nvidia/l4t-tensorflow:r32.4.3-tf1.15-py3   # NVIDIA Jetson
27 | $ python3 export_nms_only.py --modelPath faster_rcnn_inception_v2_coco_2018_01_28/frozen_inference_graph.pb --gpu_mem_fraction 0.6 --nms True --precision FP16 --max_batch_size 8 --min_segment_size 5
28 | ```
29 |
30 | ### Deepstream Configuration Files ###
31 |
32 | There are two configuration files:
33 | 1. Inference configuration file
34 | * Sets the parameters for inference. This file references the model configuration file and sets the parameters for pre/post-processing.
35 | 2. Application configuration file
36 | * Sets the configuration groups used to create a DeepStream pipeline. In this file you can set different configuration groups such as source, sink, primary-gie, osd, etc. Each group configures a GStreamer plugin. For more information on these plugins and their configuration, please check (https://docs.nvidia.com/metropolis/deepstream/plugin-manual/index.html#page/DeepStream%20Plugins%20Development%20Guide/deepstream_plugin_details.html) and (https://docs.nvidia.com/metropolis/deepstream/dev-guide/index.html)
37 |
38 | These files are located at faster_rcnn_inception_v2/config
39 |
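The inference configuration sets `num_detected_classes: 91`, which must match the label file. A quick consistency check can catch mismatches early; this is a minimal sketch, assuming it is run from this project's config directory where labels.txt lives:

```python
# Count the labels and compare against num_detected_classes in the infer config.
with open("labels.txt") as f:
    labels = [line.rstrip("\n") for line in f if line.strip()]

print(f"{len(labels)} labels found")  # expected: 91
assert len(labels) == 91, "label count must match num_detected_classes"
```
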
40 | ### Run the Application ###
41 |
42 | To run the application, make sure that the paths to the configuration files and the input video stream are correct, then launch the reference app with the application configuration file:
43 |
44 | `cd $DEEPSTREAM_DIR/samples/configs/deepstream-app-trtis`
45 | `deepstream-app -c source1_primary_faster_rcnn_inception_v2.txt`
46 |
47 | ## Performance ##
48 |
49 | Performance across 4 1080p streams with FP16 and TF-TRT optimizations
50 |
51 | | Model | WxH | Perf (FPS) | Hardware | # Streams | # Batch size |
52 | |----------------------------|-----------|-------|------------------|-----------|--------------|
53 | | TF FasterRCNN Inception V2 | 1920x1080 | 32.36 | NVIDIA T4 | 4 | 4 |
54 | | TF FasterRCNN Inception V2 | 1920x1080 | 14.92 | NVIDIA Jetson NX | 4 | 4 |
55 |
56 |
57 |
58 |
--------------------------------------------------------------------------------
/centerface/config/source1_primary_detector.txt:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | [application]
16 | enable-perf-measurement=1
17 | perf-measurement-interval-sec=5
18 | #gie-kitti-output-dir=kitti-trtis
19 |
20 | [tiled-display]
21 | enable=1
22 | rows=1
23 | columns=1
24 | width=1280
25 | height=720
26 | gpu-id=0
27 | #(0): nvbuf-mem-default - Default memory allocated, specific to particular platform
28 | #(1): nvbuf-mem-cuda-pinned - Allocate Pinned/Host cuda memory applicable for Tesla
29 | #(2): nvbuf-mem-cuda-device - Allocate Device cuda memory applicable for Tesla
30 | #(3): nvbuf-mem-cuda-unified - Allocate Unified cuda memory applicable for Tesla
31 | #(4): nvbuf-mem-surface-array - Allocate Surface Array memory, applicable for Jetson
32 | nvbuf-memory-type=0
33 |
34 | [source0]
35 | enable=1
36 | #Type - 1=CameraV4L2 2=URI 3=MultiURI 4=RTSP
37 | type=2
38 | #uri=file://../../../samples/configs/tlt_pretrained_models/Redaction-A_1.mp4
39 | uri=file:///opt/nvidia/deepstream/deepstream-5.1/samples/streams/sample_1080p_h264.mp4
40 | num-sources=1
41 | #drop-frame-interval=2
42 | gpu-id=0
43 | # (0): memtype_device - Memory type Device
44 | # (1): memtype_pinned - Memory type Host Pinned
45 | # (2): memtype_unified - Memory type Unified
46 | cudadec-memtype=0
47 |
48 | [sink0]
49 | enable=0
50 | #Type - 1=FakeSink 2=EglSink 3=File
51 | type=1
52 | sync=0
53 | source-id=0
54 | gpu-id=0
55 | nvbuf-memory-type=0
56 |
57 | [sink1]
58 | enable=1
59 | type=3
60 | #1=mp4 2=mkv
61 | container=1
62 | #1=h264 2=h265
63 | codec=1
64 | sync=0
65 | #iframeinterval=10
66 | bitrate=2000000
67 | output-file=out.mp4
68 | source-id=0
69 |
70 | [sink2]
71 | enable=0
72 | #Type - 1=FakeSink 2=EglSink 3=File 4=RTSPStreaming
73 | type=4
74 | #1=h264 2=h265
75 | codec=1
76 | sync=0
77 | bitrate=4000000
78 | # set below properties in case of RTSPStreaming
79 | rtsp-port=8554
80 | udp-port=5400
81 |
82 | [osd]
83 | enable=1
84 | gpu-id=0
85 | border-width=1
86 | text-size=15
87 | text-color=1;1;1;1;
88 | text-bg-color=0.3;0.3;0.3;1
89 | font=Serif
90 | show-clock=0
91 | clock-x-offset=800
92 | clock-y-offset=820
93 | clock-text-size=12
94 | clock-color=1;0;0;0
95 | nvbuf-memory-type=0
96 |
97 | [streammux]
98 | gpu-id=0
99 | ##Boolean property to inform muxer that sources are live
100 | live-source=0
101 | batch-size=1
102 | ##time out in usec, to wait after the first buffer is available
103 | ##to push the batch even if the complete batch is not formed
104 | batched-push-timeout=40000
105 | ## Set muxer output width and height
106 | width=1920
107 | height=1080
108 | ##Enable to maintain aspect ratio wrt source, and allow black borders, works
109 | ##along with width, height properties
110 | enable-padding=0
111 | nvbuf-memory-type=0
112 |
113 | # config-file property is mandatory for any gie section.
114 | # Other properties are optional and if set will override the properties set in
115 | # the infer config file.
116 | [primary-gie]
117 | enable=1
118 | #(0): nvinfer; (1): nvinferserver
119 | plugin-type=1
120 | #infer-raw-output-dir=trtis-output
121 | batch-size=1
122 | interval=0
123 | gie-unique-id=1
124 | bbox-border-color0=1;0;0;1
125 | bbox-border-color1=0;1;1;1
126 | #bbox-border-color2=0;0;1;1
127 | #bbox-border-color3=0;1;0;1
128 | config-file=centerface.txt
129 |
130 | [tests]
131 | file-loop=0
132 |
--------------------------------------------------------------------------------
/centerface/README.md:
--------------------------------------------------------------------------------
1 | # ONNX Centerface Model with Deepstream #
2 |
3 |
4 |
5 |
6 |
7 | We are using Deepstream-5.0 with Triton Inference Server to deploy the CenterFace network for face detection and alignment. For more information about the network, please read (https://arxiv.org/ftp/arxiv/papers/1911/1911.03599.pdf). This example shows a step-by-step process to deploy the CenterFace network.
8 |
9 | Currently ONNX on Triton Inference Server with DeepStream is supported only on x86.
10 |
11 | ---
12 |
13 | ### Prerequisites: ###
14 |
15 | [DeepStream SDK 5.0](https://developer.nvidia.com/deepstream-sdk)
16 |
17 | Download and install the DeepStream SDK or use the DeepStream docker image (nvcr.io/nvidia/deepstream:5.0.1-20.09-triton).
18 |
19 | Follow the instructions mentioned in the quick start guide: (https://docs.nvidia.com/metropolis/deepstream/dev-guide/index.html#page/DeepStream_Development_Guide/deepstream_quick_start.html)
20 |
21 | ### Running the model with DeepStream ###
22 |
23 | `cd centerface/1 && ./run.sh`
24 |
25 | * centerface_labels.txt: This is the label file for the CenterFace network. There is only one label, "face". If you are using a different model with different classes, you will have to update this file.
26 |
27 | * config.pbtxt: This is a model configuration file that provides information about the model. This file must specify the name, platform, max_batch_size, input, and output. For more information on this file, please check: (https://docs.nvidia.com/deeplearning/triton-inference-server/user-guide/docs/model_configuration.html)
28 |
29 | * 1/change_dim.py: In Triton Inference Server, if you want the input and output nodes to have variable size, the relevant dimensions should be specified as -1. change_dim.py reads the input ONNX model, updates the batch, height, and width dimensions to -1, and saves the resulting model.
30 |
31 | * 1/run.sh: This script downloads the model and updates the dimensions of its input and output nodes; a quick way to verify the result is sketched after this list.
32 |
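After run.sh finishes, it is worth confirming that the saved model.onnx really carries the -1 dimensions Triton expects. A minimal sketch using the same `onnx` package that change_dim.py depends on (the file name matches what run.sh produces):

```python
import onnx

# Print the input/output tensor shapes of the exported model.
model = onnx.load("model.onnx")
for tensor in list(model.graph.input) + list(model.graph.output):
    dims = [d.dim_value for d in tensor.type.tensor_type.shape.dim]
    print(tensor.name, dims)  # batch, height, and width should show -1
```
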
33 | ## Custom Parser ##
34 |
35 | `cd customparser`
36 |
37 | In the custom parser, we take the output layers, apply the post-processing algorithm, and then attach the bounding boxes to NvDsInferObjectDetectionInfo. For more information on NvDsInferObjectDetectionInfo, please check (https://docs.nvidia.com/metropolis/deepstream/4.0/dev-guide/DeepStream_Development_Guide/baggage/nvdsinfer_8h_source.html#l00126). The decoding math is sketched below.
38 |
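For reference, the box decoding performed by the parser can be expressed compactly in Python. This is an illustrative sketch that mirrors the C++ logic (4x output stride, exp-scaled box sizes), assuming NumPy arrays already squeezed to `heatmap[H, W]`, `scale[2, H, W]`, and `offset[2, H, W]`:

```python
import numpy as np

def decode_centerface(heatmap, scale, offset, score_thresh=0.5):
    """Heatmap peaks -> face boxes (x1, y1, x2, y2, score), mirroring the C++ parser."""
    boxes = []
    for i, j in zip(*np.where(heatmap > score_thresh)):  # candidate feature-map cells
        s0 = np.exp(scale[0, i, j]) * 4                  # box height (4x output stride)
        s1 = np.exp(scale[1, i, j]) * 4                  # box width
        x1 = max(0.0, (j + offset[1, i, j] + 0.5) * 4 - s1 / 2)
        y1 = max(0.0, (i + offset[0, i, j] + 0.5) * 4 - s0 / 2)
        boxes.append((x1, y1, x1 + s1, y1 + s0, heatmap[i, j]))
    return boxes  # these candidates are then de-duplicated by NMS, as nms() does in the C++ parser
```
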
39 | If you need to adapt the custom parser to your own model, update customparserbbox_centernet.cpp and rebuild it with 'make'.
40 |
41 | ## Deepstream Configuration Files ##
42 |
43 | There are two configuration files:
44 | 1. Inference configuration file
45 | * Sets the parameters for inference. This file references the model configuration file and sets the parameters for pre/post-processing.
46 | 2. Application configuration file
47 | * Sets the configuration groups used to create a DeepStream pipeline. In this file you can set different configuration groups such as source, sink, primary-gie, osd, etc. Each group configures a GStreamer plugin. For more information on these plugins and their configuration, please check (https://docs.nvidia.com/metropolis/deepstream/plugin-manual/index.html#page/DeepStream%20Plugins%20Development%20Guide/deepstream_plugin_details.html) and (https://docs.nvidia.com/metropolis/deepstream/dev-guide/index.html)
48 |
49 | These files are located at centerface/config
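
Before wiring the model into DeepStream, it can help to sanity-check model.onnx outside the pipeline. A minimal sketch using `onnxruntime` (an extra dependency, not shipped with this repo); the input name, dtype, and 3x480x640 layout come from config.pbtxt, and the relative path assumes you run it from centerface/config:

```python
import numpy as np
import onnxruntime as ort

# Run one dummy frame shaped like the config.pbtxt input: [-1, 3, 480, 640], FP32.
session = ort.InferenceSession("../centerface/1/model.onnx")
frame = np.zeros((1, 3, 480, 640), dtype=np.float32)
outputs = session.run(None, {"input.1": frame})
for meta, out in zip(session.get_outputs(), outputs):
    print(meta.name, out.shape)  # expect 537 (heatmap), 538 (scale), 539 (offset), 540 (landmarks)
```
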
50 | ## Run the Application ##
51 |
52 | To run the application:
53 |
54 | `cd centerface/config`
55 |
56 | `deepstream-app -c source1_primary_detector.txt`
57 |
58 | ## Performance ##
59 | | Model | WxH | Perf. | Hardware | # Streams | # Batch Size |
60 | | ------ | ------ | ------ | ------ | ------ | ------ |
61 | | Centerface | 640x480 | 136 fps | T4 | 20 | 20 |
62 |
63 | ## FAQ ##
64 | Question: I am getting an error related to model dimensions (failed to load 'CenterNet' version 1: Invalid argument: model 'CenterNet', tensor 'input.1': the model expects 4 dimensions (shape [10,3,32,32]) but the model configuration specifies 4 dimensions (shape [1,3,480,640])).
65 |
66 | Answer: Please make sure you have updated the input node dimensions to -1 as mentioned in the pre-processing step and are using the correct ONNX model.
67 |
68 |
69 |
70 |
--------------------------------------------------------------------------------
/faster_rcnn_inception_v2/export_nms_only.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import tensorflow as tf
16 | from PIL import Image
17 | import numpy as np
18 |
19 | tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
20 | import tensorflow.compat.v1 as tf1
21 |
22 | tf1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
23 | from tensorflow.python.compiler.tensorrt import trt_convert as trt
24 | import argparse
25 | from tensorflow.python.util import deprecation
26 |
27 | deprecation._PRINT_DEPRECATION_WARNINGS = False
28 |
29 | DEFAULT_FROZEN_GRAPH_NAME = "frozen_inference_graph.pb"
30 | DEFAULT_MAX_BATCHSIZE = 1
31 | DEFAULT_INPUT_NAME = "image_tensor"
32 | DEFAULT_BOXES_NAME = "detection_boxes"
33 | DEFAULT_CLASSES_NAME = "detection_classes"
34 | DEFAULT_SCORES_NAME = "detection_scores"
35 | DEFAULT_NUM_DETECTIONS_NAME = "num_detections"
36 | DEFAULT_PRECISION = "FP32"
37 | DEFAULT_NMS = False
38 | # Default workspace size : 512MB
39 | DEFAULT_MAX_WORKSPACE_SIZE = 1 << 29
40 | DEFAULT_MIN_SEGMENT_SIZE = 10
41 | DEFAULT_GPU_MEMORY_FRACTION = 0.6
42 |
43 | TfConfig = tf.ConfigProto()
44 | # TfConfig.gpu_options.allow_growth=True
45 | TfConfig.gpu_options.allow_growth = False
46 | TfConfig.gpu_options.per_process_gpu_memory_fraction = DEFAULT_GPU_MEMORY_FRACTION
47 |
48 |
49 | def loadGraphDef(modelFile):
50 | graphDef = tf.GraphDef()
51 | with open(modelFile, "rb") as f:
52 | graphDef.ParseFromString(f.read())
53 | return graphDef
54 |
55 |
56 | def saveGraphDef(graphDef, outputFilePath):
57 | with open(outputFilePath, "wb") as f:
58 | f.write(graphDef.SerializeToString())
59 | print("---------saved graphdef to {}".format(outputFilePath))
60 |
61 |
62 | def updateNmsCpu(graphDef):
63 | for node in graphDef.node:
64 | # if 'NonMaxSuppressionV' in node.name and not node.device:
65 | if "NonMaxSuppression" in node.name and "TRTEngineOp" not in node.name:
66 | # node.device = '/device:CPU:0'
67 | node.device = "/job:localhost/replica:0/task:0/device:CPU:0"
68 |
69 |
70 | def main():
71 |
72 | parser = argparse.ArgumentParser(description="Offline tf-trt GraphDef")
73 | parser.add_argument(
74 | "--modelPath",
75 | type=str,
76 | default=DEFAULT_FROZEN_GRAPH_NAME,
77 | help="path to frozen model",
78 | required=True,
79 | )
80 | parser.add_argument(
81 | "--gpu_mem_fraction",
82 | type=float,
83 | default=DEFAULT_GPU_MEMORY_FRACTION,
84 | help="Tensorflow gpu memory fraction, suggested value [0.2, 0.6]",
85 | )
86 | parser.add_argument(  # type=bool treats any non-empty string (even "False") as True
87 | "--nms", type=lambda s: str(s).lower() in ("true", "1", "yes"), default=DEFAULT_NMS, help="to offload NMS operation to CPU"
88 | )
89 | parser.add_argument(
90 | "--precision", type=str, default=DEFAULT_PRECISION, help="Precision mode to use"
91 | )
92 | parser.add_argument(
93 | "--max_batch_size",
94 | type=int,
95 | default=DEFAULT_MAX_BATCHSIZE,
96 | help="Specify max batch size",
97 | )
98 | parser.add_argument(
99 | "--save_graph", type=str, default=None, help="TF-TRT optimized model file"
100 | )
101 | parser.add_argument(
102 | "--min_segment_size",
103 | type=int,
104 | default=DEFAULT_MIN_SEGMENT_SIZE,
105 | help="the minimum number of nodes required for a subgraph to be replaced by TRTEngineOp",
106 | )
107 | args = parser.parse_args()
108 | saveGraphPath = args.save_graph
109 | if not saveGraphPath:
110 | saveGraphPath = (
111 | "frozen_tfrtr_"
112 | + args.precision.lower()
113 | + "_bs"
114 | + str(args.max_batch_size)
115 | + "_mss"
116 | + str(args.min_segment_size)
117 | + ".pb"
118 | )
119 | TfConfig.gpu_options.per_process_gpu_memory_fraction = args.gpu_mem_fraction
120 | outputNames = [
121 | DEFAULT_BOXES_NAME,
122 | DEFAULT_CLASSES_NAME,
123 | DEFAULT_SCORES_NAME,
124 | DEFAULT_NUM_DETECTIONS_NAME,
125 | ]
126 | nnGraphDef = loadGraphDef(args.modelPath)
127 | converter = trt.TrtGraphConverter(
128 | is_dynamic_op=True,
129 | input_graph_def=nnGraphDef,
130 | nodes_blacklist=outputNames,
131 | max_batch_size=args.max_batch_size,
132 | max_workspace_size_bytes=DEFAULT_MAX_WORKSPACE_SIZE,
133 | precision_mode=args.precision,
134 | minimum_segment_size=args.min_segment_size,
135 | )
136 | trtGraphDef = converter.convert()
137 | print("-------tf-trt model has been rebuilt.")
138 | if args.nms:
139 | # Update NMS to CPU and save the model
140 | print("-------updateNMS to CPU.")
141 | updateNmsCpu(trtGraphDef)
142 | saveGraphPath = "nms_" + saveGraphPath
143 | saveGraphDef(trtGraphDef, saveGraphPath)
144 |
145 |
146 | if __name__ == "__main__":
147 | main()
148 |
--------------------------------------------------------------------------------
/centerface/customparser/customparserbbox_centernet.cpp:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | /* This custom post processing parser is for centernet face detection model */
18 | #include <cstring>
19 | #include <iostream>
20 | #include "nvdsinfer_custom_impl.h"
21 | #include <vector>
22 | #include <algorithm>
23 | #include <cmath>
24 | #include <string>
25 | #include <cassert>
26 |
27 | #define CLIP(a, min, max) (MAX(MIN(a, max), min))
28 |
29 | /* C-linkage to prevent name-mangling */
30 | extern "C" bool NvDsInferParseCustomTfSSD(std::vector<NvDsInferLayerInfo> const &outputLayersInfo,
31 | NvDsInferNetworkInfo const &networkInfo,
32 | NvDsInferParseDetectionParams const &detectionParams,
33 | std::vector<NvDsInferObjectDetectionInfo> &objectList);
34 |
35 | /* This is a sample bbox parsing function for the CenterNet face detection ONNX model */
36 | struct FrcnnParams
37 | {
38 | int inputHeight;
39 | int inputWidth;
40 | int outputClassSize;
41 | float visualizeThreshold;
42 | int postNmsTopN;
43 | int outputBboxSize;
44 | std::vector<float> classifierRegressorStd;
45 | };
46 |
47 | struct FaceInfo
48 | {
49 | float x1;
50 | float y1;
51 | float x2;
52 | float y2;
53 | float score;
54 | float landmarks[10];
55 | };
56 |
57 | /* NMS for centernet */
58 | static void nms(std::vector<FaceInfo> &input, std::vector<FaceInfo> &output, float nmsthreshold)
59 | {
60 | std::sort(input.begin(), input.end(),
61 | [](const FaceInfo &a, const FaceInfo &b) {
62 | return a.score > b.score;
63 | });
64 |
65 | int box_num = input.size();
66 |
67 | std::vector<int> merged(box_num, 0);
68 |
69 | for (int i = 0; i < box_num; i++)
70 | {
71 | if (merged[i])
72 | continue;
73 |
74 | output.push_back(input[i]);
75 |
76 | float h0 = input[i].y2 - input[i].y1 + 1;
77 | float w0 = input[i].x2 - input[i].x1 + 1;
78 |
79 | float area0 = h0 * w0;
80 |
81 | for (int j = i + 1; j < box_num; j++)
82 | {
83 | if (merged[j])
84 | continue;
85 |
86 | float inner_x0 = input[i].x1 > input[j].x1 ? input[i].x1 : input[j].x1; //std::max(input[i].x1, input[j].x1);
87 | float inner_y0 = input[i].y1 > input[j].y1 ? input[i].y1 : input[j].y1;
88 |
89 | float inner_x1 = input[i].x2 < input[j].x2 ? input[i].x2 : input[j].x2; //bug fixed ,sorry
90 | float inner_y1 = input[i].y2 < input[j].y2 ? input[i].y2 : input[j].y2;
91 |
92 | float inner_h = inner_y1 - inner_y0 + 1;
93 | float inner_w = inner_x1 - inner_x0 + 1;
94 |
95 | if (inner_h <= 0 || inner_w <= 0)
96 | continue;
97 |
98 | float inner_area = inner_h * inner_w;
99 |
100 | float h1 = input[j].y2 - input[j].y1 + 1;
101 | float w1 = input[j].x2 - input[j].x1 + 1;
102 |
103 | float area1 = h1 * w1;
104 |
105 | float score;
106 |
107 | score = inner_area / (area0 + area1 - inner_area);
108 |
109 | if (score > nmsthreshold)
110 | merged[j] = 1;
111 | }
112 | }
113 | }
114 | /* For CenterNet face detection */
115 | //extern "C"
116 | static std::vector<int> getIds(float *heatmap, int h, int w, float thresh)
117 | {
118 | std::vector<int> ids;
119 | for (int i = 0; i < h; i++)
120 | {
121 | for (int j = 0; j < w; j++)
122 | {
123 |
124 | // std::cout<<"ids"<<heatmap[i*w+j]<<std::endl;
125 | if (heatmap[i * w + j] > thresh)
126 | {
127 | // std::array<int, 2> id = { i,j };
128 | ids.push_back(i);
129 | ids.push_back(j);
130 | // std::cout<<"print ids"<<i<<" "<<j<<std::endl;
131 | }
132 | }
133 | }
134 | return ids;
135 | }
136 |
137 | /* Custom bbox parsing function for the CenterNet face detection model */
138 | extern "C" bool NvDsInferParseCustomCenterNetFace(std::vector<NvDsInferLayerInfo> const &outputLayersInfo,
139 | NvDsInferNetworkInfo const &networkInfo,
140 | NvDsInferParseDetectionParams const &detectionParams,
141 | std::vector &objectList)
142 | {
143 | auto layerFinder = [&outputLayersInfo](const std::string &name)
144 | -> const NvDsInferLayerInfo * {
145 | for (auto &layer : outputLayersInfo)
146 | {
147 |
148 | if (layer.dataType == FLOAT &&
149 | (layer.layerName && name == layer.layerName))
150 | {
151 | return &layer;
152 | }
153 | }
154 | return nullptr;
155 | };
156 | objectList.clear();
157 | const NvDsInferLayerInfo *heatmap = layerFinder("537");
158 | const NvDsInferLayerInfo *scale = layerFinder("538");
159 | const NvDsInferLayerInfo *offset = layerFinder("539");
160 | const NvDsInferLayerInfo *landmarks = layerFinder("540");
161 | // std::cout<<"width"<<&networkInfo.width<<std::endl;
162 |
163 | if (!heatmap || !scale || !offset || !landmarks)
164 | {
165 | std::cerr << "ERROR: some layers missing or unsupported data types "
166 | << "in output tensors" << std::endl;
167 | return false;
168 | }
169 |
170 | int fea_h = heatmap->inferDims.d[1]; //#heatmap.size[2];
171 | int fea_w = heatmap->inferDims.d[2]; //heatmap.size[3];
172 | int spacial_size = fea_w * fea_h;
173 | // std::cout<<"features"<<fea_h<<" "<<fea_w<<std::endl;
174 | float *heatmap_ = (float *)(heatmap->buffer);
175 |
175 |
176 | float *scale0 = (float *)(scale->buffer);
177 | float *scale1 = scale0 + spacial_size;
178 |
179 | float *offset0 = (float *)(offset->buffer);
180 | float *offset1 = offset0 + spacial_size;
181 | float *lm = (float *)landmarks->buffer;
182 |
183 | float scoreThresh = 0.5;
184 | std::vector<int> ids = getIds(heatmap_, fea_h, fea_w, scoreThresh);
185 | //?? d_w, d_h
186 | int width = networkInfo.width;
187 | int height = networkInfo.height;
188 | int d_h = (int)(std::ceil(height / 32.0) * 32); // round up to a multiple of 32
189 | int d_w = (int)(std::ceil(width / 32.0) * 32);
190 | // int d_scale_h = height/d_h ;
191 | // int d_scale_w = width/d_w ;
192 | // float scale_w = (float)width / (float)d_w;
193 | // float scale_h = (float)height / (float)d_h;
194 | std::vector<FaceInfo> faces_tmp;
195 | std::vector<FaceInfo> faces;
196 | for (int i = 0; i < ids.size() / 2; i++)
197 | {
198 | int id_h = ids[2 * i];
199 | int id_w = ids[2 * i + 1];
200 | int index = id_h * fea_w + id_w;
201 |
202 | float s0 = std::exp(scale0[index]) * 4;
203 | float s1 = std::exp(scale1[index]) * 4;
204 | float o0 = offset0[index];
205 | float o1 = offset1[index];
206 | float x1 = std::max(0., (id_w + o1 + 0.5) * 4 - s1 / 2);
207 | float y1 = std::max(0., (id_h + o0 + 0.5) * 4 - s0 / 2);
208 | float x2 = 0, y2 = 0;
209 | x1 = std::min(x1, (float)d_w);
210 | y1 = std::min(y1, (float)d_h);
211 | x2 = std::min(x1 + s1, (float)d_w);
212 | y2 = std::min(y1 + s0, (float)d_h);
213 |
214 | FaceInfo facebox;
215 | facebox.x1 = x1;
216 | facebox.y1 = y1;
217 | facebox.x2 = x2;
218 | facebox.y2 = y2;
219 | facebox.score = heatmap_[index];
220 | for (int j = 0; j < 5; j++)
221 | {
222 | facebox.landmarks[2 * j] = x1 + lm[(2 * j + 1) * spacial_size + index] * s1;
223 | facebox.landmarks[2 * j + 1] = y1 + lm[(2 * j) * spacial_size + index] * s0;
224 | }
225 | faces_tmp.push_back(facebox);
226 | }
227 |
228 | const float threshold = 0.3;
229 | nms(faces_tmp, faces, threshold);
230 | for (int k = 0; k < faces.size(); k++)
231 | {
232 | NvDsInferObjectDetectionInfo object;
233 | /* Clip object box co-ordinates to network resolution */
234 | object.left = CLIP(faces[k].x1, 0, networkInfo.width - 1);
235 | object.top = CLIP(faces[k].y1, 0, networkInfo.height - 1);
236 | object.width = CLIP((faces[k].x2 - faces[k].x1), 0, networkInfo.width - 1);
237 | object.height = CLIP((faces[k].y2 - faces[k].y1), 0, networkInfo.height - 1);
238 |
239 | if (object.width && object.height)
240 | {
241 | object.detectionConfidence = 0.99;
242 | object.classId = 0;
243 | objectList.push_back(object);
244 | }
245 | }
246 | return true;
247 | }
248 | /* Check that the custom function has been defined correctly */
249 | CHECK_CUSTOM_PARSE_FUNC_PROTOTYPE(NvDsInferParseCustomCenterNetFace);
250 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "{}"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright 2020 Dhruv Singal
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------