├── .gitattributes ├── Imginput ├── bus.jpg ├── demo1.jpg ├── demo2.jpg ├── test1.jpg ├── test2.jpg └── zidane.jpg ├── Imgoutput ├── bus_m.jpg ├── bus_ms.jpg ├── demo1_m.jpg ├── demo2_m.jpg ├── test1_m.jpg ├── test2_m.jpg ├── demo1_ms.jpg ├── demo2_ms.jpg ├── test1_ms.jpg ├── test2_ms.jpg ├── zidane_m.jpg └── zidane_ms.jpg ├── models ├── yolov8m.onnx ├── yolov8m-seg.onnx └── coco.names ├── run.sh ├── CMakeLists.txt ├── include ├── utils.h ├── yolov8Predictor.h └── cmdline.h ├── README.md └── src ├── main.cpp ├── utils.cpp └── yolov8Predictor.cpp /.gitattributes: -------------------------------------------------------------------------------- 1 | *.onnx filter=lfs diff=lfs merge=lfs -text 2 | -------------------------------------------------------------------------------- /Imginput/bus.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Li-99/yolov8_onnxruntime/HEAD/Imginput/bus.jpg -------------------------------------------------------------------------------- /Imginput/demo1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Li-99/yolov8_onnxruntime/HEAD/Imginput/demo1.jpg -------------------------------------------------------------------------------- /Imginput/demo2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Li-99/yolov8_onnxruntime/HEAD/Imginput/demo2.jpg -------------------------------------------------------------------------------- /Imginput/test1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Li-99/yolov8_onnxruntime/HEAD/Imginput/test1.jpg -------------------------------------------------------------------------------- /Imginput/test2.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Li-99/yolov8_onnxruntime/HEAD/Imginput/test2.jpg -------------------------------------------------------------------------------- /Imginput/zidane.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Li-99/yolov8_onnxruntime/HEAD/Imginput/zidane.jpg -------------------------------------------------------------------------------- /Imgoutput/bus_m.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Li-99/yolov8_onnxruntime/HEAD/Imgoutput/bus_m.jpg -------------------------------------------------------------------------------- /Imgoutput/bus_ms.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Li-99/yolov8_onnxruntime/HEAD/Imgoutput/bus_ms.jpg -------------------------------------------------------------------------------- /Imgoutput/demo1_m.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Li-99/yolov8_onnxruntime/HEAD/Imgoutput/demo1_m.jpg -------------------------------------------------------------------------------- /Imgoutput/demo2_m.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Li-99/yolov8_onnxruntime/HEAD/Imgoutput/demo2_m.jpg -------------------------------------------------------------------------------- /Imgoutput/test1_m.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Li-99/yolov8_onnxruntime/HEAD/Imgoutput/test1_m.jpg -------------------------------------------------------------------------------- /Imgoutput/test2_m.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Li-99/yolov8_onnxruntime/HEAD/Imgoutput/test2_m.jpg 
-------------------------------------------------------------------------------- /Imgoutput/demo1_ms.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Li-99/yolov8_onnxruntime/HEAD/Imgoutput/demo1_ms.jpg -------------------------------------------------------------------------------- /Imgoutput/demo2_ms.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Li-99/yolov8_onnxruntime/HEAD/Imgoutput/demo2_ms.jpg -------------------------------------------------------------------------------- /Imgoutput/test1_ms.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Li-99/yolov8_onnxruntime/HEAD/Imgoutput/test1_ms.jpg -------------------------------------------------------------------------------- /Imgoutput/test2_ms.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Li-99/yolov8_onnxruntime/HEAD/Imgoutput/test2_ms.jpg -------------------------------------------------------------------------------- /Imgoutput/zidane_m.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Li-99/yolov8_onnxruntime/HEAD/Imgoutput/zidane_m.jpg -------------------------------------------------------------------------------- /Imgoutput/zidane_ms.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Li-99/yolov8_onnxruntime/HEAD/Imgoutput/zidane_ms.jpg -------------------------------------------------------------------------------- /models/yolov8m.onnx: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:8ae23d1e9e9943f2cd051859c4352bfbe0fcf92ec4b2bbf60d35a9f09d2bb4dd 3 | size 103773495 4 | 
-------------------------------------------------------------------------------- /models/yolov8m-seg.onnx: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:6c3616b0bd2e1940f9bb5b737052696ac5313a85640e1b8c85898d3f80d59938 3 | size 109312393 4 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | ./build/yolov8_ort -m ./models/yolov8m.onnx -i ./Imginput -o ./Imgoutput -c ./models/coco.names -x m --gpu 2 | ./build/yolov8_ort -m ./models/yolov8m-seg.onnx -i ./Imginput -o ./Imgoutput -c ./models/coco.names -x ms --gpu 3 | 4 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.0.0) 2 | project(yolov8_ort) 3 | 4 | option(ONNXRUNTIME_DIR "Path to built ONNX Runtime directory." 
STRING) 5 | message(STATUS "ONNXRUNTIME_DIR: ${ONNXRUNTIME_DIR}") 6 | 7 | find_package(OpenCV REQUIRED) 8 | 9 | include_directories("include/") 10 | 11 | add_executable(yolov8_ort 12 | src/utils.cpp 13 | src/yolov8Predictor.cpp 14 | src/main.cpp) 15 | 16 | set(CMAKE_CXX_STANDARD 17) 17 | set(CMAKE_CXX_STANDARD_REQUIRED ON) 18 | 19 | target_include_directories(yolov8_ort PRIVATE "${ONNXRUNTIME_DIR}/include") 20 | 21 | target_compile_features(yolov8_ort PRIVATE cxx_std_17) 22 | target_link_libraries(yolov8_ort ${OpenCV_LIBS}) 23 | 24 | 25 | if (WIN32) 26 | target_link_libraries(yolov8_ort "${ONNXRUNTIME_DIR}/lib/onnxruntime.lib") 27 | endif(WIN32) 28 | 29 | if (UNIX) 30 | target_link_libraries(yolov8_ort "${ONNXRUNTIME_DIR}/lib/libonnxruntime.so") 31 | endif(UNIX) 32 | 33 | -------------------------------------------------------------------------------- /models/coco.names: -------------------------------------------------------------------------------- 1 | person 2 | bicycle 3 | car 4 | motorbike 5 | aeroplane 6 | bus 7 | train 8 | truck 9 | boat 10 | traffic light 11 | fire hydrant 12 | stop sign 13 | parking meter 14 | bench 15 | bird 16 | cat 17 | dog 18 | horse 19 | sheep 20 | cow 21 | elephant 22 | bear 23 | zebra 24 | giraffe 25 | backpack 26 | umbrella 27 | handbag 28 | tie 29 | suitcase 30 | frisbee 31 | skis 32 | snowboard 33 | sports ball 34 | kite 35 | baseball bat 36 | baseball glove 37 | skateboard 38 | surfboard 39 | tennis racket 40 | bottle 41 | wine glass 42 | cup 43 | fork 44 | knife 45 | spoon 46 | bowl 47 | banana 48 | apple 49 | sandwich 50 | orange 51 | broccoli 52 | carrot 53 | hot dog 54 | pizza 55 | donut 56 | cake 57 | chair 58 | sofa 59 | pottedplant 60 | bed 61 | diningtable 62 | toilet 63 | tvmonitor 64 | laptop 65 | mouse 66 | remote 67 | keyboard 68 | cell phone 69 | microwave 70 | oven 71 | toaster 72 | sink 73 | refrigerator 74 | book 75 | clock 76 | vase 77 | scissors 78 | teddy bear 79 | hair drier 80 | toothbrush 
-------------------------------------------------------------------------------- /include/utils.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | #include 5 | 6 | struct Yolov8Result 7 | { 8 | cv::Rect box; 9 | cv::Mat boxMask; // mask in box 10 | float conf{}; 11 | int classId{}; 12 | }; 13 | 14 | namespace utils 15 | { 16 | static std::vector colors; 17 | 18 | size_t vectorProduct(const std::vector &vector); 19 | std::wstring charToWstring(const char *str); 20 | std::vector loadNames(const std::string &path); 21 | void visualizeDetection(cv::Mat &image, std::vector &results, 22 | const std::vector &classNames); 23 | 24 | void letterbox(const cv::Mat &image, cv::Mat &outImage, 25 | const cv::Size &newShape, 26 | const cv::Scalar &color, 27 | bool auto_, 28 | bool scaleFill, 29 | bool scaleUp, 30 | int stride); 31 | 32 | void scaleCoords(cv::Rect &coords, cv::Mat &mask, 33 | const float maskThreshold, 34 | const cv::Size &imageShape, const cv::Size &imageOriginalShape); 35 | 36 | template 37 | T clip(const T &n, const T &lower, const T &upper); 38 | } 39 | -------------------------------------------------------------------------------- /include/yolov8Predictor.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | #include 5 | 6 | #include "utils.h" 7 | 8 | class YOLOPredictor 9 | { 10 | public: 11 | explicit YOLOPredictor(std::nullptr_t){}; 12 | YOLOPredictor(const std::string &modelPath, 13 | const bool &isGPU, 14 | float confThreshold, 15 | float iouThreshold, 16 | float maskThreshold); 17 | // ~YOLOPredictor(); 18 | std::vector predict(cv::Mat &image); 19 | int classNums = 80; 20 | 21 | private: 22 | Ort::Env env{nullptr}; 23 | Ort::SessionOptions sessionOptions{nullptr}; 24 | Ort::Session session{nullptr}; 25 | 26 | void preprocessing(cv::Mat &image, float *&blob, std::vector &inputTensorShape); 27 | 
std::vector postprocessing(const cv::Size &resizedImageShape, 28 | const cv::Size &originalImageShape, 29 | std::vector &outputTensors); 30 | 31 | static void getBestClassInfo(std::vector::iterator it, 32 | float &bestConf, 33 | int &bestClassId, 34 | const int _classNums); 35 | cv::Mat getMask(const cv::Mat &maskProposals, const cv::Mat &maskProtos); 36 | bool isDynamicInputShape{}; 37 | 38 | std::vector inputNames; 39 | std::vector input_names_ptr; 40 | 41 | std::vector outputNames; 42 | std::vector output_names_ptr; 43 | 44 | std::vector> inputShapes; 45 | std::vector> outputShapes; 46 | float confThreshold = 0.3f; 47 | float iouThreshold = 0.4f; 48 | 49 | bool hasMask = false; 50 | float maskThreshold = 0.5f; 51 | }; -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # yolov8-onnxruntime 2 | 3 | **C++ YOLOv8 ONNXRuntime** inference code for *Object Detection* or *Instance Segmentation*. 4 | 5 | **Support for custom training model deployment !!!** 6 | ## Demo 7 | 8 | YOLOv8m and yolov8m-seg: 9 | 10 |

11 | 13 | 14 |

15 | 16 |

17 | 19 | 20 |

21 | 22 | ## My Dependencies: 23 | - OpenCV 4.2 24 | - ONNXRuntime 1.15. 25 | - OS: Windows or Linux 26 | - CUDA 12.0 [Optional] 27 | - YOLOv8 exported with onnx 1.14. opset 17 28 | 29 | ## Build 30 | 31 | Find the compiled package for your system on the [official website](https://github.com/microsoft/onnxruntime/releases), then unzip it and replace *path/to/onnxruntime* in the command below with the path to the extracted directory. 32 | 33 | ```bash 34 | mkdir build 35 | cd build 36 | cmake .. -DONNXRUNTIME_DIR=path/to/onnxruntime -DCMAKE_BUILD_TYPE=Debug 37 | make 38 | cd .. 39 | # Or you can simply run this instead 40 | # sh build.sh 41 | ``` 42 | 43 | ## Run 44 | You should convert your PyTorch model (.pt) to ONNX (.onnx). 45 | 46 | The [official tutorial](https://docs.ultralytics.com/modes/export/) may help you. 47 | 48 | Make sure you have added the OpenCV libraries to your environment. 49 | 50 | Run on Linux 51 | ```bash 52 | ./build/yolov8_ort -m ./models/yolov8m.onnx -i ./Imginput -o ./Imgoutput -c ./models/coco.names -x m --gpu 53 | 54 | ./build/yolov8_ort -m ./models/yolov8m-seg.onnx -i ./Imginput -o ./Imgoutput -c ./models/coco.names -x ms --gpu 55 | 56 | # for your custom model 57 | ./build/yolov8_ort -m ./models/modelname.onnx -i ./Imginput -o ./Imgoutput -c ./models/class.names -x ms --gpu 58 | #-m Path to onnx model. 59 | #-i Directory of images to be predicted. 60 | #-o Directory in which to save results. 61 | #-c Path to class names file. 62 | #-x Suffix appended to the saved file names. 63 | #--gpu Run inference on a CUDA device if one is available. 
64 | ``` 65 | For Windows 66 | ```bash 67 | ./build/yolov8_ort.exe -m ./models/modelname.onnx -i ./Imginput -o ./Imgoutput -c ./models/class.names -x ms --gpu 68 | ``` 69 | 70 | 71 | ## References 72 | 73 | - ONNXRuntime Inference examples: https://github.com/microsoft/onnxruntime-inference-examples 74 | - YOLO v8 repo: https://docs.ultralytics.com/ 75 | -------------------------------------------------------------------------------- /src/main.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include "cmdline.h" 7 | #include "utils.h" 8 | #include "yolov8Predictor.h" 9 | 10 | int main(int argc, char *argv[]) 11 | { 12 | float confThreshold = 0.4f; 13 | float iouThreshold = 0.4f; 14 | 15 | float maskThreshold = 0.5f; 16 | 17 | cmdline::parser cmd; 18 | cmd.add("model_path", 'm', "Path to onnx model.", false, "yolov8m.onnx"); 19 | cmd.add("image_path", 'i', "Image source to be predicted.", false, "./Imginput"); 20 | cmd.add("out_path", 'o', "Path to save result.", false, "./Imgoutput"); 21 | cmd.add("class_names", 'c', "Path to class names file.", false, "coco.names"); 22 | 23 | cmd.add("suffix_name", 'x', "Suffix names.", false, "yolov8m"); 24 | 25 | cmd.add("gpu", '\0', "Inference on cuda device."); 26 | 27 | cmd.parse_check(argc, argv); 28 | 29 | bool isGPU = cmd.exist("gpu"); 30 | const std::string classNamesPath = cmd.get("class_names"); 31 | const std::vector classNames = utils::loadNames(classNamesPath); 32 | const std::string imagePath = cmd.get("image_path"); 33 | const std::string savePath = cmd.get("out_path"); 34 | const std::string suffixName = cmd.get("suffix_name"); 35 | const std::string modelPath = cmd.get("model_path"); 36 | 37 | if (classNames.empty()) 38 | { 39 | std::cerr << "Error: Empty class names file." << std::endl; 40 | return -1; 41 | } 42 | if (!std::filesystem::exists(modelPath)) 43 | { 44 | std::cerr << "Error: There is no model." 
<< std::endl; 45 | return -1; 46 | } 47 | if (!std::filesystem::is_directory(imagePath)) 48 | { 49 | std::cerr << "Error: There is no model." << std::endl; 50 | return -1; 51 | } 52 | if (!std::filesystem::is_directory(savePath)) 53 | { 54 | std::filesystem::create_directory(savePath); 55 | } 56 | std::cout << "Model from :::" << modelPath << std::endl; 57 | std::cout << "Images from :::" << imagePath << std::endl; 58 | std::cout << "Resluts will be saved :::" << savePath << std::endl; 59 | 60 | YOLOPredictor predictor{nullptr}; 61 | try 62 | { 63 | predictor = YOLOPredictor(modelPath, isGPU, 64 | confThreshold, 65 | iouThreshold, 66 | maskThreshold); 67 | std::cout << "Model was initialized." << std::endl; 68 | } 69 | catch (const std::exception &e) 70 | { 71 | std::cerr << e.what() << std::endl; 72 | return -1; 73 | } 74 | assert(classNames.size() == predictor.classNums); 75 | std::regex pattern(".+\\.(jpg|jpeg|png|gif)$"); 76 | std::cout << "Start predicting..." << std::endl; 77 | 78 | clock_t startTime, endTime; 79 | startTime = clock(); 80 | 81 | int picNums = 0; 82 | 83 | for (const auto &entry : std::filesystem::directory_iterator(imagePath)) 84 | { 85 | if (std::filesystem::is_regular_file(entry.path()) && std::regex_match(entry.path().filename().string(), pattern)) 86 | { 87 | picNums += 1; 88 | std::string Filename = entry.path().string(); 89 | std::string baseName = std::filesystem::path(Filename).filename().string(); 90 | std::cout << Filename << " predicting..." << std::endl; 91 | 92 | cv::Mat image = cv::imread(Filename); 93 | std::vector result = predictor.predict(image); 94 | utils::visualizeDetection(image, result, classNames); 95 | 96 | std::string newFilename = baseName.substr(0, baseName.find_last_of('.')) + "_" + suffixName + baseName.substr(baseName.find_last_of('.')); 97 | std::string outputFilename = savePath + "/" + newFilename; 98 | cv::imwrite(outputFilename, image); 99 | std::cout << outputFilename << " Saved !!!" 
<< std::endl; 100 | } 101 | } 102 | endTime = clock(); 103 | std::cout << "The total run time is: " << (double)(endTime - startTime) / CLOCKS_PER_SEC << "seconds" << std::endl; 104 | std::cout << "The average run time is: " << (double)(endTime - startTime) / picNums / CLOCKS_PER_SEC << "seconds" << std::endl; 105 | 106 | std::cout << "##########DONE################" << std::endl; 107 | 108 | return 0; 109 | } 110 | -------------------------------------------------------------------------------- /src/utils.cpp: -------------------------------------------------------------------------------- 1 | #include "utils.h" 2 | 3 | size_t utils::vectorProduct(const std::vector &vector) 4 | { 5 | if (vector.empty()) 6 | return 0; 7 | 8 | size_t product = 1; 9 | for (const auto &element : vector) 10 | product *= element; 11 | 12 | return product; 13 | } 14 | 15 | std::wstring utils::charToWstring(const char *str) 16 | { 17 | typedef std::codecvt_utf8 convert_type; 18 | std::wstring_convert converter; 19 | 20 | return converter.from_bytes(str); 21 | } 22 | 23 | std::vector utils::loadNames(const std::string &path) 24 | { 25 | // load class names 26 | std::vector classNames; 27 | std::ifstream infile(path); 28 | if (infile.good()) 29 | { 30 | std::string line; 31 | while (getline(infile, line)) 32 | { 33 | if (line.back() == '\r') 34 | line.pop_back(); 35 | classNames.emplace_back(line); 36 | } 37 | infile.close(); 38 | } 39 | else 40 | { 41 | std::cerr << "ERROR: Failed to access class name path: " << path << std::endl; 42 | } 43 | // set color 44 | srand(time(0)); 45 | 46 | for (int i = 0; i < 2 * classNames.size(); i++) 47 | { 48 | int b = rand() % 256; 49 | int g = rand() % 256; 50 | int r = rand() % 256; 51 | colors.push_back(cv::Scalar(b, g, r)); 52 | } 53 | return classNames; 54 | } 55 | 56 | void utils::visualizeDetection(cv::Mat &im, std::vector &results, 57 | const std::vector &classNames) 58 | { 59 | cv::Mat image = im.clone(); 60 | for (const Yolov8Result &result : 
results) 61 | { 62 | 63 | int x = result.box.x; 64 | int y = result.box.y; 65 | 66 | int conf = (int)std::round(result.conf * 100); 67 | int classId = result.classId; 68 | std::string label = classNames[classId] + " 0." + std::to_string(conf); 69 | 70 | int baseline = 0; 71 | cv::Size size = cv::getTextSize(label, cv::FONT_ITALIC, 0.4, 1, &baseline); 72 | image(result.box).setTo(colors[classId + classNames.size()], result.boxMask); 73 | cv::rectangle(image, result.box, colors[classId], 2); 74 | cv::rectangle(image, 75 | cv::Point(x, y), cv::Point(x + size.width, y + 12), 76 | colors[classId], -1); 77 | cv::putText(image, label, 78 | cv::Point(x, y - 3 + 12), cv::FONT_ITALIC, 79 | 0.4, cv::Scalar(0, 0, 0), 1); 80 | } 81 | cv::addWeighted(im, 0.4, image, 0.6, 0, im); 82 | } 83 | 84 | void utils::letterbox(const cv::Mat &image, cv::Mat &outImage, 85 | const cv::Size &newShape = cv::Size(640, 640), 86 | const cv::Scalar &color = cv::Scalar(114, 114, 114), 87 | bool auto_ = true, 88 | bool scaleFill = false, 89 | bool scaleUp = true, 90 | int stride = 32) 91 | { 92 | cv::Size shape = image.size(); 93 | float r = std::min((float)newShape.height / (float)shape.height, 94 | (float)newShape.width / (float)shape.width); 95 | if (!scaleUp) 96 | r = std::min(r, 1.0f); 97 | 98 | float ratio[2]{r, r}; 99 | int newUnpad[2]{(int)std::round((float)shape.width * r), 100 | (int)std::round((float)shape.height * r)}; 101 | 102 | auto dw = (float)(newShape.width - newUnpad[0]); 103 | auto dh = (float)(newShape.height - newUnpad[1]); 104 | 105 | if (auto_) 106 | { 107 | dw = (float)((int)dw % stride); 108 | dh = (float)((int)dh % stride); 109 | } 110 | else if (scaleFill) 111 | { 112 | dw = 0.0f; 113 | dh = 0.0f; 114 | newUnpad[0] = newShape.width; 115 | newUnpad[1] = newShape.height; 116 | ratio[0] = (float)newShape.width / (float)shape.width; 117 | ratio[1] = (float)newShape.height / (float)shape.height; 118 | } 119 | 120 | dw /= 2.0f; 121 | dh /= 2.0f; 122 | 123 | if (shape.width != 
newUnpad[0] && shape.height != newUnpad[1]) 124 | { 125 | cv::resize(image, outImage, cv::Size(newUnpad[0], newUnpad[1])); 126 | } 127 | 128 | int top = int(std::round(dh - 0.1f)); 129 | int bottom = int(std::round(dh + 0.1f)); 130 | int left = int(std::round(dw - 0.1f)); 131 | int right = int(std::round(dw + 0.1f)); 132 | cv::copyMakeBorder(outImage, outImage, top, bottom, left, right, cv::BORDER_CONSTANT, color); 133 | } 134 | 135 | void utils::scaleCoords(cv::Rect &coords, 136 | cv::Mat &mask, 137 | const float maskThreshold, 138 | const cv::Size &imageShape, 139 | const cv::Size &imageOriginalShape) 140 | { 141 | float gain = std::min((float)imageShape.height / (float)imageOriginalShape.height, 142 | (float)imageShape.width / (float)imageOriginalShape.width); 143 | 144 | int pad[2] = {(int)(((float)imageShape.width - (float)imageOriginalShape.width * gain) / 2.0f), 145 | (int)(((float)imageShape.height - (float)imageOriginalShape.height * gain) / 2.0f)}; 146 | 147 | coords.x = (int)std::round(((float)(coords.x - pad[0]) / gain)); 148 | coords.x = std::max(0, coords.x); 149 | coords.y = (int)std::round(((float)(coords.y - pad[1]) / gain)); 150 | coords.y = std::max(0, coords.y); 151 | 152 | coords.width = (int)std::round(((float)coords.width / gain)); 153 | coords.width = std::min(coords.width, imageOriginalShape.width - coords.x); 154 | coords.height = (int)std::round(((float)coords.height / gain)); 155 | coords.height = std::min(coords.height, imageOriginalShape.height - coords.y); 156 | mask = mask(cv::Rect(pad[0], pad[1], imageShape.width - 2 * pad[0], imageShape.height - 2 * pad[1])); 157 | 158 | cv::resize(mask, mask, imageOriginalShape, cv::INTER_LINEAR); 159 | 160 | mask = mask(coords) > maskThreshold; 161 | } 162 | template 163 | T utils::clip(const T &n, const T &lower, const T &upper) 164 | { 165 | return std::max(lower, std::min(n, upper)); 166 | } 167 | -------------------------------------------------------------------------------- 
/src/yolov8Predictor.cpp: -------------------------------------------------------------------------------- 1 | #include "yolov8Predictor.h" 2 | 3 | YOLOPredictor::YOLOPredictor(const std::string &modelPath, 4 | const bool &isGPU, 5 | float confThreshold, 6 | float iouThreshold, 7 | float maskThreshold) 8 | { 9 | this->confThreshold = confThreshold; 10 | this->iouThreshold = iouThreshold; 11 | this->maskThreshold = maskThreshold; 12 | env = Ort::Env(OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, "YOLOV8"); 13 | sessionOptions = Ort::SessionOptions(); 14 | 15 | std::vector availableProviders = Ort::GetAvailableProviders(); 16 | auto cudaAvailable = std::find(availableProviders.begin(), availableProviders.end(), "CUDAExecutionProvider"); 17 | OrtCUDAProviderOptions cudaOption; 18 | 19 | if (isGPU && (cudaAvailable == availableProviders.end())) 20 | { 21 | std::cout << "GPU is not supported by your ONNXRuntime build. Fallback to CPU." << std::endl; 22 | std::cout << "Inference device: CPU" << std::endl; 23 | } 24 | else if (isGPU && (cudaAvailable != availableProviders.end())) 25 | { 26 | std::cout << "Inference device: GPU" << std::endl; 27 | sessionOptions.AppendExecutionProvider_CUDA(cudaOption); 28 | } 29 | else 30 | { 31 | std::cout << "Inference device: CPU" << std::endl; 32 | } 33 | 34 | #ifdef _WIN32 35 | std::wstring w_modelPath = utils::charToWstring(modelPath.c_str()); 36 | session = Ort::Session(env, w_modelPath.c_str(), sessionOptions); 37 | #else 38 | session = Ort::Session(env, modelPath.c_str(), sessionOptions); 39 | #endif 40 | const size_t num_input_nodes = session.GetInputCount(); //==1 41 | const size_t num_output_nodes = session.GetOutputCount(); //==1,2 42 | if (num_output_nodes > 1) 43 | { 44 | this->hasMask = true; 45 | std::cout << "Instance Segmentation" << std::endl; 46 | } 47 | else 48 | std::cout << "Object Detection" << std::endl; 49 | 50 | Ort::AllocatorWithDefaultOptions allocator; 51 | for (int i = 0; i < num_input_nodes; i++) 52 | { 53 | 
auto input_name = session.GetInputNameAllocated(i, allocator); 54 | this->inputNames.push_back(input_name.get()); 55 | input_names_ptr.push_back(std::move(input_name)); 56 | 57 | Ort::TypeInfo inputTypeInfo = session.GetInputTypeInfo(i); 58 | std::vector inputTensorShape = inputTypeInfo.GetTensorTypeAndShapeInfo().GetShape(); 59 | this->inputShapes.push_back(inputTensorShape); 60 | this->isDynamicInputShape = false; 61 | // checking if width and height are dynamic 62 | if (inputTensorShape[2] == -1 && inputTensorShape[3] == -1) 63 | { 64 | std::cout << "Dynamic input shape" << std::endl; 65 | this->isDynamicInputShape = true; 66 | } 67 | } 68 | for (int i = 0; i < num_output_nodes; i++) 69 | { 70 | auto output_name = session.GetOutputNameAllocated(i, allocator); 71 | this->outputNames.push_back(output_name.get()); 72 | output_names_ptr.push_back(std::move(output_name)); 73 | 74 | Ort::TypeInfo outputTypeInfo = session.GetOutputTypeInfo(i); 75 | std::vector outputTensorShape = outputTypeInfo.GetTensorTypeAndShapeInfo().GetShape(); 76 | this->outputShapes.push_back(outputTensorShape); 77 | if (i == 0) 78 | { 79 | if (!this->hasMask) 80 | classNums = outputTensorShape[1] - 4; 81 | else 82 | classNums = outputTensorShape[1] - 4 - 32; 83 | } 84 | } 85 | // for (const char *x : this->inputNames) 86 | // { 87 | // std::cout << x << std::endl; 88 | // } 89 | // for (const char *x : this->outputNames) 90 | // { 91 | // std::cout << x << std::endl; 92 | // } 93 | // std::cout << classNums << std::endl; 94 | } 95 | 96 | void YOLOPredictor::getBestClassInfo(std::vector::iterator it, 97 | float &bestConf, 98 | int &bestClassId, 99 | const int _classNums) 100 | { 101 | // first 4 element are box 102 | bestClassId = 4; 103 | bestConf = 0; 104 | 105 | for (int i = 4; i < _classNums + 4; i++) 106 | { 107 | if (it[i] > bestConf) 108 | { 109 | bestConf = it[i]; 110 | bestClassId = i - 4; 111 | } 112 | } 113 | } 114 | cv::Mat YOLOPredictor::getMask(const cv::Mat &maskProposals, 115 | 
const cv::Mat &maskProtos) 116 | { 117 | cv::Mat protos = maskProtos.reshape(0, {(int)this->outputShapes[1][1], (int)this->outputShapes[1][2] * (int)this->outputShapes[1][3]}); 118 | 119 | cv::Mat matmul_res = (maskProposals * protos).t(); 120 | cv::Mat masks = matmul_res.reshape(1, {(int)this->outputShapes[1][2], (int)this->outputShapes[1][3]}); 121 | cv::Mat dest; 122 | 123 | // sigmoid 124 | cv::exp(-masks, dest); 125 | dest = 1.0 / (1.0 + dest); 126 | cv::resize(dest, dest, cv::Size((int)this->inputShapes[0][2], (int)this->inputShapes[0][3]), cv::INTER_LINEAR); 127 | return dest; 128 | } 129 | 130 | void YOLOPredictor::preprocessing(cv::Mat &image, float *&blob, std::vector &inputTensorShape) 131 | { 132 | cv::Mat resizedImage, floatImage; 133 | cv::cvtColor(image, resizedImage, cv::COLOR_BGR2RGB); 134 | utils::letterbox(resizedImage, resizedImage, cv::Size((int)this->inputShapes[0][2], (int)this->inputShapes[0][3]), 135 | cv::Scalar(114, 114, 114), this->isDynamicInputShape, 136 | false, true, 32); 137 | 138 | inputTensorShape[2] = resizedImage.rows; 139 | inputTensorShape[3] = resizedImage.cols; 140 | 141 | resizedImage.convertTo(floatImage, CV_32FC3, 1 / 255.0); 142 | blob = new float[floatImage.cols * floatImage.rows * floatImage.channels()]; 143 | cv::Size floatImageSize{floatImage.cols, floatImage.rows}; 144 | 145 | // hwc -> chw 146 | std::vector chw(floatImage.channels()); 147 | for (int i = 0; i < floatImage.channels(); ++i) 148 | { 149 | chw[i] = cv::Mat(floatImageSize, CV_32FC1, blob + i * floatImageSize.width * floatImageSize.height); 150 | } 151 | cv::split(floatImage, chw); 152 | } 153 | 154 | std::vector YOLOPredictor::postprocessing(const cv::Size &resizedImageShape, 155 | const cv::Size &originalImageShape, 156 | std::vector &outputTensors) 157 | { 158 | 159 | // for box 160 | std::vector boxes; 161 | std::vector confs; 162 | std::vector classIds; 163 | 164 | float *boxOutput = outputTensors[0].GetTensorMutableData(); 165 | 
//[1,4+n,8400]=>[1,8400,4+n] or [1,4+n+32,8400]=>[1,8400,4+n+32] 166 | cv::Mat output0 = cv::Mat(cv::Size((int)this->outputShapes[0][2], (int)this->outputShapes[0][1]), CV_32F, boxOutput).t(); 167 | float *output0ptr = (float *)output0.data; 168 | int rows = (int)this->outputShapes[0][2]; 169 | int cols = (int)this->outputShapes[0][1]; 170 | // std::cout << rows << cols << std::endl; 171 | // if hasMask 172 | std::vector> picked_proposals; 173 | cv::Mat mask_protos; 174 | 175 | for (int i = 0; i < rows; i++) 176 | { 177 | std::vector it(output0ptr + i * cols, output0ptr + (i + 1) * cols); 178 | float confidence; 179 | int classId; 180 | this->getBestClassInfo(it.begin(), confidence, classId, classNums); 181 | 182 | if (confidence > this->confThreshold) 183 | { 184 | if (this->hasMask) 185 | { 186 | std::vector temp(it.begin() + 4 + classNums, it.end()); 187 | picked_proposals.push_back(temp); 188 | } 189 | int centerX = (int)(it[0]); 190 | int centerY = (int)(it[1]); 191 | int width = (int)(it[2]); 192 | int height = (int)(it[3]); 193 | int left = centerX - width / 2; 194 | int top = centerY - height / 2; 195 | boxes.emplace_back(left, top, width, height); 196 | confs.emplace_back(confidence); 197 | classIds.emplace_back(classId); 198 | } 199 | } 200 | 201 | std::vector indices; 202 | cv::dnn::NMSBoxes(boxes, confs, this->confThreshold, this->iouThreshold, indices); 203 | 204 | if (this->hasMask) 205 | { 206 | float *maskOutput = outputTensors[1].GetTensorMutableData(); 207 | std::vector mask_protos_shape = {1, (int)this->outputShapes[1][1], (int)this->outputShapes[1][2], (int)this->outputShapes[1][3]}; 208 | mask_protos = cv::Mat(mask_protos_shape, CV_32F, maskOutput); 209 | } 210 | 211 | std::vector results; 212 | for (int idx : indices) 213 | { 214 | Yolov8Result res; 215 | res.box = cv::Rect(boxes[idx]); 216 | if (this->hasMask) 217 | res.boxMask = this->getMask(cv::Mat(picked_proposals[idx]).t(), mask_protos); 218 | else 219 | res.boxMask = 
cv::Mat::zeros((int)this->inputShapes[0][2], (int)this->inputShapes[0][3], CV_8U); 220 | 221 | utils::scaleCoords(res.box, res.boxMask, this->maskThreshold, resizedImageShape, originalImageShape); 222 | res.conf = confs[idx]; 223 | res.classId = classIds[idx]; 224 | results.emplace_back(res); 225 | } 226 | 227 | return results; 228 | } 229 | 230 | std::vector YOLOPredictor::predict(cv::Mat &image) 231 | { 232 | float *blob = nullptr; 233 | std::vector inputTensorShape{1, 3, -1, -1}; 234 | this->preprocessing(image, blob, inputTensorShape); 235 | 236 | size_t inputTensorSize = utils::vectorProduct(inputTensorShape); 237 | 238 | std::vector inputTensorValues(blob, blob + inputTensorSize); 239 | 240 | std::vector inputTensors; 241 | 242 | Ort::MemoryInfo memoryInfo = Ort::MemoryInfo::CreateCpu( 243 | OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault); 244 | 245 | inputTensors.push_back(Ort::Value::CreateTensor( 246 | memoryInfo, inputTensorValues.data(), inputTensorSize, 247 | inputTensorShape.data(), inputTensorShape.size())); 248 | 249 | std::vector outputTensors = this->session.Run(Ort::RunOptions{nullptr}, 250 | this->inputNames.data(), 251 | inputTensors.data(), 252 | 1, 253 | this->outputNames.data(), 254 | this->outputNames.size()); 255 | 256 | cv::Size resizedShape = cv::Size((int)inputTensorShape[3], (int)inputTensorShape[2]); 257 | std::vector result = this->postprocessing(resizedShape, 258 | image.size(), 259 | outputTensors); 260 | 261 | delete[] blob; 262 | 263 | return result; 264 | } 265 | -------------------------------------------------------------------------------- /include/cmdline.h: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright (c) 2009, Hideyuki Tanaka 3 | All rights reserved. 
4 | 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions are met: 7 | * Redistributions of source code must retain the above copyright 8 | notice, this list of conditions and the following disclaimer. 9 | * Redistributions in binary form must reproduce the above copyright 10 | notice, this list of conditions and the following disclaimer in the 11 | documentation and/or other materials provided with the distribution. 12 | * Neither the name of the <organization> nor the 13 | names of its contributors may be used to endorse or promote products 14 | derived from this software without specific prior written permission. 15 | 16 | THIS SOFTWARE IS PROVIDED BY <copyright holder> ''AS IS'' AND ANY 17 | EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 18 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 19 | DISCLAIMED. IN NO EVENT SHALL <copyright holder> BE LIABLE FOR ANY 20 | DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 21 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 22 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 23 | ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 24 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 25 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
26 | */

#pragma once

//#define USE_DEMANGLING

#include <iostream>
#include <sstream>
#include <vector>
#include <map>
#include <string>
#include <stdexcept>
#include <typeinfo>
#include <cstring>
#include <algorithm>
#ifdef USE_DEMANGLING
#include <cxxabi.h>
#endif
#include <cstdlib>

namespace cmdline{

namespace detail{

// Generic conversion: streams the source into a stringstream and reads the target
// back out. Throws std::bad_cast unless the whole text is consumed.
template <typename Target, typename Source, bool Same>
class lexical_cast_t{
public:
  static Target cast(const Source &arg){
    Target ret;
    std::stringstream ss;
    if (!(ss<<arg && ss>>ret && ss.eof()))
      throw std::bad_cast();

    return ret;
  }
};

// Source and target are the same type: return the argument unchanged.
template <typename Target, typename Source>
class lexical_cast_t<Target, Source, true>{
public:
  static Target cast(const Source &arg){
    return arg;
  }
};

// Anything -> std::string: format through an output stream.
template <typename Source>
class lexical_cast_t<std::string, Source, false>{
public:
  static std::string cast(const Source &arg){
    std::ostringstream ss;
    ss<<arg;
    return ss.str();
  }
};

// std::string -> anything: parse through an input stream.
// Throws std::bad_cast unless the whole string is consumed.
template <typename Target>
class lexical_cast_t<Target, std::string, false>{
public:
  static Target cast(const std::string &arg){
    Target ret;
    std::istringstream ss(arg);
    if (!(ss>>ret && ss.eof()))
      throw std::bad_cast();
    return ret;
  }
};

// Compile-time type equality test used to select the lexical_cast_t specialization.
template <typename T1, typename T2>
struct is_same {
  static const bool value = false;
};

template <typename T>
struct is_same<T, T>{
  static const bool value = true;
};

// Converts arg to Target by dispatching to the matching lexical_cast_t.
template <typename Target, typename Source>
Target lexical_cast(const Source &arg)
{
  return lexical_cast_t<Target, Source, detail::is_same<Target, Source>::value>::cast(arg);
}

// Returns a readable form of a typeid name. Without USE_DEMANGLING the name is
// returned unchanged; it is also returned unchanged when demangling fails.
static inline std::string demangle(const std::string &name)
{
#ifdef USE_DEMANGLING
  int status=0;
  // __cxa_demangle returns a malloc'ed buffer (or NULL on failure) that the caller
  // must free, so it has to be held through a non-const pointer.
  char *p=abi::__cxa_demangle(name.c_str(), 0, 0, &status);
  if (status!=0 || p==NULL){
    free(p);
    return name;
  }
  std::string ret(p);
  free(p);
  return ret;
#else
  return name;
#endif
}

// Type name of T as shown in option descriptions.
template <class T>
std::string readable_typename()
{
  return demangle(typeid(T).name());
}

// Textual form of a default value as shown in option descriptions.
template <class T>
std::string default_value(T def)
{
  return detail::lexical_cast<std::string>(def);
}

template <>
inline std::string readable_typename<std::string>()
{
  return "string";
}

} // detail

//-----

// Exception type thrown by this library; what() returns the stored message.
class cmdline_error : public std::exception {
public:
  cmdline_error(const std::string &msg): msg(msg){}
  ~cmdline_error() throw() {}
  const char *what() const throw() { return msg.c_str(); }
private:
  std::string msg;
};

// Default value reader: converts the raw text with detail::lexical_cast.
template <class T>
struct default_reader{
  T operator()(const std::string &str){
    return detail::lexical_cast<T>(str);
  }
};

// Reader that only accepts values inside the closed interval [low, high];
// anything else raises cmdline_error("range_error").
template <class T>
struct range_reader{
  range_reader(const T &low, const T &high): low(low), high(high) {}
  T operator()(const std::string &s) const {
    T ret=default_reader<T>()(s);
    if (!(ret>=low && ret<=high)) throw cmdline::cmdline_error("range_error");
    return ret;
  }
private:
  T low, high;
};

// Convenience factory for range_reader.
template <class T>
range_reader<T> range(const T &low, const T &high)
{
  return range_reader<T>(low, high);
}

// Reader that only accepts values previously registered with add();
// anything else raises cmdline_error with an empty message.
template <class T>
struct oneof_reader{
  T operator()(const std::string &s){
    T ret=default_reader<T>()(s);
    if (std::find(alt.begin(), alt.end(), ret)==alt.end())
      throw cmdline_error("");
    return ret;
  }
  void add(const T &v){ alt.push_back(v); }
private:
  std::vector<T> alt;
};

// oneof(a1, ...) builds a oneof_reader that accepts exactly the listed values.
template <class T>
oneof_reader<T> oneof(T a1)
{
  oneof_reader<T> ret;
  ret.add(a1);
  return ret;
}

template <class T>
oneof_reader<T> oneof(T a1, T a2)
{
  oneof_reader<T> ret;
  ret.add(a1);
  ret.add(a2);
  return ret;
}

template <class T>
oneof_reader<T> oneof(T a1, T a2, T a3)
{
  oneof_reader<T> ret;
  ret.add(a1);
  ret.add(a2);
  ret.add(a3);
  return ret;
}

template <class T>
oneof_reader<T> oneof(T a1, T a2, T a3, T a4)
{
  oneof_reader<T> ret;
  ret.add(a1);
  ret.add(a2);
  ret.add(a3);
  ret.add(a4);
  return ret;
}

template 230 | oneof_reader oneof(T a1, T a2, T a3, T a4, T a5) 231 | { 232 | oneof_reader ret; 233 | ret.add(a1); 234 | ret.add(a2); 235 | ret.add(a3); 236 | ret.add(a4); 237 | ret.add(a5); 238 | return ret; 239 | } 240 | 241 | template 242 | oneof_reader oneof(T a1, T a2, T a3, T a4, T a5, T a6) 243 | { 244 | oneof_reader ret; 245 | ret.add(a1); 246 | ret.add(a2); 247 | ret.add(a3); 248 | ret.add(a4); 249 | ret.add(a5); 250 | ret.add(a6); 251 | return ret; 252 | } 253 | 254 | template 255 | oneof_reader oneof(T a1, T a2, T a3, T a4, T a5, T a6, T a7) 256 | { 257 | oneof_reader ret; 258 | ret.add(a1); 259 | ret.add(a2); 260 | ret.add(a3); 261 | ret.add(a4); 262 | ret.add(a5); 263 | ret.add(a6); 264 | ret.add(a7); 265 | return ret; 266 | } 267 | 268 | template 269 | oneof_reader oneof(T a1, T a2, T a3, T a4, T a5, T a6, T a7, T a8) 270 | { 271 | oneof_reader ret; 272 | ret.add(a1); 273 | ret.add(a2); 274 | ret.add(a3); 275 | ret.add(a4); 276 | ret.add(a5); 277 | ret.add(a6); 278 | ret.add(a7); 279 | ret.add(a8); 280 | return ret; 281 | } 282 | 283 | template 284 | oneof_reader oneof(T a1, T a2, T a3, T a4, T a5, T a6, T a7, T a8, T a9) 285 | { 286 | oneof_reader ret; 287 | ret.add(a1); 288 | ret.add(a2); 289 | ret.add(a3); 290 | ret.add(a4); 291 | ret.add(a5); 292 | ret.add(a6); 293 | ret.add(a7); 294 | ret.add(a8); 295 | ret.add(a9); 296 | return ret; 297 | } 298 | 299 | template 300 | oneof_reader oneof(T a1, T a2, T a3, T a4, T a5, T a6, T a7, T a8, T a9, T a10) 301 | { 302 | oneof_reader ret; 303 | ret.add(a1); 304 | ret.add(a2); 305 | ret.add(a3); 306 | ret.add(a4); 307 | ret.add(a5); 308 | ret.add(a6); 309 | ret.add(a7); 310 | ret.add(a8); 311 | ret.add(a9); 312 | ret.add(a10); 313 | return ret; 314 | } 315 | 316 | //----- 317 | 318 | class parser{ 319 | public: 320 | parser(){ 321 | } 322 | ~parser(){ 323 | for (std::map::iterator p=options.begin(); 324 | p!=options.end(); p++) 325 | delete p->second; 326 | } 327 | 328 | void add(const std::string &name, 329 
| char short_name=0, 330 | const std::string &desc=""){ 331 | if (options.count(name)) throw cmdline_error("multiple definition: "+name); 332 | options[name]=new option_without_value(name, short_name, desc); 333 | ordered.push_back(options[name]); 334 | } 335 | 336 | template 337 | void add(const std::string &name, 338 | char short_name=0, 339 | const std::string &desc="", 340 | bool need=true, 341 | const T def=T()){ 342 | add(name, short_name, desc, need, def, default_reader()); 343 | } 344 | 345 | template 346 | void add(const std::string &name, 347 | char short_name=0, 348 | const std::string &desc="", 349 | bool need=true, 350 | const T def=T(), 351 | F reader=F()){ 352 | if (options.count(name)) throw cmdline_error("multiple definition: "+name); 353 | options[name]=new option_with_value_with_reader(name, short_name, need, def, desc, reader); 354 | ordered.push_back(options[name]); 355 | } 356 | 357 | void footer(const std::string &f){ 358 | ftr=f; 359 | } 360 | 361 | void set_program_name(const std::string &name){ 362 | prog_name=name; 363 | } 364 | 365 | bool exist(const std::string &name) const { 366 | if (options.count(name)==0) throw cmdline_error("there is no flag: --"+name); 367 | return options.find(name)->second->has_set(); 368 | } 369 | 370 | template 371 | const T &get(const std::string &name) const { 372 | if (options.count(name)==0) throw cmdline_error("there is no flag: --"+name); 373 | const option_with_value *p=dynamic_cast*>(options.find(name)->second); 374 | if (p==NULL) throw cmdline_error("type mismatch flag '"+name+"'"); 375 | return p->get(); 376 | } 377 | 378 | const std::vector &rest() const { 379 | return others; 380 | } 381 | 382 | bool parse(const std::string &arg){ 383 | std::vector args; 384 | 385 | std::string buf; 386 | bool in_quote=false; 387 | for (std::string::size_type i=0; i=arg.length()){ 402 | errors.push_back("unexpected occurrence of '\\' at end of string"); 403 | return false; 404 | } 405 | } 406 | 407 | buf+=arg[i]; 
408 | } 409 | 410 | if (in_quote){ 411 | errors.push_back("quote is not closed"); 412 | return false; 413 | } 414 | 415 | if (buf.length()>0) 416 | args.push_back(buf); 417 | 418 | //for (size_t i=0; i &args){ 425 | int argc=static_cast(args.size()); 426 | std::vector argv(argc); 427 | 428 | for (int i=0; i lookup; 446 | for (std::map::iterator p=options.begin(); 447 | p!=options.end(); p++){ 448 | if (p->first.length()==0) continue; 449 | char initial=p->second->short_name(); 450 | if (initial){ 451 | if (lookup.count(initial)>0){ 452 | lookup[initial]=""; 453 | errors.push_back(std::string("short option '")+initial+"' is ambiguous"); 454 | return false; 455 | } 456 | else lookup[initial]=p->first; 457 | } 458 | } 459 | 460 | for (int i=1; i &args){ 542 | if (!options.count("help")) 543 | add("help", '?', "print this message"); 544 | check(args.size(), parse(args)); 545 | } 546 | 547 | void parse_check(int argc, char *argv[]){ 548 | if (!options.count("help")) 549 | add("help", '?', "print this message"); 550 | check(argc, parse(argc, argv)); 551 | } 552 | 553 | std::string error() const{ 554 | return errors.size()>0?errors[0]:""; 555 | } 556 | 557 | std::string error_full() const{ 558 | std::ostringstream oss; 559 | for (size_t i=0; imust()) 569 | oss<short_description()<<" "; 570 | } 571 | 572 | oss<<"[options] ... 
"<name().length()); 578 | } 579 | for (size_t i=0; ishort_name()){ 581 | oss<<" -"<short_name()<<", "; 582 | } 583 | else{ 584 | oss<<" "; 585 | } 586 | 587 | oss<<"--"<name(); 588 | for (size_t j=ordered[i]->name().length(); jdescription()<set()){ 615 | errors.push_back("option needs value: --"+name); 616 | return; 617 | } 618 | } 619 | 620 | void set_option(const std::string &name, const std::string &value){ 621 | if (options.count(name)==0){ 622 | errors.push_back("undefined option: --"+name); 623 | return; 624 | } 625 | if (!options[name]->set(value)){ 626 | errors.push_back("option value is invalid: --"+name+"="+value); 627 | return; 628 | } 629 | } 630 | 631 | class option_base{ 632 | public: 633 | virtual ~option_base(){} 634 | 635 | virtual bool has_value() const=0; 636 | virtual bool set()=0; 637 | virtual bool set(const std::string &value)=0; 638 | virtual bool has_set() const=0; 639 | virtual bool valid() const=0; 640 | virtual bool must() const=0; 641 | 642 | virtual const std::string &name() const=0; 643 | virtual char short_name() const=0; 644 | virtual const std::string &description() const=0; 645 | virtual std::string short_description() const=0; 646 | }; 647 | 648 | class option_without_value : public option_base { 649 | public: 650 | option_without_value(const std::string &name, 651 | char short_name, 652 | const std::string &desc) 653 | :nam(name), snam(short_name), desc(desc), has(false){ 654 | } 655 | ~option_without_value(){} 656 | 657 | bool has_value() const { return false; } 658 | 659 | bool set(){ 660 | has=true; 661 | return true; 662 | } 663 | 664 | bool set(const std::string &){ 665 | return false; 666 | } 667 | 668 | bool has_set() const { 669 | return has; 670 | } 671 | 672 | bool valid() const{ 673 | return true; 674 | } 675 | 676 | bool must() const{ 677 | return false; 678 | } 679 | 680 | const std::string &name() const{ 681 | return nam; 682 | } 683 | 684 | char short_name() const{ 685 | return snam; 686 | } 687 | 688 | const 
std::string &description() const { 689 | return desc; 690 | } 691 | 692 | std::string short_description() const{ 693 | return "--"+nam; 694 | } 695 | 696 | private: 697 | std::string nam; 698 | char snam; 699 | std::string desc; 700 | bool has; 701 | }; 702 | 703 | template 704 | class option_with_value : public option_base { 705 | public: 706 | option_with_value(const std::string &name, 707 | char short_name, 708 | bool need, 709 | const T &def, 710 | const std::string &desc) 711 | : nam(name), snam(short_name), need(need), has(false) 712 | , def(def), actual(def) { 713 | this->desc=full_description(desc); 714 | } 715 | ~option_with_value(){} 716 | 717 | const T &get() const { 718 | return actual; 719 | } 720 | 721 | bool has_value() const { return true; } 722 | 723 | bool set(){ 724 | return false; 725 | } 726 | 727 | bool set(const std::string &value){ 728 | try{ 729 | actual=read(value); 730 | has=true; 731 | } 732 | catch(const std::exception &e){ 733 | return false; 734 | } 735 | return true; 736 | } 737 | 738 | bool has_set() const{ 739 | return has; 740 | } 741 | 742 | bool valid() const{ 743 | if (need && !has) return false; 744 | return true; 745 | } 746 | 747 | bool must() const{ 748 | return need; 749 | } 750 | 751 | const std::string &name() const{ 752 | return nam; 753 | } 754 | 755 | char short_name() const{ 756 | return snam; 757 | } 758 | 759 | const std::string &description() const { 760 | return desc; 761 | } 762 | 763 | std::string short_description() const{ 764 | return "--"+nam+"="+detail::readable_typename(); 765 | } 766 | 767 | protected: 768 | std::string full_description(const std::string &desc){ 769 | return 770 | desc+" ("+detail::readable_typename()+ 771 | (need?"":" [="+detail::default_value(def)+"]") 772 | +")"; 773 | } 774 | 775 | virtual T read(const std::string &s)=0; 776 | 777 | std::string nam; 778 | char snam; 779 | bool need; 780 | std::string desc; 781 | 782 | bool has; 783 | T def; 784 | T actual; 785 | }; 786 | 787 | 
template 788 | class option_with_value_with_reader : public option_with_value { 789 | public: 790 | option_with_value_with_reader(const std::string &name, 791 | char short_name, 792 | bool need, 793 | const T def, 794 | const std::string &desc, 795 | F reader) 796 | : option_with_value(name, short_name, need, def, desc), reader(reader){ 797 | } 798 | 799 | private: 800 | T read(const std::string &s){ 801 | return reader(s); 802 | } 803 | 804 | F reader; 805 | }; 806 | 807 | std::map options; 808 | std::vector ordered; 809 | std::string ftr; 810 | 811 | std::string prog_name; 812 | std::vector others; 813 | 814 | std::vector errors; 815 | }; 816 | 817 | } // cmdline 818 | --------------------------------------------------------------------------------