├── .dockerignore ├── .gitignore ├── CHANGELOG.md ├── CONTRIBUTING.md ├── Dockerfile ├── Dockerfile.deepstream ├── INFERENCE.md ├── LICENSE ├── README.md ├── TRAINING.md ├── csrc ├── calibrator.h ├── cuda │ ├── decode.cu │ ├── decode.h │ ├── decode_rotate.cu │ ├── decode_rotate.h │ ├── nms.cu │ ├── nms.h │ ├── nms_iou.cu │ ├── nms_iou.h │ └── utils.h ├── engine.cpp ├── engine.h ├── extensions.cpp └── plugins │ ├── DecodePlugin.h │ ├── DecodeRotatePlugin.h │ ├── NMSPlugin.h │ └── NMSRotatePlugin.h ├── extras ├── cppapi │ ├── CMakeLists.txt │ ├── README.md │ ├── export.cpp │ ├── generate_anchors.py │ ├── infer.cpp │ └── infervideo.cpp ├── deepstream │ ├── README.md │ └── deepstream-sample │ │ ├── CMakeLists.txt │ │ ├── ds_config_1vid.txt │ │ ├── ds_config_8vid.txt │ │ ├── infer_config_batch1.txt │ │ ├── infer_config_batch8.txt │ │ ├── labels_coco.txt │ │ └── nvdsparsebbox_retinanet.cpp └── test.sh ├── odtk ├── __init__.py ├── backbones │ ├── __init__.py │ ├── fpn.py │ ├── layers.py │ ├── mobilenet.py │ ├── resnet.py │ └── utils.py ├── box.py ├── dali.py ├── data.py ├── infer.py ├── loss.py ├── main.py ├── model.py ├── train.py └── utils.py └── setup.py /.dockerignore: -------------------------------------------------------------------------------- 1 | .git 2 | .DS_Store 3 | __pycache__ 4 | *.pyc 5 | *.o 6 | *.so 7 | *.egg-info 8 | build 9 | dist 10 | .vscode 11 | *.jpg 12 | !tests/*.jpg 13 | *.pkl 14 | *.torch 15 | *.plan 16 | 17 | venv/ -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | __pycache__ 3 | *.pyc 4 | *.o 5 | *.so 6 | odtk/tensorrt/src/*.py 7 | odtk/tensorrt/src/*.cxx 8 | *.egg-info 9 | build 10 | dist 11 | models 12 | data_tools 13 | .vscode 14 | *.jpg 15 | !tests/*.jpg 16 | *.pkl 17 | *.torch 18 | *.plan 19 | 20 | venv/ 21 | .idea -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # NVIDIA ODTK change log 2 | 3 | ## Version 0.2.6 -- 2021-08-11 4 | 5 | ### Added 6 | * `--with-apex` option to `odtk train` and `odtk infer`. 7 | * This parameter allows you to switch to NVIDIA APEX AMP and DistributedDataParallel. 8 | * Adding validation stats to TensorBoard. 9 | 10 | ### Changed 11 | * Pytorch Docker container 21.09 from 20.06 12 | * Added training and inference support for PyTorch native AMP, and torch.nn.parallel.DistributedDataParallel (default). 13 | * Switched the Pytorch Model and Data Memory Format to Channels Last. (see [Memory Format Tutorial](https://pytorch.org/tutorials/intermediate/memory_format_tutorial.html)) 14 | * Bug fixes: 15 | * Workaround for `'No detections!'` during vlidation added. (see [#52663](https://github.com/pytorch/pytorch/issues/52663)) 16 | * Freeze unused parameters from torchvision models from autograd gradient calculations. 17 | * Make tensorboard writer exclusive to the master process to prevent race conditions. 18 | * Renamed instances of `retinanet` to `odtk` (folder, C++ namepsaces, etc.) 19 | 20 | 21 | ## Version 0.2.5 -- 2020-06-27 22 | 23 | ### Added 24 | * `--dynamic-batch-opts` option to `odtk export`. 25 | * This parameter allows you to provide TensorRT Optimiation Profile batch sizes for engine export (min, opt, max). 26 | 27 | ### Changed 28 | * Updated TensorRT plugins to allow for dynamic batch sizes (see https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#work_dynamic_shapes and https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/classnvinfer1_1_1_i_plugin_v2_dynamic_ext.html). 29 | 30 | 31 | ## Version 0.2.4 -- 2020-04-20 32 | 33 | ### Added 34 | * `--anchor-ious` option to `odtk train`. 35 | * This parameter allows you to adjust the background and foreground anchor IoU threshold. The default values are `[0.4, 0.5].` 36 | * Example `--anchor-ious 0.3 0.5`. This would mean that any anchor with an IoU of less than 0.3 is assigned to background, 37 | and that any anchor with an IoU of greater than 0.5 is assigned to the foreground object, which is atmost one. 38 | 39 | ## Version 0.2.3 -- 2020-04-14 40 | 41 | ### Added 42 | * `MobileNetV2FPN` backbone 43 | 44 | ## Version 0.2.2 -- 2020-04-01 45 | 46 | ### Added 47 | * Rotated bounding box detections models can now be exported to ONNX and TensorRT using `odtk export model.pth model.plan --rotated-bbox` 48 | * The `--rotated-bbox` flag is automatically applied when running `odtk infer` or `odtk export` _on a model trained with ODTK version 0.2.2 or later_. 49 | 50 | ### Changed 51 | 52 | * Improvements to the rotated IoU calculations. 53 | 54 | ### Limitations 55 | 56 | * The C++ API cannot currently infer rotated bounding box models. 57 | 58 | ## Version 0.2.1 -- 2020-03-18 59 | 60 | ### Added 61 | * The DALI dataloader (flag `--with-dali`) now supports image augmentation using: 62 | * `--augment-brightness` : Randomly adjusts brightness of image 63 | * `--augment-contrast` : Randomly adjusts contrast of image 64 | * `--augment-hue` : Randomly adjusts hue of image 65 | * `--augment-saturation` : Randomly adjusts saturation of image 66 | 67 | ### Changed 68 | * The code in `box.py` for generating anchors has been improved. 69 | 70 | ## Version 0.2.0 -- 2020-03-13 71 | 72 | Version 0.2.0 introduces rotated detections. 73 | 74 | ### Added 75 | * `train arguments`: 76 | * `--rotated-bbox`: Trains a model is predict rotated bounding boxes `[x, y, w, h, theta]` instead of axis aligned boxes `[x, y, w, h]`. 77 | * `infer arguments`: 78 | * `--rotated-bbox`: Infer a rotated model. 79 | 80 | ### Changed 81 | The project has reverted to the name **Object Detection Toolkit** (ODTK), to better reflect the multi-network nature of the repo. 82 | * `retinanet` has been replaced with `odtk`. All subcommands remain the same. 83 | 84 | ### Limitations 85 | * Models trained using the `--rotated-bbox` flag cannot be exported to ONNX or a TensorRT Engine. 86 | * PyTorch raises two warnings which can be ignored: 87 | 88 | Warning 1: NCCL watchdog 89 | ``` 90 | [E ProcessGroupNCCL.cpp:284] NCCL watchdog thread terminated 91 | ``` 92 | 93 | Warning 2: Save state warning 94 | ``` 95 | /opt/conda/lib/python3.6/site-packages/torch/optim/lr_scheduler.py:201: UserWarning: Please also save or load the state of the optimzer when saving or loading the scheduler. 96 | warnings.warn(SAVE_STATE_WARNING, UserWarning) 97 | ``` 98 | 99 | ## Version 0.1.1 -- 2020-03-06 100 | 101 | ### Added 102 | * `train` arguments 103 | * `--augment-rotate`: Randomly rotates the training images by 0°, 90°, 180° or 270°. 104 | * `--augment-brightness` : Randomly adjusts brightness of image 105 | * `--augment-contrast` : Randomly adjusts contrast of image 106 | * `--augment-hue` : Randomly adjusts hue of image 107 | * `--augment-saturation` : Randomly adjusts saturation of image 108 | * `--regularization-l2` : Sets the L2 regularization of the optimizer. 109 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | Reporting problems, asking questions 2 | ------------------------------------ 3 | 4 | 5 | We appreciate feedback, questions or bug reports. When you need help with the code, try to follow the process outlined in the Stack Overflow (https://stackoverflow.com/help/mcve) document. 6 | 7 | At a minimum, your issues should describe the following: 8 | 9 | * What command you ran 10 | * The hardware and container that you are using 11 | * The version of ODTK you are using 12 | * What was the result you observed 13 | * What was the result you expected 14 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvcr.io/nvidia/pytorch:21.09-py3 2 | 3 | COPY . odtk/ 4 | RUN pip install --no-cache-dir -e odtk/ 5 | -------------------------------------------------------------------------------- /Dockerfile.deepstream: -------------------------------------------------------------------------------- 1 | FROM nvcr.io/nvidia/pytorch:20.03-py3 2 | COPY . /workspace/retinanet-examples/ 3 | RUN apt-get update && apt-get install -y libssl1.0.0 libgstreamer1.0-0 gstreamer1.0-tools gstreamer1.0-plugins-good gstreamer1.0-plugins-bad gstreamer1.0-plugins-ugly gstreamer1.0-libav libgstrtspserver-1.0-0 libjansson4 ffmpeg libjson-glib-1.0 libgles2-mesa 4 | RUN git clone https://github.com/edenhill/librdkafka.git /librdkafka && \ 5 | cd /librdkafka && ./configure && make -j && make -j install && \ 6 | mkdir -p /opt/nvidia/deepstream/deepstream-4.0/lib && \ 7 | cp /usr/local/lib/librdkafka* /opt/nvidia/deepstream/deepstream-4.0/lib && \ 8 | rm -rf /librdkafka 9 | WORKDIR /workspace/retinanet-examples/extras/deepstream/DeepStream_Release/deepstream_sdk_v4.0.2_x86_64 10 | RUN tar -xvf binaries.tbz2 -C / && \ 11 | ./install.sh 12 | # config files + sample apps 13 | RUN chmod u+x ./sources/tools/nvds_logger/setup_nvds_logger.sh 14 | 15 | WORKDIR /usr/lib/x86_64-linux-gnu 16 | RUN ln -sf libnvcuvid.so.1 libnvcuvid.so 17 | 18 | WORKDIR /workspace/retinanet-examples 19 | RUN pip install --no-cache-dir -e . 20 | RUN mkdir extras/deepstream/deepstream-sample/build && \ 21 | cd extras/deepstream/deepstream-sample/build && \ 22 | cmake -DDeepStream_DIR=/workspace/retinanet-examples/extras/deepstream/DeepStream_Release/deepstream_sdk_v4.0.2_x86_64 .. && make -j 23 | WORKDIR /workspace/retinanet-examples/extras/deepstream 24 | -------------------------------------------------------------------------------- /INFERENCE.md: -------------------------------------------------------------------------------- 1 | # Inference 2 | 3 | We provide two ways inferring using `odtk`: 4 | * PyTorch inference using a trained model (FP32 or FP16 precision) 5 | * Export trained pytorch model to TensorRT for optimized inference (FP32, FP16 or INT8 precision) 6 | 7 | `odtk infer` will run distributed inference across all available GPUs. When using PyTorch, the default behavior is to run inference with mixed precision. The precision used when running inference with a TensorRT engine will correspond to the precision chosen when the model was exported to TensorRT (see [TensorRT section](#exporting-trained-pytorch-model-to-tensorrt) below). 8 | 9 | **NOTE**: Availability of HW support for fast FP16 and INT8 precision like [NVIDIA Tensor Cores](https://www.nvidia.com/en-us/data-center/tensorcore/) depends on your GPU architecture: Volta or newer GPUs support both FP16 and INT8, and Pascal GPUs can support either FP16 or INT8. 10 | 11 | ## PyTorch Inference 12 | 13 | Evaluate trained PyTorch detection model on COCO 2017 (mixed precision): 14 | 15 | ```bash 16 | odtk infer model.pth --images=/data/coco/val2017 --annotations=instances_val2017.json --batch 8 17 | ``` 18 | **NOTE**: `--batch N` specifies *global* batch size to be used for inference. The batch size per GPU will be `N // num_gpus`. 19 | 20 | Use full precision (FP32) during evaluation: 21 | 22 | ```bash 23 | odtk infer model.pth --images=/data/coco/val2017 --annotations=instances_val2017.json --full-precision 24 | ``` 25 | 26 | Evaluate PyTorch detection model with a small input image size: 27 | 28 | ```bash 29 | odtk infer model.pth --images=/data/coco/val2017 --annotations=instances_val2017.json --resize 400 --max-size 640 30 | ``` 31 | Here, the shorter side of the input images will be resized to `resize` as long as the longer side doesn't get larger than `max-size`, otherwise the longer side of the input image will be resized to `max-size`. 32 | 33 | **NOTE**: To get best accuracy, training the model at the preferred export size is encouraged. 34 | 35 | Run inference using your own dataset: 36 | 37 | ```bash 38 | odtk infer model.pth --images=/data/your_images --output=detections.json 39 | ``` 40 | 41 | ## Exporting trained PyTorch model to TensorRT 42 | 43 | `odtk` provides an simple workflow to optimize a trained PyTorch model for inference deployment using TensorRT. The PyTorch model is exported to [ONNX](https://github.com/onnx/onnx), and then the ONNX model is consumed and optimized by TensorRT. 44 | To learn more about TensorRT optimization, refer here: https://developer.nvidia.com/tensorrt 45 | 46 | **NOTE**: When a model is optimized with TensorRT, the output is a TensorRT engine (.plan file) that can be used for deployment. This TensorRT engine has several fixed properties that are specified during the export process. 47 | * Input image size: TensorRT engines only support a fixed input size. 48 | * Precision: TensorRT supports FP32, FP16, or INT8 precision. 49 | * Target GPU: TensorRT optimizations are tied to the type of GPU on the system where optimization is performed. They are not transferable across different types of GPUs. Put another way, if you aim to deploy your TensorRT engine on a Tesla T4 GPU, you must run the optimization on a system with a T4 GPU. 50 | 51 | The workflow for exporting a trained PyTorch detection model to TensorRT is as simple as: 52 | 53 | ```bash 54 | odtk export model.pth model_fp16.plan --size 1280 55 | ``` 56 | This will create a TensorRT engine optimized for batch size 1, using an input size of 1280x1280. By default, the engine will be created to run in FP16 precision. 57 | 58 | Export your model to use full precision using a non-square input size: 59 | ```bash 60 | odtk export model.pth model_fp32.plan --full-precision --size 800 1280 61 | ``` 62 | 63 | In order to use INT8 precision with TensorRT, you need to provide calibration images (images that are representative of what will be seen at runtime) that will be used to rescale the network. 64 | ```bash 65 | odtk export model.pth model_int8.plan --int8 --calibration-images /data/val/ --calibration-batches 2 --calibration-table model_calibration_table 66 | ``` 67 | 68 | This will randomly select 16 images from `/data/val/` to calibrate the network for INT8 precision. The results from calibration will be saved to `model_calibration_table` that can be used to create subsequent INT8 engines for this model without needed to recalibrate. 69 | 70 | **NOTE:** Number of images in `/data/val/` must be greater than or equal to the kOPT(middle) optimization profile from `--dynamic-batch-opts`. Here, the default kOPT is 8. 71 | 72 | Build an INT8 engine for a previously calibrated model: 73 | ```bash 74 | odtk export model.pth model_int8.plan --int8 --calibration-table model_calibration_table 75 | ``` 76 | 77 | ## Deployment with TensorRT on NVIDIA Jetson AGX Xavier 78 | 79 | We provide a path for deploying trained models with TensorRT onto embedded platforms like [NVIDIA Jetson AGX Xavier](https://developer.nvidia.com/embedded/buy/jetson-agx-xavier-devkit), where PyTorch is not readily available. 80 | 81 | You will need to export your trained PyTorch model to ONNX representation on your host system, and copy the resulting ONNX model to your Jetson AGX Xavier: 82 | ```bash 83 | odtk export model.pth model.onnx --size 800 1280 84 | ``` 85 | 86 | Refer to additional documentation on using the example cppapi code to build the TensorRT engine and run inference here: [cppapi example code](extras/cppapi/README.md) 87 | 88 | ## Rotated detections 89 | 90 | *Rotated ODTK* allows users to train and infer rotated bounding boxes in imagery. 91 | 92 | ### Inference 93 | 94 | An example command: 95 | ``` 96 | odtk infer model.pth --images /data/val --annotations /data/val_rotated.json --output /data/detections.json \ 97 | --resize 768 --rotated-bbox 98 | ``` 99 | 100 | ### Export 101 | 102 | Rotated bounding box models can be exported to create TensorRT engines by using the axis aligned command with the addition of `--rotated-bbox`. -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 2 | 3 | Redistribution and use in source and binary forms, with or without 4 | modification, are permitted provided that the following conditions 5 | are met: 6 | * Redistributions of source code must retain the above copyright 7 | notice, this list of conditions and the following disclaimer. 8 | * Redistributions in binary form must reproduce the above copyright 9 | notice, this list of conditions and the following disclaimer in the 10 | documentation and/or other materials provided with the distribution. 11 | * Neither the name of NVIDIA CORPORATION nor the names of its 12 | contributors may be used to endorse or promote products derived 13 | from this software without specific prior written permission. 14 | 15 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY 16 | EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 17 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 18 | PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR 19 | CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 20 | EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 21 | PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 22 | PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY 23 | OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 24 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 26 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # NVIDIA Object Detection Toolkit (ODTK) 2 | 3 | **Fast** and **accurate** single stage object detection with end-to-end GPU optimization. 4 | 5 | ## Description 6 | 7 | ODTK is a single shot object detector with various backbones and detection heads. This allows performance/accuracy trade-offs. 8 | 9 | It is optimized for end-to-end GPU processing using: 10 | * The [PyTorch](https://pytorch.org) deep learning framework with [ONNX](https://onnx.ai) support 11 | * NVIDIA [Apex](https://github.com/NVIDIA/apex) for mixed precision and distributed training 12 | * NVIDIA [DALI](https://github.com/NVIDIA/DALI) for optimized data pre-processing 13 | * NVIDIA [TensorRT](https://developer.nvidia.com/tensorrt) for high-performance inference 14 | * NVIDIA [DeepStream](https://developer.nvidia.com/deepstream-sdk) for optimized real-time video streams support 15 | 16 | ## Rotated bounding box detections 17 | 18 | This repo now supports rotated bounding box detections. See [rotated detections training](TRAINING.md#rotated-detections) and [rotated detections inference](INFERENCE.md#rotated-detections) documents for more information on how to use the `--rotated-bbox` command. 19 | 20 | Bounding box annotations are described by `[x, y, w, h, theta]`. 21 | 22 | ## Performance 23 | 24 | The detection pipeline allows the user to select a specific backbone depending on the latency-accuracy trade-off preferred. 25 | 26 | ODTK **RetinaNet** model accuracy and inference latency & FPS (frames per seconds) for [COCO 2017](http://cocodataset.org/#detection-2017) (train/val) after full training schedule. Inference results include bounding boxes post-processing for a batch size of 1. Inference measured at `--resize 800` using `--with-dali` on a FP16 TensorRT engine. 27 | 28 | Backbone | mAP @[IoU=0.50:0.95] | Training Time on [DGX1v](https://www.nvidia.com/en-us/data-center/dgx-1/) | Inference latency FP16 on [V100](https://www.nvidia.com/en-us/data-center/tesla-v100/) | Inference latency INT8 on [T4](https://www.nvidia.com/en-us/data-center/tesla-t4/) | Inference latency FP16 on [A100](https://www.nvidia.com/en-us/data-center/a100/) | Inference latency INT8 on [A100](https://www.nvidia.com/en-us/data-center/a100/) 29 | --- | :---: | :---: | :---: | :---: | :---: | :---: 30 | [ResNet18FPN](https://github.com/NVIDIA/retinanet-examples/releases/download/19.04/retinanet_rn18fpn.zip) | 0.318 | 5 hrs | 14 ms;
71 FPS | 18 ms;
56 FPS | 9 ms;
110 FPS | 7 ms;
141 FPS 31 | [MobileNetV2FPN](https://github.com/NVIDIA/retinanet-examples/releases/download/v0.2.3/retinanet_mobilenetv2fpn.pth) | 0.333 | | 14 ms;
74 FPS | 18 ms;
56 FPS | 9 ms;
114 FPS | 7 ms;
138 FPS 32 | [ResNet34FPN](https://github.com/NVIDIA/retinanet-examples/releases/download/19.04/retinanet_rn34fpn.zip) | 0.343 | 6 hrs | 16 ms;
64 FPS | 20 ms;
50 FPS | 10 ms;
103 FPS | 7 ms;
142 FPS 33 | [ResNet50FPN](https://github.com/NVIDIA/retinanet-examples/releases/download/19.04/retinanet_rn50fpn.zip) | 0.358 | 7 hrs | 18 ms;
56 FPS | 22 ms;
45 FPS | 11 ms;
93 FPS | 8 ms;
129 FPS 34 | [ResNet101FPN](https://github.com/NVIDIA/retinanet-examples/releases/download/19.04/retinanet_rn101fpn.zip) | 0.376 | 10 hrs | 22 ms;
46 FPS | 27 ms;
37 FPS | 13 ms;
78 FPS | 9 ms;
117 FPS 35 | [ResNet152FPN](https://github.com/NVIDIA/retinanet-examples/releases/download/19.04/retinanet_rn152fpn.zip) | 0.393 | 12 hrs | 26 ms;
38 FPS | 33 ms;
31 FPS | 15 ms;
66 FPS | 10 ms;
103 FPS 36 | 37 | ## Installation 38 | 39 | For best performance, use the latest [PyTorch NGC docker container](https://ngc.nvidia.com/catalog/containers/nvidia:pytorch). Clone this repository, build and run your own image: 40 | 41 | ```bash 42 | git clone https://github.com/nvidia/retinanet-examples 43 | docker build -t odtk:latest retinanet-examples/ 44 | docker run --gpus all --rm --ipc=host -it odtk:latest 45 | ``` 46 | 47 | ## Usage 48 | 49 | Training, inference, evaluation and model export can be done through the `odtk` utility. 50 | For more details, including a list of parameters, please refer to the [TRAINING](TRAINING.md) and [INFERENCE](INFERENCE.md) documentation. 51 | 52 | ### Training 53 | 54 | Train a detection model on [COCO 2017](http://cocodataset.org/#download) from pre-trained backbone: 55 | ```bash 56 | odtk train retinanet_rn50fpn.pth --backbone ResNet50FPN \ 57 | --images /coco/images/train2017/ --annotations /coco/annotations/instances_train2017.json \ 58 | --val-images /coco/images/val2017/ --val-annotations /coco/annotations/instances_val2017.json 59 | ``` 60 | 61 | ### Fine Tuning 62 | 63 | Fine-tune a pre-trained model on your dataset. In the example below we use [Pascal VOC](http://host.robots.ox.ac.uk/pascal/VOC/voc2012/index.html) with [JSON annotations](https://storage.googleapis.com/coco-dataset/external/PASCAL_VOC.zip): 64 | ```bash 65 | odtk train model_mydataset.pth --backbone ResNet50FPN \ 66 | --fine-tune retinanet_rn50fpn.pth \ 67 | --classes 20 --iters 10000 --val-iters 1000 --lr 0.0005 \ 68 | --resize 512 --jitter 480 640 --images /voc/JPEGImages/ \ 69 | --annotations /voc/pascal_train2012.json --val-annotations /voc/pascal_val2012.json 70 | ``` 71 | 72 | Note: the shorter side of the input images will be resized to `resize` as long as the longer side doesn't get larger than `max-size`. During training, the images will be randomly randomly resized to a new size within the `jitter` range. 73 | 74 | ### Inference 75 | 76 | Evaluate your detection model on [COCO 2017](http://cocodataset.org/#download): 77 | ```bash 78 | odtk infer retinanet_rn50fpn.pth --images /coco/images/val2017/ --annotations /coco/annotations/instances_val2017.json 79 | ``` 80 | 81 | Run inference on [your dataset](#datasets): 82 | ```bash 83 | odtk infer retinanet_rn50fpn.pth --images /dataset/val --output detections.json 84 | ``` 85 | 86 | ### Optimized Inference with TensorRT 87 | 88 | For faster inference, export the detection model to an optimized FP16 TensorRT engine: 89 | ```bash 90 | odtk export model.pth engine.plan 91 | ``` 92 | 93 | Evaluate the model with TensorRT backend on [COCO 2017](http://cocodataset.org/#download): 94 | ```bash 95 | odtk infer engine.plan --images /coco/images/val2017/ --annotations /coco/annotations/instances_val2017.json 96 | ``` 97 | 98 | ### INT8 Inference with TensorRT 99 | 100 | For even faster inference, do INT8 calibration to create an optimized INT8 TensorRT engine: 101 | ```bash 102 | odtk export model.pth engine.plan --int8 --calibration-images /coco/images/val2017/ 103 | ``` 104 | This will create an INT8CalibrationTable file that can be used to create INT8 TensorRT engines for the same model later on without needing to do calibration. 105 | 106 | Or create an optimized INT8 TensorRT engine using a cached calibration table: 107 | ```bash 108 | odtk export model.pth engine.plan --int8 --calibration-table /path/to/INT8CalibrationTable 109 | ``` 110 | 111 | ## Datasets 112 | 113 | RetinaNet supports annotations in the [COCO JSON format](http://cocodataset.org/#format-data). 114 | When converting the annotations from your own dataset into JSON, the following entries are required: 115 | ``` 116 | { 117 | "images": [{ 118 | "id" : int, 119 | "file_name" : str 120 | }], 121 | "annotations": [{ 122 | "id" : int, 123 | "image_id" : int, 124 | "category_id" : int, 125 | "bbox" : [x, y, w, h] # all floats 126 | "area": float # w * h. Required for validation scores 127 | "iscrowd": 0 # Required for validation scores 128 | }], 129 | "categories": [{ 130 | "id" : int 131 | ]} 132 | } 133 | ``` 134 | 135 | If using the `--rotated-bbox` flag for rotated detections, add an additional float `theta` to the annotations. To get validation scores you also need to fill the `segmentation` section. 136 | ``` 137 | "bbox" : [x, y, w, h, theta] # all floats, where theta is measured in radians anti-clockwise from the x-axis. 138 | "segmentation" : [[x1, y1, x2, y2, x3, y3, x4, y4]] 139 | # Required for validation scores. 140 | ``` 141 | 142 | ## Disclaimer 143 | 144 | This is a research project, not an official NVIDIA product. 145 | 146 | ## Jetpack compatibility 147 | 148 | This branch uses TensorRT 7. If you are training and inferring models using PyTorch, or are creating TensorRT engines on Tesla GPUs (eg V100, T4), then you should use this branch. 149 | 150 | If you wish to deploy your model to a Jetson device (eg - Jetson AGX Xavier) running Jetpack version 4.3, then you should use the `19.10` branch of this repo. 151 | 152 | ## References 153 | 154 | - [Focal Loss for Dense Object Detection](https://arxiv.org/abs/1708.02002). 155 | Tsung-Yi Lin, Priya Goyal, Ross Girshick, Kaiming He, Piotr Dollár. 156 | ICCV, 2017. 157 | - [Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour](https://arxiv.org/abs/1706.02677). 158 | Priya Goyal, Piotr Dollár, Ross Girshick, Pieter Noordhuis, Lukasz Wesolowski, Aapo Kyrola, Andrew Tulloch, Yangqing Jia, Kaiming He. 159 | June 2017. 160 | - [Feature Pyramid Networks for Object Detection](https://arxiv.org/abs/1612.03144). 161 | Tsung-Yi Lin, Piotr Dollár, Ross Girshick, Kaiming He, Bharath Hariharan, Serge Belongie. 162 | CVPR, 2017. 163 | - [Deep Residual Learning for Image Recognition](http://arxiv.org/abs/1512.03385). 164 | Kaiming He, Xiangyu Zhang, Shaoqing Renm Jian Sun. 165 | CVPR, 2016. 166 | -------------------------------------------------------------------------------- /TRAINING.md: -------------------------------------------------------------------------------- 1 | # Training 2 | 3 | There are two main ways to train a model with `odtk`: 4 | * Fine-tuning the detection model using a model already trained on a large dataset (like MS-COCO) 5 | * Fully training the detection model from random initialization using a pre-trained backbone (usually ImageNet) 6 | 7 | ## Fine-tuning 8 | 9 | Fine-tuning an existing model trained on COCO allows you to use transfer learning to get a accurate model for your own dataset with minimal training. 10 | When fine-tuning, we re-initialize the last layer of the classification head so the network will re-learn how to map features to classes scores regardless of the number of classes in your own dataset. 11 | 12 | You can fine-tune a pre-trained model on your dataset. In the example below we take a model trained on COCO, and then fine-tune using [Pascal VOC](http://host.robots.ox.ac.uk/pascal/VOC/voc2012/index.html) with [JSON annotations](https://storage.googleapis.com/coco-dataset/external/PASCAL_VOC.zip): 13 | ```bash 14 | odtk train model_mydataset.pth \ 15 | --fine-tune retinanet_rn50fpn.pth \ 16 | --classes 20 --iters 10000 --val-iters 1000 --lr 0.0005 \ 17 | --resize 512 --jitter 480 640 --images /voc/JPEGImages/ \ 18 | --annotations /voc/pascal_train2012.json --val-annotations /voc/pascal_val2012.json 19 | ``` 20 | 21 | Even though the COCO model was trained on 80 classes, we can easily use tranfer learning to fine-tune it on the Pascal VOC model representing only 20 classes. 22 | 23 | The shorter side of the input images will be resized to `resize` as long as the longer side doesn't get larger than `max-size`. 24 | During training the images will be randomly resized to a new size within the `jitter` range. 25 | 26 | We usually want to fine-tune the model with a lower learning rate `lr` than during full training and for less iterations `iters`. 27 | 28 | ## Full Training 29 | 30 | If you do not have a pre-trained model, if your dataset is substantially large, or if you have written your own backbone, then you should fully train the detection model. 31 | 32 | Full training usually starts from a pre-trained backbone (automatically downloaded with the current backbones we offer) that has been pre-trained on a classification task with a large dataset like [ImageNet](http://www.image-net.org). 33 | This is especially necessary for backbones using batch normalization as they require large batch sizes during training that cannot be provided when training on the detection task as the input images have to be relatively large. 34 | 35 | Train a detection model on [COCO 2017](http://cocodataset.org/#download) from pre-trained backbone: 36 | ```bash 37 | odtk train retinanet_rn50fpn.pth --backbone ResNet50FPN \ 38 | --images /coco/images/train2017/ --annotations /coco/annotations/instances_train2017.json \ 39 | --val-images /coco/images/val2017/ --val-annotations /coco/annotations/instances_val2017.json 40 | ``` 41 | 42 | ## Training arguments 43 | 44 | ### Positional arguments 45 | * The only positional argument is the name of the model. This can be a full path, or relative to the current directory. 46 | ```bash 47 | odtk train model.pth 48 | ``` 49 | 50 | ### Other arguments 51 | The following arguments are available during training: 52 | 53 | * `--annotations` (str): Path to COCO style annotations (required). 54 | * `--images` (str): Path to a directory of images (required). 55 | * `--lr` (float): Sets the learning rate. Default: 0.01. 56 | * `--full-precision`: By default we train using mixed precision. Include this argument to instead train in full precision. 57 | * `--warmup` (int): The number of initial iterations during which we want to linearly ramp-up the learning rate to avoid early divergence of the loss. Default: 1000 58 | * `--backbone` (str): Specify one of the supported backbones. Default: `ResNet50FPN` 59 | * `--classes` (int): The number of classes in your dataset. Default: 80 60 | * `--batch` (int): The size of each training batch. Default: 2 x number of GPUs. 61 | * `--max-size` (int): The longest edge of your training image will be resized, so that it is always less than or equal to `max-size`. Default: 1333. 62 | * `--jitter` (int int): The shortest edge of your training images will be resized to int1 >= shortest edge >= int2, unless the longest edge exceeds `max-size`, in which case the longest edge will be resized to `max-size` and the shortest length will be sized to keep the aspect ratio constant. Default: 640 1024. 63 | * `--resize` (int): During validation inference, the shortest edge of your training images will be resized to int, unless the longest edge exceeds `max-size`, in which case the longest edge will be resized to `max-size` and the shortest length will be sized to keep the aspect ratio constant. Default: 800. 64 | * `--iters` (int): The number of iterations to process. An iteration is the processing (forward and backward pass) of one batch. Number of epochs is (`iters` x `batch`) / `len(data)`. Default: 90000. 65 | * `--milestones` (int int): The learning rate is multiplied by `--gamma` every time it reaches a milestone. Default: 60000 80000. 66 | * `--gamma` (float): The learning rate is multiplied by `--gamma` every time it reaches a milestone. Default: 0.1. 67 | * `--override`: Do not continue training from `model.pth`, instead overwrite it. 68 | * `--val-annotations` (str): Path to COCO style annotations. If supplied, `pycocotools` will be used to give validation mAP. 69 | * `--val-images` (str): Path to directory of validation images. 70 | * `--val-iters` (int): Run inference on the validation set every int iterations. 71 | * `--fine-tune` (str): Fine tune from a model at path str. 72 | * `--with-dali`: Load data using DALI. 73 | * `--augment-rotate`: Randomly rotates the training images by 0°, 90°, 180° or 270°. 74 | * `--augment-brightness` (float): Randomly adjusts brightness of image. The value sets the standard deviation of a Gaussian distribution. The degree of augmentation is selected from this distribution. Default: 0.002 75 | * `--augment-contrast` (float): Randomly adjusts contrast of image. The value sets the standard deviation of a Gaussian distribution. The degree of augmentation is selected from this distribution. Default: 0.002 76 | * `--augment-hue` (float): Randomly adjusts hue of image. The value sets the standard deviation of a Gaussian distribution. The degree of augmentation is selected from this distribution. Default: 0.0002 77 | * `--augment-saturation` (float): Randomly adjusts saturation of image. The value sets the standard deviation of a Gaussian distribution. The degree of augmentation is selected from this distribution. Default: 0.002 78 | * `--regularization-l2` (float): Sets the L2 regularization of the optimizer. Default: 0.0001 79 | 80 | You can also monitor the loss and learning rate schedule of the training using TensorBoard bu specifying a `logdir` path. 81 | 82 | ## Rotated detections 83 | 84 | *Rotated ODTK* allows users to train and infer rotated bounding boxes in imagery. 85 | 86 | ### Dataset 87 | Annotations need to conform to the COCO standard, with the addition of an angle (radians) in the bounding box (bbox) entry `[xmin, ymin, width, height, **theta**]`. `xmin`, `ymin`, `width` and `height` are in the axis aligned coordinates, ie floats, measured from the top left of the image. `theta` is in radians, measured anti-clockwise from the x-axis. We constrain theta between - \pi/4 and \pi/4. 88 | 89 | In order for the validation metrics to calculate, you also need to fill the `segmentation` entry with the coordinates of the corners of your bounding box. 90 | 91 | If using the `--rotated-bbox` flag for rotated detections, add an additional float `theta` to the annotations. To get validation scores you also need to fill the `segmentation` section. 92 | ``` 93 | "bbox" : [x, y, w, h, theta] # all floats, where theta is measured in radians anti-clockwise from the x-axis. 94 | "segmentation" : [[x1, y1], [x2, y2], [x3, y3], [x4, y4]] 95 | # Required for validation scores. 96 | ``` 97 | 98 | ### Anchors 99 | 100 | As with all single shot detectors, the anchor boxes may need to be adjusted to suit your dataset. You may need to adjust the anchors in `odtk/model.py` 101 | 102 | The default anchors are: 103 | 104 | ```python 105 | self.ratios = [0.5, 1.0, 2.0] 106 | self.scales = [4 * 2**(i/3) for i in range(3)] 107 | self.angles = [-np.pi/6, 0, np.pi/6] 108 | ``` 109 | 110 | ### Training 111 | 112 | We recommend reducing your learning rate, for example using `--lr 0.0005`. 113 | 114 | An example training command for training remote sensing imagery. Note that `--augment-rotate` has been used to randomly rotated the imagery during training. 115 | ``` 116 | odtk train model.pth --images /data/train --annotations /data/train_rotated.json --backbone ResNet50FPN \ 117 | --lr 0.00005 --fine-tune /data/saved_models/retinanet_rn50fpn.pth \ 118 | --val-images /data/val --val-annotations /data/val_rotated.json --classes 1 \ 119 | --jitter 688 848 --resize 768 \ 120 | --augment-rotate --augment-brightness 0.01 --augment-contrast 0.01 --augment-hue 0.002 \ 121 | --augment-saturation 0.01 --batch 16 --regularization-l2 0.0001 --val-iters 20000 --rotated-bbox 122 | ``` 123 | 124 | -------------------------------------------------------------------------------- /csrc/calibrator.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Permission is hereby granted, free of charge, to any person obtaining a 5 | * copy of this software and associated documentation files (the "Software"), 6 | * to deal in the Software without restriction, including without limitation 7 | * the rights to use, copy, modify, merge, publish, distribute, sublicense, 8 | * and/or sell copies of the Software, and to permit persons to whom the 9 | * Software is furnished to do so, subject to the following conditions: 10 | * 11 | * The above copyright notice and this permission notice shall be included in 12 | * all copies or substantial portions of the Software. 13 | * 14 | * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 15 | * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 16 | * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 17 | * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 18 | * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 19 | * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 20 | * DEALINGS IN THE SOFTWARE. 21 | */ 22 | 23 | #pragma once 24 | 25 | #include 26 | #include 27 | #include 28 | #include 29 | #include 30 | #include 31 | #include 32 | #include 33 | #include "NvInfer.h" 34 | 35 | using namespace std; 36 | using namespace cv; 37 | 38 | class ImageStream { 39 | public: 40 | ImageStream(int batchSize, Dims inputDims, const vector calibrationImages) 41 | : _batchSize(batchSize) 42 | , _calibrationImages(calibrationImages) 43 | , _currentBatch(0) 44 | , _maxBatches(_calibrationImages.size() / _batchSize) 45 | , _inputDims(inputDims) { 46 | _batch.resize(_batchSize * _inputDims.d[1] * _inputDims.d[2] * _inputDims.d[3]); 47 | } 48 | 49 | int getBatchSize() const noexcept { return _batchSize;} 50 | 51 | int getMaxBatches() const { return _maxBatches;} 52 | 53 | float* getBatch() noexcept { return &_batch[0];} 54 | 55 | Dims getInputDims() { return _inputDims;} 56 | 57 | bool next() { 58 | 59 | if (_currentBatch == _maxBatches) 60 | return false; 61 | 62 | for (int i = 0; i < _batchSize; i++) { 63 | auto image = imread(_calibrationImages[_batchSize * _currentBatch + i].c_str(), IMREAD_COLOR); 64 | cv::resize(image, image, Size(_inputDims.d[3], _inputDims.d[2])); 65 | cv::Mat pixels; 66 | image.convertTo(pixels, CV_32FC3, 1.0 / 255, 0); 67 | 68 | vector img; 69 | 70 | if (pixels.isContinuous()) 71 | img.assign((float*)pixels.datastart, (float*)pixels.dataend); 72 | else 73 | return false; 74 | 75 | auto hw = _inputDims.d[2] * _inputDims.d[3]; 76 | auto channels = _inputDims.d[1]; 77 | auto vol = channels * hw; 78 | 79 | for (int c = 0; c < channels; c++) { 80 | for (int j = 0; j < hw; j++) { 81 | _batch[i * vol + c * hw + j] = (img[channels * j + 2 - c] - _mean[c]) / _std[c]; 82 | } 83 | } 84 | } 85 | 86 | _currentBatch++; 87 | return true; 88 | } 89 | 90 | void reset() { 91 | _currentBatch = 0; 92 | } 93 | 94 | private: 95 | int _batchSize; 96 | vector _calibrationImages; 97 | int _currentBatch; 98 | int _maxBatches; 99 | Dims _inputDims; 100 | 101 | vector _mean {0.485, 0.456, 0.406}; 102 | vector _std {0.229, 0.224, 0.225}; 103 | vector _batch; 104 | 105 | }; 106 | 107 | class Int8EntropyCalibrator: public IInt8EntropyCalibrator2 { 108 | public: 109 | Int8EntropyCalibrator(ImageStream& stream, const string networkName, const string calibrationCacheName, bool readCache = true) 110 | : _stream(stream) 111 | , _networkName(networkName) 112 | , _calibrationCacheName(calibrationCacheName) 113 | , _readCache(readCache) { 114 | Dims d = _stream.getInputDims(); 115 | _inputCount = _stream.getBatchSize() * d.d[1] * d.d[2] * d.d[3]; 116 | cudaMalloc(&_deviceInput, _inputCount * sizeof(float)); 117 | } 118 | 119 | int getBatchSize() const noexcept override {return _stream.getBatchSize();} 120 | 121 | virtual ~Int8EntropyCalibrator() {cudaFree(_deviceInput);} 122 | 123 | bool getBatch(void* bindings[], const char* names[], int nbBindings) noexcept override { 124 | 125 | if (!_stream.next()) 126 | return false; 127 | 128 | cudaMemcpy(_deviceInput, _stream.getBatch(), _inputCount * sizeof(float), cudaMemcpyHostToDevice); 129 | bindings[0] = _deviceInput; 130 | return true; 131 | } 132 | 133 | const void* readCalibrationCache(size_t& length) noexcept { 134 | _calibrationCache.clear(); 135 | ifstream input(calibrationTableName(), ios::binary); 136 | input >> noskipws; 137 | if (_readCache && input.good()) 138 | copy(istream_iterator(input), istream_iterator(), back_inserter(_calibrationCache)); 139 | 140 | length = _calibrationCache.size(); 141 | return length ? &_calibrationCache[0] : nullptr; 142 | } 143 | 144 | void writeCalibrationCache(const void* cache, size_t length) noexcept { 145 | std::ofstream output(calibrationTableName(), std::ios::binary); 146 | output.write(reinterpret_cast(cache), length); 147 | } 148 | 149 | private: 150 | std::string calibrationTableName() { 151 | // Use calibration cache if provided 152 | if(_calibrationCacheName.length() > 0) 153 | return _calibrationCacheName; 154 | 155 | assert(_networkName.length() > 0); 156 | Dims d = _stream.getInputDims(); 157 | return std::string("Int8CalibrationTable_") + _networkName + to_string(d.d[2]) + "x" + to_string(d.d[3]) + "_" + to_string(_stream.getMaxBatches()); 158 | } 159 | 160 | ImageStream _stream; 161 | const string _networkName; 162 | const string _calibrationCacheName; 163 | bool _readCache {true}; 164 | size_t _inputCount; 165 | void* _deviceInput {nullptr}; 166 | vector _calibrationCache; 167 | 168 | }; 169 | -------------------------------------------------------------------------------- /csrc/cuda/decode.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Permission is hereby granted, free of charge, to any person obtaining a 5 | * copy of this software and associated documentation files (the "Software"), 6 | * to deal in the Software without restriction, including without limitation 7 | * the rights to use, copy, modify, merge, publish, distribute, sublicense, 8 | * and/or sell copies of the Software, and to permit persons to whom the 9 | * Software is furnished to do so, subject to the following conditions: 10 | * 11 | * The above copyright notice and this permission notice shall be included in 12 | * all copies or substantial portions of the Software. 13 | * 14 | * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 15 | * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 16 | * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 17 | * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 18 | * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 19 | * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 20 | * DEALINGS IN THE SOFTWARE. 21 | */ 22 | 23 | #include "decode.h" 24 | #include "utils.h" 25 | 26 | #include 27 | #include 28 | 29 | #include 30 | #include 31 | #include 32 | #include 33 | #include 34 | #include 35 | #include 36 | #include 37 | #include 38 | 39 | #include 40 | 41 | namespace odtk { 42 | namespace cuda { 43 | 44 | int decode(int batch_size, 45 | const void *const *inputs, void *const *outputs, 46 | size_t height, size_t width, size_t scale, 47 | size_t num_anchors, size_t num_classes, 48 | const std::vector &anchors, float score_thresh, int top_n, 49 | void *workspace, size_t workspace_size, cudaStream_t stream) { 50 | 51 | int scores_size = num_anchors * num_classes * height * width; 52 | 53 | if (!workspace || !workspace_size) { 54 | // Return required scratch space size cub style 55 | workspace_size = get_size_aligned(anchors.size()); // anchors 56 | workspace_size += get_size_aligned(scores_size); // flags 57 | workspace_size += get_size_aligned(scores_size); // indices 58 | workspace_size += get_size_aligned(scores_size); // indices_sorted 59 | workspace_size += get_size_aligned(scores_size); // scores 60 | workspace_size += get_size_aligned(scores_size); // scores_sorted 61 | 62 | size_t temp_size_flag = 0; 63 | cub::DeviceSelect::Flagged((void *)nullptr, temp_size_flag, 64 | cub::CountingInputIterator(scores_size), 65 | (bool *)nullptr, (int *)nullptr, (int *)nullptr, scores_size); 66 | size_t temp_size_sort = 0; 67 | cub::DeviceRadixSort::SortPairsDescending((void *)nullptr, temp_size_sort, 68 | (float *)nullptr, (float *)nullptr, (int *)nullptr, (int *)nullptr, scores_size); 69 | workspace_size += std::max(temp_size_flag, temp_size_sort); 70 | 71 | return workspace_size; 72 | } 73 | 74 | auto anchors_d = get_next_ptr(anchors.size(), workspace, workspace_size); 75 | cudaMemcpyAsync(anchors_d, anchors.data(), anchors.size() * sizeof *anchors_d, cudaMemcpyHostToDevice, stream); 76 | 77 | auto on_stream = thrust::cuda::par.on(stream); 78 | 79 | auto flags = get_next_ptr(scores_size, workspace, workspace_size); 80 | auto indices = get_next_ptr(scores_size, workspace, workspace_size); 81 | auto indices_sorted = get_next_ptr(scores_size, workspace, workspace_size); 82 | auto scores = get_next_ptr(scores_size, workspace, workspace_size); 83 | auto scores_sorted = get_next_ptr(scores_size, workspace, workspace_size); 84 | 85 | 86 | for (int batch = 0; batch < batch_size; batch++) { 87 | 88 | auto in_scores = static_cast(inputs[0]) + batch * scores_size; 89 | auto in_boxes = static_cast(inputs[1]) + batch * (scores_size / num_classes) * 4; 90 | 91 | auto out_scores = static_cast(outputs[0]) + batch * top_n; 92 | auto out_boxes = static_cast(outputs[1]) + batch * top_n; 93 | auto out_classes = static_cast(outputs[2]) + batch * top_n; 94 | 95 | // Discard scores below threshold 96 | thrust::transform(on_stream, in_scores, in_scores + scores_size, 97 | flags, thrust::placeholders::_1 > score_thresh); 98 | 99 | int *num_selected = reinterpret_cast(indices_sorted); 100 | cub::DeviceSelect::Flagged(workspace, workspace_size, 101 | cub::CountingInputIterator(0), 102 | flags, indices, num_selected, scores_size, stream); 103 | cudaStreamSynchronize(stream); 104 | int num_detections = *thrust::device_pointer_cast(num_selected); 105 | 106 | // Only keep top n scores 107 | auto indices_filtered = indices; 108 | if (num_detections > top_n) { 109 | thrust::gather(on_stream, indices, indices + num_detections, 110 | in_scores, scores); 111 | cub::DeviceRadixSort::SortPairsDescending(workspace, workspace_size, 112 | scores, scores_sorted, indices, indices_sorted, num_detections, 0, sizeof(*scores)*8, stream); 113 | indices_filtered = indices_sorted; 114 | num_detections = top_n; 115 | } 116 | 117 | // Gather boxes 118 | bool has_anchors = !anchors.empty(); 119 | thrust::transform(on_stream, indices_filtered, indices_filtered + num_detections, 120 | thrust::make_zip_iterator(thrust::make_tuple(out_scores, out_boxes, out_classes)), 121 | [=] __device__ (int i) { 122 | int x = i % width; 123 | int y = (i / width) % height; 124 | int a = (i / num_classes / height / width) % num_anchors; 125 | int cls = (i / height / width) % num_classes; 126 | float4 box = float4{ 127 | in_boxes[((a * 4 + 0) * height + y) * width + x], 128 | in_boxes[((a * 4 + 1) * height + y) * width + x], 129 | in_boxes[((a * 4 + 2) * height + y) * width + x], 130 | in_boxes[((a * 4 + 3) * height + y) * width + x] 131 | }; 132 | 133 | if (has_anchors) { 134 | // Add anchors offsets to deltas 135 | float x = (i % width) * scale; 136 | float y = ((i / width) % height) * scale; 137 | float *d = anchors_d + 4*a; 138 | 139 | float x1 = x + d[0]; 140 | float y1 = y + d[1]; 141 | float x2 = x + d[2]; 142 | float y2 = y + d[3]; 143 | float w = x2 - x1 + 1.0f; 144 | float h = y2 - y1 + 1.0f; 145 | float pred_ctr_x = box.x * w + x1 + 0.5f * w; 146 | float pred_ctr_y = box.y * h + y1 + 0.5f * h; 147 | float pred_w = exp(box.z) * w; 148 | float pred_h = exp(box.w) * h; 149 | 150 | box = float4{ 151 | max(0.0f, pred_ctr_x - 0.5f * pred_w), 152 | max(0.0f, pred_ctr_y - 0.5f * pred_h), 153 | min(pred_ctr_x + 0.5f * pred_w - 1.0f, width * scale - 1.0f), 154 | min(pred_ctr_y + 0.5f * pred_h - 1.0f, height * scale - 1.0f) 155 | }; 156 | } 157 | 158 | return thrust::make_tuple(in_scores[i], box, cls); 159 | }); 160 | 161 | // Zero-out unused scores 162 | if (num_detections < top_n) { 163 | thrust::fill(on_stream, out_scores + num_detections, 164 | out_scores + top_n, 0.0f); 165 | thrust::fill(on_stream, out_classes + num_detections, 166 | out_classes + top_n, 0.0f); 167 | } 168 | } 169 | 170 | return 0; 171 | } 172 | 173 | } 174 | } 175 | -------------------------------------------------------------------------------- /csrc/cuda/decode.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Permission is hereby granted, free of charge, to any person obtaining a 5 | * copy of this software and associated documentation files (the "Software"), 6 | * to deal in the Software without restriction, including without limitation 7 | * the rights to use, copy, modify, merge, publish, distribute, sublicense, 8 | * and/or sell copies of the Software, and to permit persons to whom the 9 | * Software is furnished to do so, subject to the following conditions: 10 | * 11 | * The above copyright notice and this permission notice shall be included in 12 | * all copies or substantial portions of the Software. 13 | * 14 | * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 15 | * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 16 | * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 17 | * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 18 | * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 19 | * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 20 | * DEALINGS IN THE SOFTWARE. 21 | */ 22 | 23 | #pragma once 24 | 25 | #include 26 | 27 | namespace odtk { 28 | namespace cuda { 29 | 30 | int decode(int batch_size, 31 | const void *const *inputs, void *const *outputs, 32 | size_t height, size_t width, size_t scale, 33 | size_t num_anchors, size_t num_classes, 34 | const std::vector &anchors, float score_thresh, int top_n, 35 | void *workspace, size_t workspace_size, cudaStream_t stream); 36 | 37 | 38 | } 39 | } -------------------------------------------------------------------------------- /csrc/cuda/decode_rotate.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Permission is hereby granted, free of charge, to any person obtaining a 5 | * copy of this software and associated documentation files (the "Software"), 6 | * to deal in the Software without restriction, including without limitation 7 | * the rights to use, copy, modify, merge, publish, distribute, sublicense, 8 | * and/or sell copies of the Software, and to permit persons to whom the 9 | * Software is furnished to do so, subject to the following conditions: 10 | * 11 | * The above copyright notice and this permission notice shall be included in 12 | * all copies or substantial portions of the Software. 13 | * 14 | * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 15 | * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 16 | * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 17 | * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 18 | * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 19 | * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 20 | * DEALINGS IN THE SOFTWARE. 21 | */ 22 | 23 | #include "decode_rotate.h" 24 | #include "utils.h" 25 | 26 | #include 27 | #include 28 | 29 | #include 30 | #include 31 | #include 32 | #include 33 | #include 34 | #include 35 | #include 36 | #include 37 | #include 38 | 39 | namespace odtk { 40 | namespace cuda { 41 | 42 | int decode_rotate(int batch_size, 43 | const void *const *inputs, void *const *outputs, 44 | size_t height, size_t width, size_t scale, 45 | size_t num_anchors, size_t num_classes, 46 | const std::vector &anchors, float score_thresh, int top_n, 47 | void *workspace, size_t workspace_size, cudaStream_t stream) { 48 | 49 | int scores_size = num_anchors * num_classes * height * width; 50 | 51 | if (!workspace || !workspace_size) { 52 | // Return required scratch space size cub style 53 | workspace_size = get_size_aligned(anchors.size()); // anchors 54 | workspace_size += get_size_aligned(scores_size); // flags 55 | workspace_size += get_size_aligned(scores_size); // indices 56 | workspace_size += get_size_aligned(scores_size); // indices_sorted 57 | workspace_size += get_size_aligned(scores_size); // scores 58 | workspace_size += get_size_aligned(scores_size); // scores_sorted 59 | 60 | size_t temp_size_flag = 0; 61 | cub::DeviceSelect::Flagged((void *)nullptr, temp_size_flag, 62 | cub::CountingInputIterator(scores_size), 63 | (bool *)nullptr, (int *)nullptr, (int *)nullptr, scores_size); 64 | size_t temp_size_sort = 0; 65 | cub::DeviceRadixSort::SortPairsDescending((void *)nullptr, temp_size_sort, 66 | (float *)nullptr, (float *)nullptr, (int *)nullptr, (int *)nullptr, scores_size); 67 | workspace_size += std::max(temp_size_flag, temp_size_sort); 68 | 69 | return workspace_size; 70 | } 71 | 72 | auto anchors_d = get_next_ptr(anchors.size(), workspace, workspace_size); 73 | cudaMemcpyAsync(anchors_d, anchors.data(), anchors.size() * sizeof *anchors_d, cudaMemcpyHostToDevice, stream); 74 | 75 | auto on_stream = thrust::cuda::par.on(stream); 76 | 77 | auto flags = get_next_ptr(scores_size, workspace, workspace_size); 78 | auto indices = get_next_ptr(scores_size, workspace, workspace_size); 79 | auto indices_sorted = get_next_ptr(scores_size, workspace, workspace_size); 80 | auto scores = get_next_ptr(scores_size, workspace, workspace_size); 81 | auto scores_sorted = get_next_ptr(scores_size, workspace, workspace_size); 82 | 83 | for (int batch = 0; batch < batch_size; batch++) { 84 | auto in_scores = static_cast(inputs[0]) + batch * scores_size; 85 | auto in_boxes = static_cast(inputs[1]) + batch * (scores_size / num_classes) * 6; //From 4 86 | 87 | auto out_scores = static_cast(outputs[0]) + batch * top_n; 88 | auto out_boxes = static_cast(outputs[1]) + batch * top_n; // From float4 89 | auto out_classes = static_cast(outputs[2]) + batch * top_n; 90 | 91 | // Discard scores below threshold 92 | thrust::transform(on_stream, in_scores, in_scores + scores_size, 93 | flags, thrust::placeholders::_1 > score_thresh); 94 | 95 | int *num_selected = reinterpret_cast(indices_sorted); 96 | cub::DeviceSelect::Flagged(workspace, workspace_size, cub::CountingInputIterator(0), 97 | flags, indices, num_selected, scores_size, stream); 98 | cudaStreamSynchronize(stream); 99 | int num_detections = *thrust::device_pointer_cast(num_selected); 100 | 101 | // Only keep top n scores 102 | auto indices_filtered = indices; 103 | if (num_detections > top_n) { 104 | thrust::gather(on_stream, indices, indices + num_detections, 105 | in_scores, scores); 106 | cub::DeviceRadixSort::SortPairsDescending(workspace, workspace_size, 107 | scores, scores_sorted, indices, indices_sorted, num_detections, 0, sizeof(*scores)*8, stream); 108 | indices_filtered = indices_sorted; 109 | num_detections = top_n; 110 | } 111 | 112 | // Gather boxes 113 | bool has_anchors = !anchors.empty(); 114 | thrust::transform(on_stream, indices_filtered, indices_filtered + num_detections, 115 | thrust::make_zip_iterator(thrust::make_tuple(out_scores, out_boxes, out_classes)), 116 | [=] __device__ (int i) { 117 | int x = i % width; 118 | int y = (i / width) % height; 119 | int a = (i / num_classes / height / width) % num_anchors; 120 | int cls = (i / height / width) % num_classes; 121 | 122 | float6 box = make_float6( 123 | make_float4( 124 | in_boxes[((a * 6 + 0) * height + y) * width + x], 125 | in_boxes[((a * 6 + 1) * height + y) * width + x], 126 | in_boxes[((a * 6 + 2) * height + y) * width + x], 127 | in_boxes[((a * 6 + 3) * height + y) * width + x] 128 | ), 129 | make_float2( 130 | in_boxes[((a * 6 + 4) * height + y) * width + x], 131 | in_boxes[((a * 6 + 5) * height + y) * width + x] 132 | ) 133 | ); 134 | 135 | if (has_anchors) { 136 | // Add anchors offsets to deltas 137 | float x = (i % width) * scale; 138 | float y = ((i / width) % height) * scale; 139 | float *d = anchors_d + 4*a; 140 | 141 | float x1 = x + d[0]; 142 | float y1 = y + d[1]; 143 | float x2 = x + d[2]; 144 | float y2 = y + d[3]; 145 | 146 | float w = x2 - x1 + 1.0f; 147 | float h = y2 - y1 + 1.0f; 148 | float pred_ctr_x = box.x1 * w + x1 + 0.5f * w; 149 | float pred_ctr_y = box.y1 * h + y1 + 0.5f * h; 150 | float pred_w = exp(box.x2) * w; 151 | float pred_h = exp(box.y2) * h; 152 | float pred_sin = box.s; 153 | float pred_cos = box.c; 154 | 155 | box = make_float6( 156 | make_float4( 157 | max(0.0f, pred_ctr_x - 0.5f * pred_w), 158 | max(0.0f, pred_ctr_y - 0.5f * pred_h), 159 | min(pred_ctr_x + 0.5f * pred_w - 1.0f, width * scale - 1.0f), 160 | min(pred_ctr_y + 0.5f * pred_h - 1.0f, height * scale - 1.0f) 161 | ), 162 | make_float2(pred_sin, pred_cos) 163 | ); 164 | } 165 | 166 | return thrust::make_tuple(in_scores[i], box, cls); 167 | }); 168 | 169 | // Zero-out unused scores 170 | if (num_detections < top_n) { 171 | thrust::fill(on_stream, out_scores + num_detections, 172 | out_scores + top_n, 0.0f); 173 | thrust::fill(on_stream, out_classes + num_detections, 174 | out_classes + top_n, 0.0f); 175 | } 176 | } 177 | 178 | return 0; 179 | } 180 | 181 | } 182 | } 183 | -------------------------------------------------------------------------------- /csrc/cuda/decode_rotate.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Permission is hereby granted, free of charge, to any person obtaining a 5 | * copy of this software and associated documentation files (the "Software"), 6 | * to deal in the Software without restriction, including without limitation 7 | * the rights to use, copy, modify, merge, publish, distribute, sublicense, 8 | * and/or sell copies of the Software, and to permit persons to whom the 9 | * Software is furnished to do so, subject to the following conditions: 10 | * 11 | * The above copyright notice and this permission notice shall be included in 12 | * all copies or substantial portions of the Software. 13 | * 14 | * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 15 | * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 16 | * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 17 | * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 18 | * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 19 | * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 20 | * DEALINGS IN THE SOFTWARE. 21 | */ 22 | 23 | #pragma once 24 | 25 | #include 26 | 27 | namespace odtk { 28 | namespace cuda { 29 | 30 | int decode_rotate(int batchSize, 31 | const void *const *inputs, void *const *outputs, 32 | size_t height, size_t width, size_t scale, 33 | size_t num_anchors, size_t num_classes, 34 | const std::vector &anchors, float score_thresh, int top_n, 35 | void *workspace, size_t workspace_size, cudaStream_t stream); 36 | 37 | } 38 | } -------------------------------------------------------------------------------- /csrc/cuda/nms.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Permission is hereby granted, free of charge, to any person obtaining a 5 | * copy of this software and associated documentation files (the "Software"), 6 | * to deal in the Software without restriction, including without limitation 7 | * the rights to use, copy, modify, merge, publish, distribute, sublicense, 8 | * and/or sell copies of the Software, and to permit persons to whom the 9 | * Software is furnished to do so, subject to the following conditions: 10 | * 11 | * The above copyright notice and this permission notice shall be included in 12 | * all copies or substantial portions of the Software. 13 | * 14 | * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 15 | * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 16 | * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 17 | * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 18 | * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 19 | * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 20 | * DEALINGS IN THE SOFTWARE. 21 | */ 22 | 23 | #include "nms.h" 24 | #include "utils.h" 25 | 26 | #include 27 | #include 28 | #include 29 | #include 30 | #include 31 | #include 32 | 33 | #include 34 | #include 35 | #include 36 | #include 37 | #include 38 | #include 39 | #include 40 | 41 | namespace odtk { 42 | namespace cuda { 43 | 44 | __global__ void nms_kernel( 45 | const int num_per_thread, const float threshold, const int num_detections, 46 | const int *indices, float *scores, const float *classes, const float4 *boxes) { 47 | 48 | // Go through detections by descending score 49 | for (int m = 0; m < num_detections; m++) { 50 | for (int n = 0; n < num_per_thread; n++) { 51 | int i = threadIdx.x * num_per_thread + n; 52 | if (i < num_detections && m < i && scores[m] > 0.0f) { 53 | int idx = indices[i]; 54 | int max_idx = indices[m]; 55 | int icls = classes[idx]; 56 | int mcls = classes[max_idx]; 57 | if (mcls == icls) { 58 | float4 ibox = boxes[idx]; 59 | float4 mbox = boxes[max_idx]; 60 | float x1 = max(ibox.x, mbox.x); 61 | float y1 = max(ibox.y, mbox.y); 62 | float x2 = min(ibox.z, mbox.z); 63 | float y2 = min(ibox.w, mbox.w); 64 | float w = max(0.0f, x2 - x1 + 1); 65 | float h = max(0.0f, y2 - y1 + 1); 66 | float iarea = (ibox.z - ibox.x + 1) * (ibox.w - ibox.y + 1); 67 | float marea = (mbox.z - mbox.x + 1) * (mbox.w - mbox.y + 1); 68 | float inter = w * h; 69 | float overlap = inter / (iarea + marea - inter); 70 | if (overlap > threshold) { 71 | scores[i] = 0.0f; 72 | } 73 | } 74 | } 75 | } 76 | 77 | // Sync discarded detections 78 | __syncthreads(); 79 | } 80 | } 81 | 82 | int nms(int batch_size, 83 | const void *const *inputs, void *const *outputs, 84 | size_t count, int detections_per_im, float nms_thresh, 85 | void *workspace, size_t workspace_size, cudaStream_t stream) { 86 | 87 | if (!workspace || !workspace_size) { 88 | // Return required scratch space size cub style 89 | workspace_size = get_size_aligned(count); // flags 90 | workspace_size += get_size_aligned(count); // indices 91 | workspace_size += get_size_aligned(count); // indices_sorted 92 | workspace_size += get_size_aligned(count); // scores 93 | workspace_size += get_size_aligned(count); // scores_sorted 94 | 95 | size_t temp_size_flag = 0; 96 | cub::DeviceSelect::Flagged((void *)nullptr, temp_size_flag, 97 | cub::CountingInputIterator(count), 98 | (bool *)nullptr, (int *)nullptr, (int *)nullptr, count); 99 | size_t temp_size_sort = 0; 100 | cub::DeviceRadixSort::SortPairsDescending((void *)nullptr, temp_size_sort, 101 | (float *)nullptr, (float *)nullptr, (int *)nullptr, (int *)nullptr, count); 102 | workspace_size += std::max(temp_size_flag, temp_size_sort); 103 | 104 | return workspace_size; 105 | } 106 | 107 | auto on_stream = thrust::cuda::par.on(stream); 108 | 109 | auto flags = get_next_ptr(count, workspace, workspace_size); 110 | auto indices = get_next_ptr(count, workspace, workspace_size); 111 | auto indices_sorted = get_next_ptr(count, workspace, workspace_size); 112 | auto scores = get_next_ptr(count, workspace, workspace_size); 113 | auto scores_sorted = get_next_ptr(count, workspace, workspace_size); 114 | 115 | for (int batch = 0; batch < batch_size; batch++) { 116 | auto in_scores = static_cast(inputs[0]) + batch * count; 117 | auto in_boxes = static_cast(inputs[1]) + batch * count; 118 | auto in_classes = static_cast(inputs[2]) + batch * count; 119 | 120 | auto out_scores = static_cast(outputs[0]) + batch * detections_per_im; 121 | auto out_boxes = static_cast(outputs[1]) + batch * detections_per_im; 122 | auto out_classes = static_cast(outputs[2]) + batch * detections_per_im; 123 | 124 | // Discard null scores 125 | thrust::transform(on_stream, in_scores, in_scores + count, 126 | flags, thrust::placeholders::_1 > 0.0f); 127 | 128 | int *num_selected = reinterpret_cast(indices_sorted); 129 | cub::DeviceSelect::Flagged(workspace, workspace_size, cub::CountingInputIterator(0), 130 | flags, indices, num_selected, count, stream); 131 | cudaStreamSynchronize(stream); 132 | int num_detections = *thrust::device_pointer_cast(num_selected); 133 | 134 | // Sort scores and corresponding indices 135 | thrust::gather(on_stream, indices, indices + num_detections, in_scores, scores); 136 | cub::DeviceRadixSort::SortPairsDescending(workspace, workspace_size, 137 | scores, scores_sorted, indices, indices_sorted, num_detections, 0, sizeof(*scores)*8, stream); 138 | 139 | // Launch actual NMS kernel - 1 block with each thread handling n detections 140 | const int max_threads = 1024; 141 | int num_per_thread = ceil((float)num_detections / max_threads); 142 | nms_kernel<<<1, max_threads, 0, stream>>>(num_per_thread, nms_thresh, num_detections, 143 | indices_sorted, scores_sorted, in_classes, in_boxes); 144 | 145 | // Re-sort with updated scores 146 | cub::DeviceRadixSort::SortPairsDescending(workspace, workspace_size, 147 | scores_sorted, scores, indices_sorted, indices, num_detections, 0, sizeof(*scores)*8, stream); 148 | 149 | // Gather filtered scores, boxes, classes 150 | num_detections = min(detections_per_im, num_detections); 151 | cudaMemcpyAsync(out_scores, scores, num_detections * sizeof *scores, cudaMemcpyDeviceToDevice, stream); 152 | if (num_detections < detections_per_im) { 153 | thrust::fill_n(on_stream, out_scores + num_detections, detections_per_im - num_detections, 0); 154 | } 155 | thrust::gather(on_stream, indices, indices + num_detections, in_boxes, out_boxes); 156 | thrust::gather(on_stream, indices, indices + num_detections, in_classes, out_classes); 157 | } 158 | 159 | return 0; 160 | } 161 | 162 | } 163 | } 164 | -------------------------------------------------------------------------------- /csrc/cuda/nms.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Permission is hereby granted, free of charge, to any person obtaining a 5 | * copy of this software and associated documentation files (the "Software"), 6 | * to deal in the Software without restriction, including without limitation 7 | * the rights to use, copy, modify, merge, publish, distribute, sublicense, 8 | * and/or sell copies of the Software, and to permit persons to whom the 9 | * Software is furnished to do so, subject to the following conditions: 10 | * 11 | * The above copyright notice and this permission notice shall be included in 12 | * all copies or substantial portions of the Software. 13 | * 14 | * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 15 | * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 16 | * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 17 | * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 18 | * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 19 | * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 20 | * DEALINGS IN THE SOFTWARE. 21 | */ 22 | 23 | #pragma once 24 | 25 | namespace odtk { 26 | namespace cuda { 27 | 28 | int nms(int batchSize, 29 | const void *const *inputs, void *const *outputs, 30 | size_t count, int detections_per_im, float nms_thresh, 31 | void *workspace, size_t workspace_size, cudaStream_t stream); 32 | 33 | } 34 | } -------------------------------------------------------------------------------- /csrc/cuda/nms_iou.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Permission is hereby granted, free of charge, to any person obtaining a 5 | * copy of this software and associated documentation files (the "Software"), 6 | * to deal in the Software without restriction, including without limitation 7 | * the rights to use, copy, modify, merge, publish, distribute, sublicense, 8 | * and/or sell copies of the Software, and to permit persons to whom the 9 | * Software is furnished to do so, subject to the following conditions: 10 | * 11 | * The above copyright notice and this permission notice shall be included in 12 | * all copies or substantial portions of the Software. 13 | * 14 | * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 15 | * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 16 | * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 17 | * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 18 | * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 19 | * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 20 | * DEALINGS IN THE SOFTWARE. 21 | */ 22 | 23 | #pragma once 24 | 25 | namespace odtk { 26 | namespace cuda { 27 | 28 | int nms_rotate(int batchSize, 29 | const void *const *inputs, void *const *outputs, 30 | size_t count, int detections_per_im, float nms_thresh, 31 | void *workspace, size_t workspace_size, cudaStream_t stream); 32 | 33 | int iou( 34 | const void *const *inputs, void *const *outputs, 35 | int num_boxes, int num_anchors, cudaStream_t stream); 36 | } 37 | } -------------------------------------------------------------------------------- /csrc/cuda/utils.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Permission is hereby granted, free of charge, to any person obtaining a 5 | * copy of this software and associated documentation files (the "Software"), 6 | * to deal in the Software without restriction, including without limitation 7 | * the rights to use, copy, modify, merge, publish, distribute, sublicense, 8 | * and/or sell copies of the Software, and to permit persons to whom the 9 | * Software is furnished to do so, subject to the following conditions: 10 | * 11 | * The above copyright notice and this permission notice shall be included in 12 | * all copies or substantial portions of the Software. 13 | * 14 | * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 15 | * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 16 | * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 17 | * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 18 | * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 19 | * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 20 | * DEALINGS IN THE SOFTWARE. 21 | */ 22 | 23 | #pragma once 24 | #include 25 | #include 26 | #include 27 | 28 | #define CUDA_ALIGN 256 29 | 30 | struct float6 31 | { 32 | float x1, y1, x2, y2, s, c; 33 | }; 34 | 35 | inline __host__ __device__ float6 make_float6(float4 f, float2 t) 36 | { 37 | float6 fs; 38 | fs.x1 = f.x; fs.y1 = f.y; fs.x2 = f.z; fs.y2 = f.w; fs.s = t.x; fs.c = t.y; 39 | return fs; 40 | } 41 | 42 | template 43 | inline size_t get_size_aligned(size_t num_elem) { 44 | size_t size = num_elem * sizeof(T); 45 | size_t extra_align = 0; 46 | if (size % CUDA_ALIGN != 0) { 47 | extra_align = CUDA_ALIGN - size % CUDA_ALIGN; 48 | } 49 | return size + extra_align; 50 | } 51 | 52 | template 53 | inline T *get_next_ptr(size_t num_elem, void *&workspace, size_t &workspace_size) { 54 | size_t size = get_size_aligned(num_elem); 55 | if (size > workspace_size) { 56 | throw std::runtime_error("Workspace is too small!"); 57 | } 58 | workspace_size -= size; 59 | T *ptr = reinterpret_cast(workspace); 60 | workspace = reinterpret_cast(reinterpret_cast(workspace) + size); 61 | return ptr; 62 | } 63 | -------------------------------------------------------------------------------- /csrc/engine.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Permission is hereby granted, free of charge, to any person obtaining a 5 | * copy of this software and associated documentation files (the "Software"), 6 | * to deal in the Software without restriction, including without limitation 7 | * the rights to use, copy, modify, merge, publish, distribute, sublicense, 8 | * and/or sell copies of the Software, and to permit persons to whom the 9 | * Software is furnished to do so, subject to the following conditions: 10 | * 11 | * The above copyright notice and this permission notice shall be included in 12 | * all copies or substantial portions of the Software. 13 | * 14 | * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 15 | * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 16 | * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 17 | * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 18 | * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 19 | * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 20 | * DEALINGS IN THE SOFTWARE. 21 | */ 22 | 23 | #include "engine.h" 24 | 25 | #include 26 | #include 27 | 28 | #include 29 | #include 30 | 31 | #include "plugins/DecodePlugin.h" 32 | #include "plugins/NMSPlugin.h" 33 | #include "plugins/DecodeRotatePlugin.h" 34 | #include "plugins/NMSRotatePlugin.h" 35 | #include "calibrator.h" 36 | 37 | #include 38 | #include 39 | 40 | using namespace nvinfer1; 41 | using namespace nvonnxparser; 42 | 43 | namespace odtk { 44 | 45 | class Logger : public ILogger { 46 | public: 47 | Logger(bool verbose) 48 | : _verbose(verbose) { 49 | } 50 | 51 | void log(Severity severity, const char *msg) noexcept override { 52 | if (_verbose || ((severity != Severity::kINFO) && (severity != Severity::kVERBOSE))) 53 | cout << msg << endl; 54 | } 55 | 56 | private: 57 | bool _verbose{false}; 58 | }; 59 | 60 | void Engine::_load(const string &path) { 61 | ifstream file(path, ios::in | ios::binary); 62 | file.seekg (0, file.end); 63 | size_t size = file.tellg(); 64 | file.seekg (0, file.beg); 65 | 66 | auto buffer = std::unique_ptr(new char[size]); 67 | file.read(buffer.get(), size); 68 | file.close(); 69 | 70 | _engine = std::unique_ptr(_runtime->deserializeCudaEngine(buffer.get(), size)); 71 | } 72 | 73 | void Engine::_prepare() { 74 | _context = std::unique_ptr(_engine->createExecutionContext()); 75 | _context->setOptimizationProfileAsync(0, _stream); 76 | cudaStreamCreate(&_stream); 77 | } 78 | 79 | Engine::Engine(const string &engine_path, bool verbose) { 80 | Logger logger(verbose); 81 | _runtime = std::unique_ptr(createInferRuntime(logger)); 82 | _load(engine_path); 83 | _prepare(); 84 | } 85 | 86 | Engine::~Engine() { 87 | if (_stream) cudaStreamDestroy(_stream); 88 | } 89 | 90 | Engine::Engine(const char *onnx_model, size_t onnx_size, const vector& dynamic_batch_opts, 91 | string precision, float score_thresh, int top_n, const vector>& anchors, 92 | bool rotated, float nms_thresh, int detections_per_im, const vector& calibration_images, 93 | string model_name, string calibration_table, bool verbose, size_t workspace_size) { 94 | 95 | Logger logger(verbose); 96 | _runtime = std::unique_ptr(createInferRuntime(logger)); 97 | 98 | bool fp16 = precision.compare("FP16") == 0; 99 | bool int8 = precision.compare("INT8") == 0; 100 | 101 | // Create builder 102 | auto builder = std::unique_ptr(createInferBuilder(logger)); 103 | auto builderConfig = std::unique_ptr(builder->createBuilderConfig()); 104 | // Allow use of FP16 layers when running in INT8 105 | if(fp16 || int8) builderConfig->setFlag(BuilderFlag::kFP16); 106 | builderConfig->setMaxWorkspaceSize(workspace_size); 107 | 108 | // Parse ONNX FCN 109 | cout << "Building " << precision << " core model..." << endl; 110 | const auto flags = 1U << static_cast(NetworkDefinitionCreationFlag::kEXPLICIT_BATCH); 111 | auto network = std::unique_ptr(builder->createNetworkV2(flags)); 112 | auto parser = std::unique_ptr(createParser(*network, logger)); 113 | parser->parse(onnx_model, onnx_size); 114 | 115 | auto input = network->getInput(0); 116 | auto inputDims = input->getDimensions(); 117 | auto profile = builder->createOptimizationProfile(); 118 | auto inputName = input->getName(); 119 | auto profileDimsmin = Dims4{dynamic_batch_opts[0], inputDims.d[1], inputDims.d[2], inputDims.d[3]}; 120 | auto profileDimsopt = Dims4{dynamic_batch_opts[1], inputDims.d[1], inputDims.d[2], inputDims.d[3]}; 121 | auto profileDimsmax = Dims4{dynamic_batch_opts[2], inputDims.d[1], inputDims.d[2], inputDims.d[3]}; 122 | 123 | profile->setDimensions(inputName, nvinfer1::OptProfileSelector::kMIN, profileDimsmin); 124 | profile->setDimensions(inputName, nvinfer1::OptProfileSelector::kOPT, profileDimsopt); 125 | profile->setDimensions(inputName, nvinfer1::OptProfileSelector::kMAX, profileDimsmax); 126 | 127 | if(profile->isValid()) 128 | builderConfig->addOptimizationProfile(profile); 129 | 130 | std::unique_ptr calib; 131 | if (int8) { 132 | builderConfig->setFlag(BuilderFlag::kINT8); 133 | // Calibration is performed using kOPT values of the profile. 134 | // Calibration batch size must match this profile. 135 | builderConfig->setCalibrationProfile(profile); 136 | ImageStream stream(dynamic_batch_opts[1], inputDims, calibration_images); 137 | calib = std::unique_ptr(new Int8EntropyCalibrator(stream, model_name, calibration_table)); 138 | builderConfig->setInt8Calibrator(calib.get()); 139 | } 140 | 141 | // Add decode plugins 142 | cout << "Building accelerated plugins..." << endl; 143 | vector decodePlugins; 144 | vector decodeRotatePlugins; 145 | vector scores, boxes, classes; 146 | auto nbOutputs = network->getNbOutputs(); 147 | 148 | for (int i = 0; i < nbOutputs / 2; i++) { 149 | auto classOutput = network->getOutput(i); 150 | auto boxOutput = network->getOutput(nbOutputs / 2 + i); 151 | auto outputDims = classOutput->getDimensions(); 152 | int scale = inputDims.d[2] / outputDims.d[2]; 153 | auto decodePlugin = DecodePlugin(score_thresh, top_n, anchors[i], scale); 154 | auto decodeRotatePlugin = DecodeRotatePlugin(score_thresh, top_n, anchors[i], scale); 155 | decodePlugins.push_back(decodePlugin); 156 | decodeRotatePlugins.push_back(decodeRotatePlugin); 157 | vector inputs = {classOutput, boxOutput}; 158 | auto layer = (!rotated) ? network->addPluginV2(inputs.data(), inputs.size(), decodePlugin) \ 159 | : network->addPluginV2(inputs.data(), inputs.size(), decodeRotatePlugin); 160 | scores.push_back(layer->getOutput(0)); 161 | boxes.push_back(layer->getOutput(1)); 162 | classes.push_back(layer->getOutput(2)); 163 | } 164 | 165 | // Cleanup outputs 166 | for (int i = 0; i < nbOutputs; i++) { 167 | auto output = network->getOutput(0); 168 | network->unmarkOutput(*output); 169 | } 170 | 171 | // Concat tensors from each feature map 172 | vector concat; 173 | for (auto tensors : {scores, boxes, classes}) { 174 | auto layer = network->addConcatenation(tensors.data(), tensors.size()); 175 | concat.push_back(layer->getOutput(0)); 176 | } 177 | 178 | // Add NMS plugin 179 | auto nmsPlugin = NMSPlugin(nms_thresh, detections_per_im); 180 | auto nmsRotatePlugin = NMSRotatePlugin(nms_thresh, detections_per_im); 181 | auto layer = (!rotated) ? network->addPluginV2(concat.data(), concat.size(), nmsPlugin) \ 182 | : network->addPluginV2(concat.data(), concat.size(), nmsRotatePlugin); 183 | vector names = {"scores", "boxes", "classes"}; 184 | for (int i = 0; i < layer->getNbOutputs(); i++) { 185 | auto output = layer->getOutput(i); 186 | network->markOutput(*output); 187 | output->setName(names[i].c_str()); 188 | } 189 | 190 | // Build engine 191 | cout << "Applying optimizations and building TRT CUDA engine..." << endl; 192 | _plan = std::unique_ptr(builder->buildSerializedNetwork(*network, *builderConfig)); 193 | } 194 | 195 | void Engine::save(const string &path) { 196 | cout << "Writing to " << path << "..." << endl; 197 | ofstream file(path, ios::out | ios::binary); 198 | file.write(reinterpret_cast(_plan->data()), _plan->size()); 199 | } 200 | 201 | void Engine::infer(vector &buffers, int batch){ 202 | auto dims = _engine->getBindingDimensions(0); 203 | _context->setBindingDimensions(0, Dims4(batch, dims.d[1], dims.d[2], dims.d[3])); 204 | _context->enqueueV2(buffers.data(), _stream, nullptr); 205 | cudaStreamSynchronize(_stream); 206 | } 207 | 208 | vector Engine::getInputSize() { 209 | auto dims = _engine->getBindingDimensions(0); 210 | return {dims.d[2], dims.d[3]}; 211 | } 212 | 213 | int Engine::getMaxBatchSize() { 214 | return _engine->getMaxBatchSize(); 215 | } 216 | 217 | int Engine::getMaxDetections() { 218 | return _engine->getBindingDimensions(1).d[1]; 219 | } 220 | 221 | int Engine::getStride() { 222 | return 1; 223 | } 224 | 225 | } 226 | -------------------------------------------------------------------------------- /csrc/engine.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Permission is hereby granted, free of charge, to any person obtaining a 5 | * copy of this software and associated documentation files (the "Software"), 6 | * to deal in the Software without restriction, including without limitation 7 | * the rights to use, copy, modify, merge, publish, distribute, sublicense, 8 | * and/or sell copies of the Software, and to permit persons to whom the 9 | * Software is furnished to do so, subject to the following conditions: 10 | * 11 | * The above copyright notice and this permission notice shall be included in 12 | * all copies or substantial portions of the Software. 13 | * 14 | * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 15 | * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 16 | * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 17 | * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 18 | * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 19 | * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 20 | * DEALINGS IN THE SOFTWARE. 21 | */ 22 | 23 | #pragma once 24 | 25 | #include 26 | #include 27 | #include 28 | 29 | #include 30 | 31 | #include 32 | 33 | using namespace std; 34 | using namespace nvinfer1; 35 | 36 | namespace odtk { 37 | 38 | // RetinaNet wrapper around TensorRT CUDA engine 39 | class Engine { 40 | public: 41 | // Create engine from engine path 42 | Engine(const string &engine_path, bool verbose=false); 43 | 44 | // Create engine from serialized onnx model 45 | 46 | Engine(const char *onnx_model, size_t onnx_size, const vector& dynamic_batch_opts, 47 | string precision, float score_thresh, int top_n, const vector>& anchors, 48 | bool rotated, float nms_thresh, int detections_per_im, const vector& calibration_images, 49 | string model_name, string calibration_table, bool verbose, size_t workspace_size=(1ULL << 30)); 50 | 51 | ~Engine(); 52 | 53 | // Save model to path 54 | void save(const string &path); 55 | 56 | // Infer using pre-allocated GPU buffers {data, scores, boxes, classes} 57 | void infer(vector &buffers, int batch); 58 | 59 | // Get (h, w) size of the fixed input 60 | vector getInputSize(); 61 | 62 | // Get max allowed batch size 63 | int getMaxBatchSize(); 64 | 65 | // Get max number of detections 66 | int getMaxDetections(); 67 | 68 | // Get stride 69 | int getStride(); 70 | 71 | private: 72 | std::unique_ptr _runtime; 73 | std::unique_ptr _engine; 74 | std::unique_ptr _plan; 75 | std::unique_ptr _context; 76 | cudaStream_t _stream = nullptr; 77 | 78 | void _load(const string &path); 79 | void _prepare(); 80 | 81 | }; 82 | 83 | } 84 | -------------------------------------------------------------------------------- /csrc/extensions.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Permission is hereby granted, free of charge, to any person obtaining a 5 | * copy of this software and associated documentation files (the "Software"), 6 | * to deal in the Software without restriction, including without limitation 7 | * the rights to use, copy, modify, merge, publish, distribute, sublicense, 8 | * and/or sell copies of the Software, and to permit persons to whom the 9 | * Software is furnished to do so, subject to the following conditions: 10 | * 11 | * The above copyright notice and this permission notice shall be included in 12 | * all copies or substantial portions of the Software. 13 | * 14 | * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 15 | * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 16 | * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 17 | * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 18 | * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 19 | * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 20 | * DEALINGS IN THE SOFTWARE. 21 | */ 22 | 23 | #include 24 | #include 25 | #include 26 | #include 27 | #include 28 | 29 | #include 30 | #include 31 | 32 | #include 33 | #include 34 | 35 | #include "engine.h" 36 | #include "cuda/decode.h" 37 | #include "cuda/decode_rotate.h" 38 | #include "cuda/nms.h" 39 | #include "cuda/nms_iou.h" 40 | #include 41 | 42 | #define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor") 43 | #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") 44 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 45 | 46 | 47 | vector iou(torch::Tensor boxes, torch::Tensor anchors) { 48 | 49 | CHECK_INPUT(boxes); 50 | CHECK_INPUT(anchors); 51 | 52 | int num_boxes = boxes.numel() / 8; 53 | int num_anchors = anchors.numel() / 8; 54 | auto options = boxes.options(); 55 | 56 | auto iou_vals = torch::zeros({num_boxes*num_anchors}, options); 57 | 58 | // Calculate Polygon IOU 59 | vector inputs = {boxes.data_ptr(), anchors.data_ptr()}; 60 | vector outputs = {iou_vals.data_ptr()}; 61 | 62 | odtk::cuda::iou(inputs.data(), outputs.data(), num_boxes, num_anchors, c10::cuda::getCurrentCUDAStream()); 63 | 64 | auto shape = std::vector{num_anchors, num_boxes}; 65 | 66 | return {iou_vals.reshape(shape)}; 67 | } 68 | 69 | vector decode(torch::Tensor cls_head, torch::Tensor box_head, 70 | vector &anchors, int scale, float score_thresh, int top_n, bool rotated=false) { 71 | 72 | CHECK_INPUT(cls_head); 73 | CHECK_INPUT(box_head); 74 | 75 | int num_boxes = (!rotated) ? 4 : 6; 76 | int batch = cls_head.size(0); 77 | int num_anchors = anchors.size() / 4; 78 | int num_classes = cls_head.size(1) / num_anchors; 79 | int height = cls_head.size(2); 80 | int width = cls_head.size(3); 81 | auto options = cls_head.options(); 82 | 83 | auto scores = torch::zeros({batch, top_n}, options); 84 | auto boxes = torch::zeros({batch, top_n, num_boxes}, options); 85 | auto classes = torch::zeros({batch, top_n}, options); 86 | 87 | vector inputs = {cls_head.data_ptr(), box_head.data_ptr()}; 88 | vector outputs = {scores.data_ptr(), boxes.data_ptr(), classes.data_ptr()}; 89 | 90 | if(!rotated) { 91 | // Create scratch buffer 92 | int size = odtk::cuda::decode(batch, nullptr, nullptr, height, width, scale, 93 | num_anchors, num_classes, anchors, score_thresh, top_n, nullptr, 0, nullptr); 94 | auto scratch = torch::zeros({size}, options.dtype(torch::kUInt8)); 95 | 96 | // Decode boxes 97 | odtk::cuda::decode(batch, inputs.data(), outputs.data(), height, width, scale, 98 | num_anchors, num_classes, anchors, score_thresh, top_n, 99 | scratch.data_ptr(), size, c10::cuda::getCurrentCUDAStream()); 100 | 101 | } 102 | else { 103 | // Create scratch buffer 104 | int size = odtk::cuda::decode_rotate(batch, nullptr, nullptr, height, width, scale, 105 | num_anchors, num_classes, anchors, score_thresh, top_n, nullptr, 0, nullptr); 106 | auto scratch = torch::zeros({size}, options.dtype(torch::kUInt8)); 107 | 108 | // Decode boxes 109 | odtk::cuda::decode_rotate(batch, inputs.data(), outputs.data(), height, width, scale, 110 | num_anchors, num_classes, anchors, score_thresh, top_n, 111 | scratch.data_ptr(), size, c10::cuda::getCurrentCUDAStream()); 112 | } 113 | 114 | return {scores, boxes, classes}; 115 | } 116 | 117 | vector nms(torch::Tensor scores, torch::Tensor boxes, torch::Tensor classes, 118 | float nms_thresh, int detections_per_im, bool rotated=false) { 119 | 120 | CHECK_INPUT(scores); 121 | CHECK_INPUT(boxes); 122 | CHECK_INPUT(classes); 123 | 124 | int num_boxes = (!rotated) ? 4 : 6; 125 | int batch = scores.size(0); 126 | int count = scores.size(1); 127 | auto options = scores.options(); 128 | auto nms_scores = torch::zeros({batch, detections_per_im}, scores.options()); 129 | auto nms_boxes = torch::zeros({batch, detections_per_im, num_boxes}, boxes.options()); 130 | auto nms_classes = torch::zeros({batch, detections_per_im}, classes.options()); 131 | 132 | vector inputs = {scores.data_ptr(), boxes.data_ptr(), classes.data_ptr()}; 133 | vector outputs = {nms_scores.data_ptr(), nms_boxes.data_ptr(), nms_classes.data_ptr()}; 134 | 135 | if(!rotated) { 136 | // Create scratch buffer 137 | int size = odtk::cuda::nms(batch, nullptr, nullptr, count, 138 | detections_per_im, nms_thresh, nullptr, 0, nullptr); 139 | auto scratch = torch::zeros({size}, options.dtype(torch::kUInt8)); 140 | 141 | // Perform NMS 142 | odtk::cuda::nms(batch, inputs.data(), outputs.data(), count, detections_per_im, 143 | nms_thresh, scratch.data_ptr(), size, c10::cuda::getCurrentCUDAStream()); 144 | } 145 | else { 146 | // Create scratch buffer 147 | int size = odtk::cuda::nms_rotate(batch, nullptr, nullptr, count, 148 | detections_per_im, nms_thresh, nullptr, 0, nullptr); 149 | auto scratch = torch::zeros({size}, options.dtype(torch::kUInt8)); 150 | 151 | // Perform NMS 152 | odtk::cuda::nms_rotate(batch, inputs.data(), outputs.data(), count, 153 | detections_per_im, nms_thresh, scratch.data_ptr(), size, c10::cuda::getCurrentCUDAStream()); 154 | } 155 | 156 | 157 | return {nms_scores, nms_boxes, nms_classes}; 158 | } 159 | 160 | vector infer(odtk::Engine &engine, torch::Tensor data, bool rotated=false) { 161 | CHECK_INPUT(data); 162 | 163 | int num_boxes = (!rotated) ? 4 : 6; 164 | int batch = data.size(0); 165 | auto input_size = engine.getInputSize(); 166 | data = torch::constant_pad_nd(data, {0, input_size[1] - data.size(3), 0, input_size[0] - data.size(2)}); 167 | 168 | int num_detections = engine.getMaxDetections(); 169 | auto scores = torch::zeros({batch, num_detections}, data.options()); 170 | auto boxes = torch::zeros({batch, num_detections, num_boxes}, data.options()); 171 | auto classes = torch::zeros({batch, num_detections}, data.options()); 172 | 173 | vector buffers; 174 | for (auto buffer : {data, scores, boxes, classes}) { 175 | buffers.push_back(buffer.data_ptr()); 176 | } 177 | 178 | engine.infer(buffers, batch); 179 | 180 | return {scores, boxes, classes}; 181 | } 182 | 183 | 184 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 185 | pybind11::class_(m, "Engine") 186 | .def(pybind11::init&, string, float, int, 187 | const vector>&, bool, float, int, const vector&, string, string, bool>()) 188 | .def("save", &odtk::Engine::save) 189 | .def("infer", &odtk::Engine::infer) 190 | .def_property_readonly("stride", &odtk::Engine::getStride) 191 | .def_property_readonly("input_size", &odtk::Engine::getInputSize) 192 | .def_static("load", [](const string &path) { 193 | return new odtk::Engine(path); 194 | }) 195 | .def("__call__", [](odtk::Engine &engine, torch::Tensor data, bool rotated=false) { 196 | return infer(engine, data, rotated); 197 | }); 198 | m.def("decode", &decode); 199 | m.def("nms", &nms); 200 | m.def("iou", &iou); 201 | } 202 | -------------------------------------------------------------------------------- /csrc/plugins/DecodePlugin.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Permission is hereby granted, free of charge, to any person obtaining a 5 | * copy of this software and associated documentation files (the "Software"), 6 | * to deal in the Software without restriction, including without limitation 7 | * the rights to use, copy, modify, merge, publish, distribute, sublicense, 8 | * and/or sell copies of the Software, and to permit persons to whom the 9 | * Software is furnished to do so, subject to the following conditions: 10 | * 11 | * The above copyright notice and this permission notice shall be included in 12 | * all copies or substantial portions of the Software. 13 | * 14 | * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 15 | * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 16 | * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 17 | * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 18 | * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 19 | * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 20 | * DEALINGS IN THE SOFTWARE. 21 | */ 22 | 23 | #pragma once 24 | 25 | #include 26 | 27 | #include 28 | #include 29 | 30 | #include "../cuda/decode.h" 31 | 32 | using namespace nvinfer1; 33 | 34 | #define RETINANET_PLUGIN_NAME "RetinaNetDecode" 35 | #define RETINANET_PLUGIN_VERSION "1" 36 | #define RETINANET_PLUGIN_NAMESPACE "" 37 | 38 | namespace odtk { 39 | 40 | class DecodePlugin : public IPluginV2DynamicExt { 41 | float _score_thresh; 42 | int _top_n; 43 | std::vector _anchors; 44 | float _scale; 45 | 46 | size_t _height; 47 | size_t _width; 48 | size_t _num_anchors; 49 | size_t _num_classes; 50 | mutable int size = -1; 51 | 52 | protected: 53 | void deserialize(void const* data, size_t length) { 54 | const char* d = static_cast(data); 55 | read(d, _score_thresh); 56 | read(d, _top_n); 57 | size_t anchors_size; 58 | read(d, anchors_size); 59 | while( anchors_size-- ) { 60 | float val; 61 | read(d, val); 62 | _anchors.push_back(val); 63 | } 64 | read(d, _scale); 65 | read(d, _height); 66 | read(d, _width); 67 | read(d, _num_anchors); 68 | read(d, _num_classes); 69 | } 70 | 71 | size_t getSerializationSize() const noexcept override { 72 | return sizeof(_score_thresh) + sizeof(_top_n) 73 | + sizeof(size_t) + sizeof(float) * _anchors.size() + sizeof(_scale) 74 | + sizeof(_height) + sizeof(_width) + sizeof(_num_anchors) + sizeof(_num_classes); 75 | } 76 | 77 | void serialize(void *buffer) const noexcept override { 78 | char* d = static_cast(buffer); 79 | write(d, _score_thresh); 80 | write(d, _top_n); 81 | write(d, _anchors.size()); 82 | for( auto &val : _anchors ) { 83 | write(d, val); 84 | } 85 | write(d, _scale); 86 | write(d, _height); 87 | write(d, _width); 88 | write(d, _num_anchors); 89 | write(d, _num_classes); 90 | } 91 | 92 | public: 93 | DecodePlugin(float score_thresh, int top_n, std::vector const& anchors, int scale) 94 | : _score_thresh(score_thresh), _top_n(top_n), _anchors(anchors), _scale(scale) {} 95 | 96 | DecodePlugin(float score_thresh, int top_n, std::vector const& anchors, int scale, 97 | size_t height, size_t width, size_t num_anchors, size_t num_classes) 98 | : _score_thresh(score_thresh), _top_n(top_n), _anchors(anchors), _scale(scale), 99 | _height(height), _width(width), _num_anchors(num_anchors), _num_classes(num_classes) {} 100 | 101 | DecodePlugin(void const* data, size_t length) { 102 | this->deserialize(data, length); 103 | } 104 | 105 | const char *getPluginType() const noexcept override { 106 | return RETINANET_PLUGIN_NAME; 107 | } 108 | 109 | const char *getPluginVersion() const noexcept override { 110 | return RETINANET_PLUGIN_VERSION; 111 | } 112 | 113 | int getNbOutputs() const noexcept override { 114 | return 3; 115 | } 116 | 117 | DimsExprs getOutputDimensions(int outputIndex, const DimsExprs *inputs, 118 | int nbInputs, IExprBuilder &exprBuilder) noexcept override 119 | { 120 | DimsExprs output(inputs[0]); 121 | output.d[1] = exprBuilder.constant(_top_n * (outputIndex == 1 ? 4 : 1)); 122 | output.d[2] = exprBuilder.constant(1); 123 | output.d[3] = exprBuilder.constant(1); 124 | 125 | return output; 126 | } 127 | 128 | bool supportsFormatCombination(int pos, const PluginTensorDesc *inOut, 129 | int nbInputs, int nbOutputs) noexcept override 130 | { 131 | assert(nbInputs == 2); 132 | assert(nbOutputs == 3); 133 | assert(pos < 5); 134 | return inOut[pos].type == DataType::kFLOAT && inOut[pos].format == nvinfer1::PluginFormat::kLINEAR; 135 | } 136 | 137 | int initialize() noexcept override { return 0; } 138 | 139 | void terminate() noexcept override {} 140 | 141 | size_t getWorkspaceSize(const PluginTensorDesc *inputs, 142 | int nbInputs, const PluginTensorDesc *outputs, int nbOutputs) const noexcept override 143 | { 144 | if (size < 0) { 145 | size = cuda::decode(inputs->dims.d[0], nullptr, nullptr, _height, _width, _scale, 146 | _num_anchors, _num_classes, _anchors, _score_thresh, _top_n, 147 | nullptr, 0, nullptr); 148 | } 149 | return size; 150 | } 151 | 152 | int enqueue(const PluginTensorDesc *inputDesc, 153 | const PluginTensorDesc *outputDesc, const void *const *inputs, 154 | void *const *outputs, void *workspace, cudaStream_t stream) noexcept override 155 | { 156 | 157 | return cuda::decode(inputDesc->dims.d[0], inputs, outputs, _height, _width, _scale, 158 | _num_anchors, _num_classes, _anchors, _score_thresh, _top_n, 159 | workspace, getWorkspaceSize(inputDesc, 2, outputDesc, 3), stream); 160 | 161 | } 162 | 163 | void destroy() noexcept override { 164 | delete this; 165 | }; 166 | 167 | const char *getPluginNamespace() const noexcept override { 168 | return RETINANET_PLUGIN_NAMESPACE; 169 | } 170 | 171 | void setPluginNamespace(const char *N) noexcept override {} 172 | 173 | DataType getOutputDataType(int index, const DataType* inputTypes, int nbInputs) const noexcept 174 | { 175 | assert(index < 3); 176 | return DataType::kFLOAT; 177 | } 178 | 179 | void configurePlugin(const DynamicPluginTensorDesc *in, int nbInputs, 180 | const DynamicPluginTensorDesc *out, int nbOutputs) noexcept 181 | { 182 | assert(nbInputs == 2); 183 | assert(nbOutputs == 3); 184 | auto const& scores_dims = in[0].desc.dims; 185 | auto const& boxes_dims = in[1].desc.dims; 186 | assert(scores_dims.d[2] == boxes_dims.d[2]); 187 | assert(scores_dims.d[3] == boxes_dims.d[3]); 188 | _height = scores_dims.d[2]; 189 | _width = scores_dims.d[3]; 190 | _num_anchors = boxes_dims.d[1] / 4; 191 | _num_classes = scores_dims.d[1] / _num_anchors; 192 | } 193 | 194 | IPluginV2DynamicExt *clone() const noexcept override { 195 | return new DecodePlugin(_score_thresh, _top_n, _anchors, _scale, _height, _width, 196 | _num_anchors, _num_classes); 197 | } 198 | 199 | private: 200 | template void write(char*& buffer, const T& val) const { 201 | *reinterpret_cast(buffer) = val; 202 | buffer += sizeof(T); 203 | } 204 | 205 | template void read(const char*& buffer, T& val) { 206 | val = *reinterpret_cast(buffer); 207 | buffer += sizeof(T); 208 | } 209 | }; 210 | 211 | class DecodePluginCreator : public IPluginCreator { 212 | public: 213 | DecodePluginCreator() {} 214 | 215 | const char *getPluginName () const noexcept override { 216 | return RETINANET_PLUGIN_NAME; 217 | } 218 | 219 | const char *getPluginVersion () const noexcept override { 220 | return RETINANET_PLUGIN_VERSION; 221 | } 222 | 223 | const char *getPluginNamespace() const noexcept override { 224 | return RETINANET_PLUGIN_NAMESPACE; 225 | } 226 | 227 | 228 | IPluginV2DynamicExt *deserializePlugin (const char *name, const void *serialData, size_t serialLength) noexcept override { 229 | return new DecodePlugin(serialData, serialLength); 230 | } 231 | 232 | void setPluginNamespace(const char *N) noexcept override {} 233 | const PluginFieldCollection *getFieldNames() noexcept override { return nullptr; } 234 | IPluginV2DynamicExt *createPlugin (const char *name, const PluginFieldCollection *fc) noexcept override { return nullptr; } 235 | }; 236 | 237 | REGISTER_TENSORRT_PLUGIN(DecodePluginCreator); 238 | 239 | } 240 | 241 | #undef RETINANET_PLUGIN_NAME 242 | #undef RETINANET_PLUGIN_VERSION 243 | #undef RETINANET_PLUGIN_NAMESPACE -------------------------------------------------------------------------------- /csrc/plugins/DecodeRotatePlugin.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Permission is hereby granted, free of charge, to any person obtaining a 5 | * copy of this software and associated documentation files (the "Software"), 6 | * to deal in the Software without restriction, including without limitation 7 | * the rights to use, copy, modify, merge, publish, distribute, sublicense, 8 | * and/or sell copies of the Software, and to permit persons to whom the 9 | * Software is furnished to do so, subject to the following conditions: 10 | * 11 | * The above copyright notice and this permission notice shall be included in 12 | * all copies or substantial portions of the Software. 13 | * 14 | * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 15 | * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 16 | * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 17 | * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 18 | * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 19 | * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 20 | * DEALINGS IN THE SOFTWARE. 21 | */ 22 | 23 | #pragma once 24 | 25 | #include 26 | 27 | #include 28 | #include 29 | 30 | #include "../cuda/decode_rotate.h" 31 | 32 | using namespace nvinfer1; 33 | 34 | #define RETINANET_PLUGIN_NAME "RetinaNetDecodeRotate" 35 | #define RETINANET_PLUGIN_VERSION "1" 36 | #define RETINANET_PLUGIN_NAMESPACE "" 37 | 38 | namespace odtk { 39 | 40 | class DecodeRotatePlugin : public IPluginV2DynamicExt { 41 | float _score_thresh; 42 | int _top_n; 43 | std::vector _anchors; 44 | float _scale; 45 | 46 | size_t _height; 47 | size_t _width; 48 | size_t _num_anchors; 49 | size_t _num_classes; 50 | mutable int size = -1; 51 | 52 | protected: 53 | void deserialize(void const* data, size_t length) { 54 | const char* d = static_cast(data); 55 | read(d, _score_thresh); 56 | read(d, _top_n); 57 | size_t anchors_size; 58 | read(d, anchors_size); 59 | while( anchors_size-- ) { 60 | float val; 61 | read(d, val); 62 | _anchors.push_back(val); 63 | } 64 | read(d, _scale); 65 | read(d, _height); 66 | read(d, _width); 67 | read(d, _num_anchors); 68 | read(d, _num_classes); 69 | } 70 | 71 | size_t getSerializationSize() const noexcept override { 72 | return sizeof(_score_thresh) + sizeof(_top_n) 73 | + sizeof(size_t) + sizeof(float) * _anchors.size() + sizeof(_scale) 74 | + sizeof(_height) + sizeof(_width) + sizeof(_num_anchors) + sizeof(_num_classes); 75 | } 76 | 77 | void serialize(void *buffer) const noexcept override { 78 | char* d = static_cast(buffer); 79 | write(d, _score_thresh); 80 | write(d, _top_n); 81 | write(d, _anchors.size()); 82 | for( auto &val : _anchors ) { 83 | write(d, val); 84 | } 85 | write(d, _scale); 86 | write(d, _height); 87 | write(d, _width); 88 | write(d, _num_anchors); 89 | write(d, _num_classes); 90 | } 91 | 92 | public: 93 | DecodeRotatePlugin(float score_thresh, int top_n, std::vector const& anchors, int scale) 94 | : _score_thresh(score_thresh), _top_n(top_n), _anchors(anchors), _scale(scale) {} 95 | 96 | DecodeRotatePlugin(float score_thresh, int top_n, std::vector const& anchors, int scale, 97 | size_t height, size_t width, size_t num_anchors, size_t num_classes) 98 | : _score_thresh(score_thresh), _top_n(top_n), _anchors(anchors), _scale(scale), 99 | _height(height), _width(width), _num_anchors(num_anchors), _num_classes(num_classes) {} 100 | 101 | DecodeRotatePlugin(void const* data, size_t length) { 102 | this->deserialize(data, length); 103 | } 104 | 105 | const char *getPluginType() const noexcept override { 106 | return RETINANET_PLUGIN_NAME; 107 | } 108 | 109 | const char *getPluginVersion() const noexcept override { 110 | return RETINANET_PLUGIN_VERSION; 111 | } 112 | 113 | int getNbOutputs() const noexcept override { 114 | return 3; 115 | } 116 | 117 | DimsExprs getOutputDimensions(int outputIndex, const DimsExprs *inputs, 118 | int nbInputs, IExprBuilder &exprBuilder) noexcept override 119 | { 120 | DimsExprs output(inputs[0]); 121 | output.d[1] = exprBuilder.constant(_top_n * (outputIndex == 1 ? 6 : 1)); 122 | output.d[2] = exprBuilder.constant(1); 123 | output.d[3] = exprBuilder.constant(1); 124 | 125 | return output; 126 | } 127 | 128 | 129 | bool supportsFormatCombination(int pos, const PluginTensorDesc *inOut, 130 | int nbInputs, int nbOutputs) noexcept override 131 | { 132 | assert(nbInputs == 2); 133 | assert(nbOutputs == 3); 134 | assert(pos < 5); 135 | return inOut[pos].type == DataType::kFLOAT && inOut[pos].format == nvinfer1::PluginFormat::kLINEAR; 136 | } 137 | 138 | 139 | int initialize() noexcept override { return 0; } 140 | 141 | void terminate() noexcept override {} 142 | 143 | size_t getWorkspaceSize(const PluginTensorDesc *inputs, 144 | int nbInputs, const PluginTensorDesc *outputs, int nbOutputs) const noexcept override 145 | { 146 | if (size < 0) { 147 | size = cuda::decode_rotate(inputs->dims.d[0], nullptr, nullptr, _height, _width, _scale, 148 | _num_anchors, _num_classes, _anchors, _score_thresh, _top_n, 149 | nullptr, 0, nullptr); 150 | } 151 | return size; 152 | } 153 | 154 | int enqueue(const PluginTensorDesc *inputDesc, 155 | const PluginTensorDesc *outputDesc, const void *const *inputs, 156 | void *const *outputs, void *workspace, cudaStream_t stream) noexcept override 157 | { 158 | return cuda::decode_rotate(inputDesc->dims.d[0], inputs, outputs, _height, _width, _scale, 159 | _num_anchors, _num_classes, _anchors, _score_thresh, _top_n, 160 | workspace, getWorkspaceSize(inputDesc, 2, outputDesc, 3), stream); 161 | } 162 | 163 | void destroy() noexcept override { 164 | delete this; 165 | }; 166 | 167 | const char *getPluginNamespace() const noexcept override { 168 | return RETINANET_PLUGIN_NAMESPACE; 169 | } 170 | 171 | void setPluginNamespace(const char *N) noexcept override {} 172 | 173 | DataType getOutputDataType(int index, const DataType* inputTypes, int nbInputs) const noexcept 174 | { 175 | assert(index < 3); 176 | return DataType::kFLOAT; 177 | } 178 | 179 | void configurePlugin(const DynamicPluginTensorDesc *in, int nbInputs, 180 | const DynamicPluginTensorDesc *out, int nbOutputs) noexcept 181 | { 182 | assert(nbInputs == 2); 183 | assert(nbOutputs == 3); 184 | auto const& scores_dims = in[0].desc.dims; 185 | auto const& boxes_dims = in[1].desc.dims; 186 | assert(scores_dims.d[2] == boxes_dims.d[2]); 187 | assert(scores_dims.d[3] == boxes_dims.d[3]); 188 | _height = scores_dims.d[2]; 189 | _width = scores_dims.d[3]; 190 | _num_anchors = boxes_dims.d[1] / 6; 191 | _num_classes = scores_dims.d[1] / _num_anchors; 192 | } 193 | 194 | IPluginV2DynamicExt *clone() const noexcept override { 195 | return new DecodeRotatePlugin(_score_thresh, _top_n, _anchors, _scale, _height, _width, 196 | _num_anchors, _num_classes); 197 | } 198 | 199 | private: 200 | template void write(char*& buffer, const T& val) const { 201 | *reinterpret_cast(buffer) = val; 202 | buffer += sizeof(T); 203 | } 204 | 205 | template void read(const char*& buffer, T& val) { 206 | val = *reinterpret_cast(buffer); 207 | buffer += sizeof(T); 208 | } 209 | }; 210 | 211 | class DecodeRotatePluginCreator : public IPluginCreator { 212 | public: 213 | DecodeRotatePluginCreator() {} 214 | 215 | const char *getPluginName () const noexcept override { 216 | return RETINANET_PLUGIN_NAME; 217 | } 218 | 219 | const char *getPluginVersion () const noexcept override { 220 | return RETINANET_PLUGIN_VERSION; 221 | } 222 | 223 | const char *getPluginNamespace() const noexcept override { 224 | return RETINANET_PLUGIN_NAMESPACE; 225 | } 226 | 227 | IPluginV2DynamicExt *deserializePlugin (const char *name, const void *serialData, size_t serialLength) noexcept override { 228 | return new DecodeRotatePlugin(serialData, serialLength); 229 | } 230 | 231 | void setPluginNamespace(const char *N) noexcept override {} 232 | const PluginFieldCollection *getFieldNames() noexcept override { return nullptr; } 233 | IPluginV2DynamicExt *createPlugin (const char *name, const PluginFieldCollection *fc) noexcept override { return nullptr; } 234 | }; 235 | 236 | REGISTER_TENSORRT_PLUGIN(DecodeRotatePluginCreator); 237 | 238 | } 239 | 240 | #undef RETINANET_PLUGIN_NAME 241 | #undef RETINANET_PLUGIN_VERSION 242 | #undef RETINANET_PLUGIN_NAMESPACE 243 | -------------------------------------------------------------------------------- /csrc/plugins/NMSPlugin.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Permission is hereby granted, free of charge, to any person obtaining a 5 | * copy of this software and associated documentation files (the "Software"), 6 | * to deal in the Software without restriction, including without limitation 7 | * the rights to use, copy, modify, merge, publish, distribute, sublicense, 8 | * and/or sell copies of the Software, and to permit persons to whom the 9 | * Software is furnished to do so, subject to the following conditions: 10 | * 11 | * The above copyright notice and this permission notice shall be included in 12 | * all copies or substantial portions of the Software. 13 | * 14 | * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 15 | * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 16 | * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 17 | * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 18 | * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 19 | * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 20 | * DEALINGS IN THE SOFTWARE. 21 | */ 22 | 23 | #pragma once 24 | 25 | #include 26 | 27 | #include 28 | #include 29 | 30 | #include "../cuda/nms.h" 31 | 32 | using namespace nvinfer1; 33 | 34 | #define RETINANET_PLUGIN_NAME "RetinaNetNMS" 35 | #define RETINANET_PLUGIN_VERSION "1" 36 | #define RETINANET_PLUGIN_NAMESPACE "" 37 | 38 | namespace odtk { 39 | 40 | class NMSPlugin : public IPluginV2DynamicExt { 41 | float _nms_thresh; 42 | int _detections_per_im; 43 | 44 | size_t _count; 45 | mutable int size = -1; 46 | 47 | protected: 48 | void deserialize(void const* data, size_t length) { 49 | const char* d = static_cast(data); 50 | read(d, _nms_thresh); 51 | read(d, _detections_per_im); 52 | read(d, _count); 53 | } 54 | 55 | size_t getSerializationSize() const noexcept override { 56 | return sizeof(_nms_thresh) + sizeof(_detections_per_im) 57 | + sizeof(_count); 58 | } 59 | 60 | void serialize(void *buffer) const noexcept override { 61 | char* d = static_cast(buffer); 62 | write(d, _nms_thresh); 63 | write(d, _detections_per_im); 64 | write(d, _count); 65 | } 66 | 67 | public: 68 | NMSPlugin(float nms_thresh, int detections_per_im) 69 | : _nms_thresh(nms_thresh), _detections_per_im(detections_per_im) { 70 | assert(nms_thresh > 0); 71 | assert(detections_per_im > 0); 72 | } 73 | 74 | NMSPlugin(float nms_thresh, int detections_per_im, size_t count) 75 | : _nms_thresh(nms_thresh), _detections_per_im(detections_per_im), _count(count) { 76 | assert(nms_thresh > 0); 77 | assert(detections_per_im > 0); 78 | assert(count > 0); 79 | } 80 | 81 | NMSPlugin(void const* data, size_t length) { 82 | this->deserialize(data, length); 83 | } 84 | 85 | const char *getPluginType() const noexcept override { 86 | return RETINANET_PLUGIN_NAME; 87 | } 88 | 89 | const char *getPluginVersion() const noexcept override { 90 | return RETINANET_PLUGIN_VERSION; 91 | } 92 | 93 | int getNbOutputs() const noexcept override { 94 | return 3; 95 | } 96 | 97 | DimsExprs getOutputDimensions(int outputIndex, const DimsExprs *inputs, 98 | int nbInputs, IExprBuilder &exprBuilder) noexcept override 99 | { 100 | DimsExprs output(inputs[0]); 101 | output.d[1] = exprBuilder.constant(_detections_per_im * (outputIndex == 1 ? 4 : 1)); 102 | output.d[2] = exprBuilder.constant(1); 103 | output.d[3] = exprBuilder.constant(1); 104 | return output; 105 | } 106 | 107 | bool supportsFormatCombination(int pos, const PluginTensorDesc *inOut, 108 | int nbInputs, int nbOutputs) noexcept override 109 | { 110 | assert(nbInputs == 3); 111 | assert(nbOutputs == 3); 112 | assert(pos < 6); 113 | return inOut[pos].type == DataType::kFLOAT && inOut[pos].format == nvinfer1::PluginFormat::kLINEAR; 114 | } 115 | 116 | int initialize() noexcept override { return 0; } 117 | 118 | void terminate() noexcept override {} 119 | 120 | size_t getWorkspaceSize(const PluginTensorDesc *inputs, 121 | int nbInputs, const PluginTensorDesc *outputs, int nbOutputs) const noexcept override 122 | { 123 | if (size < 0) { 124 | size = cuda::nms(inputs->dims.d[0], nullptr, nullptr, _count, 125 | _detections_per_im, _nms_thresh, 126 | nullptr, 0, nullptr); 127 | } 128 | return size; 129 | } 130 | 131 | int enqueue(const PluginTensorDesc *inputDesc, 132 | const PluginTensorDesc *outputDesc, const void *const *inputs, 133 | void *const *outputs, void *workspace, cudaStream_t stream) noexcept override 134 | { 135 | return cuda::nms(inputDesc->dims.d[0], inputs, outputs, _count, 136 | _detections_per_im, _nms_thresh, 137 | workspace, getWorkspaceSize(inputDesc, 3, outputDesc, 3), stream); 138 | } 139 | 140 | void destroy() noexcept override { 141 | delete this; 142 | } 143 | 144 | const char *getPluginNamespace() const noexcept override { 145 | return RETINANET_PLUGIN_NAMESPACE; 146 | } 147 | 148 | void setPluginNamespace(const char *N) noexcept override {} 149 | 150 | DataType getOutputDataType(int index, const DataType* inputTypes, int nbInputs) const noexcept 151 | { 152 | assert(index < 3); 153 | return DataType::kFLOAT; 154 | } 155 | 156 | void configurePlugin(const DynamicPluginTensorDesc *in, int nbInputs, 157 | const DynamicPluginTensorDesc *out, int nbOutputs) noexcept 158 | { 159 | assert(nbInputs == 3); 160 | assert(in[0].desc.dims.d[1] == in[2].desc.dims.d[1]); 161 | assert(in[1].desc.dims.d[1] == in[2].desc.dims.d[1] * 4); 162 | _count = in[0].desc.dims.d[1]; 163 | } 164 | 165 | IPluginV2DynamicExt *clone() const noexcept override { 166 | return new NMSPlugin(_nms_thresh, _detections_per_im, _count); 167 | } 168 | 169 | 170 | private: 171 | template void write(char*& buffer, const T& val) const { 172 | *reinterpret_cast(buffer) = val; 173 | buffer += sizeof(T); 174 | } 175 | 176 | template void read(const char*& buffer, T& val) { 177 | val = *reinterpret_cast(buffer); 178 | buffer += sizeof(T); 179 | } 180 | }; 181 | 182 | class NMSPluginCreator : public IPluginCreator { 183 | public: 184 | NMSPluginCreator() {} 185 | 186 | const char *getPluginNamespace() const noexcept override { 187 | return RETINANET_PLUGIN_NAMESPACE; 188 | } 189 | const char *getPluginName () const noexcept override { 190 | return RETINANET_PLUGIN_NAME; 191 | } 192 | 193 | const char *getPluginVersion () const noexcept override { 194 | return RETINANET_PLUGIN_VERSION; 195 | } 196 | 197 | //Was IPluginV2 198 | IPluginV2DynamicExt *deserializePlugin (const char *name, const void *serialData, size_t serialLength) noexcept override { 199 | return new NMSPlugin(serialData, serialLength); 200 | } 201 | 202 | //Was IPluginV2 203 | void setPluginNamespace(const char *N) noexcept override {} 204 | const PluginFieldCollection *getFieldNames() noexcept override { return nullptr; } 205 | IPluginV2DynamicExt *createPlugin (const char *name, const PluginFieldCollection *fc) noexcept override { return nullptr; } 206 | }; 207 | 208 | REGISTER_TENSORRT_PLUGIN(NMSPluginCreator); 209 | 210 | } 211 | 212 | #undef RETINANET_PLUGIN_NAME 213 | #undef RETINANET_PLUGIN_VERSION 214 | #undef RETINANET_PLUGIN_NAMESPACE -------------------------------------------------------------------------------- /csrc/plugins/NMSRotatePlugin.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Permission is hereby granted, free of charge, to any person obtaining a 5 | * copy of this software and associated documentation files (the "Software"), 6 | * to deal in the Software without restriction, including without limitation 7 | * the rights to use, copy, modify, merge, publish, distribute, sublicense, 8 | * and/or sell copies of the Software, and to permit persons to whom the 9 | * Software is furnished to do so, subject to the following conditions: 10 | * 11 | * The above copyright notice and this permission notice shall be included in 12 | * all copies or substantial portions of the Software. 13 | * 14 | * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 15 | * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 16 | * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 17 | * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 18 | * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 19 | * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 20 | * DEALINGS IN THE SOFTWARE. 21 | */ 22 | 23 | #pragma once 24 | 25 | #include 26 | 27 | #include 28 | #include 29 | 30 | #include "../cuda/nms_iou.h" 31 | 32 | using namespace nvinfer1; 33 | 34 | #define RETINANET_PLUGIN_NAME "RetinaNetNMSRotate" 35 | #define RETINANET_PLUGIN_VERSION "1" 36 | #define RETINANET_PLUGIN_NAMESPACE "" 37 | 38 | namespace odtk { 39 | 40 | class NMSRotatePlugin : public IPluginV2DynamicExt { 41 | float _nms_thresh; 42 | int _detections_per_im; 43 | 44 | size_t _count; 45 | mutable int size = -1; 46 | 47 | protected: 48 | void deserialize(void const* data, size_t length) { 49 | const char* d = static_cast(data); 50 | read(d, _nms_thresh); 51 | read(d, _detections_per_im); 52 | read(d, _count); 53 | } 54 | 55 | size_t getSerializationSize() const noexcept override { 56 | return sizeof(_nms_thresh) + sizeof(_detections_per_im) 57 | + sizeof(_count); 58 | } 59 | 60 | void serialize(void *buffer) const noexcept override { 61 | char* d = static_cast(buffer); 62 | write(d, _nms_thresh); 63 | write(d, _detections_per_im); 64 | write(d, _count); 65 | } 66 | 67 | public: 68 | NMSRotatePlugin(float nms_thresh, int detections_per_im) 69 | : _nms_thresh(nms_thresh), _detections_per_im(detections_per_im) { 70 | assert(nms_thresh > 0); 71 | assert(detections_per_im > 0); 72 | } 73 | 74 | NMSRotatePlugin(float nms_thresh, int detections_per_im, size_t count) 75 | : _nms_thresh(nms_thresh), _detections_per_im(detections_per_im), _count(count) { 76 | assert(nms_thresh > 0); 77 | assert(detections_per_im > 0); 78 | assert(count > 0); 79 | } 80 | 81 | NMSRotatePlugin(void const* data, size_t length) { 82 | this->deserialize(data, length); 83 | } 84 | 85 | const char *getPluginType() const noexcept override { 86 | return RETINANET_PLUGIN_NAME; 87 | } 88 | 89 | const char *getPluginVersion() const noexcept override { 90 | return RETINANET_PLUGIN_VERSION; 91 | } 92 | 93 | int getNbOutputs() const noexcept override { 94 | return 3; 95 | } 96 | 97 | DimsExprs getOutputDimensions(int outputIndex, const DimsExprs *inputs, 98 | int nbInputs, IExprBuilder &exprBuilder) noexcept override 99 | { 100 | DimsExprs output(inputs[0]); 101 | output.d[1] = exprBuilder.constant(_detections_per_im * (outputIndex == 1 ? 6 : 1)); 102 | output.d[2] = exprBuilder.constant(1); 103 | output.d[3] = exprBuilder.constant(1); 104 | return output; 105 | } 106 | 107 | bool supportsFormatCombination(int pos, const PluginTensorDesc *inOut, 108 | int nbInputs, int nbOutputs) noexcept override 109 | { 110 | assert(nbInputs == 3); 111 | assert(nbOutputs == 3); 112 | assert(pos < 6); 113 | return inOut[pos].type == DataType::kFLOAT && inOut[pos].format == nvinfer1::PluginFormat::kLINEAR; 114 | } 115 | 116 | int initialize() noexcept override { return 0; } 117 | 118 | void terminate() noexcept override {} 119 | 120 | size_t getWorkspaceSize(const PluginTensorDesc *inputs, 121 | int nbInputs, const PluginTensorDesc *outputs, int nbOutputs) const noexcept override 122 | { 123 | if (size < 0) { 124 | size = cuda::nms_rotate(inputs->dims.d[0], nullptr, nullptr, _count, 125 | _detections_per_im, _nms_thresh, 126 | nullptr, 0, nullptr); 127 | } 128 | return size; 129 | } 130 | 131 | int enqueue(const PluginTensorDesc *inputDesc, 132 | const PluginTensorDesc *outputDesc, const void *const *inputs, 133 | void *const *outputs, void *workspace, cudaStream_t stream) noexcept override 134 | { 135 | return cuda::nms_rotate(inputDesc->dims.d[0], inputs, outputs, _count, 136 | _detections_per_im, _nms_thresh, 137 | workspace, getWorkspaceSize(inputDesc, 3, outputDesc, 3), stream); 138 | } 139 | 140 | void destroy() noexcept override { 141 | delete this; 142 | } 143 | 144 | const char *getPluginNamespace() const noexcept override { 145 | return RETINANET_PLUGIN_NAMESPACE; 146 | } 147 | 148 | void setPluginNamespace(const char *N) noexcept override {} 149 | 150 | DataType getOutputDataType(int index, const DataType* inputTypes, int nbInputs) const noexcept 151 | { 152 | assert(index < 3); 153 | return DataType::kFLOAT; 154 | } 155 | 156 | 157 | void configurePlugin(const DynamicPluginTensorDesc *in, int nbInputs, 158 | const DynamicPluginTensorDesc *out, int nbOutputs) noexcept 159 | { 160 | assert(nbInputs == 3); 161 | assert(in[0].desc.dims.d[1] == in[2].desc.dims.d[1]); 162 | assert(in[1].desc.dims.d[1] == in[2].desc.dims.d[1] * 6); 163 | _count = in[0].desc.dims.d[1]; 164 | } 165 | 166 | IPluginV2DynamicExt *clone() const noexcept override { 167 | return new NMSRotatePlugin(_nms_thresh, _detections_per_im, _count); 168 | } 169 | 170 | private: 171 | template void write(char*& buffer, const T& val) const { 172 | *reinterpret_cast(buffer) = val; 173 | buffer += sizeof(T); 174 | } 175 | 176 | template void read(const char*& buffer, T& val) { 177 | val = *reinterpret_cast(buffer); 178 | buffer += sizeof(T); 179 | } 180 | }; 181 | 182 | class NMSRotatePluginCreator : public IPluginCreator { 183 | public: 184 | NMSRotatePluginCreator() {} 185 | 186 | const char *getPluginNamespace() const noexcept override { 187 | return RETINANET_PLUGIN_NAMESPACE; 188 | } 189 | const char *getPluginName () const noexcept override { 190 | return RETINANET_PLUGIN_NAME; 191 | } 192 | 193 | const char *getPluginVersion () const noexcept override { 194 | return RETINANET_PLUGIN_VERSION; 195 | } 196 | 197 | IPluginV2DynamicExt *deserializePlugin (const char *name, const void *serialData, size_t serialLength) noexcept override { 198 | return new NMSRotatePlugin(serialData, serialLength); 199 | } 200 | 201 | void setPluginNamespace(const char *N) noexcept override {} 202 | const PluginFieldCollection *getFieldNames() noexcept override { return nullptr; } 203 | IPluginV2DynamicExt *createPlugin (const char *name, const PluginFieldCollection *fc) noexcept override { return nullptr; } 204 | }; 205 | 206 | REGISTER_TENSORRT_PLUGIN(NMSRotatePluginCreator); 207 | 208 | } 209 | 210 | #undef RETINANET_PLUGIN_NAME 211 | #undef RETINANET_PLUGIN_VERSION 212 | #undef RETINANET_PLUGIN_NAMESPACE 213 | -------------------------------------------------------------------------------- /extras/cppapi/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.9 FATAL_ERROR) 2 | 3 | project(odtk_infer LANGUAGES CXX) 4 | set(CMAKE_CXX_STANDARD 14) 5 | find_package(CUDA REQUIRED) 6 | enable_language(CUDA) 7 | find_package(OpenCV REQUIRED) 8 | 9 | if(DEFINED TensorRT_DIR) 10 | include_directories("${TensorRT_DIR}/include") 11 | link_directories("${TensorRT_DIR}/lib") 12 | endif(DEFINED TensorRT_DIR) 13 | include_directories(${CUDA_INCLUDE_DIRS}) 14 | 15 | add_library(odtk SHARED 16 | ../../csrc/cuda/decode.h 17 | ../../csrc/cuda/decode.cu 18 | ../../csrc/cuda/nms.h 19 | ../../csrc/cuda/nms.cu 20 | ../../csrc/cuda/decode_rotate.h 21 | ../../csrc/cuda/decode_rotate.cu 22 | ../../csrc/cuda/nms_iou.h 23 | ../../csrc/cuda/nms_iou.cu 24 | ../../csrc/cuda/utils.h 25 | ../../csrc/engine.h 26 | ../../csrc/engine.cpp 27 | ../../csrc/calibrator.h 28 | ) 29 | set_target_properties(odtk PROPERTIES 30 | CUDA_RESOLVE_DEVICE_SYMBOLS ON 31 | CUDA_ARCHITECTURES 60 61 70 72 75 80 86 32 | ) 33 | include_directories(${OpenCV_INCLUDE_DIRS}) 34 | target_link_libraries(odtk PUBLIC nvinfer nvonnxparser ${OpenCV_LIBS}) 35 | 36 | add_executable(export export.cpp) 37 | include_directories(${OpenCV_INCLUDE_DIRS}) 38 | target_link_libraries(export PRIVATE odtk ${OpenCV_LIBS}) 39 | 40 | add_executable(infer infer.cpp) 41 | include_directories(${OpenCV_INCLUDE_DIRS}) 42 | target_link_libraries(infer PRIVATE odtk ${OpenCV_LIBS} cuda ${CUDA_LIBRARIES}) 43 | 44 | if(CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64") 45 | add_executable(infervideo infervideo.cpp) 46 | include_directories(${OpenCV_INCLUDE_DIRS}) 47 | target_link_libraries(infervideo PRIVATE odtk ${OpenCV_LIBS} cuda ${CUDA_LIBRARIES}) 48 | endif() 49 | -------------------------------------------------------------------------------- /extras/cppapi/README.md: -------------------------------------------------------------------------------- 1 | # RetinaNet C++ Inference API - Sample Code 2 | 3 | The C++ API allows you to build a TensorRT engine for inference using the ONNX export of a core model. 4 | 5 | The following shows how to build and run code samples for exporting an ONNX core model (from RetinaNet or other toolkit supporting the same sort of core model structure) to a TensorRT engine and doing inference on images. 6 | 7 | ## Building 8 | 9 | Building the example requires the following toolkits and libraries to be set up properly on your system: 10 | * A proper C++ toolchain (MSVS on Windows) 11 | * [CMake](https://cmake.org/download/) version 3.9 or later 12 | * NVIDIA [CUDA](https://developer.nvidia.com/cuda-toolkit) 13 | * NVIDIA [CuDNN](https://developer.nvidia.com/cudnn) 14 | * NVIDIA [TensorRT](https://developer.nvidia.com/tensorrt) 15 | * [OpenCV](https://opencv.org/releases.html) 16 | 17 | ### Linux 18 | ```bash 19 | mkdir build && cd build 20 | cmake -DCMAKE_CUDA_FLAGS="--expt-extended-lambda -std=c++14" .. 21 | make 22 | ``` 23 | 24 | ### Windows 25 | ```bash 26 | mkdir build && cd build 27 | cmake -G "Visual Studio 15 2017" -A x64 -T host=x64,cuda=10.0 -DTensorRT_DIR="C:\path\to\tensorrt" -DOpenCV_DIR="C:\path\to\opencv\build" .. 28 | msbuild odtk_infer.sln 29 | ``` 30 | 31 | ## Running 32 | 33 | If you don't have an ONNX core model, generate one from your RetinaNet model: 34 | ```bash 35 | odtk export model.pth model.onnx 36 | ``` 37 | 38 | Load the ONNX core model and export it to a RetinaNet TensorRT engine (using FP16 precision): 39 | ```bash 40 | export{.exe} model.onnx engine.plan 41 | ``` 42 | 43 | You can also export the ONNX core model to an INT8 TensorRT engine if you have already done INT8 calibration: 44 | ```bash 45 | export{.exe} model.onnx engine.plan INT8CalibrationTable 46 | ``` 47 | 48 | Run a test inference (default output if none provided: "detections.png"): 49 | ```bash 50 | infer{.exe} engine.plan image.jpg [.png] 51 | ``` 52 | 53 | Note: make sure the TensorRT, CuDNN and OpenCV libraries are available in your environment and path. 54 | 55 | We have verified these steps with the following configurations: 56 | * DGX-1V using the provided Docker container (CUDA 10, cuDNN 7.4.2, TensorRT 5.0.2, OpenCV 3.4.3) 57 | * Jetson AGX Xavier with JetPack 4.1.1 Developer Preview (CUDA 10, cuDNN 7.3.1, TensorRT 5.0.3, OpenCV 3.3.1) 58 | 59 | 60 | 61 | 62 | -------------------------------------------------------------------------------- /extras/cppapi/export.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | 7 | #include "../../csrc/engine.h" 8 | 9 | #define ROTATED false // Change to true for Rotated Bounding Box export 10 | #define COCO_PATH "/coco/coco2017/val2017" // Path to calibration images 11 | 12 | using namespace std; 13 | 14 | // Sample program to build a TensorRT Engine from an ONNX model from RetinaNet 15 | // 16 | // By default TensorRT will target FP16 precision (supported on Pascal, Volta, and Turing GPUs) 17 | // 18 | // You can optionally provide an INT8CalibrationTable file created during RetinaNet INT8 calibration 19 | // to build a TensorRT engine with INT8 precision 20 | 21 | inline vector glob(int batch){ 22 | glob_t glob_result; 23 | string path = string(COCO_PATH); 24 | if(path.back()!='/') path+="/"; 25 | glob((path+"*").c_str(), (GLOB_TILDE | GLOB_NOSORT), NULL, &glob_result); 26 | vector calibration_files; 27 | for(int i=0; i(new char[size]); 53 | onnxFile.read(buffer.get(), size); 54 | onnxFile.close(); 55 | 56 | // Define default RetinaNet parameters to use for TRT export 57 | const vector dynamic_batch_opts{1, 8, 16}; 58 | int calibration_batches = 2; // must be >= 1 59 | float score_thresh = 0.05f; 60 | int top_n = 1000; 61 | size_t workspace_size =(1ULL << 30); 62 | float nms_thresh = 0.5; 63 | int detections_per_im = 100; 64 | bool verbose = false; 65 | // Generated from generate_anchors.py 66 | vector> anchors; 67 | if(!ROTATED) { 68 | // Axis-aligned 69 | anchors = { 70 | {-12.0, -12.0, 20.0, 20.0, -7.31, -18.63, 15.31, 26.63, -18.63, -7.31, 26.63, 15.31, -16.16, -16.16, 24.16, 24.16, -10.25, -24.51, 18.25, 32.51, -24.51, -10.25, 32.51, 18.25, -21.4, -21.4, 29.4, 29.4, -13.96, -31.92, 21.96, 39.92, -31.92, -13.96, 39.92, 21.96}, 71 | {-24.0, -24.0, 40.0, 40.0, -14.63, -37.25, 30.63, 53.25, -37.25, -14.63, 53.25, 30.63, -32.32, -32.32, 48.32, 48.32, -20.51, -49.02, 36.51, 65.02, -49.02, -20.51, 65.02, 36.51, -42.8, -42.8, 58.8, 58.8, -27.92, -63.84, 43.92, 79.84, -63.84, -27.92, 79.84, 43.92}, 72 | {-48.0, -48.0, 80.0, 80.0, -29.25, -74.51, 61.25, 106.51, -74.51, -29.25, 106.51, 61.25, -64.63, -64.63, 96.63, 96.63, -41.02, -98.04, 73.02, 130.04, -98.04, -41.02, 130.04, 73.02, -85.59, -85.59, 117.59, 117.59, -55.84, -127.68, 87.84, 159.68, -127.68, -55.84, 159.68, 87.84}, 73 | {-96.0, -96.0, 160.0, 160.0, -58.51, -149.02, 122.51, 213.02, -149.02, -58.51, 213.02, 122.51, -129.27, -129.27, 193.27, 193.27, -82.04, -196.07, 146.04, 260.07, -196.07, -82.04, 260.07, 146.04, -171.19, -171.19, 235.19, 235.19, -111.68, -255.35, 175.68, 319.35, -255.35, -111.68, 319.35, 175.68}, 74 | {-192.0, -192.0, 320.0, 320.0, -117.02, -298.04, 245.02, 426.04, -298.04, -117.02, 426.04, 245.02, -258.54, -258.54, 386.54, 386.54, -164.07, -392.14, 292.07, 520.14, -392.14, -164.07, 520.14, 292.07, -342.37, -342.37, 470.37, 470.37, -223.35, -510.7, 351.35, 638.7, -510.7, -223.35, 638.7, 351.35} 75 | }; 76 | } 77 | else { 78 | // Rotated-bboxes 79 | anchors = { 80 | {-12.0, 0.0, 19.0, 7.0, -7.0, -2.0, 14.0, 9.0, -4.0, -4.0, 11.0, 11.0, -2.0, -8.0, 9.0, 15.0, 0.0, -12.0, 7.0, 19.0, -21.4, -2.35, 28.4, 9.35, -13.46, -5.52, 20.46, 12.52, -8.7, -8.7, 15.7, 15.7, -5.52, -15.05, 12.52, 22.05, -2.35, -21.4, 9.35, 28.4, -36.32, -6.08, 43.32, 13.08, -23.72, -11.12, 30.72, 18.12, -16.16, -16.16, 23.16, 23.16, -11.12, -26.24, 18.12, 33.24, -6.08, -36.32, 13.08, 43.32, -12.0, 0.0, 19.0, 7.0, -7.0, -2.0, 14.0, 9.0, -4.0, -4.0, 11.0, 11.0, -2.0, -8.0, 9.0, 15.0, 0.0, -12.0, 7.0, 19.0, -21.4, -2.35, 28.4, 9.35, -13.46, -5.52, 20.46, 12.52, -8.7, -8.7, 15.7, 15.7, -5.52, -15.05, 12.52, 22.05, -2.35, -21.4, 9.35, 28.4, -36.32, -6.08, 43.32, 13.08, -23.72, -11.12, 30.72, 18.12, -16.16, -16.16, 23.16, 23.16, -11.12, -26.24, 18.12, 33.24, -6.08, -36.32, 13.08, 43.32, -12.0, 0.0, 19.0, 7.0, -7.0, -2.0, 14.0, 9.0, -4.0, -4.0, 11.0, 11.0, -2.0, -8.0, 9.0, 15.0, 0.0, -12.0, 7.0, 19.0, -21.4, -2.35, 28.4, 9.35, -13.46, -5.52, 20.46, 12.52, -8.7, -8.7, 15.7, 15.7, -5.52, -15.05, 12.52, 22.05, -2.35, -21.4, 9.35, 28.4, -36.32, -6.08, 43.32, 13.08, -23.72, -11.12, 30.72, 18.12, -16.16, -16.16, 23.16, 23.16, -11.12, -26.24, 18.12, 33.24, -6.08, -36.32, 13.08, 43.32}, 81 | {-24.0, 0.0, 39.0, 15.0, -15.0, -4.0, 30.0, 19.0, -8.0, -8.0, 23.0, 23.0, -3.0, -14.0, 18.0, 29.0, 0.0, -24.0, 15.0, 39.0, -42.8, -4.7, 57.8, 19.7, -28.51, -11.05, 43.51, 26.05, -17.4, -17.4, 32.4, 32.4, -9.46, -26.92, 24.46, 41.92, -4.7, -42.8, 19.7, 57.8, -72.63, -12.16, 87.63, 27.16, -49.96, -22.24, 64.96, 37.24, -32.32, -32.32, 47.32, 47.32, -19.72, -47.44, 34.72, 62.44, -12.16, -72.63, 27.16, 87.63, -24.0, 0.0, 39.0, 15.0, -15.0, -4.0, 30.0, 19.0, -8.0, -8.0, 23.0, 23.0, -3.0, -14.0, 18.0, 29.0, 0.0, -24.0, 15.0, 39.0, -42.8, -4.7, 57.8, 19.7, -28.51, -11.05, 43.51, 26.05, -17.4, -17.4, 32.4, 32.4, -9.46, -26.92, 24.46, 41.92, -4.7, -42.8, 19.7, 57.8, -72.63, -12.16, 87.63, 27.16, -49.96, -22.24, 64.96, 37.24, -32.32, -32.32, 47.32, 47.32, -19.72, -47.44, 34.72, 62.44, -12.16, -72.63, 27.16, 87.63, -24.0, 0.0, 39.0, 15.0, -15.0, -4.0, 30.0, 19.0, -8.0, -8.0, 23.0, 23.0, -3.0, -14.0, 18.0, 29.0, 0.0, -24.0, 15.0, 39.0, -42.8, -4.7, 57.8, 19.7, -28.51, -11.05, 43.51, 26.05, -17.4, -17.4, 32.4, 32.4, -9.46, -26.92, 24.46, 41.92, -4.7, -42.8, 19.7, 57.8, -72.63, -12.16, 87.63, 27.16, -49.96, -22.24, 64.96, 37.24, -32.32, -32.32, 47.32, 47.32, -19.72, -47.44, 34.72, 62.44, -12.16, -72.63, 27.16, 87.63}, 82 | {-48.0, 0.0, 79.0, 31.0, -29.0, -6.0, 60.0, 37.0, -16.0, -16.0, 47.0, 47.0, -7.0, -30.0, 38.0, 61.0, 0.0, -48.0, 31.0, 79.0, -85.59, -9.4, 116.59, 40.4, -55.43, -18.92, 86.43, 49.92, -34.8, -34.8, 65.8, 65.8, -20.51, -57.02, 51.51, 88.02, -9.4, -85.59, 40.4, 116.59, -145.27, -24.32, 176.27, 55.32, -97.39, -39.44, 128.39, 70.44, -64.63, -64.63, 95.63, 95.63, -41.96, -99.91, 72.96, 130.91, -24.32, -145.27, 55.32, 176.27, -48.0, 0.0, 79.0, 31.0, -29.0, -6.0, 60.0, 37.0, -16.0, -16.0, 47.0, 47.0, -7.0, -30.0, 38.0, 61.0, 0.0, -48.0, 31.0, 79.0, -85.59, -9.4, 116.59, 40.4, -55.43, -18.92, 86.43, 49.92, -34.8, -34.8, 65.8, 65.8, -20.51, -57.02, 51.51, 88.02, -9.4, -85.59, 40.4, 116.59, -145.27, -24.32, 176.27, 55.32, -97.39, -39.44, 128.39, 70.44, -64.63, -64.63, 95.63, 95.63, -41.96, -99.91, 72.96, 130.91, -24.32, -145.27, 55.32, 176.27, -48.0, 0.0, 79.0, 31.0, -29.0, -6.0, 60.0, 37.0, -16.0, -16.0, 47.0, 47.0, -7.0, -30.0, 38.0, 61.0, 0.0, -48.0, 31.0, 79.0, -85.59, -9.4, 116.59, 40.4, -55.43, -18.92, 86.43, 49.92, -34.8, -34.8, 65.8, 65.8, -20.51, -57.02, 51.51, 88.02, -9.4, -85.59, 40.4, 116.59, -145.27, -24.32, 176.27, 55.32, -97.39, -39.44, 128.39, 70.44, -64.63, -64.63, 95.63, 95.63, -41.96, -99.91, 72.96, 130.91, -24.32, -145.27, 55.32, 176.27}, 83 | {-96.0, 0.0, 159.0, 63.0, -59.0, -14.0, 122.0, 77.0, -32.0, -32.0, 95.0, 95.0, -13.0, -58.0, 76.0, 121.0, 0.0, -96.0, 63.0, 159.0, -171.19, -18.8, 234.19, 81.8, -112.45, -41.02, 175.45, 104.02, -69.59, -69.59, 132.59, 132.59, -39.43, -110.87, 102.43, 173.87, -18.8, -171.19, 81.8, 234.19, -290.54, -48.63, 353.54, 111.63, -197.31, -83.91, 260.31, 146.91, -129.27, -129.27, 192.27, 192.27, -81.39, -194.79, 144.39, 257.79, -48.63, -290.54, 111.63, 353.54, -96.0, 0.0, 159.0, 63.0, -59.0, -14.0, 122.0, 77.0, -32.0, -32.0, 95.0, 95.0, -13.0, -58.0, 76.0, 121.0, 0.0, -96.0, 63.0, 159.0, -171.19, -18.8, 234.19, 81.8, -112.45, -41.02, 175.45, 104.02, -69.59, -69.59, 132.59, 132.59, -39.43, -110.87, 102.43, 173.87, -18.8, -171.19, 81.8, 234.19, -290.54, -48.63, 353.54, 111.63, -197.31, -83.91, 260.31, 146.91, -129.27, -129.27, 192.27, 192.27, -81.39, -194.79, 144.39, 257.79, -48.63, -290.54, 111.63, 353.54, -96.0, 0.0, 159.0, 63.0, -59.0, -14.0, 122.0, 77.0, -32.0, -32.0, 95.0, 95.0, -13.0, -58.0, 76.0, 121.0, 0.0, -96.0, 63.0, 159.0, -171.19, -18.8, 234.19, 81.8, -112.45, -41.02, 175.45, 104.02, -69.59, -69.59, 132.59, 132.59, -39.43, -110.87, 102.43, 173.87, -18.8, -171.19, 81.8, 234.19, -290.54, -48.63, 353.54, 111.63, -197.31, -83.91, 260.31, 146.91, -129.27, -129.27, 192.27, 192.27, -81.39, -194.79, 144.39, 257.79, -48.63, -290.54, 111.63, 353.54}, 84 | {-192.0, 0.0, 319.0, 127.0, -117.0, -26.0, 244.0, 153.0, -64.0, -64.0, 191.0, 191.0, -27.0, -118.0, 154.0, 245.0, 0.0, -192.0, 127.0, 319.0, -342.37, -37.59, 469.37, 164.59, -223.32, -78.87, 350.32, 205.87, -139.19, -139.19, 266.19, 266.19, -80.45, -224.91, 207.45, 351.91, -37.59, -342.37, 164.59, 469.37, -581.08, -97.27, 708.08, 224.27, -392.09, -162.79, 519.09, 289.79, -258.54, -258.54, 385.54, 385.54, -165.31, -394.61, 292.31, 521.61, -97.27, -581.08, 224.27, 708.08, -192.0, 0.0, 319.0, 127.0, -117.0, -26.0, 244.0, 153.0, -64.0, -64.0, 191.0, 191.0, -27.0, -118.0, 154.0, 245.0, 0.0, -192.0, 127.0, 319.0, -342.37, -37.59, 469.37, 164.59, -223.32, -78.87, 350.32, 205.87, -139.19, -139.19, 266.19, 266.19, -80.45, -224.91, 207.45, 351.91, -37.59, -342.37, 164.59, 469.37, -581.08, -97.27, 708.08, 224.27, -392.09, -162.79, 519.09, 289.79, -258.54, -258.54, 385.54, 385.54, -165.31, -394.61, 292.31, 521.61, -97.27, -581.08, 224.27, 708.08, -192.0, 0.0, 319.0, 127.0, -117.0, -26.0, 244.0, 153.0, -64.0, -64.0, 191.0, 191.0, -27.0, -118.0, 154.0, 245.0, 0.0, -192.0, 127.0, 319.0, -342.37, -37.59, 469.37, 164.59, -223.32, -78.87, 350.32, 205.87, -139.19, -139.19, 266.19, 266.19, -80.45, -224.91, 207.45, 351.91, -37.59, -342.37, 164.59, 469.37, -581.08, -97.27, 708.08, 224.27, -392.09, -162.79, 519.09, 289.79, -258.54, -258.54, 385.54, 385.54, -165.31, -394.61, 292.31, 521.61, -97.27, -581.08, 224.27, 708.08} 85 | }; 86 | } 87 | 88 | // For INT8 calibration, after setting COCO_PATH on line 10: 89 | // const vector calibration_files = glob(calibration_batches*dynamic_batch_opts[1]); 90 | const vector calibration_files; 91 | string model_name = ""; 92 | string calibration_table = argc == 4 ? string(argv[3]) : ""; 93 | 94 | // Use FP16 precision by default, use INT8 if calibration table is provided 95 | string precision = "FP16"; 96 | if (argc == 4) 97 | precision = "INT8"; 98 | 99 | cout << "Building engine..." << endl; 100 | auto engine = std::unique_ptr(new odtk::Engine(buffer.get(), size, dynamic_batch_opts, precision, score_thresh, top_n, 101 | anchors, ROTATED, nms_thresh, detections_per_im, calibration_files, model_name, calibration_table, verbose, workspace_size)); 102 | engine->save(string(argv[2])); 103 | 104 | return 0; 105 | } 106 | -------------------------------------------------------------------------------- /extras/cppapi/generate_anchors.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from odtk.box import generate_anchors, generate_anchors_rotated 3 | 4 | # Generates anchors for export.cpp 5 | 6 | # ratios = [1.0, 2.0, 0.5] 7 | # scales = [4 * 2 ** (i / 3) for i in range(3)] 8 | ratios = [0.25, 0.5, 1.0, 2.0, 4.0] 9 | scales = [2 * 2**(2 * i/3) for i in range(3)] 10 | angles = [-np.pi / 6, 0, np.pi / 6] 11 | strides = [2**i for i in range(3,8)] 12 | 13 | axis = str(np.round([generate_anchors(stride, ratios, scales, 14 | angles).view(-1).tolist() for stride in strides], decimals=2).tolist() 15 | ).replace('[', '{').replace(']', '}').replace('}, ', '},\n') 16 | 17 | rot = str(np.round([generate_anchors_rotated(stride, ratios, scales, 18 | angles)[0].view(-1).tolist() for stride in strides], decimals=2).tolist() 19 | ).replace('[', '{').replace(']', '}').replace('}, ', '},\n') 20 | 21 | print("Axis-aligned:\n"+axis+'\n') 22 | print("Rotated:\n"+rot) 23 | -------------------------------------------------------------------------------- /extras/cppapi/infer.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | 7 | #include 8 | #include 9 | #include 10 | #include 11 | 12 | #include 13 | 14 | #include "../../csrc/engine.h" 15 | 16 | using namespace std; 17 | using namespace cv; 18 | 19 | int main(int argc, char *argv[]) { 20 | if (argc<3 || argc>4) { 21 | cerr << "Usage: " << argv[0] << " engine.plan image.jpg [.png]" << endl; 22 | return 1; 23 | } 24 | 25 | cout << "Loading engine..." << endl; 26 | auto engine = std::unique_ptr(new odtk::Engine(argv[1])); 27 | 28 | cout << "Preparing data..." << endl; 29 | auto image = imread(argv[2], IMREAD_COLOR); 30 | auto inputSize = engine->getInputSize(); 31 | cv::resize(image, image, Size(inputSize[1], inputSize[0])); 32 | cv::Mat pixels; 33 | image.convertTo(pixels, CV_32FC3, 1.0 / 255, 0); 34 | 35 | int channels = 3; 36 | vector img; 37 | vector data (channels * inputSize[0] * inputSize[1]); 38 | 39 | if (pixels.isContinuous()) 40 | img.assign((float*)pixels.datastart, (float*)pixels.dataend); 41 | else { 42 | cerr << "Error reading image " << argv[2] << endl; 43 | return -1; 44 | } 45 | 46 | vector mean {0.485, 0.456, 0.406}; 47 | vector std {0.229, 0.224, 0.225}; 48 | 49 | for (int c = 0; c < channels; c++) { 50 | for (int j = 0, hw = inputSize[0] * inputSize[1]; j < hw; j++) { 51 | data[c * hw + j] = (img[channels * j + 2 - c] - mean[c]) / std[c]; 52 | } 53 | } 54 | 55 | // Create device buffers 56 | void *data_d, *scores_d, *boxes_d, *classes_d; 57 | auto num_det = engine->getMaxDetections(); 58 | cudaMalloc(&data_d, 3 * inputSize[0] * inputSize[1] * sizeof(float)); 59 | cudaMalloc(&scores_d, num_det * sizeof(float)); 60 | cudaMalloc(&boxes_d, num_det * 4 * sizeof(float)); 61 | cudaMalloc(&classes_d, num_det * sizeof(float)); 62 | 63 | // Copy image to device 64 | size_t dataSize = data.size() * sizeof(float); 65 | cudaMemcpy(data_d, data.data(), dataSize, cudaMemcpyHostToDevice); 66 | 67 | // Run inference n times 68 | cout << "Running inference..." << endl; 69 | const int count = 100; 70 | auto start = chrono::steady_clock::now(); 71 | vector buffers = { data_d, scores_d, boxes_d, classes_d }; 72 | for (int i = 0; i < count; i++) { 73 | engine->infer(buffers, 1); 74 | } 75 | auto stop = chrono::steady_clock::now(); 76 | auto timing = chrono::duration_cast>(stop - start); 77 | cout << "Took " << timing.count() / count << " seconds per inference." << endl; 78 | 79 | cudaFree(data_d); 80 | 81 | // Get back the bounding boxes 82 | unique_ptr scores(new float[num_det]); 83 | unique_ptr boxes(new float[num_det * 4]); 84 | unique_ptr classes(new float[num_det]); 85 | cudaMemcpy(scores.get(), scores_d, sizeof(float) * num_det, cudaMemcpyDeviceToHost); 86 | cudaMemcpy(boxes.get(), boxes_d, sizeof(float) * num_det * 4, cudaMemcpyDeviceToHost); 87 | cudaMemcpy(classes.get(), classes_d, sizeof(float) * num_det, cudaMemcpyDeviceToHost); 88 | 89 | cudaFree(scores_d); 90 | cudaFree(boxes_d); 91 | cudaFree(classes_d); 92 | 93 | for (int i = 0; i < num_det; i++) { 94 | // Show results over confidence threshold 95 | if (scores[i] >= 0.3f) { 96 | float x1 = boxes[i*4+0]; 97 | float y1 = boxes[i*4+1]; 98 | float x2 = boxes[i*4+2]; 99 | float y2 = boxes[i*4+3]; 100 | cout << "Found box {" << x1 << ", " << y1 << ", " << x2 << ", " << y2 101 | << "} with score " << scores[i] << " and class " << classes[i] << endl; 102 | 103 | // Draw bounding box on image 104 | cv::rectangle(image, Point(x1, y1), Point(x2, y2), cv::Scalar(0, 255, 0)); 105 | } 106 | } 107 | 108 | // Write image 109 | string out_file = argc == 4 ? string(argv[3]) : "detections.png"; 110 | cout << "Saving result to " << out_file << endl; 111 | imwrite(out_file, image); 112 | 113 | return 0; 114 | } 115 | -------------------------------------------------------------------------------- /extras/cppapi/infervideo.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | 7 | #include 8 | #include 9 | #include 10 | #include 11 | 12 | #include 13 | 14 | #include "../../csrc/engine.h" 15 | 16 | using namespace std; 17 | using namespace cv; 18 | 19 | int main(int argc, char *argv[]) { 20 | if (argc != 4) { 21 | cerr << "Usage: " << argv[0] << " engine.plan input.mov output.mp4" << endl; 22 | return 1; 23 | } 24 | 25 | cout << "Loading engine..." << endl; 26 | auto engine = std::unique_ptr(new odtk::Engine(argv[1])); 27 | VideoCapture src(argv[2]); 28 | 29 | if (!src.isOpened()){ 30 | cerr << "Could not read " << argv[2] << endl; 31 | return 1; 32 | } 33 | 34 | auto fh=src.get(CAP_PROP_FRAME_HEIGHT); 35 | auto fw=src.get(CAP_PROP_FRAME_WIDTH); 36 | auto fps=src.get(CAP_PROP_FPS); 37 | auto nframes=src.get(CAP_PROP_FRAME_COUNT); 38 | 39 | VideoWriter sink; 40 | sink.open(argv[3], 0x31637661, fps, Size(fw, fh)); 41 | Mat frame; 42 | Mat resized_frame; 43 | Mat inferred_frame; 44 | int count=1; 45 | 46 | auto inputSize = engine->getInputSize(); 47 | // Create device buffers 48 | void *data_d, *scores_d, *boxes_d, *classes_d; 49 | auto num_det = engine->getMaxDetections(); 50 | cudaMalloc(&data_d, 3 * inputSize[0] * inputSize[1] * sizeof(float)); 51 | cudaMalloc(&scores_d, num_det * sizeof(float)); 52 | cudaMalloc(&boxes_d, num_det * 4 * sizeof(float)); 53 | cudaMalloc(&classes_d, num_det * sizeof(float)); 54 | 55 | unique_ptr scores(new float[num_det]); 56 | unique_ptr boxes(new float[num_det * 4]); 57 | unique_ptr classes(new float[num_det]); 58 | 59 | vector mean {0.485, 0.456, 0.406}; 60 | vector std {0.229, 0.224, 0.225}; 61 | 62 | vector blues {0,63,127,191,255,0}; //colors for bounding boxes 63 | vector greens {0,255,191,127,63,0}; 64 | vector reds {191,255,0,0,63,127}; 65 | 66 | int channels = 3; 67 | vector img; 68 | vector data (channels * inputSize[0] * inputSize[1]); 69 | 70 | while(1) 71 | { 72 | src >> frame; 73 | if (frame.empty()){ 74 | cout << "Finished inference!" << endl; 75 | break; 76 | } 77 | 78 | cv::resize(frame, resized_frame, Size(inputSize[1], inputSize[0])); 79 | cv::Mat pixels; 80 | resized_frame.convertTo(pixels, CV_32FC3, 1.0 / 255, 0); 81 | 82 | img.assign((float*)pixels.datastart, (float*)pixels.dataend); 83 | 84 | for (int c = 0; c < channels; c++) { 85 | for (int j = 0, hw = inputSize[0] * inputSize[1]; j < hw; j++) { 86 | data[c * hw + j] = (img[channels * j + 2 - c] - mean[c]) / std[c]; 87 | } 88 | } 89 | 90 | // Copy image to device 91 | size_t dataSize = data.size() * sizeof(float); 92 | cudaMemcpy(data_d, data.data(), dataSize, cudaMemcpyHostToDevice); 93 | 94 | //Do inference 95 | cout << "Inferring on frame: " << count <<"/" << nframes << endl; 96 | count++; 97 | vector buffers = { data_d, scores_d, boxes_d, classes_d }; 98 | engine->infer(buffers, 1); 99 | 100 | cudaMemcpy(scores.get(), scores_d, sizeof(float) * num_det, cudaMemcpyDeviceToHost); 101 | cudaMemcpy(boxes.get(), boxes_d, sizeof(float) * num_det * 4, cudaMemcpyDeviceToHost); 102 | cudaMemcpy(classes.get(), classes_d, sizeof(float) * num_det, cudaMemcpyDeviceToHost); 103 | 104 | // Get back the bounding boxes 105 | for (int i = 0; i < num_det; i++) { 106 | // Show results over confidence threshold 107 | if (scores[i] >= 0.2f) { 108 | float x1 = boxes[i*4+0]; 109 | float y1 = boxes[i*4+1]; 110 | float x2 = boxes[i*4+2]; 111 | float y2 = boxes[i*4+3]; 112 | int cls=classes[i]; 113 | // Draw bounding box on image 114 | cv::rectangle(resized_frame, Point(x1, y1), Point(x2, y2), cv::Scalar(blues[cls], greens[cls], reds[cls])); 115 | } 116 | } 117 | cv::resize(resized_frame, inferred_frame, Size(fw, fh)); 118 | sink.write(inferred_frame); 119 | } 120 | src.release(); 121 | sink.release(); 122 | cudaFree(data_d); 123 | cudaFree(scores_d); 124 | cudaFree(boxes_d); 125 | cudaFree(classes_d); 126 | return 0; 127 | } 128 | -------------------------------------------------------------------------------- /extras/deepstream/README.md: -------------------------------------------------------------------------------- 1 | # Deploying RetinaNet in DeepStream 4.0 2 | 3 | This shows how to export a trained RetinaNet model to TensorRT and deploy it in a video analytics application using NVIDIA DeepStream 4.0. 4 | 5 | ## Prerequisites 6 | * A GPU supported by DeepStream: Jetson Xavier, Tesla P4/P40/V100/T4 7 | * A trained PyTorch RetinaNet model. 8 | * A video source, either `.mp4` files or a webcam. 9 | 10 | ## Tesla GPUs 11 | Setup instructions: 12 | 13 | #### 1. Download DeepStream 4.0 14 | Download DeepStream 4.0 SDK for Tesla "Download .tar" from [https://developer.nvidia.com/deepstream-download](https://developer.nvidia.com/deepstream-download) and place in the `extras/deepstream` directory. 15 | 16 | This file should be called `deepstream_sdk_v4.0.2_x86_64.tbz2`. 17 | 18 | #### 2. Unpack DeepStream 19 | You may need to adjust the permissions on the `.tbz2` file before you can extract it. 20 | 21 | ``` 22 | cd extras/deepstream 23 | mkdir DeepStream_Release 24 | tar -xvf deepstream_sdk_v4.0.2_x86_64.tbz2 -C DeepStream_Release/ 25 | ``` 26 | 27 | #### 3. Build and enter the DeepStream docker container 28 | ``` 29 | docker build -f /retinanet-examples/Dockerfile.deepstream -t ds_odtk:latest /retinanet-examples 30 | docker run --gpus all -it --rm --ipc=host -v :/data ds_odtk:latest 31 | ``` 32 | 33 | #### 4. Export your trained PyTorch RetinaNet model to TensorRT per the [INFERENCE](https://github.com/NVIDIA/retinanet-examples/blob/master/INFERENCE.md) instructions: 34 | ``` 35 | odtk export --batch n 36 | 37 | OR 38 | 39 | odtk export --int8 --calibration-images --batch n 40 | ``` 41 | 42 | #### 5. Run deepstream-app 43 | Once all of the config files have been modified, launch the DeepStream application: 44 | ``` 45 | cd /workspace/retinanet-examples/extras/deepstream/deepstream-sample/ 46 | LD_PRELOAD=build/libnvdsparsebbox_odtk.so deepstream-app -c 47 | ``` 48 | 49 | ## Jetson AGX Xavier 50 | Setup instructions. 51 | 52 | #### 1. Flash Jetson Xavier with [Jetpack 4.3](https://developer.nvidia.com/embedded/jetpack) 53 | 54 | **Ensure that you tick the DeepStream box, under Additional SDKs** 55 | 56 | #### 2. (on host) Covert PyTorch model to ONNX. 57 | 58 | ```bash 59 | odtk export model.pth model.onnx 60 | ``` 61 | 62 | #### 3. Copy ONNX RetinaNet model and config files to Jetson Xavier 63 | 64 | Use `scp` or a memory card. 65 | 66 | #### 4. (on Jetson) Make the C++ API 67 | 68 | ```bash 69 | cd extras/cppapi 70 | mkdir build && cd build 71 | cmake -DCMAKE_CUDA_FLAGS="--expt-extended-lambda -std=c++14" .. 72 | make 73 | ``` 74 | 75 | #### 5. (on Jetson) Make the RetinaNet plugin 76 | 77 | ```bash 78 | cd extras/deepstream/deepstream-sample 79 | mkdir build && cd build 80 | cmake -DDeepStream_DIR=/opt/nvidia/deepstream/deepstream-4.0 .. && make -j 81 | ``` 82 | 83 | #### 6. (on Jetson) Build the TensorRT Engine 84 | 85 | ```bash 86 | cd extras/cppapi/build 87 | ./export model.onnx engine.plan 88 | ``` 89 | 90 | #### 7. (on Jetson) Modify the DeepStream config files 91 | As described in the "preparing the DeepStream config file" section below. 92 | 93 | #### 8. (on Jetson) Run deepstream-app 94 | Once all of the config files have been modified, launch the DeepStream application: 95 | ``` 96 | cd extras/deepstream/deepstream-sample 97 | LD_PRELOAD=build/libnvdsparsebbox_odtk.so deepstream-app -c 98 | ``` 99 | 100 | ## Preparing the DeepStream config file: 101 | We have included two example DeepStream config files in `deepstream-sample`. 102 | - `ds_config_1vids.txt`: Performs detection on a single video, using the detector specified by `infer_config_batch1.txt`. 103 | - `ds_config_8vids.txt`: Performs detection on multiple video streams simultaneously, using the detector specified by `infer_config_batch8.txt`. Frames from each video are combined into a single batch and passed to the detector for inference. 104 | 105 | The `ds_config_*` files are DeepStream config files. They describe the overall processing. `infer_config_*` files define the individual detectors, which can be chained in series. 106 | 107 | Before they can be used, these config files must be modified to specify the correct paths to the input and output videos files, and the TensorRT engines. 108 | 109 | * **Input files** are specified in the deepstream config files by the `uri=file://` parameter. 110 | 111 | * **Output files** are specified in the deepstream config files by the `output-file=` parameter. 112 | 113 | * **TensorRT engines** are specified in both the DeepStream config files, and also the detector config files, by the `model-engine-file=` parameters. 114 | 115 | On Xavier, you can optionally set `enable=1` to `[sink1]` in `ds_config_*` files to display the processed video stream. 116 | 117 | 118 | ## Convert output video file to mp4 119 | You can convert the outputted `.mkv` file to `.mp4` using `ffmpeg`. 120 | ``` 121 | ffmpeg -i /data/output/file1.mkv -c copy /data/output/file2.mp4 122 | ``` 123 | -------------------------------------------------------------------------------- /extras/deepstream/deepstream-sample/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.5.1) 2 | 3 | project(deepstream-odtk) 4 | enable_language(CXX) 5 | include(FindCUDA) 6 | 7 | set(CMAKE_CXX_STANDARD 14) 8 | find_package(CUDA REQUIRED) 9 | find_package(OpenCV REQUIRED) 10 | 11 | if(DEFINED TensorRT_DIR) 12 | include_directories("${TensorRT_DIR}/include") 13 | link_directories("${TensorRT_DIR}/lib") 14 | endif(DEFINED TensorRT_DIR) 15 | if(DEFINED DeepStream_DIR) 16 | include_directories("${DeepStream_DIR}/sources/includes") 17 | endif(DEFINED DeepStream_DIR) 18 | include_directories(${CUDA_INCLUDE_DIRS}) 19 | 20 | if(NOT DEFINED ARCH) 21 | set(ARCH "sm_70") 22 | endif(NOT DEFINED ARCH) 23 | 24 | cuda_add_library(nvdsparsebbox_odtk SHARED 25 | ../../../csrc/cuda/decode.h 26 | ../../../csrc/cuda/decode.cu 27 | ../../../csrc/cuda/nms.h 28 | ../../../csrc/cuda/nms.cu 29 | ../../../csrc/cuda/utils.h 30 | ../../../csrc/engine.cpp 31 | nvdsparsebbox_odtk.cpp 32 | OPTIONS -arch ${ARCH} -std=c++14 --expt-extended-lambda 33 | ) 34 | include_directories(${OpenCV_INCLUDE_DIRS}) 35 | target_link_libraries(nvdsparsebbox_odtk ${CUDA_LIBRARIES} nvinfer nvinfer_plugin nvonnxparser ${OpenCV_LIBS}) 36 | -------------------------------------------------------------------------------- /extras/deepstream/deepstream-sample/ds_config_1vid.txt: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018 NVIDIA Corporation. All rights reserved. 2 | # 3 | # NVIDIA Corporation and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA Corporation is strictly prohibited. 8 | 9 | [application] 10 | enable-perf-measurement=1 11 | perf-measurement-interval-sec=1 12 | 13 | [tiled-display] 14 | enable=0 15 | rows=1 16 | columns=1 17 | width=1280 18 | height=720 19 | gpu-id=0 20 | 21 | [source0] 22 | enable=1 23 | type=2 24 | num-sources=1 25 | uri=file:// 26 | gpu-id=0 27 | 28 | [streammux] 29 | gpu-id=0 30 | batch-size=1 31 | batched-push-timeout=-1 32 | ## Set muxer output width and height 33 | width=1280 34 | height=720 35 | cuda-memory-type=1 36 | enable-padding=1 37 | 38 | [sink0] 39 | enable=1 40 | type=3 41 | #1=mp4 2=mkv 42 | container=1 43 | #1=h264 2=h265 3=mpeg4 44 | ## only SW mpeg4 is supported right now. 45 | codec=3 46 | sync=1 47 | bitrate=80000000 48 | output-file= 49 | source-id=0 50 | 51 | [sink1] 52 | enable=0 53 | #Type - 1=FakeSink 2=EglSink 3=File 54 | type=2 55 | sync=1 56 | source-id=0 57 | gpu-id=0 58 | cuda-memory-type=1 59 | 60 | 61 | [osd] 62 | enable=1 63 | gpu-id=0 64 | border-width=2 65 | text-size=12 66 | text-color=1;1;1;1; 67 | text-bg-color=0.3;0.3;0.3;1 68 | font=Arial 69 | show-clock=0 70 | clock-x-offset=800 71 | clock-y-offset=820 72 | clock-text-size=12 73 | clock-color=1;0;0;0 74 | 75 | [primary-gie] 76 | enable=1 77 | gpu-id=0 78 | batch-size=1 79 | gie-unique-id=1 80 | interval=0 81 | labelfile-path=labels_coco.txt 82 | model-engine-file= 83 | config-file=infer_config_batch1.txt 84 | -------------------------------------------------------------------------------- /extras/deepstream/deepstream-sample/ds_config_8vid.txt: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018 NVIDIA Corporation. All rights reserved. 2 | # 3 | # NVIDIA Corporation and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA Corporation is strictly prohibited. 8 | 9 | [application] 10 | enable-perf-measurement=1 11 | perf-measurement-interval-sec=5 12 | 13 | [tiled-display] 14 | enable=1 15 | rows=2 16 | columns=4 17 | width=1280 18 | height=720 19 | gpu-id=0 20 | cuda-memory-type=1 21 | 22 | [source0] 23 | enable=1 24 | type=3 25 | num-sources=4 26 | uri=file:// 27 | gpu-id=0 28 | cuda-memory-type=1 29 | 30 | [source1] 31 | enable=1 32 | type=3 33 | num-sources=4 34 | uri=file:// 35 | gpu-id=0 36 | cuda-memory-type=1 37 | 38 | [streammux] 39 | gpu-id=0 40 | batched-push-timeout=-1 41 | ## Set muxer output width and height 42 | width=1280 43 | height=720 44 | cuda-memory-type=1 45 | enable-padding=1 46 | batch-size=8 47 | 48 | [sink0] 49 | enable=1 50 | type=3 51 | #1=mp4 2=mkv 52 | container=1 53 | #1=h264 2=h265 3=mpeg4 54 | ## only SW mpeg4 is supported right now. 55 | codec=3 56 | sync=0 57 | bitrate=32000000 58 | output-file= 59 | source-id=0 60 | cuda-memory-type=1 61 | 62 | [sink1] 63 | enable=0 64 | #Type - 1=FakeSink 2=EglSink 3=File 65 | type=2 66 | sync=1 67 | source-id=0 68 | gpu-id=0 69 | cuda-memory-type=1 70 | 71 | 72 | [osd] 73 | enable=1 74 | gpu-id=0 75 | border-width=2 76 | text-size=12 77 | text-color=1;1;1;1; 78 | text-bg-color=0.3;0.3;0.3;1 79 | font=Arial 80 | show-clock=0 81 | clock-x-offset=800 82 | clock-y-offset=820 83 | clock-text-size=12 84 | clock-color=1;0;0;0 85 | 86 | [primary-gie] 87 | enable=1 88 | gpu-id=0 89 | batch-size=8 90 | gie-unique-id=1 91 | interval=0 92 | labelfile-path=labels_coco.txt 93 | model-engine-file= 94 | config-file=infer_config_batch8.txt 95 | cuda-memory-type=1 96 | -------------------------------------------------------------------------------- /extras/deepstream/deepstream-sample/infer_config_batch1.txt: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018 NVIDIA Corporation. All rights reserved. 2 | # NVIDIA Corporation and its licensors retain all intellectual property 3 | # and proprietary rights in and to this software, related documentation 4 | # and any modifications thereto. Any use, reproduction, disclosure or 5 | # distribution of this software and related documentation without an express 6 | # license agreement from NVIDIA Corporation is strictly prohibited. 7 | 8 | # Following properties are mandatory when engine files are not specified: 9 | # int8-calib-file(Only in INT8) 10 | # Caffemodel mandatory properties: model-file, proto-file, output-blob-names 11 | # UFF: uff-file, input-dims, uff-input-blob-name, output-blob-names 12 | # ONNX: onnx-file 13 | # 14 | # Mandatory properties for detectors: 15 | # parse-func, num-detected-classes, 16 | # custom-lib-path (when parse-func=0 i.e. custom), 17 | # parse-bbox-func-name (when parse-func=0) 18 | # 19 | # Optional properties for detectors: 20 | # enable-dbscan(Default=false), interval(Primary mode only, Default=0) 21 | # 22 | # Mandatory properties for classifiers: 23 | # classifier-threshold, is-classifier 24 | # 25 | # Optional properties for classifiers: 26 | # classifier-async-mode(Secondary mode only, Default=false) 27 | # 28 | # Optional properties in secondary mode: 29 | # operate-on-gie-id(Default=0), operate-on-class-ids(Defaults to all classes), 30 | # input-object-min-width, input-object-min-height, input-object-max-width, 31 | # input-object-max-height 32 | # 33 | # Following properties are always recommended: 34 | # batch-size(Default=1) 35 | # 36 | # Other optional properties: 37 | # net-scale-factor(Default=1), network-mode(Default=0 i.e FP32), 38 | # model-color-format(Default=0 i.e. RGB) model-engine-file, labelfile-path, 39 | # mean-file, gie-unique-id(Default=0), offsets, gie-mode (Default=1 i.e. primary), 40 | # custom-lib-path, network-mode(Default=0 i.e FP32) 41 | # 42 | # The values in the config file are overridden by values set through GObject 43 | # properties. 44 | 45 | [property] 46 | gpu-id=0 47 | net-scale-factor=0.017352074 48 | offsets=123.675;116.28;103.53 49 | model-engine-file= 50 | labelfile-path=labels_coco.txt 51 | batch-size=1 52 | ## 0=FP32, 1=INT8, 2=FP16 mode 53 | network-mode=2 54 | num-detected-classes=80 55 | interval=0 56 | gie-unique-id=1 57 | parse-func=0 58 | is-classifier=0 59 | output-blob-names=boxes;scores;classes 60 | parse-bbox-func-name=NvDsInferParseRetinaNet 61 | custom-lib-path=build/libnvdsparsebbox_odtk.so 62 | #enable-dbscan=1 63 | 64 | 65 | [class-attrs-all] 66 | threshold=0.5 67 | group-threshold=0 68 | ## Set eps=0.7 and minBoxes for enable-dbscan=1 69 | #eps=0.2 70 | ##minBoxes=3 71 | #roi-top-offset=0 72 | #roi-bottom-offset=0 73 | detected-min-w=4 74 | detected-min-h=4 75 | #detected-max-w=0 76 | #detected-max-h=0 77 | 78 | ## Per class configuration 79 | #[class-attrs-2] 80 | #threshold=0.6 81 | #eps=0.5 82 | #group-threshold=3 83 | #roi-top-offset=20 84 | #roi-bottom-offset=10 85 | #detected-min-w=40 86 | #detected-min-h=40 87 | #detected-max-w=400 88 | #detected-max-h=800 89 | -------------------------------------------------------------------------------- /extras/deepstream/deepstream-sample/infer_config_batch8.txt: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018 NVIDIA Corporation. All rights reserved. 2 | # NVIDIA Corporation and its licensors retain all intellectual property 3 | # and proprietary rights in and to this software, related documentation 4 | # and any modifications thereto. Any use, reproduction, disclosure or 5 | # distribution of this software and related documentation without an express 6 | # license agreement from NVIDIA Corporation is strictly prohibited. 7 | 8 | # Following properties are mandatory when engine files are not specified: 9 | # int8-calib-file(Only in INT8) 10 | # Caffemodel mandatory properties: model-file, proto-file, output-blob-names 11 | # UFF: uff-file, input-dims, uff-input-blob-name, output-blob-names 12 | # ONNX: onnx-file 13 | # 14 | # Mandatory properties for detectors: 15 | # parse-func, num-detected-classes, 16 | # custom-lib-path (when parse-func=0 i.e. custom), 17 | # parse-bbox-func-name (when parse-func=0) 18 | # 19 | # Optional properties for detectors: 20 | # enable-dbscan(Default=false), interval(Primary mode only, Default=0) 21 | # 22 | # Mandatory properties for classifiers: 23 | # classifier-threshold, is-classifier 24 | # 25 | # Optional properties for classifiers: 26 | # classifier-async-mode(Secondary mode only, Default=false) 27 | # 28 | # Optional properties in secondary mode: 29 | # operate-on-gie-id(Default=0), operate-on-class-ids(Defaults to all classes), 30 | # input-object-min-width, input-object-min-height, input-object-max-width, 31 | # input-object-max-height 32 | # 33 | # Following properties are always recommended: 34 | # batch-size(Default=1) 35 | # 36 | # Other optional properties: 37 | # net-scale-factor(Default=1), network-mode(Default=0 i.e FP32), 38 | # model-color-format(Default=0 i.e. RGB) model-engine-file, labelfile-path, 39 | # mean-file, gie-unique-id(Default=0), offsets, gie-mode (Default=1 i.e. primary), 40 | # custom-lib-path, network-mode(Default=0 i.e FP32) 41 | # 42 | # The values in the config file are overridden by values set through GObject 43 | # properties. 44 | 45 | [property] 46 | gpu-id=0 47 | net-scale-factor=0.017352074 48 | offsets=123.675;116.28;103.53 49 | model-engine-file= 50 | labelfile-path=labels_coco.txt 51 | #int8-calib-file=cal_trt4.bin 52 | batch-size=8 53 | ## 0=FP32, 1=INT8, 2=FP16 mode 54 | network-mode=2 55 | num-detected-classes=80 56 | interval=0 57 | gie-unique-id=1 58 | parse-func=0 59 | is-classifier=0 60 | output-blob-names=boxes;scores;classes 61 | parse-bbox-func-name=NvDsInferParseRetinaNet 62 | custom-lib-path=build/libnvdsparsebbox_odtk.so 63 | #enable-dbscan=1 64 | 65 | 66 | [class-attrs-all] 67 | threshold=0.5 68 | group-threshold=0 69 | ## Set eps=0.7 and minBoxes for enable-dbscan=1 70 | #eps=0.2 71 | ##minBoxes=3 72 | #roi-top-offset=0 73 | #roi-bottom-offset=0 74 | detected-min-w=4 75 | detected-min-h=4 76 | #detected-max-w=0 77 | #detected-max-h=0 78 | 79 | ## Per class configuration 80 | #[class-attrs-2] 81 | #threshold=0.6 82 | #eps=0.5 83 | #group-threshold=3 84 | #roi-top-offset=20 85 | #roi-bottom-offset=10 86 | #detected-min-w=40 87 | #detected-min-h=40 88 | #detected-max-w=400 89 | #detected-max-h=800 90 | -------------------------------------------------------------------------------- /extras/deepstream/deepstream-sample/labels_coco.txt: -------------------------------------------------------------------------------- 1 | person 2 | bicycle 3 | car 4 | motorcycle 5 | airplane 6 | bus 7 | train 8 | truck 9 | boat 10 | traffic light 11 | fire hydrant 12 | stop sign 13 | parking meter 14 | bench 15 | bird 16 | cat 17 | dog 18 | horse 19 | sheep 20 | cow 21 | elephant 22 | bear 23 | zebra 24 | giraffe 25 | backpack 26 | umbrella 27 | handbag 28 | tie 29 | suitcase 30 | frisbee 31 | skis 32 | snowboard 33 | sports ball 34 | kite 35 | baseball bat 36 | baseball glove 37 | skateboard 38 | surfboard 39 | tennis racket 40 | bottle 41 | wine glass 42 | cup 43 | fork 44 | knife 45 | spoon 46 | bowl 47 | banana 48 | apple 49 | sandwich 50 | orange 51 | broccoli 52 | carrot 53 | hot dog 54 | pizza 55 | donut 56 | cake 57 | chair 58 | couch 59 | potted plant 60 | bed 61 | dining table 62 | toilet 63 | tv 64 | laptop 65 | mouse 66 | remote 67 | keyboard 68 | cell phone 69 | microwave 70 | oven 71 | toaster 72 | sink 73 | refrigerator 74 | book 75 | clock 76 | vase 77 | scissors 78 | teddy bear 79 | hair drier 80 | toothbrush -------------------------------------------------------------------------------- /extras/deepstream/deepstream-sample/nvdsparsebbox_retinanet.cpp: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * NVIDIA Corporation and its licensors retain all intellectual property 5 | * and proprietary rights in and to this software, related documentation 6 | * and any modifications thereto. Any use, reproduction, disclosure or 7 | * distridbution of this software and related documentation without an express 8 | * license agreement from NVIDIA Corporation is strictly prohibited. 9 | * 10 | */ 11 | 12 | #include 13 | #include 14 | #include "nvdsinfer_custom_impl.h" 15 | 16 | #define MIN(a,b) ((a) < (b) ? (a) : (b)) 17 | 18 | /* This is a sample bounding box parsing function for the sample Resnet10 19 | * detector model provided with the SDK. */ 20 | 21 | /* C-linkage to prevent name-mangling */ 22 | extern "C" 23 | bool NvDsInferParseRetinaNet (std::vector const &outputLayersInfo, 24 | NvDsInferNetworkInfo const &networkInfo, 25 | NvDsInferParseDetectionParams const &detectionParams, 26 | std::vector &objectList) 27 | { 28 | static int bboxLayerIndex = -1; 29 | static int classesLayerIndex = -1; 30 | static int scoresLayerIndex = -1; 31 | static NvDsInferDimsCHW scoresLayerDims; 32 | int numDetsToParse; 33 | 34 | /* Find the bbox layer */ 35 | if (bboxLayerIndex == -1) { 36 | for (unsigned int i = 0; i < outputLayersInfo.size(); i++) { 37 | if (strcmp(outputLayersInfo[i].layerName, "boxes") == 0) { 38 | bboxLayerIndex = i; 39 | break; 40 | } 41 | } 42 | if (bboxLayerIndex == -1) { 43 | std::cerr << "Could not find bbox layer buffer while parsing" << std::endl; 44 | return false; 45 | } 46 | } 47 | 48 | /* Find the scores layer */ 49 | if (scoresLayerIndex == -1) { 50 | for (unsigned int i = 0; i < outputLayersInfo.size(); i++) { 51 | if (strcmp(outputLayersInfo[i].layerName, "scores") == 0) { 52 | scoresLayerIndex = i; 53 | getDimsCHWFromDims(scoresLayerDims, outputLayersInfo[i].dims); 54 | break; 55 | } 56 | } 57 | if (scoresLayerIndex == -1) { 58 | std::cerr << "Could not find scores layer buffer while parsing" << std::endl; 59 | return false; 60 | } 61 | } 62 | 63 | /* Find the classes layer */ 64 | if (classesLayerIndex == -1) { 65 | for (unsigned int i = 0; i < outputLayersInfo.size(); i++) { 66 | if (strcmp(outputLayersInfo[i].layerName, "classes") == 0) { 67 | classesLayerIndex = i; 68 | break; 69 | } 70 | } 71 | if (classesLayerIndex == -1) { 72 | std::cerr << "Could not find classes layer buffer while parsing" << std::endl; 73 | return false; 74 | } 75 | } 76 | 77 | 78 | /* Calculate the number of detections to parse */ 79 | numDetsToParse = scoresLayerDims.c; 80 | 81 | float *bboxes = (float *) outputLayersInfo[bboxLayerIndex].buffer; 82 | float *classes = (float *) outputLayersInfo[classesLayerIndex].buffer; 83 | float *scores = (float *) outputLayersInfo[scoresLayerIndex].buffer; 84 | 85 | for (int indx = 0; indx < numDetsToParse; indx++) 86 | { 87 | float outputX1 = bboxes[indx * 4]; 88 | float outputY1 = bboxes[indx * 4 + 1]; 89 | float outputX2 = bboxes[indx * 4 + 2]; 90 | float outputY2 = bboxes[indx * 4 + 3]; 91 | float this_class = classes[indx]; 92 | float this_score = scores[indx]; 93 | float threshold = detectionParams.perClassThreshold[this_class]; 94 | 95 | if (this_score >= threshold) 96 | { 97 | NvDsInferParseObjectInfo object; 98 | 99 | object.classId = this_class; 100 | object.detectionConfidence = this_score; 101 | 102 | object.left = outputX1; 103 | object.top = outputY1; 104 | object.width = outputX2 - outputX1; 105 | object.height = outputY2 - outputY1; 106 | 107 | objectList.push_back(object); 108 | } 109 | } 110 | return true; 111 | } 112 | 113 | /* Check that the custom function has been defined correctly */ 114 | CHECK_CUSTOM_PARSE_FUNC_PROTOTYPE(NvDsInferParseRetinaNet); 115 | -------------------------------------------------------------------------------- /extras/test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | if [ $# -ne 2 ]; then 4 | echo "Usage: $0 images_path annotations.json" 5 | exit 1 6 | fi 7 | 8 | tmp="/tmp/odtk" 9 | 10 | tests=( 11 | "odtk train ${tmp}/model.pth --images $1 --annotations $2 --max-size 640 --override --iters 100 --backbone ResNet18FPN ResNet50FPN" 12 | "odtk train ${tmp}/model.pth --images $1 --annotations $2 --max-size 640 --override --iters 100" 13 | "odtk train ${tmp}/model.pth --fine-tune ${tmp}/model.pth --images $1 --annotations $2 --max-size 640 --override --iters 100" 14 | "odtk infer ${tmp}/model.pth --images ${tmp}/test_images --max-size 640" 15 | "odtk export ${tmp}/model.pth ${tmp}/engine.plan --size 640" 16 | "odtk infer ${tmp}/engine.plan --images ${tmp}/test_images --max-size 640" 17 | ) 18 | 19 | start=`date +%s` 20 | 21 | # Prepare small image folder for inference 22 | if [ ! -d ${tmp}/test_images ]; then 23 | mkdir -p ${tmp}/test_images 24 | cp $(find $1 | tail -n 10) ${tmp}/test_images 25 | fi 26 | 27 | # Run all tests 28 | for test in "${tests[@]}"; do 29 | echo "Running \"${test}\"" 30 | ${test} 31 | if [ $? -ne 0 ]; then 32 | echo "Test failed!" 33 | exit 1 34 | fi 35 | done 36 | 37 | end=`date +%s` 38 | 39 | echo "All test succeeded in $((end-start)) seconds!" -------------------------------------------------------------------------------- /odtk/__init__.py: -------------------------------------------------------------------------------- 1 | from . import backbones 2 | -------------------------------------------------------------------------------- /odtk/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | from .resnet import * 4 | from .mobilenet import * 5 | from .fpn import * 6 | -------------------------------------------------------------------------------- /odtk/backbones/fpn.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | from torchvision.models import resnet as vrn 4 | from torchvision.models import mobilenet as vmn 5 | 6 | from .resnet import ResNet 7 | from .mobilenet import MobileNet 8 | from .utils import register 9 | 10 | 11 | class FPN(nn.Module): 12 | 'Feature Pyramid Network - https://arxiv.org/abs/1612.03144' 13 | 14 | def __init__(self, features): 15 | super().__init__() 16 | 17 | self.stride = 128 18 | self.features = features 19 | 20 | if isinstance(features, ResNet): 21 | is_light = features.bottleneck == vrn.BasicBlock 22 | channels = [128, 256, 512] if is_light else [512, 1024, 2048] 23 | elif isinstance(features, MobileNet): 24 | channels = [32, 96, 320] 25 | 26 | self.lateral3 = nn.Conv2d(channels[0], 256, 1) 27 | self.lateral4 = nn.Conv2d(channels[1], 256, 1) 28 | self.lateral5 = nn.Conv2d(channels[2], 256, 1) 29 | self.pyramid6 = nn.Conv2d(channels[2], 256, 3, stride=2, padding=1) 30 | self.pyramid7 = nn.Conv2d(256, 256, 3, stride=2, padding=1) 31 | self.smooth3 = nn.Conv2d(256, 256, 3, padding=1) 32 | self.smooth4 = nn.Conv2d(256, 256, 3, padding=1) 33 | self.smooth5 = nn.Conv2d(256, 256, 3, padding=1) 34 | 35 | def initialize(self): 36 | def init_layer(layer): 37 | if isinstance(layer, nn.Conv2d): 38 | nn.init.xavier_uniform_(layer.weight) 39 | if layer.bias is not None: 40 | nn.init.constant_(layer.bias, val=0) 41 | self.apply(init_layer) 42 | 43 | self.features.initialize() 44 | 45 | def forward(self, x): 46 | c3, c4, c5 = self.features(x) 47 | 48 | p5 = self.lateral5(c5) 49 | p4 = self.lateral4(c4) 50 | p4 = F.interpolate(p5, scale_factor=2) + p4 51 | p3 = self.lateral3(c3) 52 | p3 = F.interpolate(p4, scale_factor=2) + p3 53 | 54 | p6 = self.pyramid6(c5) 55 | p7 = self.pyramid7(F.relu(p6)) 56 | 57 | p3 = self.smooth3(p3) 58 | p4 = self.smooth4(p4) 59 | p5 = self.smooth5(p5) 60 | 61 | return [p3, p4, p5, p6, p7] 62 | 63 | @register 64 | def ResNet18FPN(): 65 | return FPN(ResNet(layers=[2, 2, 2, 2], bottleneck=vrn.BasicBlock, outputs=[3, 4, 5], url=vrn.model_urls['resnet18'])) 66 | 67 | @register 68 | def ResNet34FPN(): 69 | return FPN(ResNet(layers=[3, 4, 6, 3], bottleneck=vrn.BasicBlock, outputs=[3, 4, 5], url=vrn.model_urls['resnet34'])) 70 | 71 | @register 72 | def ResNet50FPN(): 73 | return FPN(ResNet(layers=[3, 4, 6, 3], bottleneck=vrn.Bottleneck, outputs=[3, 4, 5], url=vrn.model_urls['resnet50'])) 74 | 75 | @register 76 | def ResNet101FPN(): 77 | return FPN(ResNet(layers=[3, 4, 23, 3], bottleneck=vrn.Bottleneck, outputs=[3, 4, 5], url=vrn.model_urls['resnet101'])) 78 | 79 | @register 80 | def ResNet152FPN(): 81 | return FPN(ResNet(layers=[3, 8, 36, 3], bottleneck=vrn.Bottleneck, outputs=[3, 4, 5], url=vrn.model_urls['resnet152'])) 82 | 83 | @register 84 | def ResNeXt50_32x4dFPN(): 85 | return FPN(ResNet(layers=[3, 4, 6, 3], bottleneck=vrn.Bottleneck, outputs=[3, 4, 5], groups=32, width_per_group=4, url=vrn.model_urls['resnext50_32x4d'])) 86 | 87 | @register 88 | def ResNeXt101_32x8dFPN(): 89 | return FPN(ResNet(layers=[3, 4, 23, 3], bottleneck=vrn.Bottleneck, outputs=[3, 4, 5], groups=32, width_per_group=8, url=vrn.model_urls['resnext101_32x8d'])) 90 | 91 | @register 92 | def MobileNetV2FPN(): 93 | return FPN(MobileNet(outputs=[6, 13, 17], url=vmn.model_urls['mobilenet_v2'])) 94 | -------------------------------------------------------------------------------- /odtk/backbones/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | 5 | class FixedBatchNorm2d(nn.Module): 6 | 'BatchNorm2d where the batch statistics and the affine parameters are fixed' 7 | 8 | def __init__(self, n): 9 | super().__init__() 10 | self.register_buffer("weight", torch.ones(n)) 11 | self.register_buffer("bias", torch.zeros(n)) 12 | self.register_buffer("running_mean", torch.zeros(n)) 13 | self.register_buffer("running_var", torch.ones(n)) 14 | 15 | def forward(self, x): 16 | return F.batch_norm(x, running_mean=self.running_mean, running_var=self.running_var, weight=self.weight, bias=self.bias) 17 | 18 | def convert_fixedbn_model(module): 19 | 'Convert batch norm layers to fixed' 20 | 21 | mod = module 22 | if isinstance(module, nn.BatchNorm2d): 23 | mod = FixedBatchNorm2d(module.num_features) 24 | mod.running_mean = module.running_mean 25 | mod.running_var = module.running_var 26 | if module.affine: 27 | mod.weight.data = module.weight.data.clone().detach() 28 | mod.bias.data = module.bias.data.clone().detach() 29 | for name, child in module.named_children(): 30 | mod.add_module(name, convert_fixedbn_model(child)) 31 | 32 | return mod 33 | -------------------------------------------------------------------------------- /odtk/backbones/mobilenet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torchvision.models import mobilenet as vmn 3 | import torch.utils.model_zoo as model_zoo 4 | 5 | class MobileNet(vmn.MobileNetV2): 6 | 'MobileNetV2: Inverted Residuals and Linear Bottlenecks - https://arxiv.org/abs/1801.04381' 7 | 8 | def __init__(self, outputs=[18], url=None): 9 | self.stride = 128 10 | self.url = url 11 | super().__init__() 12 | self.outputs = outputs 13 | self.unused_modules = ['features.18', 'classifier'] 14 | 15 | def initialize(self): 16 | if self.url: 17 | self.load_state_dict(model_zoo.load_url(self.url)) 18 | 19 | def forward(self, x): 20 | outputs = [] 21 | for indx, feat in enumerate(self.features[:-1]): 22 | x = feat(x) 23 | if indx in self.outputs: 24 | outputs.append(x) 25 | return outputs 26 | -------------------------------------------------------------------------------- /odtk/backbones/resnet.py: -------------------------------------------------------------------------------- 1 | import torchvision 2 | from torchvision.models import resnet as vrn 3 | import torch.utils.model_zoo as model_zoo 4 | 5 | from .utils import register 6 | 7 | class ResNet(vrn.ResNet): 8 | 'Deep Residual Network - https://arxiv.org/abs/1512.03385' 9 | 10 | def __init__(self, layers=[3, 4, 6, 3], bottleneck=vrn.Bottleneck, outputs=[5], groups=1, width_per_group=64, url=None): 11 | self.stride = 128 12 | self.bottleneck = bottleneck 13 | self.outputs = outputs 14 | self.url = url 15 | 16 | kwargs = {'block': bottleneck, 'layers': layers, 'groups': groups, 'width_per_group': width_per_group} 17 | super().__init__(**kwargs) 18 | self.unused_modules = ['fc'] 19 | 20 | def initialize(self): 21 | if self.url: 22 | self.load_state_dict(model_zoo.load_url(self.url)) 23 | 24 | def forward(self, x): 25 | x = self.conv1(x) 26 | x = self.bn1(x) 27 | x = self.relu(x) 28 | x = self.maxpool(x) 29 | 30 | outputs = [] 31 | for i, layer in enumerate([self.layer1, self.layer2, self.layer3, self.layer4]): 32 | level = i + 2 33 | if level > max(self.outputs): 34 | break 35 | x = layer(x) 36 | if level in self.outputs: 37 | outputs.append(x) 38 | 39 | return outputs 40 | 41 | @register 42 | def ResNet18C4(): 43 | return ResNet(layers=[2, 2, 2, 2], bottleneck=vrn.BasicBlock, outputs=[4], url=vrn.model_urls['resnet18']) 44 | 45 | @register 46 | def ResNet34C4(): 47 | return ResNet(layers=[3, 4, 6, 3], bottleneck=vrn.BasicBlock, outputs=[4], url=vrn.model_urls['resnet34']) 48 | -------------------------------------------------------------------------------- /odtk/backbones/utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torchvision 3 | 4 | def register(f): 5 | all = sys.modules[f.__module__].__dict__.setdefault('__all__', []) 6 | if f.__name__ in all: 7 | raise RuntimeError('{} already exist!'.format(f.__name__)) 8 | all.append(f.__name__) 9 | return f 10 | -------------------------------------------------------------------------------- /odtk/dali.py: -------------------------------------------------------------------------------- 1 | from contextlib import redirect_stdout 2 | from math import ceil 3 | import ctypes 4 | import torch 5 | from nvidia.dali import pipeline, ops, types 6 | from pycocotools.coco import COCO 7 | 8 | class COCOPipeline(pipeline.Pipeline): 9 | 'Dali pipeline for COCO' 10 | 11 | def __init__(self, batch_size, num_threads, path, training, annotations, world, device_id, mean, std, resize, 12 | max_size, stride, rotate_augment=False, 13 | augment_brightness=0.0, 14 | augment_contrast=0.0, augment_hue=0.0, 15 | augment_saturation=0.0): 16 | super().__init__(batch_size=batch_size, num_threads=num_threads, device_id=device_id, 17 | prefetch_queue_depth=num_threads, seed=42) 18 | self.path = path 19 | self.training = training 20 | self.stride = stride 21 | self.iter = 0 22 | 23 | self.rotate_augment = rotate_augment 24 | self.augment_brightness = augment_brightness 25 | self.augment_contrast = augment_contrast 26 | self.augment_hue = augment_hue 27 | self.augment_saturation = augment_saturation 28 | 29 | self.reader = ops.COCOReader(annotations_file=annotations, file_root=path, num_shards=world, 30 | shard_id=torch.cuda.current_device(), 31 | ltrb=True, ratio=True, shuffle_after_epoch=True, save_img_ids=True) 32 | 33 | self.decode_train = ops.ImageDecoderSlice(device="mixed", output_type=types.RGB) 34 | self.decode_infer = ops.ImageDecoder(device="mixed", output_type=types.RGB) 35 | self.bbox_crop = ops.RandomBBoxCrop(device='cpu', bbox_layout="xyXY", scaling=[0.3, 1.0], 36 | thresholds=[0.1, 0.3, 0.5, 0.7, 0.9]) 37 | 38 | self.bbox_flip = ops.BbFlip(device='cpu', ltrb=True) 39 | self.img_flip = ops.Flip(device='gpu') 40 | self.coin_flip = ops.CoinFlip(probability=0.5) 41 | self.bc = ops.BrightnessContrast(device='gpu') 42 | self.hsv = ops.Hsv(device='gpu') 43 | 44 | # Random number generation for augmentation 45 | self.brightness_dist = ops.NormalDistribution(mean=1.0, stddev=augment_brightness) 46 | self.contrast_dist = ops.NormalDistribution(mean=1.0, stddev=augment_contrast) 47 | self.hue_dist = ops.NormalDistribution(mean=0.0, stddev=augment_hue) 48 | self.saturation_dist = ops.NormalDistribution(mean=1.0, stddev=augment_saturation) 49 | 50 | if rotate_augment: 51 | raise RuntimeWarning("--augment-rotate current has no effect when using the DALI data loader.") 52 | 53 | if isinstance(resize, list): resize = max(resize) 54 | self.rand_resize = ops.Uniform(range=[resize, float(max_size)]) 55 | 56 | self.resize_train = ops.Resize(device='gpu', interp_type=types.DALIInterpType.INTERP_CUBIC, save_attrs=True) 57 | self.resize_infer = ops.Resize(device='gpu', interp_type=types.DALIInterpType.INTERP_CUBIC, 58 | resize_longer=max_size, save_attrs=True) 59 | 60 | padded_size = max_size + ((self.stride - max_size % self.stride) % self.stride) 61 | 62 | self.pad = ops.Paste(device='gpu', fill_value=0, ratio=1.1, min_canvas_size=padded_size, paste_x=0, paste_y=0) 63 | self.normalize = ops.CropMirrorNormalize(device='gpu', mean=mean, std=std, crop=(padded_size, padded_size), 64 | crop_pos_x=0, crop_pos_y=0) 65 | 66 | def define_graph(self): 67 | 68 | images, bboxes, labels, img_ids = self.reader() 69 | 70 | if self.training: 71 | crop_begin, crop_size, bboxes, labels = self.bbox_crop(bboxes, labels) 72 | images = self.decode_train(images, crop_begin, crop_size) 73 | resize = self.rand_resize() 74 | images, attrs = self.resize_train(images, resize_longer=resize) 75 | 76 | flip = self.coin_flip() 77 | bboxes = self.bbox_flip(bboxes, horizontal=flip) 78 | images = self.img_flip(images, horizontal=flip) 79 | 80 | if self.augment_brightness or self.augment_contrast: 81 | images = self.bc(images, brightness=self.brightness_dist(), contrast=self.contrast_dist()) 82 | if self.augment_hue or self.augment_saturation: 83 | images = self.hsv(images, hue=self.hue_dist(), saturation=self.saturation_dist()) 84 | 85 | else: 86 | images = self.decode_infer(images) 87 | images, attrs = self.resize_infer(images) 88 | 89 | resized_images = images 90 | images = self.normalize(self.pad(images)) 91 | 92 | return images, bboxes, labels, img_ids, attrs, resized_images 93 | 94 | 95 | class DaliDataIterator(): 96 | 'Data loader for data parallel using Dali' 97 | 98 | def __init__(self, path, resize, max_size, batch_size, stride, world, annotations, training=False, 99 | rotate_augment=False, augment_brightness=0.0, 100 | augment_contrast=0.0, augment_hue=0.0, augment_saturation=0.0): 101 | self.training = training 102 | self.resize = resize 103 | self.max_size = max_size 104 | self.stride = stride 105 | self.batch_size = batch_size // world 106 | self.mean = [255. * x for x in [0.485, 0.456, 0.406]] 107 | self.std = [255. * x for x in [0.229, 0.224, 0.225]] 108 | self.world = world 109 | self.path = path 110 | 111 | # Setup COCO 112 | with redirect_stdout(None): 113 | self.coco = COCO(annotations) 114 | self.ids = list(self.coco.imgs.keys()) 115 | if 'categories' in self.coco.dataset: 116 | self.categories_inv = {k: i for i, k in enumerate(self.coco.getCatIds())} 117 | 118 | self.pipe = COCOPipeline(batch_size=self.batch_size, num_threads=2, 119 | path=path, training=training, annotations=annotations, world=world, 120 | device_id=torch.cuda.current_device(), mean=self.mean, std=self.std, resize=resize, 121 | max_size=max_size, stride=self.stride, rotate_augment=rotate_augment, 122 | augment_brightness=augment_brightness, 123 | augment_contrast=augment_contrast, augment_hue=augment_hue, 124 | augment_saturation=augment_saturation) 125 | 126 | self.pipe.build() 127 | 128 | def __repr__(self): 129 | return '\n'.join([ 130 | ' loader: dali', 131 | ' resize: {}, max: {}'.format(self.resize, self.max_size), 132 | ]) 133 | 134 | def __len__(self): 135 | return ceil(len(self.ids) // self.world / self.batch_size) 136 | 137 | def __iter__(self): 138 | for _ in range(self.__len__()): 139 | 140 | data, ratios, ids, num_detections = [], [], [], [] 141 | dali_data, dali_boxes, dali_labels, dali_ids, dali_attrs, dali_resize_img = self.pipe.run() 142 | 143 | for l in range(len(dali_boxes)): 144 | num_detections.append(dali_boxes.at(l).shape[0]) 145 | 146 | pyt_targets = -1 * torch.ones([len(dali_boxes), max(max(num_detections), 1), 5]) 147 | 148 | for batch in range(self.batch_size): 149 | id = int(dali_ids.at(batch)[0]) 150 | 151 | # Convert dali tensor to pytorch 152 | dali_tensor = dali_data[batch] 153 | tensor_shape = dali_tensor.shape() 154 | 155 | datum = torch.zeros(dali_tensor.shape(), dtype=torch.float, device=torch.device('cuda')) 156 | c_type_pointer = ctypes.c_void_p(datum.data_ptr()) 157 | dali_tensor.copy_to_external(c_type_pointer) 158 | 159 | # Calculate image resize ratio to rescale boxes 160 | prior_size = dali_attrs.as_cpu().at(batch) 161 | resized_size = dali_resize_img[batch].shape() 162 | ratio = max(resized_size) / max(prior_size) 163 | 164 | if self.training: 165 | # Rescale boxes 166 | b_arr = dali_boxes.at(batch) 167 | num_dets = b_arr.shape[0] 168 | if num_dets!=0: 169 | pyt_bbox = torch.from_numpy(b_arr).float() 170 | 171 | pyt_bbox[:, 0] *= float(prior_size[1]) 172 | pyt_bbox[:, 1] *= float(prior_size[0]) 173 | pyt_bbox[:, 2] *= float(prior_size[1]) 174 | pyt_bbox[:, 3] *= float(prior_size[0]) 175 | # (l,t,r,b) -> (x,y,w,h) == (l,r, r-l, b-t) 176 | pyt_bbox[:, 2] -= pyt_bbox[:, 0] 177 | pyt_bbox[:, 3] -= pyt_bbox[:, 1] 178 | pyt_targets[batch, :num_dets, :4] = pyt_bbox * ratio 179 | 180 | # Arrange labels in target tensor 181 | l_arr = dali_labels.at(batch) 182 | if num_dets!=0: 183 | pyt_label = torch.from_numpy(l_arr).float() 184 | pyt_label -= 1 # Rescale labels to [0,79] instead of [1,80] 185 | pyt_targets[batch, :num_dets, 4] = pyt_label.squeeze() 186 | 187 | ids.append(id) 188 | data.append(datum.unsqueeze(0)) 189 | ratios.append(ratio) 190 | 191 | data = torch.cat(data, dim=0) 192 | 193 | if self.training: 194 | pyt_targets = pyt_targets.cuda(non_blocking=True) 195 | yield data, pyt_targets 196 | 197 | else: 198 | ids = torch.Tensor(ids).int().cuda(non_blocking=True) 199 | ratios = torch.Tensor(ratios).cuda(non_blocking=True) 200 | yield data, ids, ratios 201 | 202 | -------------------------------------------------------------------------------- /odtk/infer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import tempfile 4 | from contextlib import redirect_stdout 5 | import torch 6 | from apex import amp 7 | from apex.parallel import DistributedDataParallel as ADDP 8 | from torch.nn.parallel import DistributedDataParallel 9 | from pycocotools.cocoeval import COCOeval 10 | import numpy as np 11 | 12 | from .data import DataIterator, RotatedDataIterator 13 | from .dali import DaliDataIterator 14 | from .model import Model 15 | from .utils import Profiler, rotate_box 16 | 17 | 18 | def infer(model, path, detections_file, resize, max_size, batch_size, mixed_precision=True, is_master=True, world=0, 19 | annotations=None, with_apex=False, use_dali=True, is_validation=False, verbose=True, rotated_bbox=False): 20 | 'Run inference on images from path' 21 | 22 | DDP = DistributedDataParallel if not with_apex else ADDP 23 | backend = 'pytorch' if isinstance(model, Model) or isinstance(model, DDP) else 'tensorrt' 24 | 25 | stride = model.module.stride if isinstance(model, DDP) else model.stride 26 | 27 | # Create annotations if none was provided 28 | if not annotations: 29 | annotations = tempfile.mktemp('.json') 30 | images = [{'id': i, 'file_name': f} for i, f in enumerate(os.listdir(path))] 31 | json.dump({'images': images}, open(annotations, 'w')) 32 | 33 | # TensorRT only supports fixed input sizes, so override input size accordingly 34 | if backend == 'tensorrt': max_size = max(model.input_size) 35 | 36 | # Prepare dataset 37 | if verbose: print('Preparing dataset...') 38 | if rotated_bbox: 39 | if use_dali: raise NotImplementedError("This repo does not currently support DALI for rotated bbox.") 40 | data_iterator = RotatedDataIterator(path, resize, max_size, batch_size, stride, 41 | world, annotations, training=False) 42 | else: 43 | data_iterator = (DaliDataIterator if use_dali else DataIterator)( 44 | path, resize, max_size, batch_size, stride, 45 | world, annotations, training=False) 46 | if verbose: print(data_iterator) 47 | 48 | # Prepare model 49 | if backend == 'pytorch': 50 | # If we are doing validation during training, 51 | # no need to register model with AMP again 52 | if not is_validation: 53 | if torch.cuda.is_available(): model = model.to(memory_format=torch.channels_last).cuda() 54 | if with_apex: 55 | model = amp.initialize(model, None, 56 | opt_level='O2' if mixed_precision else 'O0', 57 | keep_batchnorm_fp32=True, 58 | verbosity=0) 59 | 60 | model.eval() 61 | 62 | if verbose: 63 | print(' backend: {}'.format(backend)) 64 | print(' device: {} {}'.format( 65 | world, 'cpu' if not torch.cuda.is_available() else 'GPU' if world == 1 else 'GPUs')) 66 | print(' batch: {}, precision: {}'.format(batch_size, 67 | 'unknown' if backend == 'tensorrt' else 'mixed' if mixed_precision else 'full')) 68 | print(' BBOX type:', 'rotated' if rotated_bbox else 'axis aligned') 69 | print('Running inference...') 70 | 71 | results = [] 72 | profiler = Profiler(['infer', 'fw']) 73 | with torch.no_grad(): 74 | for i, (data, ids, ratios) in enumerate(data_iterator): 75 | # Forward pass 76 | if backend=='pytorch': data = data.contiguous(memory_format=torch.channels_last) 77 | profiler.start('fw') 78 | scores, boxes, classes = model(data, rotated_bbox) #Need to add model size (B, 3, W, H) 79 | profiler.stop('fw') 80 | 81 | results.append([scores, boxes, classes, ids, ratios]) 82 | 83 | profiler.bump('infer') 84 | if verbose and (profiler.totals['infer'] > 60 or i == len(data_iterator) - 1): 85 | size = len(data_iterator.ids) 86 | msg = '[{:{len}}/{}]'.format(min((i + 1) * batch_size, 87 | size), size, len=len(str(size))) 88 | msg += ' {:.3f}s/{}-batch'.format(profiler.means['infer'], batch_size) 89 | msg += ' (fw: {:.3f}s)'.format(profiler.means['fw']) 90 | msg += ', {:.1f} im/s'.format(batch_size / profiler.means['infer']) 91 | print(msg, flush=True) 92 | 93 | profiler.reset() 94 | 95 | # Gather results from all devices 96 | if verbose: print('Gathering results...') 97 | results = [torch.cat(r, dim=0) for r in zip(*results)] 98 | if world > 1: 99 | for r, result in enumerate(results): 100 | all_result = [torch.ones_like(result, device=result.device) for _ in range(world)] 101 | torch.distributed.all_gather(list(all_result), result) 102 | results[r] = torch.cat(all_result, dim=0) 103 | 104 | if is_master: 105 | # Copy buffers back to host 106 | results = [r.cpu() for r in results] 107 | 108 | # Collect detections 109 | detections = [] 110 | processed_ids = set() 111 | for scores, boxes, classes, image_id, ratios in zip(*results): 112 | image_id = image_id.item() 113 | if image_id in processed_ids: 114 | continue 115 | processed_ids.add(image_id) 116 | 117 | keep = (scores > 0).nonzero(as_tuple=False) 118 | scores = scores[keep].view(-1) 119 | if rotated_bbox: 120 | boxes = boxes[keep, :].view(-1, 6) 121 | boxes[:, :4] /= ratios 122 | else: 123 | boxes = boxes[keep, :].view(-1, 4) / ratios 124 | classes = classes[keep].view(-1).int() 125 | 126 | for score, box, cat in zip(scores, boxes, classes): 127 | if rotated_bbox: 128 | x1, y1, x2, y2, sin, cos = box.data.tolist() 129 | theta = np.arctan2(sin, cos) 130 | w = x2 - x1 + 1 131 | h = y2 - y1 + 1 132 | seg = rotate_box([x1, y1, w, h, theta]) 133 | else: 134 | x1, y1, x2, y2 = box.data.tolist() 135 | cat = cat.item() 136 | if 'annotations' in data_iterator.coco.dataset: 137 | cat = data_iterator.coco.getCatIds()[cat] 138 | this_det = { 139 | 'image_id': image_id, 140 | 'score': score.item(), 141 | 'category_id': cat} 142 | if rotated_bbox: 143 | this_det['bbox'] = [x1, y1, x2 - x1 + 1, y2 - y1 + 1, theta] 144 | this_det['segmentation'] = [seg] 145 | else: 146 | this_det['bbox'] = [x1, y1, x2 - x1 + 1, y2 - y1 + 1] 147 | 148 | detections.append(this_det) 149 | 150 | if detections: 151 | # Save detections 152 | if detections_file and verbose: print('Writing {}...'.format(detections_file)) 153 | detections = {'annotations': detections} 154 | detections['images'] = data_iterator.coco.dataset['images'] 155 | if 'categories' in data_iterator.coco.dataset: 156 | detections['categories'] = data_iterator.coco.dataset['categories'] 157 | if detections_file: 158 | for d_file in detections_file: 159 | json.dump(detections, open(d_file, 'w'), indent=4) 160 | 161 | # Evaluate model on dataset 162 | if 'annotations' in data_iterator.coco.dataset: 163 | if verbose: print('Evaluating model...') 164 | with redirect_stdout(None): 165 | coco_pred = data_iterator.coco.loadRes(detections['annotations']) 166 | if rotated_bbox: 167 | coco_eval = COCOeval(data_iterator.coco, coco_pred, 'segm') 168 | else: 169 | coco_eval = COCOeval(data_iterator.coco, coco_pred, 'bbox') 170 | coco_eval.evaluate() 171 | coco_eval.accumulate() 172 | coco_eval.summarize() 173 | return coco_eval.stats # mAP and mAR 174 | else: 175 | print('No detections!') 176 | return None 177 | return 0 178 | -------------------------------------------------------------------------------- /odtk/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class FocalLoss(nn.Module): 6 | 'Focal Loss - https://arxiv.org/abs/1708.02002' 7 | 8 | def __init__(self, alpha=0.25, gamma=2): 9 | super().__init__() 10 | self.alpha = alpha 11 | self.gamma = gamma 12 | 13 | def forward(self, pred_logits, target): 14 | pred = pred_logits.sigmoid() 15 | ce = F.binary_cross_entropy_with_logits(pred_logits, target, reduction='none') 16 | alpha = target * self.alpha + (1. - target) * (1. - self.alpha) 17 | pt = torch.where(target == 1, pred, 1 - pred) 18 | return alpha * (1. - pt) ** self.gamma * ce 19 | 20 | class SmoothL1Loss(nn.Module): 21 | 'Smooth L1 Loss' 22 | 23 | def __init__(self, beta=0.11): 24 | super().__init__() 25 | self.beta = beta 26 | 27 | def forward(self, pred, target): 28 | x = (pred - target).abs() 29 | l1 = x - 0.5 * self.beta 30 | l2 = 0.5 * x ** 2 / self.beta 31 | return torch.where(x >= self.beta, l1, l2) 32 | -------------------------------------------------------------------------------- /odtk/train.py: -------------------------------------------------------------------------------- 1 | from statistics import mean 2 | from math import isfinite 3 | import torch 4 | from torch.optim import SGD, AdamW 5 | from torch.optim.lr_scheduler import LambdaLR 6 | from apex import amp, optimizers 7 | from apex.parallel import DistributedDataParallel as ADDP 8 | from torch.nn.parallel import DistributedDataParallel as DDP 9 | from torch.cuda.amp import GradScaler, autocast 10 | from .backbones.layers import convert_fixedbn_model 11 | 12 | from .data import DataIterator, RotatedDataIterator 13 | from .dali import DaliDataIterator 14 | from .utils import ignore_sigint, post_metrics, Profiler 15 | from .infer import infer 16 | 17 | 18 | def train(model, state, path, annotations, val_path, val_annotations, resize, max_size, jitter, batch_size, iterations, 19 | val_iterations, lr, warmup, milestones, gamma, rank=0, world=1, mixed_precision=True, with_apex=False, 20 | use_dali=True, verbose=True, metrics_url=None, logdir=None, rotate_augment=False, augment_brightness=0.0, 21 | augment_contrast=0.0, augment_hue=0.0, augment_saturation=0.0, regularization_l2=0.0001, rotated_bbox=False, 22 | absolute_angle=False): 23 | 'Train the model on the given dataset' 24 | 25 | # Prepare model 26 | nn_model = model 27 | stride = model.stride 28 | 29 | model = convert_fixedbn_model(model) 30 | if torch.cuda.is_available(): 31 | model = model.to(memory_format=torch.channels_last).cuda() 32 | 33 | # Setup optimizer and schedule 34 | optimizer = SGD(model.parameters(), lr=lr, weight_decay=regularization_l2, momentum=0.9) 35 | 36 | is_master = rank==0 37 | if with_apex: 38 | loss_scale = "dynamic" if use_dali else "128.0" 39 | model, optimizer = amp.initialize(model, optimizer, 40 | opt_level='O2' if mixed_precision else 'O0', 41 | keep_batchnorm_fp32=True, 42 | loss_scale=loss_scale, 43 | verbosity=is_master) 44 | 45 | if world > 1: 46 | model = DDP(model, device_ids=[rank]) if not with_apex else ADDP(model) 47 | model.train() 48 | 49 | if 'optimizer' in state: 50 | optimizer.load_state_dict(state['optimizer']) 51 | 52 | def schedule(train_iter): 53 | if warmup and train_iter <= warmup: 54 | return 0.9 * train_iter / warmup + 0.1 55 | return gamma ** len([m for m in milestones if m <= train_iter]) 56 | 57 | scheduler = LambdaLR(optimizer, schedule) 58 | if 'scheduler' in state: 59 | scheduler.load_state_dict(state['scheduler']) 60 | 61 | # Prepare dataset 62 | if verbose: print('Preparing dataset...') 63 | if rotated_bbox: 64 | if use_dali: raise NotImplementedError("This repo does not currently support DALI for rotated bbox detections.") 65 | data_iterator = RotatedDataIterator(path, jitter, max_size, batch_size, stride, 66 | world, annotations, training=True, rotate_augment=rotate_augment, 67 | augment_brightness=augment_brightness, 68 | augment_contrast=augment_contrast, augment_hue=augment_hue, 69 | augment_saturation=augment_saturation, absolute_angle=absolute_angle) 70 | else: 71 | data_iterator = (DaliDataIterator if use_dali else DataIterator)( 72 | path, jitter, max_size, batch_size, stride, 73 | world, annotations, training=True, rotate_augment=rotate_augment, augment_brightness=augment_brightness, 74 | augment_contrast=augment_contrast, augment_hue=augment_hue, augment_saturation=augment_saturation) 75 | if verbose: print(data_iterator) 76 | 77 | if verbose: 78 | print(' device: {} {}'.format( 79 | world, 'cpu' if not torch.cuda.is_available() else 'GPU' if world == 1 else 'GPUs')) 80 | print(' batch: {}, precision: {}'.format(batch_size, 'mixed' if mixed_precision else 'full')) 81 | print(' BBOX type:', 'rotated' if rotated_bbox else 'axis aligned') 82 | print('Training model for {} iterations...'.format(iterations)) 83 | 84 | # Create TensorBoard writer 85 | if is_master and logdir is not None: 86 | from torch.utils.tensorboard import SummaryWriter 87 | if verbose: 88 | print('Writing TensorBoard logs to: {}'.format(logdir)) 89 | writer = SummaryWriter(log_dir=logdir) 90 | 91 | scaler = GradScaler(enabled=mixed_precision) 92 | profiler = Profiler(['train', 'fw', 'bw']) 93 | iteration = state.get('iteration', 0) 94 | while iteration < iterations: 95 | cls_losses, box_losses = [], [] 96 | for i, (data, target) in enumerate(data_iterator): 97 | if iteration>=iterations: 98 | break 99 | 100 | # Forward pass 101 | profiler.start('fw') 102 | 103 | optimizer.zero_grad() 104 | if with_apex: 105 | cls_loss, box_loss = model([data.contiguous(memory_format=torch.channels_last), target]) 106 | else: 107 | with autocast(enabled=mixed_precision): 108 | cls_loss, box_loss = model([data.contiguous(memory_format=torch.channels_last), target]) 109 | del data 110 | profiler.stop('fw') 111 | 112 | # Backward pass 113 | profiler.start('bw') 114 | if with_apex: 115 | with amp.scale_loss(cls_loss + box_loss, optimizer) as scaled_loss: 116 | scaled_loss.backward() 117 | optimizer.step() 118 | else: 119 | scaler.scale(cls_loss + box_loss).backward() 120 | scaler.step(optimizer) 121 | scaler.update() 122 | 123 | scheduler.step() 124 | 125 | # Reduce all losses 126 | cls_loss, box_loss = cls_loss.mean().clone(), box_loss.mean().clone() 127 | if world > 1: 128 | torch.distributed.all_reduce(cls_loss) 129 | torch.distributed.all_reduce(box_loss) 130 | cls_loss /= world 131 | box_loss /= world 132 | if is_master: 133 | cls_losses.append(cls_loss) 134 | box_losses.append(box_loss) 135 | 136 | if is_master and not isfinite(cls_loss + box_loss): 137 | raise RuntimeError('Loss is diverging!\n{}'.format( 138 | 'Try lowering the learning rate.')) 139 | 140 | del cls_loss, box_loss 141 | profiler.stop('bw') 142 | 143 | iteration += 1 144 | profiler.bump('train') 145 | if is_master and (profiler.totals['train'] > 60 or iteration == iterations): 146 | focal_loss = torch.stack(list(cls_losses)).mean().item() 147 | box_loss = torch.stack(list(box_losses)).mean().item() 148 | learning_rate = optimizer.param_groups[0]['lr'] 149 | if verbose: 150 | msg = '[{:{len}}/{}]'.format(iteration, iterations, len=len(str(iterations))) 151 | msg += ' focal loss: {:.3f}'.format(focal_loss) 152 | msg += ', box loss: {:.3f}'.format(box_loss) 153 | msg += ', {:.3f}s/{}-batch'.format(profiler.means['train'], batch_size) 154 | msg += ' (fw: {:.3f}s, bw: {:.3f}s)'.format(profiler.means['fw'], profiler.means['bw']) 155 | msg += ', {:.1f} im/s'.format(batch_size / profiler.means['train']) 156 | msg += ', lr: {:.2g}'.format(learning_rate) 157 | print(msg, flush=True) 158 | 159 | if is_master and logdir is not None: 160 | writer.add_scalar('focal_loss', focal_loss, iteration) 161 | writer.add_scalar('box_loss', box_loss, iteration) 162 | writer.add_scalar('learning_rate', learning_rate, iteration) 163 | del box_loss, focal_loss 164 | 165 | if metrics_url: 166 | post_metrics(metrics_url, { 167 | 'focal loss': mean(cls_losses), 168 | 'box loss': mean(box_losses), 169 | 'im_s': batch_size / profiler.means['train'], 170 | 'lr': learning_rate 171 | }) 172 | 173 | # Save model weights 174 | state.update({ 175 | 'iteration': iteration, 176 | 'optimizer': optimizer.state_dict(), 177 | 'scheduler': scheduler.state_dict(), 178 | }) 179 | with ignore_sigint(): 180 | nn_model.save(state) 181 | 182 | profiler.reset() 183 | del cls_losses[:], box_losses[:] 184 | 185 | if val_annotations and (iteration == iterations or iteration % val_iterations == 0): 186 | stats = infer(model, val_path, None, resize, max_size, batch_size, annotations=val_annotations, 187 | mixed_precision=mixed_precision, is_master=is_master, world=world, use_dali=use_dali, 188 | with_apex=with_apex, is_validation=True, verbose=False, rotated_bbox=rotated_bbox) 189 | model.train() 190 | if is_master and logdir is not None and stats is not None: 191 | writer.add_scalar( 192 | 'Validation_Precision/mAP', stats[0], iteration) 193 | writer.add_scalar( 194 | 'Validation_Precision/mAP@0.50IoU', stats[1], iteration) 195 | writer.add_scalar( 196 | 'Validation_Precision/mAP@0.75IoU', stats[2], iteration) 197 | writer.add_scalar( 198 | 'Validation_Precision/mAP (small)', stats[3], iteration) 199 | writer.add_scalar( 200 | 'Validation_Precision/mAP (medium)', stats[4], iteration) 201 | writer.add_scalar( 202 | 'Validation_Precision/mAP (large)', stats[5], iteration) 203 | writer.add_scalar( 204 | 'Validation_Recall/mAR (max 1 Dets)', stats[6], iteration) 205 | writer.add_scalar( 206 | 'Validation_Recall/mAR (max 10 Dets)', stats[7], iteration) 207 | writer.add_scalar( 208 | 'Validation_Recall/mAR (max 100 Dets)', stats[8], iteration) 209 | writer.add_scalar( 210 | 'Validation_Recall/mAR (small)', stats[9], iteration) 211 | writer.add_scalar( 212 | 'Validation_Recall/mAR (medium)', stats[10], iteration) 213 | writer.add_scalar( 214 | 'Validation_Recall/mAR (large)', stats[11], iteration) 215 | 216 | if (iteration==iterations and not rotated_bbox) or (iteration>iterations and rotated_bbox): 217 | break 218 | 219 | if is_master and logdir is not None: 220 | writer.close() 221 | -------------------------------------------------------------------------------- /odtk/utils.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import time 3 | import json 4 | import warnings 5 | import signal 6 | from datetime import datetime 7 | from contextlib import contextmanager 8 | from PIL import Image, ImageDraw 9 | import requests 10 | import numpy as np 11 | import math 12 | import torch 13 | 14 | 15 | def order_points(pts): 16 | pts_reorder = [] 17 | 18 | for idx, pt in enumerate(pts): 19 | idx = torch.argsort(pt[:, 0]) 20 | xSorted = pt[idx, :] 21 | leftMost = xSorted[:2, :] 22 | rightMost = xSorted[2:, :] 23 | 24 | leftMost = leftMost[torch.argsort(leftMost[:, 1]), :] 25 | (tl, bl) = leftMost 26 | 27 | D = torch.cdist(tl[np.newaxis], rightMost)[0] 28 | (br, tr) = rightMost[torch.argsort(D, descending=True), :] 29 | pts_reorder.append(torch.stack([tl, tr, br, bl])) 30 | 31 | return torch.stack([p for p in pts_reorder]) 32 | 33 | def rotate_boxes(boxes, points=False): 34 | ''' 35 | Rotate target bounding boxes 36 | 37 | Input: 38 | Target boxes (xmin_ymin, width_height, theta) 39 | Output: 40 | boxes_axis (xmin_ymin, xmax_ymax, theta) 41 | boxes_rotated (xy0, xy1, xy2, xy3) 42 | ''' 43 | 44 | u = torch.stack([torch.cos(boxes[:,4]), torch.sin(boxes[:,4])], dim=1) 45 | l = torch.stack([-torch.sin(boxes[:,4]), torch.cos(boxes[:,4])], dim=1) 46 | R = torch.stack([u, l], dim=1) 47 | 48 | if points: 49 | cents = torch.stack([(boxes[:,0]+boxes[:,2])/2, (boxes[:,1]+boxes[:,3])/2],1).transpose(1,0) 50 | boxes_rotated = torch.stack([boxes[:,0],boxes[:,1], 51 | boxes[:,2], boxes[:,1], 52 | boxes[:,2], boxes[:,3], 53 | boxes[:,0], boxes[:,3], 54 | boxes[:,-2], 55 | boxes[:,-1]],1) 56 | 57 | else: 58 | cents = torch.stack([boxes[:,0]+(boxes[:,2])/2, boxes[:,1]+(boxes[:,3])/2],1).transpose(1,0) 59 | boxes_rotated = torch.stack([boxes[:,0],boxes[:,1], 60 | (boxes[:,0]+boxes[:,2]), boxes[:,1], 61 | (boxes[:,0]+boxes[:,2]), (boxes[:,1]+boxes[:,3]), 62 | boxes[:,0], (boxes[:,1]+boxes[:,3]), 63 | boxes[:,-2], 64 | boxes[:,-1]],1) 65 | 66 | xy0R = torch.matmul(R,boxes_rotated[:,:2].transpose(1,0) - cents) + cents 67 | xy1R = torch.matmul(R,boxes_rotated[:,2:4].transpose(1,0) - cents) + cents 68 | xy2R = torch.matmul(R,boxes_rotated[:,4:6].transpose(1,0) - cents) + cents 69 | xy3R = torch.matmul(R,boxes_rotated[:,6:8].transpose(1,0) - cents) + cents 70 | 71 | xy0R = torch.stack([xy0R[i,:,i] for i in range(xy0R.size(0))]) 72 | xy1R = torch.stack([xy1R[i,:,i] for i in range(xy1R.size(0))]) 73 | xy2R = torch.stack([xy2R[i,:,i] for i in range(xy2R.size(0))]) 74 | xy3R = torch.stack([xy3R[i,:,i] for i in range(xy3R.size(0))]) 75 | 76 | boxes_axis = torch.cat([boxes[:, :2], boxes[:, :2] + boxes[:, 2:4] - 1, 77 | torch.sin(boxes[:,-1, None]), torch.cos(boxes[:,-1, None])], 1) 78 | boxes_rotated = order_points(torch.stack([xy0R,xy1R,xy2R,xy3R],dim = 1)).view(-1,8) 79 | 80 | return boxes_axis, boxes_rotated 81 | 82 | 83 | def rotate_box(bbox): 84 | xmin, ymin, width, height, theta = bbox 85 | 86 | xy1 = xmin, ymin 87 | xy2 = xmin, ymin + height - 1 88 | xy3 = xmin + width - 1, ymin + height - 1 89 | xy4 = xmin + width - 1, ymin 90 | 91 | cents = np.array([xmin + (width - 1) / 2, ymin + (height - 1) / 2]) 92 | 93 | corners = np.stack([xy1, xy2, xy3, xy4]) 94 | 95 | u = np.stack([np.cos(theta), -np.sin(theta)]) 96 | l = np.stack([np.sin(theta), np.cos(theta)]) 97 | R = np.vstack([u, l]) 98 | 99 | corners = np.matmul(R, (corners - cents).transpose(1, 0)).transpose(1, 0) + cents 100 | 101 | return corners.reshape(-1).tolist() 102 | 103 | 104 | def show_detections(detections): 105 | 'Show image with drawn detections' 106 | 107 | for image, detections in detections.items(): 108 | im = Image.open(image).convert('RGBA') 109 | overlay = Image.new('RGBA', im.size, (255, 255, 255, 0)) 110 | draw = ImageDraw.Draw(overlay) 111 | detections.sort(key=lambda d: d['score']) 112 | for detection in detections: 113 | box = detection['bbox'] 114 | alpha = int(detection['score'] * 255) 115 | draw.rectangle(box, outline=(255, 255, 255, alpha)) 116 | draw.text((box[0] + 2, box[1]), '[{}]'.format(detection['class']), 117 | fill=(255, 255, 255, alpha)) 118 | draw.text((box[0] + 2, box[1] + 10), '{:.2}'.format(detection['score']), 119 | fill=(255, 255, 255, alpha)) 120 | im = Image.alpha_composite(im, overlay) 121 | im.show() 122 | 123 | 124 | def save_detections(path, detections): 125 | print('Writing detections to {}...'.format(os.path.basename(path))) 126 | with open(path, 'w') as f: 127 | json.dump(detections, f) 128 | 129 | 130 | @contextmanager 131 | def ignore_sigint(): 132 | handler = signal.getsignal(signal.SIGINT) 133 | signal.signal(signal.SIGINT, signal.SIG_IGN) 134 | try: 135 | yield 136 | finally: 137 | signal.signal(signal.SIGINT, handler) 138 | 139 | 140 | class Profiler(object): 141 | def __init__(self, names=['main']): 142 | self.names = names 143 | self.lasts = {k: 0 for k in names} 144 | self.totals = self.lasts.copy() 145 | self.counts = self.lasts.copy() 146 | self.means = self.lasts.copy() 147 | self.reset() 148 | 149 | def reset(self): 150 | last = time.time() 151 | for name in self.names: 152 | self.lasts[name] = last 153 | self.totals[name] = 0 154 | self.counts[name] = 0 155 | self.means[name] = 0 156 | 157 | def start(self, name='main'): 158 | self.lasts[name] = time.time() 159 | 160 | def stop(self, name='main'): 161 | self.totals[name] += time.time() - self.lasts[name] 162 | self.counts[name] += 1 163 | self.means[name] = self.totals[name] / self.counts[name] 164 | 165 | def bump(self, name='main'): 166 | self.stop(name) 167 | self.start(name) 168 | 169 | 170 | def post_metrics(url, metrics): 171 | try: 172 | for k, v in metrics.items(): 173 | requests.post(url, 174 | data={'time': int(datetime.now().timestamp() * 1e9), 175 | 'metric': k, 'value': v}) 176 | except Exception as e: 177 | warnings.warn('Warning: posting metrics failed: {}'.format(e)) 178 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 3 | 4 | setup( 5 | name='odtk', 6 | version='0.2.6', 7 | description='Fast and accurate single shot object detector', 8 | author = 'NVIDIA Corporation', 9 | packages=['odtk', 'odtk.backbones'], 10 | ext_modules=[CUDAExtension('odtk._C', 11 | ['csrc/extensions.cpp', 'csrc/engine.cpp', 'csrc/cuda/decode.cu', 'csrc/cuda/decode_rotate.cu', 'csrc/cuda/nms.cu', 'csrc/cuda/nms_iou.cu'], 12 | extra_compile_args={ 13 | 'cxx': ['-std=c++14', '-O2', '-Wall'], 14 | 'nvcc': [ 15 | '-std=c++14', '--expt-extended-lambda', '--use_fast_math', '-Xcompiler', '-Wall,-fno-gnu-unique', 16 | '-gencode=arch=compute_60,code=sm_60', '-gencode=arch=compute_61,code=sm_61', 17 | '-gencode=arch=compute_70,code=sm_70', '-gencode=arch=compute_72,code=sm_72', 18 | '-gencode=arch=compute_75,code=sm_75', '-gencode=arch=compute_80,code=sm_80', 19 | '-gencode=arch=compute_86,code=sm_86', '-gencode=arch=compute_86,code=compute_86' 20 | ], 21 | }, 22 | libraries=['nvinfer', 'nvinfer_plugin', 'nvonnxparser', 'opencv_core', 'opencv_imgproc', 'opencv_highgui', 'opencv_imgcodecs']) 23 | ], 24 | cmdclass={'build_ext': BuildExtension.with_options(no_python_abi_suffix=True)}, 25 | install_requires=[ 26 | 'torch>=1.0.0a0', 27 | 'torchvision', 28 | 'apex @ git+https://github.com/NVIDIA/apex', 29 | 'pycocotools @ git+https://github.com/nvidia/cocoapi.git#subdirectory=PythonAPI', 30 | 'pillow', 31 | 'requests', 32 | ], 33 | entry_points = {'console_scripts': ['odtk=odtk.main:main']} 34 | ) 35 | --------------------------------------------------------------------------------