├── README.md
├── before_build.sh
├── coco.names
├── docker
│   ├── .dockerignore
│   └── Dockerfile
├── input_image
│   ├── bus.jpg
│   └── zidane.jpg
├── model
│   ├── yolov5s_torchscript_B1-C3-H640-W640_torch1-6-0_cuda10-2_cpu.pt
│   └── yolov5s_torchscript_B1-C3-H640-W640_torch1-6-0_cuda10-2_gpu.pt
├── result_image
│   └── result_image_YOLOv5
│       ├── YOLOv5_segfault_while_loding_TorchScript.png
│       ├── result_bus.jpg
│       └── result_zidane.jpg
└── src
    ├── CMakeLists.txt
    ├── cxxopts
    │   └── cxxopts.hpp
    ├── main.cpp
    └── object_detector
        ├── CMakeLists.txt
        ├── include
        │   └── object_detector.h
        └── src
            └── object_detector.cpp


/README.md:
--------------------------------------------------------------------------------
1 | # YOLOv5 with PyTorch C++
2 | 
3 | A C++ implementation of [Ultralytics LLC's YOLOv5](https://github.com/ultralytics/yolov5) with the [PyTorch C++ API](https://pytorch.org/cppdocs/) (LibTorch), inspired by [yasenh/libtorch-yolov5](https://github.com/yasenh/libtorch-yolov5).
4 | 
5 | ![alt](./result_image/result_image_YOLOv5/result_zidane.jpg)
6 | 
7 | ![alt](./result_image/result_image_YOLOv5/result_bus.jpg)
8 | 
9 | 
10 | 
11 | # Prerequisites
12 | 
13 | Docker 19.03 or later with the NVIDIA Container Toolkit is recommended, so that you can run this code without running into dependency-version problems.
14 | 
15 | * Docker 19.03+
16 | * NVIDIA driver
17 | * NVIDIA Container Toolkit
18 | 
19 | 
20 | 
21 | # How to build
22 | 
23 | On the host machine, pull the Docker image from [my Docker Hub repository](https://hub.docker.com/repository/docker/hotsuyuki/ubuntu18_04-cuda10_2-cudnn7_6_5-tensorrt7_0_0-opencv4_4_0) and launch a Docker container.
24 | 
25 | ```
26 | $ git clone https://github.com/hotsuyuki/YOLOv5_PyTorch_cpp.git
27 | $ cd YOLOv5_PyTorch_cpp/
28 | $ docker container run --gpus all --rm -it -v $PWD:/workspace/YOLOv5_PyTorch_cpp hotsuyuki/ubuntu18_04-cuda10_2-cudnn7_6_5-tensorrt7_0_0-opencv4_4_0
29 | ```
30 | 
31 | Then, inside the Docker container, run `before_build.sh` to download LibTorch v1.6.0 and unzip it into `src/libtorch_v1-6-0/`.
32 | 
33 | ```
34 | # cd /workspace/YOLOv5_PyTorch_cpp/
35 | # sh ./before_build.sh
36 | ```
37 | 
38 | Finally, build the source code.
39 | 
40 | ```
41 | # cd /workspace/YOLOv5_PyTorch_cpp/src/
42 | # mkdir build
43 | # cd build/
44 | # cmake ..
45 | # cmake --build .
46 | ```
47 | 
48 | These commands produce an executable named `main` in the `build` directory.
49 | 
50 | 
51 | 
52 | # How to run
53 | 
54 | The executable `main` requires two positional arguments:
55 | 
56 | * `input-dir` (directory path to the input images)
57 | * `model-file` (file path to the TorchScript model)
58 | 
59 | ## (a) Inference on CPU:
60 | 
61 | ```
62 | # ./main --input-dir ../../input_image/ --model-file ../../model/yolov5s_torchscript_B1-C3-H640-W640_torch1-6-0_cuda10-2_cpu.pt
63 | ```
64 | or
65 | ```
66 | # ./main ../../input_image/ ../../model/yolov5s_torchscript_B1-C3-H640-W640_torch1-6-0_cuda10-2_cpu.pt
67 | ```
68 | 
69 | ## (b) Inference on GPU:
70 | 
71 | ```
72 | # ./main --input-dir ../../input_image/ --model-file ../../model/yolov5s_torchscript_B1-C3-H640-W640_torch1-6-0_cuda10-2_gpu.pt
73 | ```
74 | or
75 | ```
76 | # ./main ../../input_image/ ../../model/yolov5s_torchscript_B1-C3-H640-W640_torch1-6-0_cuda10-2_gpu.pt
77 | ```
78 | 
79 | *NOTE: Use the model file ending in "_cpu.pt" for CPU inference and the one ending in "_gpu.pt" for GPU inference.*
80 | 
81 | This repository provides two TorchScript model files:
82 | 
83 | * yolov5s_torchscript_B1-C3-H640-W640_torch1-6-0_cuda10-2_cpu.pt
84 | * yolov5s_torchscript_B1-C3-H640-W640_torch1-6-0_cuda10-2_gpu.pt
85 | 
86 | (Both were exported with [ultralytics/yolov5/models/export.py](https://github.com/ultralytics/yolov5/blob/master/models/export.py).)
87 | 
88 | The full list of arguments and options is shown below:
89 | 
90 | ```
91 | usage:
92 |   main [OPTION...] input-dir model-file
93 | 
94 |   positional arguments:
95 |     input-dir             String: Path to input images directory
96 |     model-file            String: Path to TorchScript model file
97 | 
98 |   options:
99 |         --conf-thres arg  Float: Object confidence threshold (default: 0.25)
100 |         --iou-thres arg   Float: IoU threshold for NMS (default: 0.45)
101 |     -h, --help            Print usage
102 | ```
103 | 
104 | 
105 | 
106 | # Troubleshooting
107 | 
108 | ### Problem:
109 | 
110 | Sometimes the YOLOv5 program becomes unresponsive (or crashes with a segmentation fault) while loading a TorchScript model with `torch::jit::load()`.
111 | 
112 | ![alt](./result_image/result_image_YOLOv5/YOLOv5_segfault_while_loding_TorchScript.png)
113 | 
114 | ### Solution:
115 | 
116 | The TorchScript model usually loads within a few seconds.
117 | If the program is still loading the model after a minute, stop it with `Ctrl + C` and run it again.
118 | (This problem has not occurred with other TorchScript files, so the root cause may lie in the exported TorchScript file itself.)
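The CPU and GPU model files differ only in their `_cpu.pt` / `_gpu.pt` suffix, so the model file and the device must match. One way a program could catch a mismatch early is to infer the device from the file name, as in the sketch below. `Device`, `ends_with`, and `device_for_model` are hypothetical helpers, not part of this repository. A plain enum stands in for `torch::kCPU` / `torch::kCUDA` so the sketch builds without LibTorch. The repository's `main.cpp` may select the device differently.

```cpp
#include <stdexcept>
#include <string>

// Hypothetical stand-in for torch::kCPU / torch::kCUDA, so this sketch
// compiles without LibTorch.
enum class Device { CPU, GPU };

// True if `s` ends with `suffix` (std::string::ends_with is C++20; this repo builds as C++17).
bool ends_with(const std::string& s, const std::string& suffix) {
  return s.size() >= suffix.size() &&
         s.compare(s.size() - suffix.size(), suffix.size(), suffix) == 0;
}

// Maps the README's naming convention ("_cpu.pt" / "_gpu.pt") to a device and
// rejects any other file name instead of guessing.
Device device_for_model(const std::string& model_file) {
  if (ends_with(model_file, "_cpu.pt")) return Device::CPU;
  if (ends_with(model_file, "_gpu.pt")) return Device::GPU;
  throw std::invalid_argument(
      "model file name should end with _cpu.pt or _gpu.pt: " + model_file);
}
```

In a LibTorch build, `Device::CPU` and `Device::GPU` would correspond to `torch::Device(torch::kCPU)` and `torch::Device(torch::kCUDA)`. Failing with an explicit error is usually easier to diagnose than a crash deep inside inference.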
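The `--conf-thres` and `--iou-thres` options control post-processing. The following is a minimal sketch of what the two thresholds do. It assumes each detection has already been decoded into a box with corner coordinates `(x1, y1, x2, y2)` in pixels, a single score, and a class index. YOLOv5 conventionally uses objectness × class confidence as the score. `Detection`, `iou`, and `filter_and_nms` are hypothetical names, not this repository's API. The real post-processing in `object_detector.cpp` may differ, for example by suppressing across all classes at once with a per-class coordinate offset.

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Hypothetical decoded detection: corner coordinates in pixels, a single
// score (assumed to be objectness * class confidence), and a class index.
struct Detection {
  float x1, y1, x2, y2;
  float score;
  int class_id;
};

// Intersection over Union of two axis-aligned boxes (0 if the union is empty).
float iou(const Detection& a, const Detection& b) {
  const float ix1 = std::max(a.x1, b.x1);
  const float iy1 = std::max(a.y1, b.y1);
  const float ix2 = std::min(a.x2, b.x2);
  const float iy2 = std::min(a.y2, b.y2);
  const float inter = std::max(0.0f, ix2 - ix1) * std::max(0.0f, iy2 - iy1);
  const float area_a = (a.x2 - a.x1) * (a.y2 - a.y1);
  const float area_b = (b.x2 - b.x1) * (b.y2 - b.y1);
  const float uni = area_a + area_b - inter;
  return uni > 0.0f ? inter / uni : 0.0f;
}

// 1) Drop detections whose score is not above conf_thres.
// 2) Greedy per-class NMS: visit boxes from highest to lowest score, keep each
//    box that is not yet suppressed, and suppress every lower-scoring box of
//    the same class whose IoU with it exceeds iou_thres.
std::vector<Detection> filter_and_nms(const std::vector<Detection>& dets,
                                      float conf_thres, float iou_thres) {
  std::vector<Detection> cand;
  for (const Detection& d : dets) {
    if (d.score > conf_thres) cand.push_back(d);
  }
  std::sort(cand.begin(), cand.end(),
            [](const Detection& a, const Detection& b) { return a.score > b.score; });

  std::vector<bool> suppressed(cand.size(), false);
  std::vector<Detection> kept;
  for (std::size_t i = 0; i < cand.size(); ++i) {
    if (suppressed[i]) continue;
    kept.push_back(cand[i]);
    for (std::size_t j = i + 1; j < cand.size(); ++j) {
      if (!suppressed[j] && cand[j].class_id == cand[i].class_id &&
          iou(cand[i], cand[j]) > iou_thres) {
        suppressed[j] = true;
      }
    }
  }
  return kept;
}
```

Sorting by score first is what makes the greedy pass work. Each kept box is the highest-scoring survivor, so any lower-scoring box of the same class that overlaps it by more than `iou_thres` is treated as a duplicate. Take two same-class 10 × 10 boxes offset by one pixel: their IoU is 81/119 ≈ 0.68. Under the default `--iou-thres 0.45` only the higher-scoring box survives. Raising `--iou-thres` keeps more overlapping boxes, and raising `--conf-thres` discards more low-confidence boxes before NMS runs.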
119 | 
120 | 
121 | 
122 | # References
123 | 
124 | * [ultralytics/yolov5](https://github.com/ultralytics/yolov5)
125 | * [yasenh/libtorch-yolov5](https://github.com/yasenh/libtorch-yolov5)
126 | * [TadaoYamaoka/cxxopts/include/cxxopts.hpp](https://github.com/TadaoYamaoka/cxxopts/blob/master/include/cxxopts.hpp)

--------------------------------------------------------------------------------
/before_build.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 | 
3 | set -v
4 | 
5 | 
6 | ### LibTorch v1.6.0
7 | wget https://download.pytorch.org/libtorch/cu102/libtorch-cxx11-abi-shared-with-deps-1.6.0.zip
8 | unzip ./libtorch-cxx11-abi-shared-with-deps-1.6.0.zip -d ./src/
9 | rm ./libtorch-cxx11-abi-shared-with-deps-1.6.0.zip
10 | mv ./src/libtorch/ ./src/libtorch_v1-6-0/
11 | 

--------------------------------------------------------------------------------
/coco.names:
--------------------------------------------------------------------------------
1 | person
2 | bicycle
3 | car
4 | motorcycle
5 | airplane
6 | bus
7 | train
8 | truck
9 | boat
10 | traffic light
11 | fire hydrant
12 | stop sign
13 | parking meter
14 | bench
15 | bird
16 | cat
17 | dog
18 | horse
19 | sheep
20 | cow
21 | elephant
22 | bear
23 | zebra
24 | giraffe
25 | backpack
26 | umbrella
27 | handbag
28 | tie
29 | suitcase
30 | frisbee
31 | skis
32 | snowboard
33 | sports ball
34 | kite
35 | baseball bat
36 | baseball glove
37 | skateboard
38 | surfboard
39 | tennis racket
40 | bottle
41 | wine glass
42 | cup
43 | fork
44 | knife
45 | spoon
46 | bowl
47 | banana
48 | apple
49 | sandwich
50 | orange
51 | broccoli
52 | carrot
53 | hot dog
54 | pizza
55 | donut
56 | cake
57 | chair
58 | couch
59 | potted plant
60 | bed
61 | dining table
62 | toilet
63 | tv
64 | laptop
65 | mouse
66 | remote
67 | keyboard
68 | cell phone
69 | microwave
70 | oven
71 | toaster
72 | sink
73 | refrigerator
74 | book
75 | clock
76 | vase
77 | scissors
78 | teddy bear
79 | hair drier
80 | toothbrush
81 | 

--------------------------------------------------------------------------------
/docker/.dockerignore:
--------------------------------------------------------------------------------
1 | **/.DS_Store
2 | **/Thumbs.db
3 | **/__pycache__

--------------------------------------------------------------------------------
/docker/Dockerfile:
--------------------------------------------------------------------------------
1 | # Ubuntu 18.04, CUDA 10.2, cuDNN 7.6.5, TensorRT 7.0.0
2 | FROM nvcr.io/nvidia/tensorrt:20.01-py3
3 | 
4 | ENV DEBIAN_FRONTEND noninteractive
5 | 
6 | ARG OPENCV_VERSION="4.4.0"
7 | #ARG GPU_ARCH="5.0 5.2 6.1 7.0 7.5 8.6"
8 | 
9 | WORKDIR /workspace
10 | 
11 | # Build tools
12 | RUN apt update && \
13 |     apt install -y \
14 |         sudo \
15 |         tzdata \
16 |         git \
17 |         cmake \
18 |         wget \
19 |         unzip \
20 |         build-essential \
21 |         pkg-config
22 | 
23 | # Media I/O
24 | RUN sudo apt install -y \
25 |     zlib1g-dev \
26 |     libjpeg-dev \
27 |     libwebp-dev \
28 |     libpng-dev \
29 |     libtiff5-dev \
30 |     libopenexr-dev \
31 |     libgdal-dev \
32 |     libgtk2.0-dev
33 | 
34 | # Video I/O
35 | RUN sudo apt install -y \
36 |     libdc1394-22-dev \
37 |     libavcodec-dev \
38 |     libavformat-dev \
39 |     libswscale-dev \
40 |     libtheora-dev \
41 |     libvorbis-dev \
42 |     libxvidcore-dev \
43 |     libx264-dev \
44 |     yasm \
45 |     libopencore-amrnb-dev \
46 |     libopencore-amrwb-dev \
47 |     libv4l-dev \
48 |     libxine2-dev \
49 |     libgstreamer1.0-dev \
50 |     libgstreamer-plugins-base1.0-dev \
51 |     libopencv-highgui-dev \
52 |     ffmpeg
53 | 
54 | # Parallelism
55 | RUN sudo apt install -y \
56 |     libtbb-dev
57 | 
58 | # # Linear algebra
59 | # RUN sudo apt install -y libeigen3-dev
60 | 
61 | # Python
62 | RUN sudo apt install -y \
63 |     python3.8 \
64 |     python3.8-venv
65 | 
66 | # Build OpenCV
67 | RUN wget https://github.com/opencv/opencv/archive/${OPENCV_VERSION}.zip && \
68 |     unzip ${OPENCV_VERSION}.zip && rm ${OPENCV_VERSION}.zip && \
69 |     mv opencv-${OPENCV_VERSION} OpenCV && \
70 |     cd OpenCV && \
71 |     mkdir build && \
72 |     cd build && \
73 |     cmake \
74 |         -D WITH_TBB=ON \
75 |         -D CMAKE_BUILD_TYPE=RELEASE \
76 |         -D WITH_FFMPEG=ON \
77 |         -D WITH_V4L=ON \
78 |         #-D CUDA_ARCH_BIN=${GPU_ARCH} \
79 |         #-D CUDA_ARCH_PTX=${GPU_ARCH} \
80 |         #-D WITH_EIGEN=ON \
81 |         #-D EIGEN_INCLUDE_PATH=/usr/include/eigen3 \
82 |         .. && \
83 |     make all -j$(nproc) && \
84 |     make install

--------------------------------------------------------------------------------
/input_image/bus.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hotsuyuki/YOLOv5_PyTorch_cpp/15f035f53ce8bcffb8e40c3b67bad7ee9497ae9a/input_image/bus.jpg

--------------------------------------------------------------------------------
/input_image/zidane.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hotsuyuki/YOLOv5_PyTorch_cpp/15f035f53ce8bcffb8e40c3b67bad7ee9497ae9a/input_image/zidane.jpg

--------------------------------------------------------------------------------
/model/yolov5s_torchscript_B1-C3-H640-W640_torch1-6-0_cuda10-2_cpu.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hotsuyuki/YOLOv5_PyTorch_cpp/15f035f53ce8bcffb8e40c3b67bad7ee9497ae9a/model/yolov5s_torchscript_B1-C3-H640-W640_torch1-6-0_cuda10-2_cpu.pt

--------------------------------------------------------------------------------
/model/yolov5s_torchscript_B1-C3-H640-W640_torch1-6-0_cuda10-2_gpu.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hotsuyuki/YOLOv5_PyTorch_cpp/15f035f53ce8bcffb8e40c3b67bad7ee9497ae9a/model/yolov5s_torchscript_B1-C3-H640-W640_torch1-6-0_cuda10-2_gpu.pt

--------------------------------------------------------------------------------
/result_image/result_image_YOLOv5/YOLOv5_segfault_while_loding_TorchScript.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hotsuyuki/YOLOv5_PyTorch_cpp/15f035f53ce8bcffb8e40c3b67bad7ee9497ae9a/result_image/result_image_YOLOv5/YOLOv5_segfault_while_loding_TorchScript.png

--------------------------------------------------------------------------------
/result_image/result_image_YOLOv5/result_bus.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hotsuyuki/YOLOv5_PyTorch_cpp/15f035f53ce8bcffb8e40c3b67bad7ee9497ae9a/result_image/result_image_YOLOv5/result_bus.jpg

--------------------------------------------------------------------------------
/result_image/result_image_YOLOv5/result_zidane.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hotsuyuki/YOLOv5_PyTorch_cpp/15f035f53ce8bcffb8e40c3b67bad7ee9497ae9a/result_image/result_image_YOLOv5/result_zidane.jpg

--------------------------------------------------------------------------------
/src/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | cmake_minimum_required(VERSION 3.12)
2 | 
3 | project(YOLOv5_PyTorch_cpp
4 |     VERSION 1.0.0
5 |     DESCRIPTION "Ultralytics LLC's YOLOv5 with PyTorch c++ API"
6 |     LANGUAGES CXX
7 | )
8 | 
9 | set(Torch_DIR ${CMAKE_SOURCE_DIR}/libtorch_v1-6-0/share/cmake/Torch/)
10 | find_package(Torch PATHS ${Torch_DIR} REQUIRED)
11 | 
12 | add_subdirectory(${CMAKE_SOURCE_DIR}/object_detector/)
13 | 
14 | add_executable(main ${CMAKE_SOURCE_DIR}/main.cpp)
15 | target_compile_features(main PRIVATE cxx_std_17)
16 | 
17 | target_include_directories(main PRIVATE
18 |     ${CMAKE_SOURCE_DIR}/cxxopts/
19 | )
20 | 
21 | target_link_libraries(main
22 |     object_detector
23 | )

--------------------------------------------------------------------------------
/src/cxxopts/cxxopts.hpp:
--------------------------------------------------------------------------------
1 | /*
2 | 
3 | Copyright (c) 2014, 2015, 2016, 2017 Jarryd Beck
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in
13 | all copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
21 | THE SOFTWARE.
22 | 23 | */ 24 | 25 | #ifndef CXXOPTS_HPP_INCLUDED 26 | #define CXXOPTS_HPP_INCLUDED 27 | 28 | #include 29 | #include 30 | #include 31 | #include 32 | #include 33 | #include 34 | #include 35 | #include 36 | #include 37 | #include 38 | #include 39 | #include 40 | 41 | #ifdef __cpp_lib_optional 42 | #include 43 | #define CXXOPTS_HAS_OPTIONAL 44 | #endif 45 | 46 | #define CXXOPTS__VERSION_MAJOR 2 47 | #define CXXOPTS__VERSION_MINOR 2 48 | #define CXXOPTS__VERSION_PATCH 0 49 | 50 | namespace cxxopts 51 | { 52 | static constexpr struct { 53 | uint8_t major, minor, patch; 54 | } version = { 55 | CXXOPTS__VERSION_MAJOR, 56 | CXXOPTS__VERSION_MINOR, 57 | CXXOPTS__VERSION_PATCH 58 | }; 59 | } 60 | 61 | //when we ask cxxopts to use Unicode, help strings are processed using ICU, 62 | //which results in the correct lengths being computed for strings when they 63 | //are formatted for the help output 64 | //it is necessary to make sure that can be found by the 65 | //compiler, and that icu-uc is linked in to the binary. 
66 | 67 | #ifdef CXXOPTS_USE_UNICODE 68 | #include 69 | 70 | namespace cxxopts 71 | { 72 | typedef icu::UnicodeString String; 73 | 74 | inline 75 | String 76 | toLocalString(std::string s) 77 | { 78 | return icu::UnicodeString::fromUTF8(std::move(s)); 79 | } 80 | 81 | class UnicodeStringIterator : public 82 | std::iterator 83 | { 84 | public: 85 | 86 | UnicodeStringIterator(const icu::UnicodeString* string, int32_t pos) 87 | : s(string) 88 | , i(pos) 89 | { 90 | } 91 | 92 | value_type 93 | operator*() const 94 | { 95 | return s->char32At(i); 96 | } 97 | 98 | bool 99 | operator==(const UnicodeStringIterator& rhs) const 100 | { 101 | return s == rhs.s && i == rhs.i; 102 | } 103 | 104 | bool 105 | operator!=(const UnicodeStringIterator& rhs) const 106 | { 107 | return !(*this == rhs); 108 | } 109 | 110 | UnicodeStringIterator& 111 | operator++() 112 | { 113 | ++i; 114 | return *this; 115 | } 116 | 117 | UnicodeStringIterator 118 | operator+(int32_t v) 119 | { 120 | return UnicodeStringIterator(s, i + v); 121 | } 122 | 123 | private: 124 | const icu::UnicodeString* s; 125 | int32_t i; 126 | }; 127 | 128 | inline 129 | String& 130 | stringAppend(String&s, String a) 131 | { 132 | return s.append(std::move(a)); 133 | } 134 | 135 | inline 136 | String& 137 | stringAppend(String& s, int n, UChar32 c) 138 | { 139 | for (int i = 0; i != n; ++i) 140 | { 141 | s.append(c); 142 | } 143 | 144 | return s; 145 | } 146 | 147 | template 148 | String& 149 | stringAppend(String& s, Iterator begin, Iterator end) 150 | { 151 | while (begin != end) 152 | { 153 | s.append(*begin); 154 | ++begin; 155 | } 156 | 157 | return s; 158 | } 159 | 160 | inline 161 | size_t 162 | stringLength(const String& s) 163 | { 164 | return s.length(); 165 | } 166 | 167 | inline 168 | std::string 169 | toUTF8String(const String& s) 170 | { 171 | std::string result; 172 | s.toUTF8String(result); 173 | 174 | return result; 175 | } 176 | 177 | inline 178 | bool 179 | empty(const String& s) 180 | { 181 | return 
s.isEmpty(); 182 | } 183 | } 184 | 185 | namespace std 186 | { 187 | inline 188 | cxxopts::UnicodeStringIterator 189 | begin(const icu::UnicodeString& s) 190 | { 191 | return cxxopts::UnicodeStringIterator(&s, 0); 192 | } 193 | 194 | inline 195 | cxxopts::UnicodeStringIterator 196 | end(const icu::UnicodeString& s) 197 | { 198 | return cxxopts::UnicodeStringIterator(&s, s.length()); 199 | } 200 | } 201 | 202 | //ifdef CXXOPTS_USE_UNICODE 203 | #else 204 | 205 | namespace cxxopts 206 | { 207 | typedef std::string String; 208 | 209 | template 210 | T 211 | toLocalString(T&& t) 212 | { 213 | return std::forward(t); 214 | } 215 | 216 | inline 217 | size_t 218 | stringLength(const String& s) 219 | { 220 | return s.length(); 221 | } 222 | 223 | inline 224 | String& 225 | stringAppend(String&s, String a) 226 | { 227 | return s.append(std::move(a)); 228 | } 229 | 230 | inline 231 | String& 232 | stringAppend(String& s, size_t n, char c) 233 | { 234 | return s.append(n, c); 235 | } 236 | 237 | template 238 | String& 239 | stringAppend(String& s, Iterator begin, Iterator end) 240 | { 241 | return s.append(begin, end); 242 | } 243 | 244 | template 245 | std::string 246 | toUTF8String(T&& t) 247 | { 248 | return std::forward(t); 249 | } 250 | 251 | inline 252 | bool 253 | empty(const std::string& s) 254 | { 255 | return s.empty(); 256 | } 257 | } 258 | 259 | //ifdef CXXOPTS_USE_UNICODE 260 | #endif 261 | 262 | namespace cxxopts 263 | { 264 | namespace 265 | { 266 | #ifdef _WIN32 267 | const std::string LQUOTE("\'"); 268 | const std::string RQUOTE("\'"); 269 | #else 270 | const std::string LQUOTE("‘"); 271 | const std::string RQUOTE("’"); 272 | #endif 273 | } 274 | 275 | class Value : public std::enable_shared_from_this 276 | { 277 | public: 278 | 279 | virtual ~Value() = default; 280 | 281 | virtual 282 | std::shared_ptr 283 | clone() const = 0; 284 | 285 | virtual void 286 | parse(const std::string& text) const = 0; 287 | 288 | virtual void 289 | parse() const = 0; 290 | 291 
| virtual bool 292 | has_default() const = 0; 293 | 294 | virtual bool 295 | is_container() const = 0; 296 | 297 | virtual bool 298 | has_implicit() const = 0; 299 | 300 | virtual std::string 301 | get_default_value() const = 0; 302 | 303 | virtual std::string 304 | get_implicit_value() const = 0; 305 | 306 | virtual std::shared_ptr 307 | default_value(const std::string& value) = 0; 308 | 309 | virtual std::shared_ptr 310 | implicit_value(const std::string& value) = 0; 311 | 312 | virtual bool 313 | is_boolean() const = 0; 314 | }; 315 | 316 | class OptionException : public std::exception 317 | { 318 | public: 319 | OptionException(const std::string& message) 320 | : m_message(message) 321 | { 322 | } 323 | 324 | virtual const char* 325 | what() const noexcept 326 | { 327 | return m_message.c_str(); 328 | } 329 | 330 | private: 331 | std::string m_message; 332 | }; 333 | 334 | class OptionSpecException : public OptionException 335 | { 336 | public: 337 | 338 | OptionSpecException(const std::string& message) 339 | : OptionException(message) 340 | { 341 | } 342 | }; 343 | 344 | class OptionParseException : public OptionException 345 | { 346 | public: 347 | OptionParseException(const std::string& message) 348 | : OptionException(message) 349 | { 350 | } 351 | }; 352 | 353 | class option_exists_error : public OptionSpecException 354 | { 355 | public: 356 | option_exists_error(const std::string& option) 357 | : OptionSpecException(u8"Option " + LQUOTE + option + RQUOTE + u8" already exists") 358 | { 359 | } 360 | }; 361 | 362 | class invalid_option_format_error : public OptionSpecException 363 | { 364 | public: 365 | invalid_option_format_error(const std::string& format) 366 | : OptionSpecException(u8"Invalid option format " + LQUOTE + format + RQUOTE) 367 | { 368 | } 369 | }; 370 | 371 | class option_syntax_exception : public OptionParseException { 372 | public: 373 | option_syntax_exception(const std::string& text) 374 | : OptionParseException(u8"Argument " + LQUOTE + 
text + RQUOTE + 375 | u8" starts with a - but has incorrect syntax") 376 | { 377 | } 378 | }; 379 | 380 | class option_not_exists_exception : public OptionParseException 381 | { 382 | public: 383 | option_not_exists_exception(const std::string& option) 384 | : OptionParseException(u8"Option " + LQUOTE + option + RQUOTE + u8" does not exist") 385 | { 386 | } 387 | }; 388 | 389 | class missing_argument_exception : public OptionParseException 390 | { 391 | public: 392 | missing_argument_exception(const std::string& option) 393 | : OptionParseException( 394 | u8"Option " + LQUOTE + option + RQUOTE + u8" is missing an argument" 395 | ) 396 | { 397 | } 398 | }; 399 | 400 | class option_requires_argument_exception : public OptionParseException 401 | { 402 | public: 403 | option_requires_argument_exception(const std::string& option) 404 | : OptionParseException( 405 | u8"Option " + LQUOTE + option + RQUOTE + u8" requires an argument" 406 | ) 407 | { 408 | } 409 | }; 410 | 411 | class option_not_has_argument_exception : public OptionParseException 412 | { 413 | public: 414 | option_not_has_argument_exception 415 | ( 416 | const std::string& option, 417 | const std::string& arg 418 | ) 419 | : OptionParseException( 420 | u8"Option " + LQUOTE + option + RQUOTE + 421 | u8" does not take an argument, but argument " + 422 | LQUOTE + arg + RQUOTE + " given" 423 | ) 424 | { 425 | } 426 | }; 427 | 428 | class option_not_present_exception : public OptionParseException 429 | { 430 | public: 431 | option_not_present_exception(const std::string& option) 432 | : OptionParseException(u8"Option " + LQUOTE + option + RQUOTE + u8" not present") 433 | { 434 | } 435 | }; 436 | 437 | class argument_incorrect_type : public OptionParseException 438 | { 439 | public: 440 | argument_incorrect_type 441 | ( 442 | const std::string& arg 443 | ) 444 | : OptionParseException( 445 | u8"Argument " + LQUOTE + arg + RQUOTE + u8" failed to parse" 446 | ) 447 | { 448 | } 449 | }; 450 | 451 | class 
option_required_exception : public OptionParseException 452 | { 453 | public: 454 | option_required_exception(const std::string& option) 455 | : OptionParseException( 456 | u8"Option " + LQUOTE + option + RQUOTE + u8" is required but not present" 457 | ) 458 | { 459 | } 460 | }; 461 | 462 | namespace values 463 | { 464 | namespace 465 | { 466 | std::basic_regex integer_pattern 467 | ("(-)?(0x)?([0-9a-zA-Z]+)|((0x)?0)"); 468 | std::basic_regex truthy_pattern 469 | ("(t|T)(rue)?"); 470 | std::basic_regex falsy_pattern 471 | ("((f|F)(alse)?)?"); 472 | } 473 | 474 | namespace detail 475 | { 476 | template 477 | struct SignedCheck; 478 | 479 | template 480 | struct SignedCheck 481 | { 482 | template 483 | void 484 | operator()(bool negative, U u, const std::string& text) 485 | { 486 | if (negative) 487 | { 488 | if (u > static_cast(-(std::numeric_limits::min)())) 489 | { 490 | throw argument_incorrect_type(text); 491 | } 492 | } 493 | else 494 | { 495 | if (u > static_cast((std::numeric_limits::max)())) 496 | { 497 | throw argument_incorrect_type(text); 498 | } 499 | } 500 | } 501 | }; 502 | 503 | template 504 | struct SignedCheck 505 | { 506 | template 507 | void 508 | operator()(bool, U, const std::string&) {} 509 | }; 510 | 511 | template 512 | void 513 | check_signed_range(bool negative, U value, const std::string& text) 514 | { 515 | SignedCheck::is_signed>()(negative, value, text); 516 | } 517 | } 518 | 519 | template 520 | R 521 | checked_negate(T&& t, const std::string&, std::true_type) 522 | { 523 | // if we got to here, then `t` is a positive number that fits into 524 | // `R`. So to avoid MSVC C4146, we first cast it to `R`. 525 | // See https://github.com/jarro2783/cxxopts/issues/62 for more details. 
526 | return -static_cast(t); 527 | } 528 | 529 | template 530 | T 531 | checked_negate(T&&, const std::string& text, std::false_type) 532 | { 533 | throw argument_incorrect_type(text); 534 | } 535 | 536 | template 537 | void 538 | integer_parser(const std::string& text, T& value) 539 | { 540 | std::smatch match; 541 | std::regex_match(text, match, integer_pattern); 542 | 543 | if (match.length() == 0) 544 | { 545 | throw argument_incorrect_type(text); 546 | } 547 | 548 | if (match.length(4) > 0) 549 | { 550 | value = 0; 551 | return; 552 | } 553 | 554 | using US = typename std::make_unsigned::type; 555 | 556 | constexpr auto umax = (std::numeric_limits::max)(); 557 | constexpr bool is_signed = std::numeric_limits::is_signed; 558 | const bool negative = match.length(1) > 0; 559 | const uint8_t base = match.length(2) > 0 ? 16 : 10; 560 | 561 | auto value_match = match[3]; 562 | 563 | US result = 0; 564 | 565 | for (auto iter = value_match.first; iter != value_match.second; ++iter) 566 | { 567 | US digit = 0; 568 | 569 | if (*iter >= '0' && *iter <= '9') 570 | { 571 | digit = *iter - '0'; 572 | } 573 | else if (base == 16 && *iter >= 'a' && *iter <= 'f') 574 | { 575 | digit = *iter - 'a' + 10; 576 | } 577 | else if (base == 16 && *iter >= 'A' && *iter <= 'F') 578 | { 579 | digit = *iter - 'A' + 10; 580 | } 581 | else 582 | { 583 | throw argument_incorrect_type(text); 584 | } 585 | 586 | if (umax - digit < result * base) 587 | { 588 | throw argument_incorrect_type(text); 589 | } 590 | 591 | result = result * base + digit; 592 | } 593 | 594 | detail::check_signed_range(negative, result, text); 595 | 596 | if (negative) 597 | { 598 | value = checked_negate(result, 599 | text, 600 | std::integral_constant()); 601 | } 602 | else 603 | { 604 | value = result; 605 | } 606 | } 607 | 608 | template 609 | void stringstream_parser(const std::string& text, T& value) 610 | { 611 | std::stringstream in(text); 612 | in >> value; 613 | if (!in) { 614 | throw 
argument_incorrect_type(text); 615 | } 616 | } 617 | 618 | inline 619 | void 620 | parse_value(const std::string& text, uint8_t& value) 621 | { 622 | integer_parser(text, value); 623 | } 624 | 625 | inline 626 | void 627 | parse_value(const std::string& text, int8_t& value) 628 | { 629 | integer_parser(text, value); 630 | } 631 | 632 | inline 633 | void 634 | parse_value(const std::string& text, uint16_t& value) 635 | { 636 | integer_parser(text, value); 637 | } 638 | 639 | inline 640 | void 641 | parse_value(const std::string& text, int16_t& value) 642 | { 643 | integer_parser(text, value); 644 | } 645 | 646 | inline 647 | void 648 | parse_value(const std::string& text, uint32_t& value) 649 | { 650 | integer_parser(text, value); 651 | } 652 | 653 | inline 654 | void 655 | parse_value(const std::string& text, int32_t& value) 656 | { 657 | integer_parser(text, value); 658 | } 659 | 660 | inline 661 | void 662 | parse_value(const std::string& text, uint64_t& value) 663 | { 664 | integer_parser(text, value); 665 | } 666 | 667 | inline 668 | void 669 | parse_value(const std::string& text, int64_t& value) 670 | { 671 | integer_parser(text, value); 672 | } 673 | 674 | inline 675 | void 676 | parse_value(const std::string& text, bool& value) 677 | { 678 | std::smatch result; 679 | std::regex_match(text, result, truthy_pattern); 680 | 681 | if (!result.empty()) 682 | { 683 | value = true; 684 | return; 685 | } 686 | 687 | std::regex_match(text, result, falsy_pattern); 688 | if (!result.empty()) 689 | { 690 | value = false; 691 | return; 692 | } 693 | 694 | throw argument_incorrect_type(text); 695 | } 696 | 697 | inline 698 | void 699 | parse_value(const std::string& text, std::string& value) 700 | { 701 | value = text; 702 | } 703 | 704 | // The fallback parser. It uses the stringstream parser to parse all types 705 | // that have not been overloaded explicitly. It has to be placed in the 706 | // source code before all other more specialized templates. 
707 | template 708 | void 709 | parse_value(const std::string& text, T& value) { 710 | stringstream_parser(text, value); 711 | } 712 | 713 | template 714 | void 715 | parse_value(const std::string& text, std::vector& value) 716 | { 717 | T v; 718 | parse_value(text, v); 719 | value.push_back(v); 720 | } 721 | 722 | #ifdef CXXOPTS_HAS_OPTIONAL 723 | template 724 | void 725 | parse_value(const std::string& text, std::optional& value) 726 | { 727 | T result; 728 | parse_value(text, result); 729 | value = std::move(result); 730 | } 731 | #endif 732 | 733 | template 734 | struct type_is_container 735 | { 736 | static constexpr bool value = false; 737 | }; 738 | 739 | template 740 | struct type_is_container> 741 | { 742 | static constexpr bool value = true; 743 | }; 744 | 745 | template 746 | class abstract_value : public Value 747 | { 748 | using Self = abstract_value; 749 | 750 | public: 751 | abstract_value() 752 | : m_result(std::make_shared()) 753 | , m_store(m_result.get()) 754 | { 755 | } 756 | 757 | abstract_value(T* t) 758 | : m_store(t) 759 | { 760 | } 761 | 762 | virtual ~abstract_value() = default; 763 | 764 | abstract_value(const abstract_value& rhs) 765 | { 766 | if (rhs.m_result) 767 | { 768 | m_result = std::make_shared(); 769 | m_store = m_result.get(); 770 | } 771 | else 772 | { 773 | m_store = rhs.m_store; 774 | } 775 | 776 | m_default = rhs.m_default; 777 | m_implicit = rhs.m_implicit; 778 | m_default_value = rhs.m_default_value; 779 | m_implicit_value = rhs.m_implicit_value; 780 | } 781 | 782 | void 783 | parse(const std::string& text) const 784 | { 785 | parse_value(text, *m_store); 786 | } 787 | 788 | bool 789 | is_container() const 790 | { 791 | return type_is_container::value; 792 | } 793 | 794 | void 795 | parse() const 796 | { 797 | parse_value(m_default_value, *m_store); 798 | } 799 | 800 | bool 801 | has_default() const 802 | { 803 | return m_default; 804 | } 805 | 806 | bool 807 | has_implicit() const 808 | { 809 | return m_implicit; 810 | } 
811 | 812 | std::shared_ptr 813 | default_value(const std::string& value) 814 | { 815 | m_default = true; 816 | m_default_value = value; 817 | return shared_from_this(); 818 | } 819 | 820 | std::shared_ptr 821 | implicit_value(const std::string& value) 822 | { 823 | m_implicit = true; 824 | m_implicit_value = value; 825 | return shared_from_this(); 826 | } 827 | 828 | std::string 829 | get_default_value() const 830 | { 831 | return m_default_value; 832 | } 833 | 834 | std::string 835 | get_implicit_value() const 836 | { 837 | return m_implicit_value; 838 | } 839 | 840 | bool 841 | is_boolean() const 842 | { 843 | return std::is_same::value; 844 | } 845 | 846 | const T& 847 | get() const 848 | { 849 | if (m_store == nullptr) 850 | { 851 | return *m_result; 852 | } 853 | else 854 | { 855 | return *m_store; 856 | } 857 | } 858 | 859 | protected: 860 | std::shared_ptr m_result; 861 | T* m_store; 862 | 863 | bool m_default = false; 864 | bool m_implicit = false; 865 | 866 | std::string m_default_value; 867 | std::string m_implicit_value; 868 | }; 869 | 870 | template 871 | class standard_value : public abstract_value 872 | { 873 | public: 874 | using abstract_value::abstract_value; 875 | 876 | std::shared_ptr 877 | clone() const 878 | { 879 | return std::make_shared>(*this); 880 | } 881 | }; 882 | 883 | template <> 884 | class standard_value : public abstract_value 885 | { 886 | public: 887 | ~standard_value() = default; 888 | 889 | standard_value() 890 | { 891 | set_default_and_implicit(); 892 | } 893 | 894 | standard_value(bool* b) 895 | : abstract_value(b) 896 | { 897 | set_default_and_implicit(); 898 | } 899 | 900 | std::shared_ptr 901 | clone() const 902 | { 903 | return std::make_shared>(*this); 904 | } 905 | 906 | private: 907 | 908 | void 909 | set_default_and_implicit() 910 | { 911 | m_default = true; 912 | m_default_value = "false"; 913 | m_implicit = true; 914 | m_implicit_value = "true"; 915 | } 916 | }; 917 | } 918 | 919 | template 920 | std::shared_ptr 921 
| value() 922 | { 923 | return std::make_shared>(); 924 | } 925 | 926 | template 927 | std::shared_ptr 928 | value(T& t) 929 | { 930 | return std::make_shared>(&t); 931 | } 932 | 933 | class OptionAdder; 934 | 935 | class OptionDetails 936 | { 937 | public: 938 | OptionDetails 939 | ( 940 | const std::string& short_, 941 | const std::string& long_, 942 | const String& desc, 943 | std::shared_ptr val 944 | ) 945 | : m_short(short_) 946 | , m_long(long_) 947 | , m_desc(desc) 948 | , m_value(val) 949 | , m_count(0) 950 | { 951 | } 952 | 953 | OptionDetails(const OptionDetails& rhs) 954 | : m_desc(rhs.m_desc) 955 | , m_count(rhs.m_count) 956 | { 957 | m_value = rhs.m_value->clone(); 958 | } 959 | 960 | OptionDetails(OptionDetails&& rhs) = default; 961 | 962 | const String& 963 | description() const 964 | { 965 | return m_desc; 966 | } 967 | 968 | const Value& value() const { 969 | return *m_value; 970 | } 971 | 972 | std::shared_ptr 973 | make_storage() const 974 | { 975 | return m_value->clone(); 976 | } 977 | 978 | const std::string& 979 | short_name() const 980 | { 981 | return m_short; 982 | } 983 | 984 | const std::string& 985 | long_name() const 986 | { 987 | return m_long; 988 | } 989 | 990 | private: 991 | std::string m_short; 992 | std::string m_long; 993 | String m_desc; 994 | std::shared_ptr m_value; 995 | int m_count; 996 | }; 997 | 998 | struct HelpOptionDetails 999 | { 1000 | std::string s; 1001 | std::string l; 1002 | String desc; 1003 | bool has_default; 1004 | std::string default_value; 1005 | bool has_implicit; 1006 | std::string implicit_value; 1007 | std::string arg_help; 1008 | bool is_container; 1009 | bool is_boolean; 1010 | }; 1011 | 1012 | struct HelpGroupDetails 1013 | { 1014 | std::string name; 1015 | std::string description; 1016 | std::vector options; 1017 | }; 1018 | 1019 | class OptionValue 1020 | { 1021 | public: 1022 | void 1023 | parse 1024 | ( 1025 | std::shared_ptr details, 1026 | const std::string& text 1027 | ) 1028 | { 1029 | 
ensure_value(details); 1030 | ++m_count; 1031 | m_value->parse(text); 1032 | } 1033 | 1034 | void 1035 | parse_default(std::shared_ptr details) 1036 | { 1037 | ensure_value(details); 1038 | m_value->parse(); 1039 | } 1040 | 1041 | size_t 1042 | count() const 1043 | { 1044 | return m_count; 1045 | } 1046 | 1047 | template 1048 | const T& 1049 | as() const 1050 | { 1051 | if (m_value == nullptr) { 1052 | throw std::domain_error("No value"); 1053 | } 1054 | 1055 | #ifdef CXXOPTS_NO_RTTI 1056 | return static_cast&>(*m_value).get(); 1057 | #else 1058 | return dynamic_cast&>(*m_value).get(); 1059 | #endif 1060 | } 1061 | 1062 | private: 1063 | void 1064 | ensure_value(std::shared_ptr details) 1065 | { 1066 | if (m_value == nullptr) 1067 | { 1068 | m_value = details->make_storage(); 1069 | } 1070 | } 1071 | 1072 | std::shared_ptr m_value; 1073 | size_t m_count = 0; 1074 | }; 1075 | 1076 | class KeyValue 1077 | { 1078 | public: 1079 | KeyValue(std::string key_, std::string value_) 1080 | : m_key(std::move(key_)) 1081 | , m_value(std::move(value_)) 1082 | { 1083 | } 1084 | 1085 | const 1086 | std::string& 1087 | key() const 1088 | { 1089 | return m_key; 1090 | } 1091 | 1092 | const std::string 1093 | value() const 1094 | { 1095 | return m_value; 1096 | } 1097 | 1098 | template 1099 | T 1100 | as() const 1101 | { 1102 | T result; 1103 | values::parse_value(m_value, result); 1104 | return result; 1105 | } 1106 | 1107 | private: 1108 | std::string m_key; 1109 | std::string m_value; 1110 | }; 1111 | 1112 | class ParseResult 1113 | { 1114 | public: 1115 | 1116 | ParseResult( 1117 | const std::shared_ptr< 1118 | std::unordered_map> 1119 | >, 1120 | std::vector, 1121 | bool allow_unrecognised, 1122 | int&, char**&); 1123 | 1124 | size_t 1125 | count(const std::string& o) const 1126 | { 1127 | auto iter = m_options->find(o); 1128 | if (iter == m_options->end()) 1129 | { 1130 | return 0; 1131 | } 1132 | 1133 | auto riter = m_results.find(iter->second); 1134 | 1135 | return 
riter->second.count(); 1136 | } 1137 | 1138 | const OptionValue& 1139 | operator[](const std::string& option) const 1140 | { 1141 | auto iter = m_options->find(option); 1142 | 1143 | if (iter == m_options->end()) 1144 | { 1145 | throw option_not_present_exception(option); 1146 | } 1147 | 1148 | auto riter = m_results.find(iter->second); 1149 | 1150 | return riter->second; 1151 | } 1152 | 1153 | const std::vector& 1154 | arguments() const 1155 | { 1156 | return m_sequential; 1157 | } 1158 | 1159 | private: 1160 | 1161 | void 1162 | parse(int& argc, char**& argv); 1163 | 1164 | void 1165 | add_to_option(const std::string& option, const std::string& arg); 1166 | 1167 | bool 1168 | consume_positional(std::string a); 1169 | 1170 | void 1171 | parse_option 1172 | ( 1173 | std::shared_ptr value, 1174 | const std::string& name, 1175 | const std::string& arg = "" 1176 | ); 1177 | 1178 | void 1179 | parse_default(std::shared_ptr details); 1180 | 1181 | void 1182 | checked_parse_arg 1183 | ( 1184 | int argc, 1185 | char* argv[], 1186 | int& current, 1187 | std::shared_ptr value, 1188 | const std::string& name 1189 | ); 1190 | 1191 | const std::shared_ptr< 1192 | std::unordered_map> 1193 | > m_options; 1194 | std::vector m_positional; 1195 | std::vector::iterator m_next_positional; 1196 | std::unordered_set m_positional_set; 1197 | std::unordered_map, OptionValue> m_results; 1198 | 1199 | bool m_allow_unrecognised; 1200 | 1201 | std::vector m_sequential; 1202 | }; 1203 | 1204 | class Options 1205 | { 1206 | typedef std::unordered_map> 1207 | OptionMap; 1208 | public: 1209 | 1210 | Options(std::string program, std::string help_string = "") 1211 | : m_program(std::move(program)) 1212 | , m_help_string(toLocalString(std::move(help_string))) 1213 | , m_custom_help("[OPTION...]") 1214 | , m_allow_unrecognised(false) 1215 | , m_options(std::make_shared()) 1216 | , m_next_positional(m_positional.end()) 1217 | { 1218 | } 1219 | 1220 | Options& 1221 | custom_help(std::string help_text) 
1222 | { 1223 | m_custom_help = std::move(help_text); 1224 | return *this; 1225 | } 1226 | 1227 | Options& 1228 | allow_unrecognised_options() 1229 | { 1230 | m_allow_unrecognised = true; 1231 | return *this; 1232 | } 1233 | 1234 | ParseResult 1235 | parse(int& argc, char**& argv); 1236 | 1237 | OptionAdder 1238 | add_options(std::string group = ""); 1239 | 1240 | void 1241 | add_option 1242 | ( 1243 | const std::string& group, 1244 | const std::string& s, 1245 | const std::string& l, 1246 | std::string desc, 1247 | std::shared_ptr value, 1248 | std::string arg_help 1249 | ); 1250 | 1251 | //parse positional arguments into the given option 1252 | void 1253 | parse_positional(std::string option); 1254 | 1255 | void 1256 | parse_positional(std::vector options); 1257 | 1258 | void 1259 | parse_positional(std::initializer_list options); 1260 | 1261 | template 1262 | void 1263 | parse_positional(Iterator begin, Iterator end) { 1264 | parse_positional(std::vector{begin, end}); 1265 | } 1266 | 1267 | std::string 1268 | usage() const; 1269 | 1270 | std::string 1271 | help(const std::vector& groups = {}) const; 1272 | 1273 | const std::vector 1274 | groups() const; 1275 | 1276 | const HelpGroupDetails& 1277 | group_help(const std::string& group) const; 1278 | 1279 | private: 1280 | 1281 | void 1282 | add_one_option 1283 | ( 1284 | const std::string& option, 1285 | std::shared_ptr details 1286 | ); 1287 | 1288 | String 1289 | help_one_group(const std::string& group) const; 1290 | 1291 | void 1292 | generate_group_help 1293 | ( 1294 | String& result, 1295 | const std::vector& groups 1296 | ) const; 1297 | 1298 | void 1299 | generate_all_groups_help(String& result) const; 1300 | 1301 | std::string m_program; 1302 | String m_help_string; 1303 | std::string m_custom_help; 1304 | bool m_allow_unrecognised; 1305 | 1306 | std::shared_ptr m_options; 1307 | std::vector m_positional; 1308 | std::vector::iterator m_next_positional; 1309 | std::unordered_set m_positional_set; 1310 | 
1311 | //mapping from groups to help options 1312 | std::map m_help; 1313 | }; 1314 | 1315 | class OptionAdder 1316 | { 1317 | public: 1318 | 1319 | OptionAdder(Options& options, std::string group) 1320 | : m_options(options), m_group(std::move(group)) 1321 | { 1322 | } 1323 | 1324 | OptionAdder& 1325 | operator() 1326 | ( 1327 | const std::string& opts, 1328 | const std::string& desc, 1329 | std::shared_ptr value 1330 | = ::cxxopts::value(), 1331 | std::string arg_help = "" 1332 | ); 1333 | 1334 | private: 1335 | Options& m_options; 1336 | std::string m_group; 1337 | }; 1338 | 1339 | namespace 1340 | { 1341 | constexpr int OPTION_LONGEST = 30; 1342 | constexpr int OPTION_DESC_GAP = 2; 1343 | 1344 | std::basic_regex option_matcher 1345 | ("--([[:alnum:]][-_[:alnum:]]+)(=(.*))?|-([[:alnum:]]+)"); 1346 | 1347 | std::basic_regex option_specifier 1348 | ("(([[:alnum:]]),)?[ ]*([[:alnum:]][-_[:alnum:]]*)?"); 1349 | 1350 | String 1351 | format_option 1352 | ( 1353 | const HelpOptionDetails& o 1354 | ) 1355 | { 1356 | auto& s = o.s; 1357 | auto& l = o.l; 1358 | 1359 | String result = " "; 1360 | 1361 | if (s.size() > 0) 1362 | { 1363 | result += "-" + toLocalString(s) + ","; 1364 | } 1365 | else 1366 | { 1367 | result += " "; 1368 | } 1369 | 1370 | if (l.size() > 0) 1371 | { 1372 | result += " --" + toLocalString(l); 1373 | } 1374 | 1375 | auto arg = o.arg_help.size() > 0 ? 
toLocalString(o.arg_help) : "arg"; 1376 | 1377 | if (!o.is_boolean) 1378 | { 1379 | if (o.has_implicit) 1380 | { 1381 | result += " [=" + arg + "(=" + toLocalString(o.implicit_value) + ")]"; 1382 | } 1383 | else 1384 | { 1385 | result += " " + arg; 1386 | } 1387 | } 1388 | 1389 | return result; 1390 | } 1391 | 1392 | String 1393 | format_description 1394 | ( 1395 | const HelpOptionDetails& o, 1396 | size_t start, 1397 | size_t width 1398 | ) 1399 | { 1400 | auto desc = o.desc; 1401 | 1402 | if (o.has_default && (!o.is_boolean || o.default_value != "false")) 1403 | { 1404 | desc += toLocalString(" (default: " + o.default_value + ")"); 1405 | } 1406 | 1407 | String result; 1408 | 1409 | auto current = std::begin(desc); 1410 | auto startLine = current; 1411 | auto lastSpace = current; 1412 | 1413 | auto size = size_t{}; 1414 | 1415 | while (current != std::end(desc)) 1416 | { 1417 | if (*current == ' ') 1418 | { 1419 | lastSpace = current; 1420 | } 1421 | 1422 | if (*current == '\n') 1423 | { 1424 | startLine = current + 1; 1425 | lastSpace = startLine; 1426 | } 1427 | else if (size > width) 1428 | { 1429 | if (lastSpace == startLine) 1430 | { 1431 | stringAppend(result, startLine, current + 1); 1432 | stringAppend(result, "\n"); 1433 | stringAppend(result, start, ' '); 1434 | startLine = current + 1; 1435 | lastSpace = startLine; 1436 | } 1437 | else 1438 | { 1439 | stringAppend(result, startLine, lastSpace); 1440 | stringAppend(result, "\n"); 1441 | stringAppend(result, start, ' '); 1442 | startLine = lastSpace + 1; 1443 | } 1444 | size = 0; 1445 | } 1446 | else 1447 | { 1448 | ++size; 1449 | } 1450 | 1451 | ++current; 1452 | } 1453 | 1454 | //append whatever is left 1455 | stringAppend(result, startLine, current); 1456 | 1457 | return result; 1458 | } 1459 | } 1460 | 1461 | inline 1462 | ParseResult::ParseResult 1463 | ( 1464 | const std::shared_ptr< 1465 | std::unordered_map> 1466 | > options, 1467 | std::vector positional, 1468 | bool allow_unrecognised, 1469 | 
int& argc, char**& argv 1470 | ) 1471 | : m_options(options) 1472 | , m_positional(std::move(positional)) 1473 | , m_next_positional(m_positional.begin()) 1474 | , m_allow_unrecognised(allow_unrecognised) 1475 | { 1476 | parse(argc, argv); 1477 | } 1478 | 1479 | inline 1480 | OptionAdder 1481 | Options::add_options(std::string group) 1482 | { 1483 | return OptionAdder(*this, std::move(group)); 1484 | } 1485 | 1486 | inline 1487 | OptionAdder& 1488 | OptionAdder::operator() 1489 | ( 1490 | const std::string& opts, 1491 | const std::string& desc, 1492 | std::shared_ptr value, 1493 | std::string arg_help 1494 | ) 1495 | { 1496 | std::match_results result; 1497 | std::regex_match(opts.c_str(), result, option_specifier); 1498 | 1499 | if (result.empty()) 1500 | { 1501 | throw invalid_option_format_error(opts); 1502 | } 1503 | 1504 | const auto& short_match = result[2]; 1505 | const auto& long_match = result[3]; 1506 | 1507 | if (!short_match.length() && !long_match.length()) 1508 | { 1509 | throw invalid_option_format_error(opts); 1510 | } else if (long_match.length() == 1 && short_match.length()) 1511 | { 1512 | throw invalid_option_format_error(opts); 1513 | } 1514 | 1515 | auto option_names = [] 1516 | ( 1517 | const std::sub_match& short_, 1518 | const std::sub_match& long_ 1519 | ) 1520 | { 1521 | if (long_.length() == 1) 1522 | { 1523 | return std::make_tuple(long_.str(), short_.str()); 1524 | } 1525 | else 1526 | { 1527 | return std::make_tuple(short_.str(), long_.str()); 1528 | } 1529 | }(short_match, long_match); 1530 | 1531 | m_options.add_option 1532 | ( 1533 | m_group, 1534 | std::get<0>(option_names), 1535 | std::get<1>(option_names), 1536 | desc, 1537 | value, 1538 | std::move(arg_help) 1539 | ); 1540 | 1541 | return *this; 1542 | } 1543 | 1544 | inline 1545 | void 1546 | ParseResult::parse_default(std::shared_ptr details) 1547 | { 1548 | m_results[details].parse_default(details); 1549 | } 1550 | 1551 | inline 1552 | void 1553 | ParseResult::parse_option 
1554 | ( 1555 | std::shared_ptr value, 1556 | const std::string& /*name*/, 1557 | const std::string& arg 1558 | ) 1559 | { 1560 | auto& result = m_results[value]; 1561 | result.parse(value, arg); 1562 | 1563 | m_sequential.emplace_back(value->long_name(), arg); 1564 | } 1565 | 1566 | inline 1567 | void 1568 | ParseResult::checked_parse_arg 1569 | ( 1570 | int argc, 1571 | char* argv[], 1572 | int& current, 1573 | std::shared_ptr value, 1574 | const std::string& name 1575 | ) 1576 | { 1577 | if (current + 1 >= argc) 1578 | { 1579 | if (value->value().has_implicit()) 1580 | { 1581 | parse_option(value, name, value->value().get_implicit_value()); 1582 | } 1583 | else 1584 | { 1585 | throw missing_argument_exception(name); 1586 | } 1587 | } 1588 | else 1589 | { 1590 | if (value->value().has_implicit()) 1591 | { 1592 | parse_option(value, name, value->value().get_implicit_value()); 1593 | } 1594 | else 1595 | { 1596 | parse_option(value, name, argv[current + 1]); 1597 | ++current; 1598 | } 1599 | } 1600 | } 1601 | 1602 | inline 1603 | void 1604 | ParseResult::add_to_option(const std::string& option, const std::string& arg) 1605 | { 1606 | auto iter = m_options->find(option); 1607 | 1608 | if (iter == m_options->end()) 1609 | { 1610 | throw option_not_exists_exception(option); 1611 | } 1612 | 1613 | parse_option(iter->second, option, arg); 1614 | } 1615 | 1616 | inline 1617 | bool 1618 | ParseResult::consume_positional(std::string a) 1619 | { 1620 | while (m_next_positional != m_positional.end()) 1621 | { 1622 | auto iter = m_options->find(*m_next_positional); 1623 | if (iter != m_options->end()) 1624 | { 1625 | auto& result = m_results[iter->second]; 1626 | if (!iter->second->value().is_container()) 1627 | { 1628 | if (result.count() == 0) 1629 | { 1630 | add_to_option(*m_next_positional, a); 1631 | ++m_next_positional; 1632 | return true; 1633 | } 1634 | else 1635 | { 1636 | ++m_next_positional; 1637 | continue; 1638 | } 1639 | } 1640 | else 1641 | { 1642 | 
add_to_option(*m_next_positional, a); 1643 | return true; 1644 | } 1645 | } 1646 | ++m_next_positional; 1647 | } 1648 | 1649 | return false; 1650 | } 1651 | 1652 | inline 1653 | void 1654 | Options::parse_positional(std::string option) 1655 | { 1656 | parse_positional(std::vector{std::move(option)}); 1657 | } 1658 | 1659 | inline 1660 | void 1661 | Options::parse_positional(std::vector options) 1662 | { 1663 | m_positional = std::move(options); 1664 | m_next_positional = m_positional.begin(); 1665 | 1666 | m_positional_set.insert(m_positional.begin(), m_positional.end()); 1667 | } 1668 | 1669 | inline 1670 | void 1671 | Options::parse_positional(std::initializer_list options) 1672 | { 1673 | parse_positional(std::vector(std::move(options))); 1674 | } 1675 | 1676 | inline 1677 | ParseResult 1678 | Options::parse(int& argc, char**& argv) 1679 | { 1680 | if (m_options->find("help") == m_options->end()) 1681 | { 1682 | add_options()("h,help", ""); 1683 | } 1684 | ParseResult result(m_options, m_positional, m_allow_unrecognised, argc, argv); 1685 | if (result["help"].count() == 0) 1686 | { 1687 | for (const auto& arg : m_positional) 1688 | { 1689 | if (result[arg].count() == 0) 1690 | throw cxxopts::option_required_exception(arg); 1691 | } 1692 | } 1693 | return result; 1694 | } 1695 | 1696 | inline 1697 | void 1698 | ParseResult::parse(int& argc, char**& argv) 1699 | { 1700 | int current = 1; 1701 | 1702 | int nextKeep = 1; 1703 | 1704 | bool consume_remaining = false; 1705 | 1706 | while (current != argc) 1707 | { 1708 | if (strcmp(argv[current], "--") == 0) 1709 | { 1710 | consume_remaining = true; 1711 | ++current; 1712 | break; 1713 | } 1714 | 1715 | std::match_results result; 1716 | std::regex_match(argv[current], result, option_matcher); 1717 | 1718 | if (result.empty()) 1719 | { 1720 | //not a flag 1721 | 1722 | // but if it starts with a `-`, then it's an error 1723 | if (argv[current][0] == '-' && argv[current][1] != '\0') { 1724 | throw 
option_syntax_exception(argv[current]); 1725 | } 1726 | 1727 | //if true is returned here then it was consumed, otherwise it is 1728 | //ignored 1729 | if (consume_positional(argv[current])) 1730 | { 1731 | } 1732 | else 1733 | { 1734 | argv[nextKeep] = argv[current]; 1735 | ++nextKeep; 1736 | } 1737 | //if we return from here then it was parsed successfully, so continue 1738 | } 1739 | else 1740 | { 1741 | //short or long option? 1742 | if (result[4].length() != 0) 1743 | { 1744 | const std::string& s = result[4]; 1745 | 1746 | for (std::size_t i = 0; i != s.size(); ++i) 1747 | { 1748 | std::string name(1, s[i]); 1749 | auto iter = m_options->find(name); 1750 | 1751 | if (iter == m_options->end()) 1752 | { 1753 | if (m_allow_unrecognised) 1754 | { 1755 | continue; 1756 | } 1757 | else 1758 | { 1759 | //error 1760 | throw option_not_exists_exception(name); 1761 | } 1762 | } 1763 | 1764 | auto value = iter->second; 1765 | 1766 | if (i + 1 == s.size()) 1767 | { 1768 | //it must be the last argument 1769 | checked_parse_arg(argc, argv, current, value, name); 1770 | } 1771 | else if (value->value().has_implicit()) 1772 | { 1773 | parse_option(value, name, value->value().get_implicit_value()); 1774 | } 1775 | else 1776 | { 1777 | //error 1778 | throw option_requires_argument_exception(name); 1779 | } 1780 | } 1781 | } 1782 | else if (result[1].length() != 0) 1783 | { 1784 | const std::string& name = result[1]; 1785 | 1786 | auto iter = m_options->find(name); 1787 | 1788 | if (iter == m_options->end()) 1789 | { 1790 | if (m_allow_unrecognised) 1791 | { 1792 | // keep unrecognised options in argument list, skip to next argument 1793 | argv[nextKeep] = argv[current]; 1794 | ++nextKeep; 1795 | ++current; 1796 | continue; 1797 | } 1798 | else 1799 | { 1800 | //error 1801 | throw option_not_exists_exception(name); 1802 | } 1803 | } 1804 | 1805 | auto opt = iter->second; 1806 | 1807 | //equals provided for long option? 
1808 | if (result[2].length() != 0) 1809 | { 1810 | //parse the option given 1811 | 1812 | parse_option(opt, name, result[3]); 1813 | } 1814 | else 1815 | { 1816 | //parse the next argument 1817 | checked_parse_arg(argc, argv, current, opt, name); 1818 | } 1819 | } 1820 | 1821 | } 1822 | 1823 | ++current; 1824 | } 1825 | 1826 | for (auto& opt : *m_options) 1827 | { 1828 | auto& detail = opt.second; 1829 | auto& value = detail->value(); 1830 | 1831 | auto& store = m_results[detail]; 1832 | 1833 | if(!store.count() && value.has_default()){ 1834 | parse_default(detail); 1835 | } 1836 | } 1837 | 1838 | if (consume_remaining) 1839 | { 1840 | while (current < argc) 1841 | { 1842 | if (!consume_positional(argv[current])) { 1843 | break; 1844 | } 1845 | ++current; 1846 | } 1847 | 1848 | //adjust argv for any that couldn't be swallowed 1849 | while (current != argc) { 1850 | argv[nextKeep] = argv[current]; 1851 | ++nextKeep; 1852 | ++current; 1853 | } 1854 | } 1855 | 1856 | argc = nextKeep; 1857 | 1858 | } 1859 | 1860 | inline 1861 | void 1862 | Options::add_option 1863 | ( 1864 | const std::string& group, 1865 | const std::string& s, 1866 | const std::string& l, 1867 | std::string desc, 1868 | std::shared_ptr value, 1869 | std::string arg_help 1870 | ) 1871 | { 1872 | auto stringDesc = toLocalString(std::move(desc)); 1873 | auto option = std::make_shared(s, l, stringDesc, value); 1874 | 1875 | if (s.size() > 0) 1876 | { 1877 | add_one_option(s, option); 1878 | } 1879 | 1880 | if (l.size() > 0) 1881 | { 1882 | add_one_option(l, option); 1883 | } 1884 | 1885 | //add the help details 1886 | auto& options = m_help[group]; 1887 | 1888 | options.options.emplace_back(HelpOptionDetails{s, l, stringDesc, 1889 | value->has_default(), value->get_default_value(), 1890 | value->has_implicit(), value->get_implicit_value(), 1891 | std::move(arg_help), 1892 | value->is_container(), 1893 | value->is_boolean()}); 1894 | } 1895 | 1896 | inline 1897 | void 1898 | Options::add_one_option 1899 
| ( 1900 | const std::string& option, 1901 | std::shared_ptr details 1902 | ) 1903 | { 1904 | auto in = m_options->emplace(option, details); 1905 | 1906 | if (!in.second) 1907 | { 1908 | throw option_exists_error(option); 1909 | } 1910 | } 1911 | 1912 | inline 1913 | String 1914 | Options::help_one_group(const std::string& g) const 1915 | { 1916 | typedef std::vector> OptionHelp; 1917 | 1918 | auto group = m_help.find(g); 1919 | if (group == m_help.end()) 1920 | { 1921 | return ""; 1922 | } 1923 | 1924 | OptionHelp format; 1925 | 1926 | size_t longest = 0; 1927 | 1928 | String result; 1929 | 1930 | if (!g.empty()) 1931 | { 1932 | result += toLocalString("\n" + g + " options:\n"); 1933 | } 1934 | else 1935 | { 1936 | result += toLocalString("\noptions:\n"); 1937 | } 1938 | 1939 | for (const auto& o : group->second.options) 1940 | { 1941 | if (m_positional_set.find(o.l) != m_positional_set.end()) 1942 | { 1943 | continue; 1944 | } 1945 | 1946 | auto s = format_option(o); 1947 | longest = (std::max)(longest, stringLength(s)); 1948 | format.push_back(std::make_pair(s, String())); 1949 | } 1950 | 1951 | longest = (std::min)(longest, static_cast(OPTION_LONGEST)); 1952 | 1953 | //widest allowed description 1954 | auto allowed = size_t{76} - longest - OPTION_DESC_GAP; 1955 | 1956 | auto fiter = format.begin(); 1957 | for (const auto& o : group->second.options) 1958 | { 1959 | if (m_positional_set.find(o.l) != m_positional_set.end()) 1960 | { 1961 | continue; 1962 | } 1963 | 1964 | auto d = format_description(o, longest + OPTION_DESC_GAP, allowed); 1965 | 1966 | result += fiter->first; 1967 | if (stringLength(fiter->first) > longest) 1968 | { 1969 | result += '\n'; 1970 | result += toLocalString(std::string(longest + OPTION_DESC_GAP, ' ')); 1971 | } 1972 | else 1973 | { 1974 | result += toLocalString(std::string(longest + OPTION_DESC_GAP - 1975 | stringLength(fiter->first), 1976 | ' ')); 1977 | } 1978 | result += d; 1979 | result += '\n'; 1980 | 1981 | ++fiter; 1982 | } 
1983 | 1984 | return result; 1985 | } 1986 | 1987 | inline 1988 | void 1989 | Options::generate_group_help 1990 | ( 1991 | String& result, 1992 | const std::vector& print_groups 1993 | ) const 1994 | { 1995 | for (size_t i = 0; i != print_groups.size(); ++i) 1996 | { 1997 | const String& group_help_text = help_one_group(print_groups[i]); 1998 | if (empty(group_help_text)) 1999 | { 2000 | continue; 2001 | } 2002 | result += group_help_text; 2003 | } 2004 | } 2005 | 2006 | inline 2007 | void 2008 | Options::generate_all_groups_help(String& result) const 2009 | { 2010 | std::vector all_groups; 2011 | all_groups.reserve(m_help.size()); 2012 | 2013 | for (auto& group : m_help) 2014 | { 2015 | all_groups.push_back(group.first); 2016 | } 2017 | 2018 | generate_group_help(result, all_groups); 2019 | } 2020 | 2021 | inline 2022 | std::string 2023 | Options::usage() const 2024 | { 2025 | String result; 2026 | if (m_help_string != "") 2027 | { 2028 | result += m_help_string + "\n"; 2029 | } 2030 | result += "usage:\n " + 2031 | toLocalString(m_program) + " " + toLocalString(m_custom_help); 2032 | 2033 | for (const auto& arg : m_positional) 2034 | { 2035 | result += " " + arg; 2036 | } 2037 | return result; 2038 | } 2039 | 2040 | inline 2041 | std::string 2042 | Options::help(const std::vector& help_groups) const 2043 | { 2044 | String result = usage(); 2045 | 2046 | result += "\n"; 2047 | 2048 | if (m_positional.size() > 0) 2049 | { 2050 | typedef std::vector> OptionHelp; 2051 | auto group = m_help.find(""); 2052 | OptionHelp format; 2053 | 2054 | size_t longest = 0; 2055 | 2056 | result += toLocalString("\npositional arguments:\n"); 2057 | 2058 | for (const auto& o : group->second.options) 2059 | { 2060 | if (m_positional_set.find(o.l) == m_positional_set.end()) 2061 | { 2062 | continue; 2063 | } 2064 | 2065 | auto s = " " + toLocalString(o.l); 2066 | longest = (std::max)(longest, stringLength(s)); 2067 | format.push_back(std::make_pair(s, String())); 2068 | } 2069 | 2070 | 
longest = (std::min)(longest, static_cast(OPTION_LONGEST)); 2071 | 2072 | //widest allowed description 2073 | auto allowed = size_t{ 76 } -longest - OPTION_DESC_GAP; 2074 | 2075 | auto fiter = format.begin(); 2076 | for (const auto& o : group->second.options) 2077 | { 2078 | if (m_positional_set.find(o.l) == m_positional_set.end()) 2079 | { 2080 | continue; 2081 | } 2082 | 2083 | auto d = format_description(o, longest + OPTION_DESC_GAP, allowed); 2084 | 2085 | result += fiter->first; 2086 | if (stringLength(fiter->first) > longest) 2087 | { 2088 | result += '\n'; 2089 | result += toLocalString(std::string(longest + OPTION_DESC_GAP, ' ')); 2090 | } 2091 | else 2092 | { 2093 | result += toLocalString(std::string(longest + OPTION_DESC_GAP - 2094 | stringLength(fiter->first), 2095 | ' ')); 2096 | } 2097 | result += d; 2098 | result += '\n'; 2099 | 2100 | ++fiter; 2101 | } 2102 | } 2103 | 2104 | if (help_groups.size() == 0) 2105 | { 2106 | generate_all_groups_help(result); 2107 | } 2108 | else 2109 | { 2110 | generate_group_help(result, help_groups); 2111 | } 2112 | 2113 | result.pop_back(); 2114 | return toUTF8String(result); 2115 | } 2116 | 2117 | inline 2118 | const std::vector 2119 | Options::groups() const 2120 | { 2121 | std::vector g; 2122 | 2123 | std::transform( 2124 | m_help.begin(), 2125 | m_help.end(), 2126 | std::back_inserter(g), 2127 | [] (const std::map::value_type& pair) 2128 | { 2129 | return pair.first; 2130 | } 2131 | ); 2132 | 2133 | return g; 2134 | } 2135 | 2136 | inline 2137 | const HelpGroupDetails& 2138 | Options::group_help(const std::string& group) const 2139 | { 2140 | return m_help.at(group); 2141 | } 2142 | 2143 | } 2144 | 2145 | #endif //CXXOPTS_HPP_INCLUDED 2146 | -------------------------------------------------------------------------------- /src/main.cpp: -------------------------------------------------------------------------------- 1 | // This code follows Google C++ Style Guide 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | 
#include "cxxopts.hpp" 8 | #include "object_detector.h" 9 | 10 | 11 | int main(int argc, char* argv[]) { 12 | // Decomposes argv[0] into directory path and file name 13 | std::string argv0(argv[0]); 14 | std::size_t last_slash_pos = argv0.find_last_of('/'); 15 | std::string executable_directory = argv0.substr(0, last_slash_pos + 1); 16 | std::string executable_filename = argv0.substr(last_slash_pos + 1); 17 | 18 | std::string help_string = "Detects cars and pedestrians in images using YOLOv5"; 19 | cxxopts::Options option_parser(executable_filename, help_string); 20 | 21 | std::string input_directory; 22 | std::string model_filename; 23 | std::string device_option; 24 | float confidence_threshold; 25 | float iou_threshold; 26 | 27 | // Parses program options 28 | // https://tadaoyamaoka.hatenablog.com/entry/2019/01/30/235251 29 | try { 30 | option_parser.add_options() 31 | ("input-dir", "String: Path to input images directory", cxxopts::value<std::string>()) 32 | ("model-file", "String: Path to TorchScript model file", cxxopts::value<std::string>()) 33 | ("conf-thres", "Float: Object confidence threshold", cxxopts::value<float>()->default_value("0.25")) 34 | ("iou-thres", "Float: IoU threshold for NMS", cxxopts::value<float>()->default_value("0.45")) 35 | ("h,help", "Print usage"); 36 | 37 | option_parser.parse_positional({"input-dir", "model-file"}); 38 | auto options = option_parser.parse(argc, argv); 39 | 40 | if (options.count("help")) { 41 | std::cout << option_parser.help({}) << std::endl; 42 | return EXIT_SUCCESS; 43 | } 44 | 45 | input_directory = options["input-dir"].as<std::string>(); 46 | model_filename = options["model-file"].as<std::string>(); 47 | confidence_threshold = options["conf-thres"].as<float>(); 48 | iou_threshold = options["iou-thres"].as<float>(); 49 | 50 | std::cout << "\n"; 51 | std::cout << "input_directory = " << input_directory << "\n"; 52 | std::cout << "model_filename = " << model_filename << "\n"; 53 | std::cout << "confidence_threshold = " << confidence_threshold << "\n"; 54 | std::cout << "iou_threshold = "
<< iou_threshold << "\n"; 55 | std::cout << "\n"; 56 | } 57 | catch (cxxopts::OptionException &e) { 58 | std::cout << option_parser.usage() << "\n"; 59 | std::cerr << e.what() << "\n"; 60 | std::exit(EXIT_FAILURE); 61 | } 62 | 63 | yolov5::ObjectDetector object_detector(model_filename); 64 | 65 | std::string class_name_filename = executable_directory + "../../coco.names"; 66 | if (!object_detector.LoadClassNames(class_name_filename)) { 67 | return EXIT_FAILURE; 68 | } 69 | 70 | if (!object_detector.LoadInputImagePaths(input_directory)) { 71 | return EXIT_FAILURE; 72 | } 73 | 74 | object_detector.Inference(confidence_threshold, iou_threshold); 75 | 76 | return EXIT_SUCCESS; 77 | } 78 | -------------------------------------------------------------------------------- /src/object_detector/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.12) 2 | 3 | project(yolov5-object-detector-lib 4 | VERSION 1.0.0 5 | DESCRIPTION "YOLOv5 ObjectDetector library" 6 | LANGUAGES CXX 7 | ) 8 | 9 | find_package(OpenCV REQUIRED) 10 | if (OpenCV_FOUND) 11 | message(STATUS "OpenCV library status:") 12 | message(STATUS " version: ${OpenCV_VERSION}") 13 | message(STATUS " include path: ${OpenCV_INCLUDE_DIRS}") 14 | else () 15 | message(FATAL_ERROR "Could not find OpenCV") 16 | endif () 17 | 18 | set(Torch_DIR ${CMAKE_SOURCE_DIR}/libtorch_v1-6-0/share/cmake/Torch/) 19 | find_package(Torch PATHS ${Torch_DIR} REQUIRED) 20 | if (TORCH_FOUND) 21 | message(STATUS "Torch library status:") 22 | message(STATUS " version: ${Torch_VERSION}") 23 | message(STATUS " include path: ${TORCH_INCLUDE_DIRS}") 24 | else () 25 | message(FATAL_ERROR "Could not find Torch") 26 | endif () 27 | 28 | add_library(object_detector STATIC 29 | ${PROJECT_SOURCE_DIR}/src/object_detector.cpp 30 | ) 31 | target_compile_features(object_detector PRIVATE cxx_std_17) 32 | 33 | target_include_directories(object_detector PUBLIC 34 | 
${PROJECT_SOURCE_DIR}/include/ 35 | ) 36 | 37 | target_link_libraries(object_detector 38 | ${OpenCV_LIBS} 39 | ${TORCH_LIBRARIES} 40 | ) 41 | -------------------------------------------------------------------------------- /src/object_detector/include/object_detector.h: -------------------------------------------------------------------------------- 1 | // This code follows Google C++ Style Guide 2 | 3 | #ifndef COMPUTERVISION20200907T072717Z001_OBJECTDETECTOR_OBJECTDETECTOR_H_ 4 | #define COMPUTERVISION20200907T072717Z001_OBJECTDETECTOR_OBJECTDETECTOR_H_ 5 | 6 | 7 | #include 8 | #include 9 | #include 10 | 11 | #include 12 | #include 13 | #include 14 | 15 | 16 | namespace yolov5 { 17 | 18 | #define CLAMP(lower, x, upper) std::max(lower, std::min(x, upper)); 19 | #define DEBUG_PRINT(var) std::cout << #var << " = " << var << "\n"; 20 | 21 | struct ObjectInfo { 22 | cv::Rect bbox_rect; 23 | float class_score; 24 | int class_id; 25 | }; 26 | 27 | struct LetterboxInfo { 28 | int original_height; 29 | int original_width; 30 | float scale; 31 | int padding_height; 32 | int padding_width; 33 | }; 34 | 35 | class ObjectDetector { 36 | public: 37 | ObjectDetector(const std::string& model_filename) 38 | : input_height_(640), 39 | input_width_(640), 40 | nms_max_bbox_size_(4096) { 41 | std::string height_prefix = "-H"; 42 | std::size_t height_pos = model_filename.find(height_prefix); 43 | std::string height_string = model_filename.substr(height_pos + height_prefix.length(), 4); 44 | height_string.erase(height_string.find_last_not_of("-_") + 1); 45 | 46 | int input_height = std::stoi(height_string); 47 | if (input_height != input_height_) { 48 | std::cerr << "[ObjectDetector()] Error: (input_height)=" << input_height 49 | << " doesn't match (input_height_)=" << input_height_ << "\n"; 50 | std::exit(EXIT_FAILURE); 51 | } 52 | 53 | std::string width_prefix = "-W"; 54 | std::size_t width_pos = model_filename.find(width_prefix); 55 | std::string width_string =
        model_filename.substr(width_pos + width_prefix.length(), 4);
    width_string.erase(width_string.find_last_not_of("-_") + 1);

    int input_width = std::stoi(width_string);
    if (input_width != input_width_) {
      std::cerr << "[ObjectDetector()] Error: (input_width)=" << input_width
                << " doesn't match (input_width_)=" << input_width_ << "\n";
      std::exit(EXIT_FAILURE);
    }

    std::cout << "Input height = " << input_height_ << "\n";
    std::cout << "Input width = " << input_width_ << "\n\n";

    // Deserializes the ScriptModule from a file using torch::jit::load()
    // https://pytorch.org/tutorials/advanced/cpp_export.html#a-minimal-c-application
    try {
      std::cout << "[ObjectDetector()] torch::jit::load( " << model_filename << " ); ... \n";
      model_ = torch::jit::load(model_filename);
      std::cout << "[ObjectDetector()] " << model_filename << " has been loaded \n\n";
    }
    catch (const c10::Error& e) {
      std::cerr << e.what() << "\n";
      std::exit(EXIT_FAILURE);
    }
    catch (...)
    {
      std::cerr << "[ObjectDetector()] Exception: Could not load " << model_filename << "\n";
      std::exit(EXIT_FAILURE);
    }

    bool is_found_gpu_string = (model_filename.find("_gpu") != std::string::npos);
    is_gpu_ = (is_found_gpu_string && torch::cuda::is_available());

    if (is_gpu_) {
      std::cout << "Inference on GPU with CUDA \n\n";
      model_.to(torch::kCUDA);
      model_.to(torch::kHalf);
    } else {
      std::cout << "Inference on CPU \n\n";
      model_.to(torch::kCPU);
    }

    model_.eval();
  }

  ~ObjectDetector() {}

  bool LoadClassNames(const std::string& class_name_filename);

  bool LoadInputImagePaths(const std::string& input_directory);

  void Inference(float confidence_threshold, float iou_threshold);


 private:
  void Detect(const cv::Mat& input_image,
              float confidence_threshold, float iou_threshold,
              std::vector<ObjectInfo>& results);

  LetterboxInfo PreProcess(const cv::Mat& input_image,
                           std::vector<torch::jit::IValue>& inputs);

  LetterboxInfo Letterboxing(const cv::Mat& input_image, cv::Mat& letterbox_image);

  void PostProcess(const at::Tensor& output_tensor,
                   const LetterboxInfo& letterbox_info,
                   float confidence_threshold, float iou_threshold,
                   std::vector<ObjectInfo>& results);

  void XcenterYcenterWidthHeight2TopLeftBottomRight(const at::Tensor& xywh_bbox_tensor,
                                                    at::Tensor& tlbr_bbox_tensor);

  void RestoreBoundingboxSize(const std::vector<ObjectInfo>& bbox_infos,
                              const LetterboxInfo& letterbox_info,
                              std::vector<ObjectInfo>& restored_bbox_infos);

  void SaveResultImage(const cv::Mat& input_image,
                       const std::vector<ObjectInfo>& results,
                       const std::string& input_image_path);

  int input_height_;
  int input_width_;
  int nms_max_bbox_size_;
  torch::jit::script::Module model_;
  bool is_gpu_;
  std::vector<std::string> class_names_;
  std::vector<std::string> input_image_paths_;
};
}  // namespace yolov5


#endif  // COMPUTERVISION20200907T072717Z001_OBJECTDETECTOR_OBJECTDETECTOR_H_
--------------------------------------------------------------------------------
/src/object_detector/src/object_detector.cpp:
--------------------------------------------------------------------------------
// This code follows Google C++ Style Guide

#include "object_detector.h"

#include <dirent.h>

#include <chrono>
#include <cmath>
#include <fstream>
#include <iomanip>
#include <sstream>

#include <opencv2/dnn.hpp>
#include <opencv2/opencv.hpp>
#include <torch/torch.h>


namespace yolov5 {

bool ObjectDetector::LoadClassNames(const std::string& class_name_filename) {
  std::ifstream class_name_ifs(class_name_filename);
  if (class_name_ifs.is_open()) {
    std::string class_name;
    while (std::getline(class_name_ifs, class_name)) {
      class_names_.emplace_back(class_name);
    }
    class_name_ifs.close();
  } else {
    std::cerr << "[ObjectDetector::LoadClassNames()] Error: Could not open "
              << class_name_filename << "\n";
    return false;
  }

  if (class_names_.size() == 0) {
    std::cerr << "[ObjectDetector::LoadClassNames()] Error: class names are empty \n";
    return false;
  }

  return true;
}


bool ObjectDetector::LoadInputImagePaths(const std::string& input_directory) {
  // Appends a trailing slash if it is missing, so that "dir" and "dir/" both work
  std::string directory = input_directory;
  if (!directory.empty() && directory.back() != '/') {
    directory += '/';
  }

  DIR* dir;
  struct dirent* entry;
  if ((dir = opendir(directory.c_str())) != NULL) {
    while ((entry = readdir(dir)) != NULL) {
      // Skips hidden files as well as "." and ".."
      if (entry->d_name[0] != '.') {
        std::string input_image_filename(entry->d_name);
        std::string input_image_path = directory + input_image_filename;
        input_image_paths_.emplace_back(input_image_path);
      }
    }
    closedir(dir);
  } else {
    std::cerr << "[ObjectDetector::LoadInputImagePaths()] Error: Could not open "
              << input_directory << "\n";
    return false;
  }

  if (input_image_paths_.size() == 0) {
    std::cerr << "[ObjectDetector::LoadInputImagePaths()] Error: input image paths are empty \n";
    return false;
  }

  return true;
}


void ObjectDetector::Inference(float confidence_threshold, float iou_threshold) {
  std::cout << "=== Empty inferences to warm up === \n\n";
  for (std::size_t i = 0; i < 3; ++i) {
    cv::Mat tmp_image = cv::Mat::zeros(input_height_, input_width_, CV_32FC3);
    std::vector<ObjectInfo> tmp_results;
    Detect(tmp_image, 1.0, 1.0, tmp_results);
  }
  std::cout << "=== Warming up is done === \n\n\n";

  for (const auto& input_image_path : input_image_paths_) {
    std::cout << "input_image_path = " << input_image_path << "\n";

    cv::Mat input_image = cv::imread(input_image_path);
    if (input_image.empty()) {
      std::cerr << "[ObjectDetector::Inference()] Error: Could not open "
                << input_image_path << "\n";
      continue;
    }

    std::vector<ObjectInfo> results;
    Detect(input_image, confidence_threshold, iou_threshold, results);

    SaveResultImage(input_image, results, input_image_path);
  }

  return;
}


void ObjectDetector::Detect(const cv::Mat& input_image,
                            float confidence_threshold, float iou_threshold,
                            std::vector<ObjectInfo>& results) {
  torch::NoGradGuard no_grad_guard;

  auto start_preprocess = std::chrono::high_resolution_clock::now();
  std::vector<torch::jit::IValue> inputs;
  LetterboxInfo letterbox_info = PreProcess(input_image, inputs);
  auto end_preprocess = std::chrono::high_resolution_clock::now();

  auto duration_preprocess =
      std::chrono::duration_cast<std::chrono::milliseconds>(end_preprocess - start_preprocess);
  std::cout << "Pre-processing: " << duration_preprocess.count() << " [ms] \n";

  // output_tensor ... {Batch=1, Num of max bbox=25200, 85}
  // 25200 ... {(640[px]/32[stride])^2 + (640[px]/16[stride])^2 + (640[px]/8[stride])^2} x 3[anchors per grid cell]
  //       = (400 + 1600 + 6400) x 3
  // 85 ...
  //        0: center x, 1: center y, 2: width, 3: height, 4: obj conf, 5~84: class conf
  auto start_inference = std::chrono::high_resolution_clock::now();
  at::Tensor output_tensor = model_.forward(inputs).toTuple()->elements()[0].toTensor();
  auto end_inference = std::chrono::high_resolution_clock::now();

  auto duration_inference =
      std::chrono::duration_cast<std::chrono::milliseconds>(end_inference - start_inference);
  std::cout << "Inference: " << duration_inference.count() << " [ms] \n";

  // results ... {Num of obj}, each holding:
  //   bbox (top-left x, top-left y, bottom-right x, bottom-right y), class score, class id
  auto start_postprocess = std::chrono::high_resolution_clock::now();
  PostProcess(output_tensor, letterbox_info, confidence_threshold, iou_threshold, results);
  auto end_postprocess = std::chrono::high_resolution_clock::now();

  auto duration_postprocess =
      std::chrono::duration_cast<std::chrono::milliseconds>(end_postprocess - start_postprocess);
  std::cout << "Post-processing: " << duration_postprocess.count() << " [ms] \n\n";

  return;
}


LetterboxInfo ObjectDetector::PreProcess(const cv::Mat& input_image,
                                         std::vector<torch::jit::IValue>& inputs) {
  cv::Mat letterbox_image;
  LetterboxInfo letterbox_info = Letterboxing(input_image, letterbox_image);

  // BGR ---> RGB, then 0 ~ 255 ---> 0.0 ~ 1.0
  cv::cvtColor(letterbox_image, letterbox_image, cv::COLOR_BGR2RGB);
  letterbox_image.convertTo(letterbox_image, CV_32FC3, 1.0 / 255.0);

  // input_tensor ... {Batch=1, Height, Width, Channel=3}
  // --->
  // input_tensor ...
  //                  {Batch=1, Channel=3, Height, Width}
  // torch::from_blob() neither copies nor owns letterbox_image's buffer;
  // contiguous() on the permuted (non-contiguous) view materializes a copy,
  // so input_tensor stays valid after letterbox_image goes out of scope
  at::Tensor input_tensor = torch::from_blob(letterbox_image.data,
                                             {1, input_height_, input_width_, 3});
  input_tensor = input_tensor.permute({0, 3, 1, 2}).contiguous();

  if (is_gpu_) {
    input_tensor = input_tensor.to(torch::kCUDA);
    input_tensor = input_tensor.to(torch::kHalf);
  } else {
    input_tensor = input_tensor.to(torch::kCPU);
  }

  inputs.clear();
  inputs.emplace_back(input_tensor);

  return letterbox_info;
}


LetterboxInfo ObjectDetector::Letterboxing(const cv::Mat& input_image, cv::Mat& letterbox_image) {
  // Resizes the image to fit into (input_width_, input_height_) while keeping its aspect ratio
  float scale = std::min(input_height_ / static_cast<float>(input_image.size().height),
                         input_width_ / static_cast<float>(input_image.size().width));
  cv::resize(input_image, letterbox_image, cv::Size(), scale, scale);

  // Pads the remaining area with gray (114, 114, 114); an odd leftover pixel goes to the bottom/right
  int top_margin = static_cast<int>(std::floor((input_height_ - letterbox_image.size().height) / 2.0));
  int bottom_margin = static_cast<int>(std::ceil((input_height_ - letterbox_image.size().height) / 2.0));
  int left_margin = static_cast<int>(std::floor((input_width_ - letterbox_image.size().width) / 2.0));
  int right_margin = static_cast<int>(std::ceil((input_width_ - letterbox_image.size().width) / 2.0));
  cv::copyMakeBorder(letterbox_image, letterbox_image,
                     top_margin, bottom_margin, left_margin, right_margin,
                     cv::BORDER_CONSTANT, cv::Scalar(114, 114, 114));

  LetterboxInfo letterbox_info;
  letterbox_info.original_height = input_image.size().height;
  letterbox_info.original_width = input_image.size().width;
  letterbox_info.scale = scale;
  letterbox_info.padding_height = top_margin;
  letterbox_info.padding_width = left_margin;

  return letterbox_info;
}


void ObjectDetector::PostProcess(const at::Tensor& output_tensor, const LetterboxInfo& letterbox_info,
                                 float confidence_threshold, float iou_threshold,
                                 std::vector<ObjectInfo>& results) {
  int batch_size = output_tensor.size(0);
  if (batch_size != 1) {
    std::cerr << "[ObjectDetector::PostProcess()] Error: Batch size of output tensor is not 1 \n";
    return;
  }

  // 85 ... 0: center x, 1: center y, 2: width, 3: height, 4: obj conf, 5~84: class conf
  int num_bbox_confidence_class_idx = output_tensor.size(2);

  // 5 = 85 - 80 ... 0: center x, 1: center y, 2: width, 3: height, 4: obj conf
  int num_bbox_confidence_idx = num_bbox_confidence_class_idx - static_cast<int>(class_names_.size());

  // 4 = 5 - 1 ... 0: center x, 1: center y, 2: width, 3: height
  int num_bbox_idx = num_bbox_confidence_idx - 1;



  /*****************************************************************************
   * Thresholding the detected objects by object confidence
   ****************************************************************************/

  int bbox_confidence_class_dim = -1;  // always in the last dimension
  int object_confidence_idx = 4;

  // output_tensor ... {Batch=1, Num of max bbox=25200, 85}
  // --->
  // candidate_object_mask ... {Batch=1, Num of max bbox=25200, 1}
  at::Tensor candidate_object_mask = output_tensor.select(bbox_confidence_class_dim,
                                                          object_confidence_idx);
  candidate_object_mask = candidate_object_mask.gt(confidence_threshold);
  candidate_object_mask = candidate_object_mask.unsqueeze(bbox_confidence_class_dim);

  // output_tensor[0] ... {Num of max bbox=25200, 85}
  // candidate_object_mask[0] ... {Num of max bbox=25200, 1} (broadcast over the 85 columns)
  // --->
  // candidate_object_tensor ... {Num of candidate bbox*85}
  at::Tensor candidate_object_tensor = torch::masked_select(output_tensor[0],
                                                            candidate_object_mask[0]);

  // candidate_object_tensor ... {Num of candidate bbox*85}
  // --->
  // candidate_object_tensor ...
  //                             {Num of candidate bbox, 85}
  candidate_object_tensor = candidate_object_tensor.view({-1, num_bbox_confidence_class_idx});

  // If there are no candidate objects at all, return
  if (candidate_object_tensor.size(0) == 0) {
    return;
  }

  // candidate_object_tensor ... {Num of candidate bbox, 85}
  // --->
  // xywh_bbox_tensor ... {Num of candidate bbox, 4} => similar to [:, 0:4] in Python
  at::Tensor xywh_bbox_tensor = candidate_object_tensor.slice(bbox_confidence_class_dim,
                                                              0, num_bbox_idx);

  // xywh_bbox_tensor ... {Num of candidate bbox, 4}
  // 4 ... 0: x center, 1: y center, 2: width, 3: height
  // --->
  // bbox_tensor ... {Num of candidate bbox, 4}
  // 4 ... 0: top-left x, 1: top-left y, 2: bottom-right x, 3: bottom-right y
  at::Tensor bbox_tensor;
  XcenterYcenterWidthHeight2TopLeftBottomRight(xywh_bbox_tensor, bbox_tensor);

  // candidate_object_tensor ... {Num of candidate bbox, 85}
  // --->
  // object_confidence_tensor ... {Num of candidate bbox, 1} => similar to [:, 4:5] in Python
  at::Tensor object_confidence_tensor = candidate_object_tensor.slice(bbox_confidence_class_dim,
                                                                      num_bbox_idx, num_bbox_confidence_idx);

  // candidate_object_tensor ... {Num of candidate bbox, 85}
  // --->
  // class_confidence_tensor ... {Num of candidate bbox, 80} => similar to [:, 5:] in Python
  at::Tensor class_confidence_tensor = candidate_object_tensor.slice(bbox_confidence_class_dim,
                                                                     num_bbox_confidence_idx);

  // class_score_tensor ... {Num of candidate bbox, 80}
  // class score = class conf * obj conf
  at::Tensor class_score_tensor = class_confidence_tensor * object_confidence_tensor;

  // max_class_score_tuple ... (value: {Num of candidate bbox}, index: {Num of candidate bbox})
  std::tuple<at::Tensor, at::Tensor> max_class_score_tuple = torch::max(class_score_tensor,
                                                                        bbox_confidence_class_dim);

  // max_class_score ...
  //                   {Num of candidate bbox}
  // --->
  // max_class_score ... {Num of candidate bbox, 1}
  at::Tensor max_class_score = std::get<0>(max_class_score_tuple).to(torch::kFloat);
  max_class_score = max_class_score.unsqueeze(bbox_confidence_class_dim);

  // max_class_id ... {Num of candidate bbox}
  // --->
  // max_class_id ... {Num of candidate bbox, 1}
  at::Tensor max_class_id = std::get<1>(max_class_score_tuple).to(torch::kFloat);
  max_class_id = max_class_id.unsqueeze(bbox_confidence_class_dim);

  // result_tensor ... {Num of candidate bbox, 6}
  // 6 ... 0: top-left x, 1: top-left y, 2: bottom-right x, 3: bottom-right y, 4: class score, 5: class id
  // bbox_tensor is cast to float so that the half-precision (GPU) output ends up as a float tensor
  at::Tensor result_tensor = torch::cat({bbox_tensor.to(torch::kFloat), max_class_score, max_class_id},
                                        bbox_confidence_class_dim);



  /*****************************************************************************
   * Non Maximum Suppression
   ****************************************************************************/

  // class_id_tensor ... {Num of candidate bbox, 1} => similar to [:, -1:] in Python
  at::Tensor class_id_tensor = result_tensor.slice(bbox_confidence_class_dim, -1);

  // class_offset_bbox_tensor ... {Num of candidate bbox, 4}
  // 4 ...
  //     0: top-left x, 1: top-left y, 2: bottom-right x, 3: bottom-right y (but offset by +4096 * class id)
  // Shifting every box by (4096 * class id) moves each class into its own coordinate range,
  // so the single class-agnostic NMS call below only suppresses overlaps within the same class
  at::Tensor class_offset_bbox_tensor = result_tensor.slice(bbox_confidence_class_dim, 0, num_bbox_idx)
                                        + nms_max_bbox_size_ * class_id_tensor;

  // Copies tensor to CPU to access tensor elements efficiently with TensorAccessor
  // https://pytorch.org/cppdocs/notes/tensor_basics.html#efficient-access-to-tensor-elements
  at::Tensor class_offset_bbox_tensor_cpu = class_offset_bbox_tensor.cpu();
  at::Tensor result_tensor_cpu = result_tensor.cpu();
  auto class_offset_bbox_tensor_accessor = class_offset_bbox_tensor_cpu.accessor<float, 2>();
  auto result_tensor_accessor = result_tensor_cpu.accessor<float, 2>();

  std::vector<cv::Rect> offset_bboxes;
  std::vector<float> class_scores;
  offset_bboxes.reserve(result_tensor_accessor.size(0));
  class_scores.reserve(result_tensor_accessor.size(0));

  for (int64_t i = 0; i < result_tensor_accessor.size(0); ++i) {
    float class_offset_top_left_x = class_offset_bbox_tensor_accessor[i][0];
    float class_offset_top_left_y = class_offset_bbox_tensor_accessor[i][1];
    float class_offset_bottom_right_x = class_offset_bbox_tensor_accessor[i][2];
    float class_offset_bottom_right_y = class_offset_bbox_tensor_accessor[i][3];

    offset_bboxes.emplace_back(cv::Rect(cv::Point(class_offset_top_left_x, class_offset_top_left_y),
                                        cv::Point(class_offset_bottom_right_x, class_offset_bottom_right_y)));

    class_scores.emplace_back(result_tensor_accessor[i][4]);
  }

  // NMSBoxes() also drops boxes whose class score is not above confidence_threshold
  std::vector<int> nms_indecies;
  cv::dnn::NMSBoxes(offset_bboxes, class_scores, confidence_threshold, iou_threshold, nms_indecies);



  /*****************************************************************************
   * Create result data
   ****************************************************************************/

  std::vector<ObjectInfo> object_infos;
  for (const auto& nms_idx : nms_indecies) {
    float top_left_x = result_tensor_accessor[nms_idx][0];
    float top_left_y = result_tensor_accessor[nms_idx][1];
    float bottom_right_x = result_tensor_accessor[nms_idx][2];
    float bottom_right_y = result_tensor_accessor[nms_idx][3];

    ObjectInfo object_info;
    object_info.bbox_rect = cv::Rect(cv::Point(top_left_x, top_left_y),
                                     cv::Point(bottom_right_x, bottom_right_y));
    object_info.class_score = result_tensor_accessor[nms_idx][4];
    object_info.class_id = static_cast<int>(result_tensor_accessor[nms_idx][5]);

    object_infos.emplace_back(object_info);
  }

  RestoreBoundingboxSize(object_infos, letterbox_info, results);

  return;
}


// xywh_bbox_tensor ... {Num of bbox, 4}
// 4 ... 0: x center, 1: y center, 2: width, 3: height
// --->
// tlbr_bbox_tensor ... {Num of bbox, 4}
// 4 ... 0: top-left x, 1: top-left y, 2: bottom-right x, 3: bottom-right y
void ObjectDetector::XcenterYcenterWidthHeight2TopLeftBottomRight(const at::Tensor& xywh_bbox_tensor,
                                                                  at::Tensor& tlbr_bbox_tensor) {
  tlbr_bbox_tensor = torch::zeros_like(xywh_bbox_tensor);

  int bbox_dim = -1;  // the last dimension

  int x_center_idx = 0;
  int y_center_idx = 1;
  int width_idx = 2;
  int height_idx = 3;

  // Assigning to the (temporary) view returned by select() writes into tlbr_bbox_tensor in place
  tlbr_bbox_tensor.select(bbox_dim, 0) = xywh_bbox_tensor.select(bbox_dim, x_center_idx)
                                         - xywh_bbox_tensor.select(bbox_dim, width_idx).div(2.0);
  tlbr_bbox_tensor.select(bbox_dim, 1) = xywh_bbox_tensor.select(bbox_dim, y_center_idx)
                                         - xywh_bbox_tensor.select(bbox_dim, height_idx).div(2.0);
  tlbr_bbox_tensor.select(bbox_dim, 2) = xywh_bbox_tensor.select(bbox_dim, x_center_idx)
                                         + xywh_bbox_tensor.select(bbox_dim, width_idx).div(2.0);
  tlbr_bbox_tensor.select(bbox_dim, 3) = xywh_bbox_tensor.select(bbox_dim, y_center_idx)
                                         + xywh_bbox_tensor.select(bbox_dim,
                                                                   height_idx).div(2.0);

  return;
}


void ObjectDetector::RestoreBoundingboxSize(const std::vector<ObjectInfo>& object_infos,
                                            const LetterboxInfo& letterbox_info,
                                            std::vector<ObjectInfo>& restored_object_infos) {
  restored_object_infos.clear();
  restored_object_infos.reserve(object_infos.size());

  for (const auto& object_info : object_infos) {
    // Removes the letterbox padding, then undoes the resize scale
    float top_left_x = (object_info.bbox_rect.tl().x - letterbox_info.padding_width) / letterbox_info.scale;
    float top_left_y = (object_info.bbox_rect.tl().y - letterbox_info.padding_height) / letterbox_info.scale;
    float bottom_right_x = (object_info.bbox_rect.br().x - letterbox_info.padding_width) / letterbox_info.scale;
    float bottom_right_y = (object_info.bbox_rect.br().y - letterbox_info.padding_height) / letterbox_info.scale;

    top_left_x = CLAMP(0.0f, top_left_x, static_cast<float>(letterbox_info.original_width));
    top_left_y = CLAMP(0.0f, top_left_y, static_cast<float>(letterbox_info.original_height));
    bottom_right_x = CLAMP(0.0f, bottom_right_x, static_cast<float>(letterbox_info.original_width));
    bottom_right_y = CLAMP(0.0f, bottom_right_y, static_cast<float>(letterbox_info.original_height));

    ObjectInfo restored_object_info;
    restored_object_info.bbox_rect = cv::Rect(cv::Point(top_left_x, top_left_y),
                                              cv::Point(bottom_right_x, bottom_right_y));
    restored_object_info.class_score = object_info.class_score;
    restored_object_info.class_id = object_info.class_id;

    restored_object_infos.emplace_back(restored_object_info);
  }

  return;
}


void ObjectDetector::SaveResultImage(const cv::Mat& input_image,
                                     const std::vector<ObjectInfo>& results,
                                     const std::string& input_image_path) {
  // Shares pixel data with input_image (shallow copy), so the caller's image is drawn on as well
  cv::Mat result_image(input_image);

  for (const auto& object_info : results) {
    // Draws object bounding box
    cv::rectangle(result_image, object_info.bbox_rect, cv::Scalar(0,0,255), 1);

    // Class info text
    std::string class_name = class_names_[object_info.class_id];
    std::stringstream class_score;
    class_score << std::fixed << std::setprecision(2) << object_info.class_score;
    std::string class_info = class_name + " " + class_score.str();

    // Size of class info text
    auto font_face = cv::FONT_HERSHEY_SIMPLEX;
    float font_scale = 1.0;
    int thickness = 1;
    int baseline = 0;
    cv::Size class_info_size = cv::getTextSize(class_info, font_face, font_scale, thickness, &baseline);

    // Draws rectangle of class info text
    int height_offset = 5;  // [px]
    cv::Point class_info_top_left = cv::Point(object_info.bbox_rect.tl().x,
                                              object_info.bbox_rect.tl().y - class_info_size.height - height_offset);
    cv::Point class_info_bottom_right = cv::Point(object_info.bbox_rect.tl().x + class_info_size.width,
                                                  object_info.bbox_rect.tl().y);
    cv::rectangle(result_image, class_info_top_left, class_info_bottom_right, cv::Scalar(0,0,255), -1);

    // Draws class info text
    cv::Point class_info_text_position = cv::Point(object_info.bbox_rect.tl().x,
                                                   object_info.bbox_rect.tl().y - height_offset);
    cv::putText(result_image, class_info, class_info_text_position,
                font_face, font_scale, cv::Scalar(0,0,0), thickness);
  }

  // Saves to <input directory>/../result_image/result_<suffix>, where <suffix> is the part of
  // the input filename after its last underscore (or the whole filename if there is none)
  std::size_t last_slash_pos = input_image_path.find_last_of('/');
  std::string input_image_directory = input_image_path.substr(0, last_slash_pos + 1);
  std::string input_image_filename = input_image_path.substr(last_slash_pos + 1);

  std::size_t last_underscore_pos = input_image_filename.find_last_of('_');
  std::string result_image_directory = input_image_directory + "../result_image/";
  std::string result_image_filename = "result_" + input_image_filename.substr(last_underscore_pos + 1);

  if (!cv::imwrite(result_image_directory + result_image_filename, result_image)) {
    std::cerr << "[ObjectDetector::SaveResultImage()] Error: Could not write "
              << result_image_directory + result_image_filename << "\n";
  }

  return;
}

}  // namespace yolov5
--------------------------------------------------------------------------------
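The letterbox geometry computed in `Letterboxing()` and undone in `RestoreBoundingboxSize()` can be checked in isolation, without OpenCV or LibTorch. The sketch below is a minimal, hypothetical re-implementation of that arithmetic only: `LetterboxParams`, `ComputeLetterbox`, `ToOriginalX` and `ToOriginalY` are illustrative names that are not part of this repository, and it assumes the resized size is the scaled size rounded to the nearest integer, which is close to, but not guaranteed to be identical with, what `cv::resize` computes.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>

// Hypothetical mirror of the letterbox fields (sketch only, not the repository's LetterboxInfo).
struct LetterboxParams {
  float scale;
  int resized_width;
  int resized_height;
  int padding_width;   // left margin, like LetterboxInfo::padding_width
  int padding_height;  // top margin, like LetterboxInfo::padding_height
};

// Fits a (src_width x src_height) image into (dst_width x dst_height) keeping the
// aspect ratio; the top/left margin gets floor(leftover / 2), the rest goes bottom/right.
LetterboxParams ComputeLetterbox(int src_width, int src_height,
                                 int dst_width, int dst_height) {
  assert(src_width > 0 && src_height > 0);
  LetterboxParams p;
  p.scale = std::min(dst_height / static_cast<float>(src_height),
                     dst_width / static_cast<float>(src_width));
  // Assumption: resized size = scaled size rounded to the nearest integer.
  p.resized_width = static_cast<int>(std::lround(src_width * p.scale));
  p.resized_height = static_cast<int>(std::lround(src_height * p.scale));
  p.padding_width = static_cast<int>(std::floor((dst_width - p.resized_width) / 2.0));
  p.padding_height = static_cast<int>(std::floor((dst_height - p.resized_height) / 2.0));
  return p;
}

// Maps a letterboxed (network input) x coordinate back to the original image:
// remove the padding, undo the scale, then clamp to [0, src_width].
float ToOriginalX(float x, const LetterboxParams& p, int src_width) {
  float restored = (x - p.padding_width) / p.scale;
  return std::max(0.0f, std::min(restored, static_cast<float>(src_width)));
}

// Same as ToOriginalX() for the y coordinate.
float ToOriginalY(float y, const LetterboxParams& p, int src_height) {
  float restored = (y - p.padding_height) / p.scale;
  return std::max(0.0f, std::min(restored, static_cast<float>(src_height)));
}
```

For a 1280x720 image and a 640x640 input, `scale` is 0.5, the resized image is 640x360 and 140 rows of padding are added above it, so the letterboxed point (320, 320) maps back to (640, 360), the center of the original image. Points that fall in the padding are clamped to the image border, mirroring the `CLAMP` calls in `RestoreBoundingboxSize()`.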