├── README.md ├── Ubuntu ├── README.md ├── new_version │ └── erfnet-onnx-vein │ │ ├── build-erfnet-onnx-vein-JetsonTX2-Debug │ │ ├── Makefile │ │ ├── erfnet-onnx-vein │ │ ├── main.o │ │ └── onnx_test │ │ ├── erfnet-onnx-vein │ │ ├── erfnet-onnx-vein.pro │ │ ├── erfnet-onnx-vein.pro.user │ │ └── main.cpp │ │ └── libonnxruntime.so ├── onnx_erfnet │ ├── main.cpp │ ├── onnx_test_seg.pro │ └── onnx_test_seg.pro.user └── onnx_mobilenet │ ├── main.cpp │ ├── onnx_test.pro │ └── onnx_test.pro.user ├── Windows ├── README.md ├── new_version │ └── onnx_ssd │ │ ├── ort_test.sln │ │ └── ort_test │ │ ├── ReadMe.txt │ │ ├── onnxruntime.lib │ │ ├── ort_test.cpp │ │ ├── ort_test.vcxproj │ │ ├── ort_test.vcxproj.filters │ │ ├── ort_test.vcxproj.user │ │ ├── stdafx.cpp │ │ ├── stdafx.h │ │ └── targetver.h ├── onnx_crnn │ └── ort_crnn.cpp ├── onnx_erfnet │ ├── ort_test.sln │ └── ort_test │ │ ├── ReadMe.txt │ │ ├── onnxruntime.lib │ │ ├── ort_test.cpp │ │ ├── ort_test.vcxproj │ │ ├── ort_test.vcxproj.filters │ │ ├── ort_test.vcxproj.user │ │ ├── stdafx.cpp │ │ ├── stdafx.h │ │ └── targetver.h ├── onnx_mobilenet │ ├── ort_test.sln │ └── ort_test │ │ ├── ReadMe.txt │ │ ├── onnxruntime.lib │ │ ├── ort_test.cpp │ │ ├── ort_test.vcxproj │ │ ├── ort_test.vcxproj.filters │ │ ├── ort_test.vcxproj.user │ │ ├── stdafx.cpp │ │ ├── stdafx.h │ │ └── targetver.h ├── onnx_style_transfer │ ├── ort_test.sln │ └── ort_test │ │ ├── ReadMe.txt │ │ ├── onnxruntime.lib │ │ ├── ort_test.cpp │ │ ├── ort_test.vcxproj │ │ ├── ort_test.vcxproj.filters │ │ ├── ort_test.vcxproj.user │ │ ├── stdafx.cpp │ │ ├── stdafx.h │ │ └── targetver.h ├── onnx_super_resolustion │ ├── ort_test.sln │ └── ort_test │ │ ├── ReadMe.txt │ │ ├── onnxruntime.lib │ │ ├── ort_test.cpp │ │ ├── ort_test.vcxproj │ │ ├── ort_test.vcxproj.filters │ │ ├── ort_test.vcxproj.user │ │ ├── stdafx.cpp │ │ ├── stdafx.h │ │ └── targetver.h └── onnx_tiny_yolov2 │ ├── ort_test.sln │ └── ort_test │ ├── ReadMe.txt │ ├── onnxruntime.lib │ ├── ort_test.cpp │ 
├── ort_test.vcxproj │ ├── ort_test.vcxproj.filters │ ├── ort_test.vcxproj.user │ ├── stdafx.cpp │ ├── stdafx.h │ └── targetver.h ├── include └── onnxruntime │ └── core │ ├── common │ ├── code_location.h │ ├── common.h │ ├── const_pointer_container.h │ ├── exceptions.h │ ├── logging │ │ ├── capture.h │ │ ├── isink.h │ │ ├── logging.h │ │ ├── macros.h │ │ └── severity.h │ ├── ml_status.h │ └── status.h │ ├── framework │ ├── alloc_kind.h │ ├── allocator.h │ ├── customregistry.h │ ├── data_types.h │ ├── execution_provider.h │ ├── fence.h │ ├── framework_common.h │ ├── func_api.h │ ├── kernel_def_builder.h │ ├── kernel_registry.h │ ├── ml_value.h │ ├── op_kernel.h │ ├── op_kernel_info.h │ ├── op_node_proto_helper.h │ ├── run_options.h │ ├── sparse_tensor.h │ ├── tensor.h │ └── tensor_shape.h │ ├── graph │ ├── basic_types.h │ ├── constants.h │ ├── function.h │ ├── graph.h │ ├── graph_nodes.h │ ├── graph_viewer.h │ ├── indexed_sub_graph.h │ ├── node_arg.h │ ├── onnx_protobuf.h │ └── schema_registry.h │ ├── optimizer │ ├── graph_transformer.h │ ├── graph_transformer_level.h │ ├── graph_transformer_utils.h │ ├── rewrite_rule.h │ └── rule_based_graph_transformer.h │ ├── platform │ ├── ort_mutex.h │ └── threadpool.h │ ├── providers │ ├── cpu │ │ └── cpu_provider_factory.h │ ├── cuda │ │ └── cuda_provider_factory.h │ ├── mkldnn │ │ └── mkldnn_provider_factory.h │ ├── ngraph │ │ └── ngraph_provider_factory.h │ ├── nnapi │ │ └── nnapi_provider_factory.h │ ├── nuphar │ │ └── nuphar_provider_factory.h │ ├── openvino │ │ └── openvino_provider_factory.h │ ├── providers.h │ └── tensorrt │ │ └── tensorrt_provider_factory.h │ └── session │ ├── automl_data_containers.h │ ├── environment.h │ ├── onnxruntime_c_api.h │ ├── onnxruntime_cxx_api.h │ └── onnxruntime_cxx_inline.h ├── test_imgs ├── classification │ ├── cls_001.jpg │ └── cls_002.jpg ├── detection │ ├── 000001.jpg │ ├── 000002.jpg │ ├── 000003.jpg │ ├── 000004.jpg │ ├── 000005.jpg │ ├── 000006.jpg │ ├── 000007.jpg │ ├── 
000008.jpg │ ├── 000009.jpg │ ├── 000010.jpg │ ├── 000011.jpg │ ├── 000012.jpg │ ├── 000013.jpg │ ├── 000014.jpg │ ├── 000015.jpg │ ├── 000016.jpg │ ├── 000017.jpg │ └── 000018.jpg ├── segmentation │ ├── 00000.png │ ├── 00001.png │ ├── 00002.png │ ├── 00003.png │ ├── 00004.png │ ├── 00005.png │ ├── 00006.png │ ├── 00007.png │ ├── 00008.png │ ├── 00009.png │ └── 00010.png ├── style_transfer │ └── church.jpg └── super_resolution │ ├── LowResolution.png │ ├── RizeResolution.png │ ├── SuperResolution.png │ └── rawimg.jpg └── test_models ├── candy.onnx ├── erfnet.onnx ├── mobilenetv2-1.0.onnx ├── mosaic.onnx ├── pointilism.onnx ├── rain_princess.onnx ├── super_resolution.onnx ├── tiny_yolov2.onnx └── udnie.onnx /README.md: -------------------------------------------------------------------------------- 1 | # onnxruntime projects 2 | ## Introduction 3 | This repository includes code for several onnxruntime projects, such as classification, segmentation, detection, style transfer and super resolution. 4 | ## Onnxruntime 5 | ONNX Runtime is a performance-focused complete scoring engine for Open Neural Network Exchange (ONNX) models, with an open extensible architecture to continually address the latest developments in AI and Deep Learning. 6 | In this repository, onnxruntime.dll has already been compiled. You can download it and find more information about onnxruntime at https://github.com/microsoft/onnxruntime. 7 | 8 | ## Projects 9 | The programming language is C++ and the platform is Visual Studio. I have finished some projects based on the official onnxruntime samples (see the onnxruntime link above). You can also download ONNX models from https://github.com/onnx/models. If necessary, you can inspect the structure of an ONNX model at https://lutzroeder.github.io/netron/. 
10 | 11 | ##### Windows 12 | 13 | | Network | Classes | Input resolution | Batch size | Iterations | CPU Running time | GPU Running time | TRT Running time* | 14 | | :---------------------------------: | :-----: | :--------------: | :--------: | :--------: | :--------------: | :--------------: | :---------------: | 15 | | MobileNet | 1000 | 224x224 | 1 | 1000 | 19.56s | 4.15s | 1.05s | 16 | | ERFNet | 4 | 640x480 | 1 | 1000 | >100s | 12.93s | 5.6s | 17 | | Tiny_YOLOv2 | 20 | 416x416 | 1 | 1000 | 40.64s | 2.97s | 1.92s | 18 | | Super Resolution with sub-pixel CNN | - | 224x224 | 1 | 1000 | 34.14s | 1.79s | 1.14s | 19 | | Fast Neural Style Transfer | - | 224x224 | 1 | 1000 | 87.99s | 4.64s | - | 20 | 21 | ##### Ubuntu 22 | 23 | | Network | Classes | Input resolution | Batch size | Iterations | CPU Running time | GPU Running time | TRT Running time* | 24 | | :-------: | :-----: | :--------------: | :--------: | :--------: | :--------------: | :--------------: | :---------------: | 25 | | MobileNet | 1000 | 224x224 | 1 | 1000 | 20.09s | 4.24s | 0.79s | 26 | | ERFNet | 4 | 640x480 | 1 | 1000 | >100s | 13.56s | 4.90s | 27 | 28 | *The TensorRT engine is compiled with FP16 settings. If you build libonnxruntime yourself, just add "trt_builder->setFp16Mode(true);" at line 339 of tensorrt_execution_provider.cc. 29 | 30 | **These experiments were run on an NVIDIA 2080Ti. 31 | 32 | ### Classification 33 | --- 34 | The onnx model is MobileNet. You can download it from the link mentioned above. 35 | ### Segmentation 36 | --- 37 | The onnx model is our trained erfnet. We use specific datasets to train erfnet. 38 | ### Detection 39 | --- 40 | The onnx model is Tiny YOLOv2. You can download it from the link mentioned above. 41 | ### Style transfer 42 | --- 43 | The onnx model is Fast Neural Style Transfer. You can download it from the link mentioned above. 44 | 45 | ### Super resolution 46 | --- 47 | The onnx model is Super Resolution with sub-pixel CNN. 
You can download it from the link mentioned above. -------------------------------------------------------------------------------- /Ubuntu/README.md: -------------------------------------------------------------------------------- 1 | ## Start 2 | 3 | 1. Download the third-party dependencies from [BaiduYun](https://pan.baidu.com/s/1cIOIHaF058pFwhve-QO2gw), and extract them to the workspace. 4 | 5 | 2. Open the .pro file with Qt Creator, modify the configuration file and run the project. 6 | 7 | -------------------------------------------------------------------------------- /Ubuntu/new_version/erfnet-onnx-vein/build-erfnet-onnx-vein-JetsonTX2-Debug/erfnet-onnx-vein: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tenglike1997/onnxruntime-projects/7e99a75f1ce36329c4dbf2863023e7e05c73a991/Ubuntu/new_version/erfnet-onnx-vein/build-erfnet-onnx-vein-JetsonTX2-Debug/erfnet-onnx-vein -------------------------------------------------------------------------------- /Ubuntu/new_version/erfnet-onnx-vein/build-erfnet-onnx-vein-JetsonTX2-Debug/main.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tenglike1997/onnxruntime-projects/7e99a75f1ce36329c4dbf2863023e7e05c73a991/Ubuntu/new_version/erfnet-onnx-vein/build-erfnet-onnx-vein-JetsonTX2-Debug/main.o -------------------------------------------------------------------------------- /Ubuntu/new_version/erfnet-onnx-vein/build-erfnet-onnx-vein-JetsonTX2-Debug/onnx_test: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tenglike1997/onnxruntime-projects/7e99a75f1ce36329c4dbf2863023e7e05c73a991/Ubuntu/new_version/erfnet-onnx-vein/build-erfnet-onnx-vein-JetsonTX2-Debug/onnx_test -------------------------------------------------------------------------------- /Ubuntu/new_version/erfnet-onnx-vein/erfnet-onnx-vein/erfnet-onnx-vein.pro: 
-------------------------------------------------------------------------------- 1 | QT -= gui 2 | 3 | CONFIG += c++11 console 4 | CONFIG -= app_bundle 5 | 6 | # The following define makes your compiler emit warnings if you use 7 | # any feature of Qt which as been marked deprecated (the exact warnings 8 | # depend on your compiler). Please consult the documentation of the 9 | # deprecated API in order to know how to port your code away from it. 10 | DEFINES += QT_DEPRECATED_WARNINGS 11 | TEMPLATE = app 12 | # You can also make your code fail to compile if you use deprecated APIs. 13 | # In order to do so, uncomment the following line. 14 | # You can also select to disable deprecated APIs only up to a certain version of Qt. 15 | #DEFINES += QT_DISABLE_DEPRECATED_BEFORE=0x060000 # disables all the APIs deprecated before Qt 6.0.0 16 | 17 | SOURCES += main.cpp 18 | 19 | INCLUDEPATH +=/usr/include/opencv\ 20 | /usr/include/opencv2 \ 21 | #/home/tzj/onnxruntime/include/onnxruntime \ 22 | /home/tzj/onnxruntime/include/onnxruntime/core/session \ 23 | /home/tzj/onnxruntime/include/onnxruntime/core/providers/cuda 24 | 25 | 26 | LIBS += /usr/lib/libopencv_highgui.so \ 27 | /usr/lib/libopencv_core.so \ 28 | /usr/lib/libopencv_imgproc.so \ 29 | /usr/lib/libopencv_videoio.so \ 30 | /home/tzj/onnxruntime/build/Linux/RelWithDebInfo/libonnxruntime.so.1.2.0 \ 31 | /home/tzj/onnxruntime/build/Linux/RelWithDebInfo/libonnxruntime.so \ 32 | 33 | 34 | 35 | 36 | QMAKE_CXXFLAGS += -std=c++11 -g 37 | LIBS += -L/usr/local/lib -lopencv_core -lopencv_imgcodecs -lopencv_highgui 38 | 39 | 40 | 41 | 42 | 43 | -------------------------------------------------------------------------------- /Ubuntu/new_version/erfnet-onnx-vein/erfnet-onnx-vein/main.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | #include 11 | #include 12 | #include 13 | //#include 
14 | 15 | using namespace cv; 16 | using namespace std; 17 | 18 | 19 | //OrtApi *g; 20 | Ort::Env env{ ORT_LOGGING_LEVEL_WARNING, "test" }; 21 | 22 | static constexpr const int width_ = 256; 23 | static constexpr const int height_ = 256; 24 | static constexpr const int channel = 3; 25 | 26 | std::array input_image_{}; 27 | std::array results_{}; 28 | std::array results_extra{}; 29 | int result_[height_ * width_]{ 0}; 30 | 31 | Ort::Value input_tensor_{ nullptr }; 32 | std::array input_shape_{ 1,channel, height_, width_ }; 33 | 34 | Ort::Value output_tensor_{ nullptr }; 35 | std::array output_shape_{ 1,2,height_, width_ }; 36 | 37 | 38 | /* Loads the ERFNet vein model, runs it once on one image through the CUDA execution provider, turns the two output planes into a binary mask and shows it in a window. */ int main() 39 | { 40 | auto allocator_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); 41 | /* Both tensors are created over the global input_image_ / results_ buffers; input_image_ is filled below and results_ is read back after Run(). */ input_tensor_ = Ort::Value::CreateTensor(allocator_info, input_image_.data(), input_image_.size(), input_shape_.data(), input_shape_.size()); 42 | output_tensor_ = Ort::Value::CreateTensor(allocator_info, results_.data(), results_.size(), output_shape_.data(), output_shape_.size()); 43 | 44 | const char* input_names[] = { "input" }; 45 | const char* output_names[] = { "output" }; 46 | 47 | Ort::SessionOptions session_option; 48 | //ORT_THROW_ON_ERROR(OrtSessionOptionsAppendExecutionProvider_Mkldnn(session_option, 1)); 49 | //ORT_THROW_ON_ERROR(OrtSessionOptionsAppendExecutionProvider_Tensorrt(session_option, 0)); 50 | 51 | Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(session_option, 0)); 52 | const ORTCHAR_T* model_path = "/home/tzj/vein-model-pic/ERFNet-vein.onnx"; 53 | Ort::Session session_(env,model_path,session_option); 54 | 55 | const int row = height_; 56 | const int col = width_; 57 | 58 | /* NOTE(review): the result of imread is not checked; confirm what resize() does when the file is missing. */ Mat img = imread("/home/tzj/vein-model-pic/1.bmp"); 59 | Mat dst(row, col, CV_8UC3); 60 | Mat dst2; 61 | resize(img, dst, Size(col, row)); 62 | cvtColor(dst, dst, CV_BGR2RGB); 63 | //resize(dst2, dst, Size(col, row)); 64 | 65 | float* output = input_image_.data(); 66 | fill(input_image_.begin(), 
input_image_.end(), 0.f); 67 | /* NOTE(review): rgb_mean is computed but not used anywhere below. */ Scalar rgb_mean = mean(dst); 68 | for (int c = 0; c < 3; c++) { 69 | for (int i = 0; i < row; i++) { 70 | for (int j = 0; j < col; j++) { 71 | /* interleaved pixel (row i, column j, channel c) -> planar index c*row*col + i*col + j, divided by 255 */ output[c*row*col + i*col + j] = (dst.ptr(i)[j*3+c])/255.0 ; 72 | } 73 | } 74 | } 75 | 76 | double timeStart = (double)getTickCount(); 77 | /* the timing loop executes exactly one inference */ for (int i = 0; i < 1; i++) { 78 | session_.Run(nullptr, input_names, &input_tensor_, 1, output_names, &output_tensor_,1); 79 | } 80 | double nTime = ((double)getTickCount() - timeStart) / getTickFrequency(); 81 | cout << "running time :" << nTime << "sec\n" << endl; 82 | 83 | 84 | /* per pixel: label 1 when the second output plane scores higher than the first, otherwise 0 */ for (int i = 0; i < height_*width_; i++) { 85 | results_extra[0] = results_[i]; 86 | results_extra[1] = results_[i + height_ * width_]; 87 | //result_[i] = std::distance(results_extra.begin(), std::max_element(results_extra.begin(), results_extra.end())); 88 | //result_[i] = results_extra[0] < results_extra[1]; 89 | result_[i] = results_extra[1] > results_extra[0]; 90 | } 91 | int* result = result_; 92 | 93 | /* single-channel mask: label 0 -> 0, label 1 -> 255 */ Mat outputimage(height_, width_, CV_8UC1, Scalar(0)); 94 | for (int i = 0; i < height_; i++) { 95 | for (int j = 0; j < width_; j++) { 96 | if (result[i * width_ + j] == 0) { 97 | outputimage.ptr(i)[j] = 0; 98 | 99 | } 100 | if (result[i * width_ + j] == 1) { 101 | outputimage.ptr(i)[j] = 255; 102 | } 103 | } 104 | } 105 | 106 | //Mat outputimage(height_, width_, CV_8UC1, Scalar(0)); 107 | //for (int i = 0; i < height_; i++) { 108 | // for (int j = 0; j < width_; j++) { 109 | // outputimage.ptr(i)[j] = results_[width_*height_+i*width_ + j] * 255.0; 110 | // } 111 | //} 112 | 113 | /* NOTE(review): cv::Size is (width, height); Size(row, col) is only equivalent here because width_ == height_ == 256. */ resize(outputimage, outputimage, Size(row, col)); 114 | imshow("test",outputimage); 115 | waitKey(0); 116 | /* NOTE(review): "pause" is a Windows cmd builtin; on Ubuntu this presumably only makes the shell report an unknown command - confirm and remove. */ system("pause"); 117 | return 0; 118 | } 119 | 120 | -------------------------------------------------------------------------------- /Ubuntu/new_version/erfnet-onnx-vein/libonnxruntime.so: -------------------------------------------------------------------------------- 1 | 
libonnxruntime.so.1.3.0 -------------------------------------------------------------------------------- /Ubuntu/onnx_erfnet/main.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | 15 | using namespace cv; 16 | using namespace std; 17 | 18 | Ort::Env env{ ORT_LOGGING_LEVEL_WARNING, "test" }; 19 | 20 | static constexpr const int width_ = 640; 21 | static constexpr const int height_ = 480; 22 | static constexpr const int channel = 3; 23 | 24 | std::array input_image_{}; 25 | std::array results_{}; 26 | std::array results_extra{}; 27 | int result_[4*height_ * width_]{ 0}; 28 | 29 | Ort::Value input_tensor_{ nullptr }; 30 | std::array input_shape_{ 1,channel, height_, width_ }; 31 | 32 | Ort::Value output_tensor_{ nullptr }; 33 | std::array output_shape_{ 1,4,height_, width_ }; 34 | 35 | OrtSession* session_ = nullptr; 36 | OrtSessionOptions* session_option; 37 | 38 | /* Runs the 4-class ERFNet model 1000 times on one test image (TensorRT + CUDA providers), prints the total time, takes the per-pixel argmax over the four output planes and writes the colour-coded label map to 4.png. Any failing onnxruntime C-API call throws via ORT_THROW_ON_ERROR. */ int main(int argc, char *argv[]) 39 | { 40 | QCoreApplication a(argc, argv); 41 | auto allocator_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); 42 | /* Both tensors are created over the global input_image_ / results_ buffers; input_image_ is filled below and results_ is read back after the run loop. */ input_tensor_ = Ort::Value::CreateTensor(allocator_info, input_image_.data(), input_image_.size(), input_shape_.data(), input_shape_.size()); 43 | output_tensor_ = Ort::Value::CreateTensor(allocator_info, results_.data(), results_.size(), output_shape_.data(), output_shape_.size()); 44 | 45 | const char* input_names[] = { "actual_input_1" }; 46 | const char* output_names[] = { "output1" }; 47 | 48 | ORT_THROW_ON_ERROR(OrtCreateSessionOptions(&session_option)); 49 | //ORT_THROW_ON_ERROR(OrtSessionOptionsAppendExecutionProvider_Mkldnn(session_option, 1)); 50 | ORT_THROW_ON_ERROR(OrtSessionOptionsAppendExecutionProvider_Tensorrt(session_option, 0)); 51 | ORT_THROW_ON_ERROR(OrtSessionOptionsAppendExecutionProvider_CUDA(session_option, 0)); 
52 | ORT_THROW_ON_ERROR(OrtCreateSession(env, "../../test_models/erfnet.onnx", session_option, &session_)); 53 | 54 | OrtValue *input_tensor_1 = input_tensor_; 55 | OrtValue *output_tensor_1 = output_tensor_; 56 | 57 | Mat img = imread("../../test_imgs/segmentation/00004.png"); 58 | const int row = height_; 59 | const int col = width_; 60 | Mat dst(row, col, CV_8UC3); 61 | Mat dst2; 62 | /* NOTE(review): unlike the Windows onnx_erfnet sample, no BGR->RGB conversion follows this resize; confirm which channel order the model expects. */ resize(img, dst, Size(col, row)); 63 | 64 | float* output = input_image_.data(); 65 | fill(input_image_.begin(), input_image_.end(), 0.f); 66 | Scalar rgb_mean = mean(dst); 67 | for (int c = 0;c < 3;c++) { 68 | for (int i = 0;i < row;i++) { 69 | for (int j = 0;j < col;j++) { 70 | /* interleaved pixel (row i, column j, channel c) -> planar index c*row*col + i*col + j, divided by 255 */ output[c*row*col + i*col + j] = (dst.ptr(i)[j * 3 + c])/255.0; 71 | } 72 | } 73 | } 74 | double timeStart = (double)getTickCount(); 75 | for (int i = 0; i < 1000; i++) { 76 | /* OrtRun returns an OrtStatus*; check it like every other C-API call above instead of discarding it, so a failed run is not timed and post-processed as if it had succeeded. */ ORT_THROW_ON_ERROR(OrtRun(session_, nullptr, input_names, &input_tensor_1, 1, output_names, 1, &output_tensor_1)); 77 | } 78 | double nTime = ((double)getTickCount() - timeStart) / getTickFrequency(); 79 | cout << "running time :" << nTime << "sec\n" << endl; 80 | 81 | /* per pixel: gather the four class scores (one per output plane) and store the index of the largest */ for (int i = 0; i < height_*width_; i++) { 82 | results_extra[0] = results_[i]; 83 | results_extra[1] = results_[i + height_ * width_]; 84 | results_extra[2] = results_[i + height_ * width_ * 2]; 85 | results_extra[3] = results_[i + height_ * width_ * 3]; 86 | result_[i] = std::distance(results_extra.begin(), std::max_element(results_extra.begin(), results_extra.end())); 87 | } 88 | int* result = result_; 89 | 90 | /* class index -> colour, written in the Mat's channel order: 0 -> (255,0,0), 1 -> (0,255,0), 2 -> (0,0,255), 3 -> (255,255,0) */ Mat outputimage(height_, width_, CV_8UC3, Scalar(0, 0, 0)); 91 | for (int i = 0;i < height_;i++) { 92 | for (int j = 0;j < width_;j++) { 93 | if (result[i * width_ + j] == 0) { 94 | outputimage.ptr(i)[j * 3] = 255; 95 | outputimage.ptr(i)[j * 3+1] = 0; 96 | outputimage.ptr(i)[j * 3+2] = 0; 97 | } 98 | if (result[i * width_ + j] == 1) { 99 | outputimage.ptr(i)[j * 3] = 0; 100 | outputimage.ptr(i)[j * 3 + 1] = 255; 101 | outputimage.ptr(i)[j * 3 + 2] = 0; 102 | } 103 | if 
(result[i * width_ + j] == 2) { 104 | outputimage.ptr(i)[j * 3] = 0; 105 | outputimage.ptr(i)[j * 3 + 1] = 0; 106 | outputimage.ptr(i)[j * 3 + 2] = 255; 107 | } 108 | if (result[i * width_ + j] == 3) { 109 | outputimage.ptr(i)[j * 3] = 255; 110 | outputimage.ptr(i)[j * 3 + 1] = 255; 111 | outputimage.ptr(i)[j * 3 + 2] = 0; 112 | } 113 | } 114 | } 115 | imwrite("4.png", outputimage); 116 | return a.exec(); 117 | } 118 | 119 | -------------------------------------------------------------------------------- /Ubuntu/onnx_erfnet/onnx_test_seg.pro: -------------------------------------------------------------------------------- 1 | QT += core 2 | QT -= gui 3 | 4 | TARGET = onnx_test 5 | CONFIG += console 6 | CONFIG -= app_bundle 7 | CONFIG += C++11 8 | 9 | TEMPLATE = app 10 | 11 | SOURCES += main.cpp 12 | 13 | INCLUDEPATH += /usr/include \ 14 | /usr/include/opencv \ 15 | /usr/include/opencv2 \ 16 | ../../include/onnxruntime 17 | 18 | LIBS += /usr/lib/x86_64-linux-gnu/libopencv_highgui.so \ 19 | /usr/lib/x86_64-linux-gnu/libopencv_core.so \ 20 | /usr/lib/x86_64-linux-gnu/libopencv_imgproc.so \ 21 | /media/usr523/000903F80002AA1E/cxy/1908/1909/onnxruntime/build/Linux/Release/libonnxruntime.so.0.5.0 22 | -------------------------------------------------------------------------------- /Ubuntu/onnx_mobilenet/main.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | 15 | using namespace cv; 16 | using namespace std; 17 | 18 | Ort::Env env{ ORT_LOGGING_LEVEL_WARNING, "test" }; 19 | 20 | static constexpr const int width_ = 224; 21 | static constexpr const int height_ = 224; 22 | static constexpr const int channel = 3; 23 | 24 | std::array input_image_{}; 25 | std::array results_{}; 26 | int result_{ 0 }; 27 | 28 | Ort::Value input_tensor_{ nullptr }; 29 | std::array input_shape_{ 
1,3, width_, height_ }; 30 | 31 | Ort::Value output_tensor_{ nullptr }; 32 | std::array output_shape_{ 1, 1000 }; 33 | 34 | OrtSession* session_ = nullptr; 35 | 36 | OrtSessionOptions* session_option; 37 | 38 | /* Runs MobileNetV2 1000 times on one test image (TensorRT + CUDA providers), prints the total time and then the index of the highest-scoring of the 1000 outputs. Any failing onnxruntime C-API call throws via ORT_THROW_ON_ERROR. */ int main(int argc, char *argv[]) 39 | { 40 | QCoreApplication a(argc, argv); 41 | auto allocator_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); 42 | /* Both tensors are created over the global input_image_ / results_ buffers; input_image_ is filled below and results_ is read back after the run loop. */ input_tensor_ = Ort::Value::CreateTensor(allocator_info, input_image_.data(), input_image_.size(), input_shape_.data(), input_shape_.size()); 43 | output_tensor_ = Ort::Value::CreateTensor(allocator_info, results_.data(), results_.size(), output_shape_.data(), output_shape_.size()); 44 | const char* input_names[] = { "data" }; 45 | const char* output_names[] = { "mobilenetv20_output_flatten0_reshape0" }; 46 | //CreateSession(session_); 47 | ORT_THROW_ON_ERROR(OrtCreateSessionOptions(&session_option)); 48 | //ORT_THROW_ON_ERROR(OrtSessionOptionsAppendExecutionProvider_Mkldnn(session_option, 1)); 49 | ORT_THROW_ON_ERROR(OrtSessionOptionsAppendExecutionProvider_Tensorrt(session_option, 0)); 50 | ORT_THROW_ON_ERROR(OrtSessionOptionsAppendExecutionProvider_CUDA(session_option, 0)); 51 | ORT_THROW_ON_ERROR(OrtCreateSession(env, "../../test_models/mobilenetv2-1.0.onnx", session_option, &session_)); 52 | OrtValue *input_tensor_1 = input_tensor_; 53 | OrtValue *output_tensor_1 = output_tensor_; 54 | /* the classification images live under test_imgs/classification/ in this repository */ Mat img = imread("../../test_imgs/classification/cls_001.jpg"); 55 | const int row = 224; 56 | const int col = 224; 57 | Mat dst(row, col, CV_8UC3); 58 | Mat dst2; 59 | resize(img, dst, Size(row, col)); 60 | cvtColor(dst, dst, CV_BGR2RGB); 61 | 62 | float* output = input_image_.data(); 63 | fill(input_image_.begin(), input_image_.end(), 0.f); 64 | /* dst is RGB after the cvtColor above, so channel 0/1/2 = R/G/B and each is normalised with the ImageNet mean/std of that same channel (R: 0.485/0.229, G: 0.456/0.224, B: 0.406/0.225). */ for (int c = 0; c < 3; c++) { 65 | for (int i = 0; i < row; i++) { 66 | for (int j = 0; j < col; j++) { 67 | if (c == 0) { 68 | output[c*row*col + i * col + j] = ((dst.ptr(i)[j * 3 + c]) / 255.0 - 0.485) / 0.229; 69 | } 70 | if (c == 1) { 71 | output[c*row*col + i * col + j] 
= ((dst.ptr(i)[j * 3 + c]) / 255.0 - 0.456) / 0.224; 72 | } 73 | if (c == 2) { 74 | output[c*row*col + i * col + j] = ((dst.ptr(i)[j * 3 + c]) / 255.0 - 0.406) / 0.225; 75 | } 76 | } 77 | } 78 | } 79 | 80 | double timeStart = (double)getTickCount(); 81 | for (int i = 0; i < 1000; i++) { 82 | /* OrtRun returns an OrtStatus*; check it like every other C-API call above instead of discarding it. */ ORT_THROW_ON_ERROR(OrtRun(session_, nullptr, input_names, &input_tensor_1, 1, output_names, 1, &output_tensor_1)); 83 | } 84 | double nTime = ((double)getTickCount() - timeStart) / getTickFrequency(); 85 | cout << "running time : " << nTime << "sec\n" << endl; 86 | /* predicted class = index of the largest output score */ result_ = std::distance(results_.begin(), std::max_element(results_.begin(), results_.end())); 87 | int result = result_; 88 | cout << result << endl; 89 | return a.exec(); 90 | } 91 | 92 | -------------------------------------------------------------------------------- /Ubuntu/onnx_mobilenet/onnx_test.pro: -------------------------------------------------------------------------------- 1 | QT += core 2 | QT -= gui 3 | 4 | TARGET = onnx_test 5 | CONFIG += console 6 | CONFIG -= app_bundle 7 | CONFIG += C++11 8 | 9 | TEMPLATE = app 10 | 11 | SOURCES += main.cpp 12 | 13 | INCLUDEPATH += /usr/include \ 14 | /usr/include/opencv \ 15 | /usr/include/opencv2 \ 16 | /media/usr523/000903F80002AA1E/cxy/1908/1909/onnxruntime/include/onnxruntime 17 | 18 | LIBS += /usr/lib/x86_64-linux-gnu/libopencv_highgui.so \ 19 | /usr/lib/x86_64-linux-gnu/libopencv_core.so \ 20 | /usr/lib/x86_64-linux-gnu/libopencv_imgproc.so \ 21 | /media/usr523/000903F80002AA1E/cxy/1908/1909/onnxruntime/build/Linux/Release/libonnxruntime.so.0.5.0 22 | -------------------------------------------------------------------------------- /Windows/README.md: -------------------------------------------------------------------------------- 1 | ## Start 2 | 3 | 1, Download the thirdparty denpendencies from [BaiduYun](https://pan.baidu.com/s/1cIOIHaF058pFwhve-QO2gw), and extract them to the workplace. 4 | 5 | 2, Open the sln file with Visual Studio and run the project. 
-------------------------------------------------------------------------------- /Windows/new_version/onnx_ssd/ort_test.sln: -------------------------------------------------------------------------------- 1 |  2 | Microsoft Visual Studio Solution File, Format Version 12.00 3 | # Visual Studio 14 4 | VisualStudioVersion = 14.0.25420.1 5 | MinimumVisualStudioVersion = 10.0.40219.1 6 | Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "ort_test", "ort_test\ort_test.vcxproj", "{B0F7CDB8-E482-467B-B5AF-9B4AF7ABB5A4}" 7 | EndProject 8 | Global 9 | GlobalSection(SolutionConfigurationPlatforms) = preSolution 10 | Debug|x64 = Debug|x64 11 | Debug|x86 = Debug|x86 12 | Release|x64 = Release|x64 13 | Release|x86 = Release|x86 14 | EndGlobalSection 15 | GlobalSection(ProjectConfigurationPlatforms) = postSolution 16 | {B0F7CDB8-E482-467B-B5AF-9B4AF7ABB5A4}.Debug|x64.ActiveCfg = Debug|x64 17 | {B0F7CDB8-E482-467B-B5AF-9B4AF7ABB5A4}.Debug|x64.Build.0 = Debug|x64 18 | {B0F7CDB8-E482-467B-B5AF-9B4AF7ABB5A4}.Debug|x86.ActiveCfg = Debug|Win32 19 | {B0F7CDB8-E482-467B-B5AF-9B4AF7ABB5A4}.Debug|x86.Build.0 = Debug|Win32 20 | {B0F7CDB8-E482-467B-B5AF-9B4AF7ABB5A4}.Release|x64.ActiveCfg = Release|x64 21 | {B0F7CDB8-E482-467B-B5AF-9B4AF7ABB5A4}.Release|x64.Build.0 = Release|x64 22 | {B0F7CDB8-E482-467B-B5AF-9B4AF7ABB5A4}.Release|x86.ActiveCfg = Release|Win32 23 | {B0F7CDB8-E482-467B-B5AF-9B4AF7ABB5A4}.Release|x86.Build.0 = Release|Win32 24 | EndGlobalSection 25 | GlobalSection(SolutionProperties) = preSolution 26 | HideSolutionNode = FALSE 27 | EndGlobalSection 28 | EndGlobal 29 | -------------------------------------------------------------------------------- /Windows/new_version/onnx_ssd/ort_test/ReadMe.txt: -------------------------------------------------------------------------------- 1 | ======================================================================== 2 | 控制台应用程序:ort_test 项目概述 3 | ======================================================================== 4 | 5 | 应用程序向导已为您创建了此 
ort_test 应用程序。 6 | 7 | 本文件概要介绍组成 ort_test 应用程序的每个文件的内容。 8 | 9 | 10 | ort_test.vcxproj 11 | 这是使用应用程序向导生成的 VC++ 项目的主项目文件,其中包含生成该文件的 Visual C++ 的版本信息,以及有关使用应用程序向导选择的平台、配置和项目功能的信息。 12 | 13 | ort_test.vcxproj.filters 14 | 这是使用“应用程序向导”生成的 VC++ 项目筛选器文件。它包含有关项目文件与筛选器之间的关联信息。在 IDE 中,通过这种关联,在特定节点下以分组形式显示具有相似扩展名的文件。例如,“.cpp”文件与“源文件”筛选器关联。 15 | 16 | ort_test.cpp 17 | 这是主应用程序源文件。 18 | 19 | ///////////////////////////////////////////////////////////////////////////// 20 | 其他标准文件: 21 | 22 | StdAfx.h, StdAfx.cpp 23 | 这些文件用于生成名为 ort_test.pch 的预编译头 (PCH) 文件和名为 StdAfx.obj 的预编译类型文件。 24 | 25 | ///////////////////////////////////////////////////////////////////////////// 26 | 其他注释: 27 | 28 | 应用程序向导使用“TODO:”注释来指示应添加或自定义的源代码部分。 29 | 30 | ///////////////////////////////////////////////////////////////////////////// 31 | -------------------------------------------------------------------------------- /Windows/new_version/onnx_ssd/ort_test/onnxruntime.lib: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tenglike1997/onnxruntime-projects/7e99a75f1ce36329c4dbf2863023e7e05c73a991/Windows/new_version/onnx_ssd/ort_test/onnxruntime.lib -------------------------------------------------------------------------------- /Windows/new_version/onnx_ssd/ort_test/ort_test.vcxproj.filters: -------------------------------------------------------------------------------- 1 |  2 | 3 | 4 | 5 | {4FC737F1-C7A5-4376-A066-2A32D752A2FF} 6 | cpp;c;cc;cxx;def;odl;idl;hpj;bat;asm;asmx 7 | 8 | 9 | {93995380-89BD-4b04-88EB-625FBE52EBFB} 10 | h;hh;hpp;hxx;hm;inl;inc;xsd 11 | 12 | 13 | {67DA6AB6-F800-4c08-8B7A-83BB121AAD01} 14 | rc;ico;cur;bmp;dlg;rc2;rct;bin;rgs;gif;jpg;jpeg;jpe;resx;tiff;tif;png;wav;mfcribbon-ms 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 头文件 23 | 24 | 25 | 头文件 26 | 27 | 28 | 29 | 30 | 源文件 31 | 32 | 33 | 源文件 34 | 35 | 36 | -------------------------------------------------------------------------------- 
/Windows/new_version/onnx_ssd/ort_test/ort_test.vcxproj.user: -------------------------------------------------------------------------------- 1 |  2 | 3 | 4 | -------------------------------------------------------------------------------- /Windows/new_version/onnx_ssd/ort_test/stdafx.cpp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tenglike1997/onnxruntime-projects/7e99a75f1ce36329c4dbf2863023e7e05c73a991/Windows/new_version/onnx_ssd/ort_test/stdafx.cpp -------------------------------------------------------------------------------- /Windows/new_version/onnx_ssd/ort_test/stdafx.h: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tenglike1997/onnxruntime-projects/7e99a75f1ce36329c4dbf2863023e7e05c73a991/Windows/new_version/onnx_ssd/ort_test/stdafx.h -------------------------------------------------------------------------------- /Windows/new_version/onnx_ssd/ort_test/targetver.h: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tenglike1997/onnxruntime-projects/7e99a75f1ce36329c4dbf2863023e7e05c73a991/Windows/new_version/onnx_ssd/ort_test/targetver.h -------------------------------------------------------------------------------- /Windows/onnx_crnn/ort_crnn.cpp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tenglike1997/onnxruntime-projects/7e99a75f1ce36329c4dbf2863023e7e05c73a991/Windows/onnx_crnn/ort_crnn.cpp -------------------------------------------------------------------------------- /Windows/onnx_erfnet/ort_test.sln: -------------------------------------------------------------------------------- 1 |  2 | Microsoft Visual Studio Solution File, Format Version 12.00 3 | # Visual Studio 14 4 | VisualStudioVersion = 14.0.25420.1 5 | MinimumVisualStudioVersion = 10.0.40219.1 6 | 
Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "ort_test", "ort_test\ort_test.vcxproj", "{B0F7CDB8-E482-467B-B5AF-9B4AF7ABB5A4}" 7 | EndProject 8 | Global 9 | GlobalSection(SolutionConfigurationPlatforms) = preSolution 10 | Debug|x64 = Debug|x64 11 | Debug|x86 = Debug|x86 12 | Release|x64 = Release|x64 13 | Release|x86 = Release|x86 14 | EndGlobalSection 15 | GlobalSection(ProjectConfigurationPlatforms) = postSolution 16 | {B0F7CDB8-E482-467B-B5AF-9B4AF7ABB5A4}.Debug|x64.ActiveCfg = Debug|x64 17 | {B0F7CDB8-E482-467B-B5AF-9B4AF7ABB5A4}.Debug|x64.Build.0 = Debug|x64 18 | {B0F7CDB8-E482-467B-B5AF-9B4AF7ABB5A4}.Debug|x86.ActiveCfg = Debug|Win32 19 | {B0F7CDB8-E482-467B-B5AF-9B4AF7ABB5A4}.Debug|x86.Build.0 = Debug|Win32 20 | {B0F7CDB8-E482-467B-B5AF-9B4AF7ABB5A4}.Release|x64.ActiveCfg = Release|x64 21 | {B0F7CDB8-E482-467B-B5AF-9B4AF7ABB5A4}.Release|x64.Build.0 = Release|x64 22 | {B0F7CDB8-E482-467B-B5AF-9B4AF7ABB5A4}.Release|x86.ActiveCfg = Release|Win32 23 | {B0F7CDB8-E482-467B-B5AF-9B4AF7ABB5A4}.Release|x86.Build.0 = Release|Win32 24 | EndGlobalSection 25 | GlobalSection(SolutionProperties) = preSolution 26 | HideSolutionNode = FALSE 27 | EndGlobalSection 28 | EndGlobal 29 | -------------------------------------------------------------------------------- /Windows/onnx_erfnet/ort_test/ReadMe.txt: -------------------------------------------------------------------------------- 1 | ======================================================================== 2 | 控制台应用程序:ort_test 项目概述 3 | ======================================================================== 4 | 5 | 应用程序向导已为您创建了此 ort_test 应用程序。 6 | 7 | 本文件概要介绍组成 ort_test 应用程序的每个文件的内容。 8 | 9 | 10 | ort_test.vcxproj 11 | 这是使用应用程序向导生成的 VC++ 项目的主项目文件,其中包含生成该文件的 Visual C++ 的版本信息,以及有关使用应用程序向导选择的平台、配置和项目功能的信息。 12 | 13 | ort_test.vcxproj.filters 14 | 这是使用“应用程序向导”生成的 VC++ 项目筛选器文件。它包含有关项目文件与筛选器之间的关联信息。在 IDE 中,通过这种关联,在特定节点下以分组形式显示具有相似扩展名的文件。例如,“.cpp”文件与“源文件”筛选器关联。 15 | 16 | ort_test.cpp 17 | 这是主应用程序源文件。 18 | 19 | 
///////////////////////////////////////////////////////////////////////////// 20 | 其他标准文件: 21 | 22 | StdAfx.h, StdAfx.cpp 23 | 这些文件用于生成名为 ort_test.pch 的预编译头 (PCH) 文件和名为 StdAfx.obj 的预编译类型文件。 24 | 25 | ///////////////////////////////////////////////////////////////////////////// 26 | 其他注释: 27 | 28 | 应用程序向导使用“TODO:”注释来指示应添加或自定义的源代码部分。 29 | 30 | ///////////////////////////////////////////////////////////////////////////// 31 | -------------------------------------------------------------------------------- /Windows/onnx_erfnet/ort_test/onnxruntime.lib: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tenglike1997/onnxruntime-projects/7e99a75f1ce36329c4dbf2863023e7e05c73a991/Windows/onnx_erfnet/ort_test/onnxruntime.lib -------------------------------------------------------------------------------- /Windows/onnx_erfnet/ort_test/ort_test.cpp: -------------------------------------------------------------------------------- 1 | #include "stdafx.h" 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | 18 | using namespace cv; 19 | using namespace std; 20 | 21 | 22 | #pragma comment(lib, "user32.lib") 23 | #pragma comment(lib, "gdi32.lib") 24 | #pragma comment(lib, "onnxruntime.lib") 25 | 26 | Ort::Env env{ ORT_LOGGING_LEVEL_WARNING, "test" }; 27 | 28 | static constexpr const int width_ = 640; 29 | static constexpr const int height_ = 480; 30 | static constexpr const int channel = 3; 31 | 32 | std::array input_image_{}; 33 | std::array results_{}; 34 | std::array results_extra{}; 35 | int result_[4*height_ * width_]{ 0}; 36 | 37 | Ort::Value input_tensor_{ nullptr }; 38 | std::array input_shape_{ 1,channel, height_, width_ }; 39 | 40 | Ort::Value output_tensor_{ nullptr }; 41 | std::array output_shape_{ 1,4,height_, width_ }; 42 | 43 | OrtSession* session_ = nullptr; 
44 | OrtSessionOptions* session_option; 45 | 46 | int main() 47 | { 48 | auto allocator_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); 49 | input_tensor_ = Ort::Value::CreateTensor(allocator_info, input_image_.data(), input_image_.size(), input_shape_.data(), input_shape_.size()); 50 | output_tensor_ = Ort::Value::CreateTensor(allocator_info, results_.data(), results_.size(), output_shape_.data(), output_shape_.size()); 51 | 52 | const char* input_names[] = { "actual_input_1" }; 53 | const char* output_names[] = { "output1" }; 54 | 55 | ORT_THROW_ON_ERROR(OrtCreateSessionOptions(&session_option)); 56 | //ORT_THROW_ON_ERROR(OrtSessionOptionsAppendExecutionProvider_Mkldnn(session_option, 1)); 57 | ORT_THROW_ON_ERROR(OrtSessionOptionsAppendExecutionProvider_Tensorrt(session_option, 0)); 58 | ORT_THROW_ON_ERROR(OrtSessionOptionsAppendExecutionProvider_CUDA(session_option, 0)); 59 | ORT_THROW_ON_ERROR(OrtCreateSession(env, L"erfnet.onnx", session_option, &session_)); 60 | 61 | OrtValue *input_tensor_1 = input_tensor_; 62 | OrtValue *output_tensor_1 = output_tensor_; 63 | 64 | Mat img = imread("..\\..\\test_imgs\\segmentation\\00004.png"); 65 | const int row = height_; 66 | const int col = width_; 67 | Mat dst(row, col, CV_8UC3); 68 | Mat dst2; 69 | resize(img, dst, Size(col, row)); 70 | cvtColor(dst, dst, CV_BGR2RGB); 71 | 72 | float* output = input_image_.data(); 73 | fill(input_image_.begin(), input_image_.end(), 0.f); 74 | Scalar rgb_mean = mean(dst); 75 | for (int c = 0;c < 3;c++) { 76 | for (int i = 0;i < row;i++) { 77 | for (int j = 0;j < col;j++) { 78 | 79 | output[c*row*col + i*col + j] = (dst.ptr(i)[j * 3 + c])/255.0; 80 | } 81 | } 82 | } 83 | 84 | double timeStart = (double)getTickCount(); 85 | for (int i = 0; i < 1000; i++) { 86 | OrtRun(session_, nullptr, input_names, &input_tensor_1, 1, output_names, 1, &output_tensor_1); 87 | } 88 | double nTime = ((double)getTickCount() - timeStart) / getTickFrequency(); 89 | cout << "running time 
:" << nTime << "sec\n" << endl; 90 | 91 | for (int i = 0; i < height_*width_; i++) { 92 | results_extra[0] = results_[i]; 93 | results_extra[1] = results_[i + height_ * width_]; 94 | results_extra[2] = results_[i + height_ * width_ * 2]; 95 | results_extra[3] = results_[i + height_ * width_ * 3]; 96 | result_[i] = std::distance(results_extra.begin(), std::max_element(results_extra.begin(), results_extra.end())); 97 | } 98 | int* result = result_; 99 | 100 | Mat outputimage(height_, width_, CV_8UC3, Scalar(0, 0, 0)); 101 | for (int i = 0;i < height_;i++) { 102 | for (int j = 0;j < width_;j++) { 103 | if (result[i * width_ + j] == 0) { 104 | outputimage.ptr(i)[j * 3] = 255; 105 | outputimage.ptr(i)[j * 3+1] = 0; 106 | outputimage.ptr(i)[j * 3+2] = 0; 107 | } 108 | if (result[i * width_ + j] == 1) { 109 | outputimage.ptr(i)[j * 3] = 0; 110 | outputimage.ptr(i)[j * 3 + 1] = 255; 111 | outputimage.ptr(i)[j * 3 + 2] = 0; 112 | } 113 | if (result[i * width_ + j] == 2) { 114 | outputimage.ptr(i)[j * 3] = 0; 115 | outputimage.ptr(i)[j * 3 + 1] = 0; 116 | outputimage.ptr(i)[j * 3 + 2] = 255; 117 | } 118 | if (result[i * width_ + j] == 3) { 119 | outputimage.ptr(i)[j * 3] = 255; 120 | outputimage.ptr(i)[j * 3 + 1] = 255; 121 | outputimage.ptr(i)[j * 3 + 2] = 0; 122 | } 123 | } 124 | } 125 | imwrite("4.png", outputimage); 126 | system("pause"); 127 | return 0; 128 | } 129 | 130 | -------------------------------------------------------------------------------- /Windows/onnx_erfnet/ort_test/ort_test.vcxproj.filters: -------------------------------------------------------------------------------- 1 |  2 | 3 | 4 | 5 | {4FC737F1-C7A5-4376-A066-2A32D752A2FF} 6 | cpp;c;cc;cxx;def;odl;idl;hpj;bat;asm;asmx 7 | 8 | 9 | {93995380-89BD-4b04-88EB-625FBE52EBFB} 10 | h;hh;hpp;hxx;hm;inl;inc;xsd 11 | 12 | 13 | {67DA6AB6-F800-4c08-8B7A-83BB121AAD01} 14 | rc;ico;cur;bmp;dlg;rc2;rct;bin;rgs;gif;jpg;jpeg;jpe;resx;tiff;tif;png;wav;mfcribbon-ms 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 头文件 23 | 24 | 
25 | 头文件 26 | 27 | 28 | 29 | 30 | 源文件 31 | 32 | 33 | 源文件 34 | 35 | 36 | -------------------------------------------------------------------------------- /Windows/onnx_erfnet/ort_test/ort_test.vcxproj.user: -------------------------------------------------------------------------------- 1 |  2 | 3 | 4 | -------------------------------------------------------------------------------- /Windows/onnx_erfnet/ort_test/stdafx.cpp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tenglike1997/onnxruntime-projects/7e99a75f1ce36329c4dbf2863023e7e05c73a991/Windows/onnx_erfnet/ort_test/stdafx.cpp -------------------------------------------------------------------------------- /Windows/onnx_erfnet/ort_test/stdafx.h: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tenglike1997/onnxruntime-projects/7e99a75f1ce36329c4dbf2863023e7e05c73a991/Windows/onnx_erfnet/ort_test/stdafx.h -------------------------------------------------------------------------------- /Windows/onnx_erfnet/ort_test/targetver.h: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tenglike1997/onnxruntime-projects/7e99a75f1ce36329c4dbf2863023e7e05c73a991/Windows/onnx_erfnet/ort_test/targetver.h -------------------------------------------------------------------------------- /Windows/onnx_mobilenet/ort_test.sln: -------------------------------------------------------------------------------- 1 |  2 | Microsoft Visual Studio Solution File, Format Version 12.00 3 | # Visual Studio 14 4 | VisualStudioVersion = 14.0.25420.1 5 | MinimumVisualStudioVersion = 10.0.40219.1 6 | Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "ort_test", "ort_test\ort_test.vcxproj", "{B0F7CDB8-E482-467B-B5AF-9B4AF7ABB5A4}" 7 | EndProject 8 | Global 9 | GlobalSection(SolutionConfigurationPlatforms) = preSolution 10 | Debug|x64 = Debug|x64 11 | 
Debug|x86 = Debug|x86 12 | Release|x64 = Release|x64 13 | Release|x86 = Release|x86 14 | EndGlobalSection 15 | GlobalSection(ProjectConfigurationPlatforms) = postSolution 16 | {B0F7CDB8-E482-467B-B5AF-9B4AF7ABB5A4}.Debug|x64.ActiveCfg = Debug|x64 17 | {B0F7CDB8-E482-467B-B5AF-9B4AF7ABB5A4}.Debug|x64.Build.0 = Debug|x64 18 | {B0F7CDB8-E482-467B-B5AF-9B4AF7ABB5A4}.Debug|x86.ActiveCfg = Debug|Win32 19 | {B0F7CDB8-E482-467B-B5AF-9B4AF7ABB5A4}.Debug|x86.Build.0 = Debug|Win32 20 | {B0F7CDB8-E482-467B-B5AF-9B4AF7ABB5A4}.Release|x64.ActiveCfg = Release|x64 21 | {B0F7CDB8-E482-467B-B5AF-9B4AF7ABB5A4}.Release|x64.Build.0 = Release|x64 22 | {B0F7CDB8-E482-467B-B5AF-9B4AF7ABB5A4}.Release|x86.ActiveCfg = Release|Win32 23 | {B0F7CDB8-E482-467B-B5AF-9B4AF7ABB5A4}.Release|x86.Build.0 = Release|Win32 24 | EndGlobalSection 25 | GlobalSection(SolutionProperties) = preSolution 26 | HideSolutionNode = FALSE 27 | EndGlobalSection 28 | EndGlobal 29 | -------------------------------------------------------------------------------- /Windows/onnx_mobilenet/ort_test/ReadMe.txt: -------------------------------------------------------------------------------- 1 | ======================================================================== 2 | 控制台应用程序:ort_test 项目概述 3 | ======================================================================== 4 | 5 | 应用程序向导已为您创建了此 ort_test 应用程序。 6 | 7 | 本文件概要介绍组成 ort_test 应用程序的每个文件的内容。 8 | 9 | 10 | ort_test.vcxproj 11 | 这是使用应用程序向导生成的 VC++ 项目的主项目文件,其中包含生成该文件的 Visual C++ 的版本信息,以及有关使用应用程序向导选择的平台、配置和项目功能的信息。 12 | 13 | ort_test.vcxproj.filters 14 | 这是使用“应用程序向导”生成的 VC++ 项目筛选器文件。它包含有关项目文件与筛选器之间的关联信息。在 IDE 中,通过这种关联,在特定节点下以分组形式显示具有相似扩展名的文件。例如,“.cpp”文件与“源文件”筛选器关联。 15 | 16 | ort_test.cpp 17 | 这是主应用程序源文件。 18 | 19 | ///////////////////////////////////////////////////////////////////////////// 20 | 其他标准文件: 21 | 22 | StdAfx.h, StdAfx.cpp 23 | 这些文件用于生成名为 ort_test.pch 的预编译头 (PCH) 文件和名为 StdAfx.obj 的预编译类型文件。 24 | 25 | 
///////////////////////////////////////////////////////////////////////////// 26 | 其他注释: 27 | 28 | 应用程序向导使用“TODO:”注释来指示应添加或自定义的源代码部分。 29 | 30 | ///////////////////////////////////////////////////////////////////////////// 31 | -------------------------------------------------------------------------------- /Windows/onnx_mobilenet/ort_test/onnxruntime.lib: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tenglike1997/onnxruntime-projects/7e99a75f1ce36329c4dbf2863023e7e05c73a991/Windows/onnx_mobilenet/ort_test/onnxruntime.lib -------------------------------------------------------------------------------- /Windows/onnx_mobilenet/ort_test/ort_test.cpp: -------------------------------------------------------------------------------- 1 | #include "stdafx.h" 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | 18 | using namespace cv; 19 | using namespace std; 20 | 21 | 22 | #pragma comment(lib, "user32.lib") 23 | #pragma comment(lib, "gdi32.lib") 24 | #pragma comment(lib, "onnxruntime.lib") 25 | 26 | Ort::Env env{ ORT_LOGGING_LEVEL_WARNING, "test" }; 27 | 28 | static constexpr const int width_ = 224; 29 | static constexpr const int height_ = 224; 30 | static constexpr const int channel = 3; 31 | 32 | std::array input_image_{}; 33 | std::array results_{}; 34 | int result_{ 0 }; 35 | 36 | Ort::Value input_tensor_{ nullptr }; 37 | std::array input_shape_{ 1,3, width_, height_ }; 38 | 39 | Ort::Value output_tensor_{ nullptr }; 40 | std::array output_shape_{ 1, 1000 }; 41 | 42 | OrtSession* session_ = nullptr; 43 | OrtSessionOptions* session_option; 44 | 45 | int main() 46 | { 47 | auto allocator_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); 48 | input_tensor_ = Ort::Value::CreateTensor(allocator_info, input_image_.data(), input_image_.size(), 
input_shape_.data(), input_shape_.size()); 49 | output_tensor_ = Ort::Value::CreateTensor(allocator_info, results_.data(), results_.size(), output_shape_.data(), output_shape_.size()); 50 | const char* input_names[] = { "data" }; 51 | const char* output_names[] = { "mobilenetv20_output_flatten0_reshape0" }; 52 | ORT_THROW_ON_ERROR(OrtCreateSessionOptions(&session_option)); 53 | //ORT_THROW_ON_ERROR(OrtSessionOptionsAppendExecutionProvider_Mkldnn(session_option, 1)); 54 | ORT_THROW_ON_ERROR(OrtSessionOptionsAppendExecutionProvider_Tensorrt(session_option, 0)); 55 | ORT_THROW_ON_ERROR(OrtSessionOptionsAppendExecutionProvider_CUDA(session_option, 0)); 56 | ORT_THROW_ON_ERROR(OrtCreateSession(env, L"mobilenetv2-1.0.onnx", session_option, &session_)); 57 | OrtValue *input_tensor_1 = input_tensor_; 58 | OrtValue *output_tensor_1 = output_tensor_; 59 | 60 | Mat img = imread("..\\..\\test_imgs\\classification\\cls_001.jpg"); 61 | const int row = 224; 62 | const int col = 224; 63 | Mat dst(row, col, CV_8UC3); 64 | Mat dst2; 65 | resize(img, dst, Size(row, col)); 66 | cvtColor(dst, dst, CV_BGR2RGB); 67 | 68 | float* output = input_image_.data(); 69 | fill(input_image_.begin(), input_image_.end(), 0.f); 70 | for (int c = 0; c < 3; c++) { 71 | for (int i = 0; i < row; i++) { 72 | for (int j = 0; j < col; j++) { 73 | if (c == 0) { 74 | output[c*row*col + i * col + j] = ((dst.ptr(i)[j * 3 + c]) / 255.0 - 0.406) / 0.225; 75 | } 76 | if (c == 1) { 77 | output[c*row*col + i * col + j] = ((dst.ptr(i)[j * 3 + c]) / 255.0 - 0.456) / 0.224; 78 | } 79 | if (c == 2) { 80 | output[c*row*col + i * col + j] = ((dst.ptr(i)[j * 3 + c]) / 255.0 - 0.485) / 0.229; 81 | } 82 | } 83 | } 84 | } 85 | 86 | double timeStart = (double)getTickCount(); 87 | for (int i = 0; i < 1000; i++) { 88 | OrtRun(session_, nullptr, input_names, &input_tensor_1, 1, output_names, 1, &output_tensor_1); 89 | } 90 | double nTime = ((double)getTickCount() - timeStart) / getTickFrequency(); 91 | cout << "running time :" << 
nTime << "sec\n" << endl; 92 | 93 | result_ = std::distance(results_.begin(), std::max_element(results_.begin(), results_.end())); 94 | int result = result_; 95 | cout << result << endl; 96 | system("pause"); 97 | return 0; 98 | } 99 | 100 | -------------------------------------------------------------------------------- /Windows/onnx_mobilenet/ort_test/ort_test.vcxproj.filters: -------------------------------------------------------------------------------- 1 |  2 | 3 | 4 | 5 | {4FC737F1-C7A5-4376-A066-2A32D752A2FF} 6 | cpp;c;cc;cxx;def;odl;idl;hpj;bat;asm;asmx 7 | 8 | 9 | {93995380-89BD-4b04-88EB-625FBE52EBFB} 10 | h;hh;hpp;hxx;hm;inl;inc;xsd 11 | 12 | 13 | {67DA6AB6-F800-4c08-8B7A-83BB121AAD01} 14 | rc;ico;cur;bmp;dlg;rc2;rct;bin;rgs;gif;jpg;jpeg;jpe;resx;tiff;tif;png;wav;mfcribbon-ms 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 头文件 23 | 24 | 25 | 头文件 26 | 27 | 28 | 29 | 30 | 源文件 31 | 32 | 33 | 源文件 34 | 35 | 36 | -------------------------------------------------------------------------------- /Windows/onnx_mobilenet/ort_test/ort_test.vcxproj.user: -------------------------------------------------------------------------------- 1 |  2 | 3 | 4 | -------------------------------------------------------------------------------- /Windows/onnx_mobilenet/ort_test/stdafx.cpp: -------------------------------------------------------------------------------- 1 | // stdafx.cpp : ֻ������׼�����ļ���Դ�ļ� 2 | // ort_test.pch ����ΪԤ����ͷ 3 | // stdafx.obj ������Ԥ����������Ϣ 4 | 5 | #include "stdafx.h" 6 | 7 | // TODO: �� STDAFX.H �������κ�����ĸ���ͷ�ļ��� 8 | //�������ڴ��ļ������� 9 | -------------------------------------------------------------------------------- /Windows/onnx_mobilenet/ort_test/stdafx.h: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tenglike1997/onnxruntime-projects/7e99a75f1ce36329c4dbf2863023e7e05c73a991/Windows/onnx_mobilenet/ort_test/stdafx.h 
-------------------------------------------------------------------------------- /Windows/onnx_mobilenet/ort_test/targetver.h: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tenglike1997/onnxruntime-projects/7e99a75f1ce36329c4dbf2863023e7e05c73a991/Windows/onnx_mobilenet/ort_test/targetver.h -------------------------------------------------------------------------------- /Windows/onnx_style_transfer/ort_test.sln: -------------------------------------------------------------------------------- 1 |  2 | Microsoft Visual Studio Solution File, Format Version 12.00 3 | # Visual Studio 14 4 | VisualStudioVersion = 14.0.25420.1 5 | MinimumVisualStudioVersion = 10.0.40219.1 6 | Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "ort_test", "ort_test\ort_test.vcxproj", "{B0F7CDB8-E482-467B-B5AF-9B4AF7ABB5A4}" 7 | EndProject 8 | Global 9 | GlobalSection(SolutionConfigurationPlatforms) = preSolution 10 | Debug|x64 = Debug|x64 11 | Debug|x86 = Debug|x86 12 | Release|x64 = Release|x64 13 | Release|x86 = Release|x86 14 | EndGlobalSection 15 | GlobalSection(ProjectConfigurationPlatforms) = postSolution 16 | {B0F7CDB8-E482-467B-B5AF-9B4AF7ABB5A4}.Debug|x64.ActiveCfg = Debug|x64 17 | {B0F7CDB8-E482-467B-B5AF-9B4AF7ABB5A4}.Debug|x64.Build.0 = Debug|x64 18 | {B0F7CDB8-E482-467B-B5AF-9B4AF7ABB5A4}.Debug|x86.ActiveCfg = Debug|Win32 19 | {B0F7CDB8-E482-467B-B5AF-9B4AF7ABB5A4}.Debug|x86.Build.0 = Debug|Win32 20 | {B0F7CDB8-E482-467B-B5AF-9B4AF7ABB5A4}.Release|x64.ActiveCfg = Release|x64 21 | {B0F7CDB8-E482-467B-B5AF-9B4AF7ABB5A4}.Release|x64.Build.0 = Release|x64 22 | {B0F7CDB8-E482-467B-B5AF-9B4AF7ABB5A4}.Release|x86.ActiveCfg = Release|Win32 23 | {B0F7CDB8-E482-467B-B5AF-9B4AF7ABB5A4}.Release|x86.Build.0 = Release|Win32 24 | EndGlobalSection 25 | GlobalSection(SolutionProperties) = preSolution 26 | HideSolutionNode = FALSE 27 | EndGlobalSection 28 | EndGlobal 29 | 
-------------------------------------------------------------------------------- /Windows/onnx_style_transfer/ort_test/ReadMe.txt: -------------------------------------------------------------------------------- 1 | ======================================================================== 2 | 控制台应用程序:ort_test 项目概述 3 | ======================================================================== 4 | 5 | 应用程序向导已为您创建了此 ort_test 应用程序。 6 | 7 | 本文件概要介绍组成 ort_test 应用程序的每个文件的内容。 8 | 9 | 10 | ort_test.vcxproj 11 | 这是使用应用程序向导生成的 VC++ 项目的主项目文件,其中包含生成该文件的 Visual C++ 的版本信息,以及有关使用应用程序向导选择的平台、配置和项目功能的信息。 12 | 13 | ort_test.vcxproj.filters 14 | 这是使用“应用程序向导”生成的 VC++ 项目筛选器文件。它包含有关项目文件与筛选器之间的关联信息。在 IDE 中,通过这种关联,在特定节点下以分组形式显示具有相似扩展名的文件。例如,“.cpp”文件与“源文件”筛选器关联。 15 | 16 | ort_test.cpp 17 | 这是主应用程序源文件。 18 | 19 | ///////////////////////////////////////////////////////////////////////////// 20 | 其他标准文件: 21 | 22 | StdAfx.h, StdAfx.cpp 23 | 这些文件用于生成名为 ort_test.pch 的预编译头 (PCH) 文件和名为 StdAfx.obj 的预编译类型文件。 24 | 25 | ///////////////////////////////////////////////////////////////////////////// 26 | 其他注释: 27 | 28 | 应用程序向导使用“TODO:”注释来指示应添加或自定义的源代码部分。 29 | 30 | ///////////////////////////////////////////////////////////////////////////// 31 | -------------------------------------------------------------------------------- /Windows/onnx_style_transfer/ort_test/onnxruntime.lib: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tenglike1997/onnxruntime-projects/7e99a75f1ce36329c4dbf2863023e7e05c73a991/Windows/onnx_style_transfer/ort_test/onnxruntime.lib -------------------------------------------------------------------------------- /Windows/onnx_style_transfer/ort_test/ort_test.cpp: -------------------------------------------------------------------------------- 1 | #include "stdafx.h" 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | #include 11 | #include 12 | #include 13 | #include 14 | 
#include 15 | #include 16 | #include 17 | 18 | using namespace cv; 19 | using namespace std; 20 | 21 | 22 | #pragma comment(lib, "user32.lib") 23 | #pragma comment(lib, "gdi32.lib") 24 | #pragma comment(lib, "onnxruntime.lib") 25 | 26 | Ort::Env env{ ORT_LOGGING_LEVEL_WARNING, "test" }; 27 | 28 | static constexpr const int width_ = 224; 29 | static constexpr const int height_ = 224; 30 | static constexpr const int channel = 3; 31 | 32 | std::array input_image_{}; 33 | std::array results_{}; 34 | std::array results_extra{}; 35 | int result_[height_ * width_]{ 0 }; 36 | Mat outputimage(height_, width_, CV_8UC3, Scalar(0, 0, 0)); 37 | 38 | Ort::Value input_tensor_{ nullptr }; 39 | std::array input_shape_{ 1,channel, height_, width_ }; 40 | 41 | Ort::Value output_tensor_{ nullptr }; 42 | std::array output_shape_{ 1,3,height_, width_ }; 43 | 44 | OrtSession* session_ = nullptr; 45 | OrtSessionOptions* session_option; 46 | 47 | int main() 48 | { 49 | auto allocator_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); 50 | input_tensor_ = Ort::Value::CreateTensor(allocator_info, input_image_.data(), input_image_.size(), input_shape_.data(), input_shape_.size()); 51 | output_tensor_ = Ort::Value::CreateTensor(allocator_info, results_.data(), results_.size(), output_shape_.data(), output_shape_.size()); 52 | const char* input_names[] = { "input1" }; 53 | const char* output_names[] = { "output1" }; 54 | ORT_THROW_ON_ERROR(OrtCreateSessionOptions(&session_option)); 55 | //ORT_THROW_ON_ERROR(OrtSessionOptionsAppendExecutionProvider_Mkldnn(session_option, 1)); 56 | //ORT_THROW_ON_ERROR(OrtSessionOptionsAppendExecutionProvider_Tensorrt(session_option, 0)); 57 | ORT_THROW_ON_ERROR(OrtSessionOptionsAppendExecutionProvider_CUDA(session_option, 0)); 58 | ORT_THROW_ON_ERROR(OrtCreateSession(env, L"mosaic.onnx", session_option, &session_)); 59 | //ORT_THROW_ON_ERROR(OrtCreateSession(env, L"candy.onnx", session_option, &session_)); 60 | 
//ORT_THROW_ON_ERROR(OrtCreateSession(env, L"pointilism.onnx", session_option, &session_)); 61 | //ORT_THROW_ON_ERROR(OrtCreateSession(env, L"rain_princess.onnx", session_option, &session_)); 62 | //ORT_THROW_ON_ERROR(OrtCreateSession(env, L"udnie.onnx", session_option, &session_)); 63 | OrtValue *input_tensor_1 = input_tensor_; 64 | OrtValue *output_tensor_1 = output_tensor_; 65 | 66 | Mat img = imread("..\\..\\test_imgs\\style_transfer\\church.jpg"); 67 | const int row = height_; 68 | const int col = width_; 69 | Mat dst(row, col, CV_8UC3); 70 | Mat dst2; 71 | resize(img, dst, Size(col, row)); 72 | cvtColor(dst, dst, CV_BGR2RGB); 73 | 74 | float* output = input_image_.data(); 75 | fill(input_image_.begin(), input_image_.end(), 0.f); 76 | Scalar rgb_mean = mean(dst); 77 | for (int c = 0;c < 3;c++) { 78 | for (int i = 0;i < row;i++) { 79 | for (int j = 0;j < col;j++) { 80 | output[c*row*col + i*col + j] = (dst.ptr(i)[j * 3 + c]); 81 | } 82 | } 83 | } 84 | 85 | double timeStart = (double)getTickCount(); 86 | for (int i = 0; i < 1000; i++) { 87 | OrtRun(session_, nullptr, input_names, &input_tensor_1, 1, output_names, 1, &output_tensor_1); 88 | } 89 | double nTime = ((double)getTickCount() - timeStart) / getTickFrequency(); 90 | cout << "running time :" << nTime << "sec\n" << endl; 91 | 92 | for (int i = 0; i < height_*width_; i++) { 93 | results_extra[0] = results_[i]; 94 | results_extra[1] = results_[i + height_ * width_]; 95 | results_extra[2] = results_[i + height_ * width_ * 2]; 96 | outputimage.ptr(i / width_)[i%width_ * 3] = results_extra[0]; 97 | outputimage.ptr(i / width_)[i%width_ * 3 + 1] = results_extra[1]; 98 | outputimage.ptr(i / width_)[i%width_ * 3 + 2] = results_extra[2]; 99 | } 100 | 101 | dst2 = outputimage; 102 | imwrite("..\\..\\test_imgs\\style_transfer\\mosaic.png", outputimage); 103 | system("pause"); 104 | return 0; 105 | } 106 | 107 | -------------------------------------------------------------------------------- 
/Windows/onnx_style_transfer/ort_test/ort_test.vcxproj.filters: -------------------------------------------------------------------------------- 1 |  2 | 3 | 4 | 5 | {4FC737F1-C7A5-4376-A066-2A32D752A2FF} 6 | cpp;c;cc;cxx;def;odl;idl;hpj;bat;asm;asmx 7 | 8 | 9 | {93995380-89BD-4b04-88EB-625FBE52EBFB} 10 | h;hh;hpp;hxx;hm;inl;inc;xsd 11 | 12 | 13 | {67DA6AB6-F800-4c08-8B7A-83BB121AAD01} 14 | rc;ico;cur;bmp;dlg;rc2;rct;bin;rgs;gif;jpg;jpeg;jpe;resx;tiff;tif;png;wav;mfcribbon-ms 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 头文件 23 | 24 | 25 | 头文件 26 | 27 | 28 | 29 | 30 | 源文件 31 | 32 | 33 | 源文件 34 | 35 | 36 | -------------------------------------------------------------------------------- /Windows/onnx_style_transfer/ort_test/ort_test.vcxproj.user: -------------------------------------------------------------------------------- 1 |  2 | 3 | 4 | -------------------------------------------------------------------------------- /Windows/onnx_style_transfer/ort_test/stdafx.cpp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tenglike1997/onnxruntime-projects/7e99a75f1ce36329c4dbf2863023e7e05c73a991/Windows/onnx_style_transfer/ort_test/stdafx.cpp -------------------------------------------------------------------------------- /Windows/onnx_style_transfer/ort_test/stdafx.h: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tenglike1997/onnxruntime-projects/7e99a75f1ce36329c4dbf2863023e7e05c73a991/Windows/onnx_style_transfer/ort_test/stdafx.h -------------------------------------------------------------------------------- /Windows/onnx_style_transfer/ort_test/targetver.h: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tenglike1997/onnxruntime-projects/7e99a75f1ce36329c4dbf2863023e7e05c73a991/Windows/onnx_style_transfer/ort_test/targetver.h 
-------------------------------------------------------------------------------- /Windows/onnx_super_resolustion/ort_test.sln: -------------------------------------------------------------------------------- 1 |  2 | Microsoft Visual Studio Solution File, Format Version 12.00 3 | # Visual Studio 14 4 | VisualStudioVersion = 14.0.25420.1 5 | MinimumVisualStudioVersion = 10.0.40219.1 6 | Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "ort_test", "ort_test\ort_test.vcxproj", "{B0F7CDB8-E482-467B-B5AF-9B4AF7ABB5A4}" 7 | EndProject 8 | Global 9 | GlobalSection(SolutionConfigurationPlatforms) = preSolution 10 | Debug|x64 = Debug|x64 11 | Debug|x86 = Debug|x86 12 | Release|x64 = Release|x64 13 | Release|x86 = Release|x86 14 | EndGlobalSection 15 | GlobalSection(ProjectConfigurationPlatforms) = postSolution 16 | {B0F7CDB8-E482-467B-B5AF-9B4AF7ABB5A4}.Debug|x64.ActiveCfg = Debug|x64 17 | {B0F7CDB8-E482-467B-B5AF-9B4AF7ABB5A4}.Debug|x64.Build.0 = Debug|x64 18 | {B0F7CDB8-E482-467B-B5AF-9B4AF7ABB5A4}.Debug|x86.ActiveCfg = Debug|Win32 19 | {B0F7CDB8-E482-467B-B5AF-9B4AF7ABB5A4}.Debug|x86.Build.0 = Debug|Win32 20 | {B0F7CDB8-E482-467B-B5AF-9B4AF7ABB5A4}.Release|x64.ActiveCfg = Release|x64 21 | {B0F7CDB8-E482-467B-B5AF-9B4AF7ABB5A4}.Release|x64.Build.0 = Release|x64 22 | {B0F7CDB8-E482-467B-B5AF-9B4AF7ABB5A4}.Release|x86.ActiveCfg = Release|Win32 23 | {B0F7CDB8-E482-467B-B5AF-9B4AF7ABB5A4}.Release|x86.Build.0 = Release|Win32 24 | EndGlobalSection 25 | GlobalSection(SolutionProperties) = preSolution 26 | HideSolutionNode = FALSE 27 | EndGlobalSection 28 | EndGlobal 29 | -------------------------------------------------------------------------------- /Windows/onnx_super_resolustion/ort_test/ReadMe.txt: -------------------------------------------------------------------------------- 1 | ======================================================================== 2 | 控制台应用程序:ort_test 项目概述 3 | ======================================================================== 4 | 5 | 
应用程序向导已为您创建了此 ort_test 应用程序。 6 | 7 | 本文件概要介绍组成 ort_test 应用程序的每个文件的内容。 8 | 9 | 10 | ort_test.vcxproj 11 | 这是使用应用程序向导生成的 VC++ 项目的主项目文件,其中包含生成该文件的 Visual C++ 的版本信息,以及有关使用应用程序向导选择的平台、配置和项目功能的信息。 12 | 13 | ort_test.vcxproj.filters 14 | 这是使用“应用程序向导”生成的 VC++ 项目筛选器文件。它包含有关项目文件与筛选器之间的关联信息。在 IDE 中,通过这种关联,在特定节点下以分组形式显示具有相似扩展名的文件。例如,“.cpp”文件与“源文件”筛选器关联。 15 | 16 | ort_test.cpp 17 | 这是主应用程序源文件。 18 | 19 | ///////////////////////////////////////////////////////////////////////////// 20 | 其他标准文件: 21 | 22 | StdAfx.h, StdAfx.cpp 23 | 这些文件用于生成名为 ort_test.pch 的预编译头 (PCH) 文件和名为 StdAfx.obj 的预编译类型文件。 24 | 25 | ///////////////////////////////////////////////////////////////////////////// 26 | 其他注释: 27 | 28 | 应用程序向导使用“TODO:”注释来指示应添加或自定义的源代码部分。 29 | 30 | ///////////////////////////////////////////////////////////////////////////// 31 | -------------------------------------------------------------------------------- /Windows/onnx_super_resolustion/ort_test/onnxruntime.lib: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tenglike1997/onnxruntime-projects/7e99a75f1ce36329c4dbf2863023e7e05c73a991/Windows/onnx_super_resolustion/ort_test/onnxruntime.lib -------------------------------------------------------------------------------- /Windows/onnx_super_resolustion/ort_test/ort_test.cpp: -------------------------------------------------------------------------------- 1 | #include "stdafx.h" 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | 18 | using namespace cv; 19 | using namespace std; 20 | 21 | 22 | #pragma comment(lib, "user32.lib") 23 | #pragma comment(lib, "gdi32.lib") 24 | #pragma comment(lib, "onnxruntime.lib") 25 | 26 | Ort::Env env{ ORT_LOGGING_LEVEL_WARNING, "test" }; 27 | 28 | static constexpr const int width_ = 224; 29 | static constexpr const int height_ = 224; 30 | 
static constexpr const int channel = 1; 31 | 32 | std::array input_image_{}; 33 | std::array results_{}; 34 | std::array results_extra{}; 35 | int result_[672 * 672]{ 0 }; 36 | Mat outputimageY(672, 672, CV_8UC1, Scalar(0)); 37 | 38 | Ort::Value input_tensor_{ nullptr }; 39 | std::array input_shape_{ 1,channel, height_, width_ }; 40 | 41 | Ort::Value output_tensor_{ nullptr }; 42 | std::array output_shape_{ 1,1,672, 672 }; 43 | 44 | OrtSession* session_ = nullptr; 45 | OrtSessionOptions* session_option; 46 | 47 | int main() 48 | { 49 | auto allocator_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); 50 | input_tensor_ = Ort::Value::CreateTensor(allocator_info, input_image_.data(), input_image_.size(), input_shape_.data(), input_shape_.size()); 51 | output_tensor_ = Ort::Value::CreateTensor(allocator_info, results_.data(), results_.size(), output_shape_.data(), output_shape_.size()); 52 | const char* input_names[] = { "input" }; 53 | const char* output_names[] = { "output" }; 54 | ORT_THROW_ON_ERROR(OrtCreateSessionOptions(&session_option)); 55 | //ORT_THROW_ON_ERROR(OrtSessionOptionsAppendExecutionProvider_Mkldnn(session_option, 1)); 56 | ORT_THROW_ON_ERROR(OrtSessionOptionsAppendExecutionProvider_Tensorrt(session_option, 0)); 57 | ORT_THROW_ON_ERROR(OrtSessionOptionsAppendExecutionProvider_CUDA(session_option, 0)); 58 | ORT_THROW_ON_ERROR(OrtCreateSession(env, L"super_resolution.onnx", session_option, &session_)); 59 | OrtValue *input_tensor_1 = input_tensor_; 60 | OrtValue *output_tensor_1 = output_tensor_; 61 | 62 | Mat img = imread("..\\..\\test_imgs\\super_resolution\\rawimg.jpg"); 63 | const int row = height_; 64 | const int col = width_; 65 | Mat dst(row, col, CV_8UC3); 66 | Mat dst1; 67 | 68 | resize(img, dst, Size(col, row)); 69 | imwrite("..\\..\\test_imgs\\super_resolution\\LowResolution.png", dst); 70 | 71 | Mat dstresize; 72 | resize(dst, dstresize, Size(672, 672)); 73 | 
imwrite("..\\..\\test_imgs\\super_resolution\\RizeResolution.png", dstresize); 74 | 75 | cvtColor(dst, dst1, COLOR_BGR2YCrCb); 76 | Mat dstycbcr=dst1; 77 | vector channels1; 78 | vector channels2; 79 | Mat imageYChannel; 80 | Mat imageCrChannel; 81 | Mat imageCbChannel; 82 | split(dstycbcr, channels1); 83 | imageYChannel = channels1.at(0); 84 | imageCrChannel = channels1.at(1); 85 | imageCbChannel = channels1.at(2); 86 | 87 | 88 | float* output = input_image_.data(); 89 | fill(input_image_.begin(), input_image_.end(), 0.f); 90 | for (int c = 0;c < 1;c++) { 91 | for (int i = 0;i < row;i++) { 92 | for (int j = 0;j < col;j++) { 93 | output[c*row*col + i*col + j] = (imageYChannel.ptr(i)[j]); 94 | } 95 | } 96 | } 97 | 98 | double timeStart = (double)getTickCount(); 99 | for (int i = 0; i < 1000; i++) { 100 | OrtRun(session_, nullptr, input_names, &input_tensor_1, 1, output_names, 1, &output_tensor_1); 101 | } 102 | double nTime = ((double)getTickCount() - timeStart) / getTickFrequency(); 103 | cout << "running time :" << nTime << "sec\n" << endl; 104 | 105 | for (int i = 0; i < 672 * 672; i++) { 106 | outputimageY.ptr(i / 672)[i % 672] = results_[i]; 107 | } 108 | 109 | Mat outputimageCr; 110 | Mat outputimageCb; 111 | Mat outputimageY1 = outputimageY; 112 | resize(imageCrChannel, outputimageCr, Size(672, 672), INTER_CUBIC); 113 | resize(imageCbChannel, outputimageCb, Size(672, 672), INTER_CUBIC); 114 | channels2.push_back(outputimageY); 115 | channels2.push_back(outputimageCr); 116 | channels2.push_back(outputimageCb); 117 | Mat outputimage; 118 | Mat outputimagefinal; 119 | merge(channels2, outputimage); 120 | cvtColor(outputimage, outputimagefinal, COLOR_YCrCb2BGR); 121 | 122 | imwrite("..\\..\\test_imgs\\super_resolution\\SuperResolution.png", outputimagefinal); 123 | system("pause"); 124 | return 0; 125 | } 126 | 127 | -------------------------------------------------------------------------------- /Windows/onnx_super_resolustion/ort_test/ort_test.vcxproj.filters: 
-------------------------------------------------------------------------------- 1 |  2 | 3 | 4 | 5 | {4FC737F1-C7A5-4376-A066-2A32D752A2FF} 6 | cpp;c;cc;cxx;def;odl;idl;hpj;bat;asm;asmx 7 | 8 | 9 | {93995380-89BD-4b04-88EB-625FBE52EBFB} 10 | h;hh;hpp;hxx;hm;inl;inc;xsd 11 | 12 | 13 | {67DA6AB6-F800-4c08-8B7A-83BB121AAD01} 14 | rc;ico;cur;bmp;dlg;rc2;rct;bin;rgs;gif;jpg;jpeg;jpe;resx;tiff;tif;png;wav;mfcribbon-ms 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 头文件 23 | 24 | 25 | 头文件 26 | 27 | 28 | 29 | 30 | 源文件 31 | 32 | 33 | 源文件 34 | 35 | 36 | -------------------------------------------------------------------------------- /Windows/onnx_super_resolustion/ort_test/ort_test.vcxproj.user: -------------------------------------------------------------------------------- 1 |  2 | 3 | 4 | -------------------------------------------------------------------------------- /Windows/onnx_super_resolustion/ort_test/stdafx.cpp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tenglike1997/onnxruntime-projects/7e99a75f1ce36329c4dbf2863023e7e05c73a991/Windows/onnx_super_resolustion/ort_test/stdafx.cpp -------------------------------------------------------------------------------- /Windows/onnx_super_resolustion/ort_test/stdafx.h: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tenglike1997/onnxruntime-projects/7e99a75f1ce36329c4dbf2863023e7e05c73a991/Windows/onnx_super_resolustion/ort_test/stdafx.h -------------------------------------------------------------------------------- /Windows/onnx_super_resolustion/ort_test/targetver.h: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tenglike1997/onnxruntime-projects/7e99a75f1ce36329c4dbf2863023e7e05c73a991/Windows/onnx_super_resolustion/ort_test/targetver.h -------------------------------------------------------------------------------- 
/Windows/onnx_tiny_yolov2/ort_test.sln: -------------------------------------------------------------------------------- 1 |  2 | Microsoft Visual Studio Solution File, Format Version 12.00 3 | # Visual Studio 14 4 | VisualStudioVersion = 14.0.25420.1 5 | MinimumVisualStudioVersion = 10.0.40219.1 6 | Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "ort_test", "ort_test\ort_test.vcxproj", "{B0F7CDB8-E482-467B-B5AF-9B4AF7ABB5A4}" 7 | EndProject 8 | Global 9 | GlobalSection(SolutionConfigurationPlatforms) = preSolution 10 | Debug|x64 = Debug|x64 11 | Debug|x86 = Debug|x86 12 | Release|x64 = Release|x64 13 | Release|x86 = Release|x86 14 | EndGlobalSection 15 | GlobalSection(ProjectConfigurationPlatforms) = postSolution 16 | {B0F7CDB8-E482-467B-B5AF-9B4AF7ABB5A4}.Debug|x64.ActiveCfg = Debug|x64 17 | {B0F7CDB8-E482-467B-B5AF-9B4AF7ABB5A4}.Debug|x64.Build.0 = Debug|x64 18 | {B0F7CDB8-E482-467B-B5AF-9B4AF7ABB5A4}.Debug|x86.ActiveCfg = Debug|Win32 19 | {B0F7CDB8-E482-467B-B5AF-9B4AF7ABB5A4}.Debug|x86.Build.0 = Debug|Win32 20 | {B0F7CDB8-E482-467B-B5AF-9B4AF7ABB5A4}.Release|x64.ActiveCfg = Release|x64 21 | {B0F7CDB8-E482-467B-B5AF-9B4AF7ABB5A4}.Release|x64.Build.0 = Release|x64 22 | {B0F7CDB8-E482-467B-B5AF-9B4AF7ABB5A4}.Release|x86.ActiveCfg = Release|Win32 23 | {B0F7CDB8-E482-467B-B5AF-9B4AF7ABB5A4}.Release|x86.Build.0 = Release|Win32 24 | EndGlobalSection 25 | GlobalSection(SolutionProperties) = preSolution 26 | HideSolutionNode = FALSE 27 | EndGlobalSection 28 | EndGlobal 29 | -------------------------------------------------------------------------------- /Windows/onnx_tiny_yolov2/ort_test/ReadMe.txt: -------------------------------------------------------------------------------- 1 | ======================================================================== 2 | 控制台应用程序:ort_test 项目概述 3 | ======================================================================== 4 | 5 | 应用程序向导已为您创建了此 ort_test 应用程序。 6 | 7 | 本文件概要介绍组成 ort_test 应用程序的每个文件的内容。 8 | 9 | 10 | ort_test.vcxproj 11 
| 这是使用应用程序向导生成的 VC++ 项目的主项目文件,其中包含生成该文件的 Visual C++ 的版本信息,以及有关使用应用程序向导选择的平台、配置和项目功能的信息。 12 | 13 | ort_test.vcxproj.filters 14 | 这是使用“应用程序向导”生成的 VC++ 项目筛选器文件。它包含有关项目文件与筛选器之间的关联信息。在 IDE 中,通过这种关联,在特定节点下以分组形式显示具有相似扩展名的文件。例如,“.cpp”文件与“源文件”筛选器关联。 15 | 16 | ort_test.cpp 17 | 这是主应用程序源文件。 18 | 19 | ///////////////////////////////////////////////////////////////////////////// 20 | 其他标准文件: 21 | 22 | StdAfx.h, StdAfx.cpp 23 | 这些文件用于生成名为 ort_test.pch 的预编译头 (PCH) 文件和名为 StdAfx.obj 的预编译类型文件。 24 | 25 | ///////////////////////////////////////////////////////////////////////////// 26 | 其他注释: 27 | 28 | 应用程序向导使用“TODO:”注释来指示应添加或自定义的源代码部分。 29 | 30 | ///////////////////////////////////////////////////////////////////////////// 31 | -------------------------------------------------------------------------------- /Windows/onnx_tiny_yolov2/ort_test/onnxruntime.lib: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tenglike1997/onnxruntime-projects/7e99a75f1ce36329c4dbf2863023e7e05c73a991/Windows/onnx_tiny_yolov2/ort_test/onnxruntime.lib -------------------------------------------------------------------------------- /Windows/onnx_tiny_yolov2/ort_test/ort_test.cpp: -------------------------------------------------------------------------------- 1 | #include "stdafx.h" 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | 18 | using namespace cv; 19 | using namespace std; 20 | 21 | 22 | #pragma comment(lib, "user32.lib") 23 | #pragma comment(lib, "gdi32.lib") 24 | #pragma comment(lib, "onnxruntime.lib") 25 | 26 | Ort::Env env{ ORT_LOGGING_LEVEL_WARNING, "test" }; 27 | 28 | static constexpr const int width_ = 416; 29 | static constexpr const int height_ = 416; 30 | static constexpr const int channel = 3; 31 | 32 | static constexpr const float confidence_threshold = 0.1; 33 | static 
constexpr const float nms_threshold =0.3f; 34 | 35 | std::array input_image_{}; 36 | std::array results_{}; 37 | std::array results_extra{}; 38 | vectorTrue_Point; 39 | vectorType; 40 | vectorTrue_TypeIndex; 41 | float anchors[10] = { 1.08, 1.19, 3.42, 4.41, 6.63, 11.38, 9.42, 5.11, 16.62, 10.52 }; 42 | String classes[20] = { "aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor" }; 43 | Scalar colors[20] = { Scalar(255, 0, 0), Scalar(0, 255, 0),Scalar(0,0,255), 44 | Scalar(255,255,0),Scalar(255,0,255), Scalar(0,255,255), 45 | Scalar(255,255,255), Scalar(127,0,0),Scalar(0,127,0), 46 | Scalar(0,0,127),Scalar(127,127,0), Scalar(127,0,127), 47 | Scalar(0,127,127), Scalar(127,127,127),Scalar(127,255,0), 48 | Scalar(127,0,255),Scalar(127,255,255), Scalar(0,127,255), 49 | Scalar(255,127,0), Scalar(0,255,127) }; 50 | //int result_[height_ * width_*5]{ 0 }; 51 | 52 | Ort::Value input_tensor_{ nullptr }; 53 | std::array input_shape_{ 1,channel, height_, width_ }; 54 | 55 | Ort::Value output_tensor_{ nullptr }; 56 | std::array output_shape_{ 1,125,13, 13 }; 57 | 58 | OrtSession* session_ = nullptr; 59 | OrtSessionOptions* session_option; 60 | 61 | 62 | 63 | float sigmoid(float x) { 64 | return 1.0 / (1.0 + exp(-x)); 65 | } 66 | 67 | void nms( 68 | const std::vector& srcRects, 69 | std::vector& resRects, 70 | float thresh 71 | ) 72 | { 73 | resRects.clear(); 74 | 75 | const size_t size = srcRects.size(); 76 | if (!size) 77 | { 78 | return; 79 | } 80 | 81 | // Sort the bounding boxes by the bottom - right y - coordinate of the bounding box 82 | std::multimap idxs; 83 | for (size_t i = 0; i < size; ++i) 84 | { 85 | idxs.insert(std::pair(srcRects[i].br().y, i)); 86 | } 87 | 88 | // keep looping while some indexes still remain in the indexes list 89 | while (idxs.size() > 0) 90 | { 91 | // grab the last rectangle 92 | auto lastElem = 
--std::end(idxs); 93 | const cv::Rect& rect1 = srcRects[lastElem->second]; 94 | True_TypeIndex.push_back(lastElem->second); 95 | resRects.push_back(rect1); 96 | 97 | idxs.erase(lastElem); 98 | 99 | for (auto pos = std::begin(idxs); pos != std::end(idxs); ) 100 | { 101 | // grab the current rectangle 102 | const cv::Rect& rect2 = srcRects[pos->second]; 103 | 104 | float intArea = (rect1 & rect2).area(); 105 | float unionArea = rect1.area() + rect2.area() - intArea; 106 | float overlap = intArea / unionArea; 107 | 108 | // if there is sufficient overlap, suppress the current bounding box 109 | if (overlap > thresh) 110 | { 111 | pos = idxs.erase(pos); 112 | } 113 | else 114 | { 115 | ++pos; 116 | } 117 | } 118 | } 119 | } 120 | 121 | 122 | int main() 123 | { 124 | auto allocator_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); 125 | input_tensor_ = Ort::Value::CreateTensor(allocator_info, input_image_.data(), input_image_.size(), input_shape_.data(), input_shape_.size()); 126 | output_tensor_ = Ort::Value::CreateTensor(allocator_info, results_.data(), results_.size(), output_shape_.data(), output_shape_.size()); 127 | const char* input_names[] = { "image" }; 128 | const char* output_names[] = { "grid" }; 129 | ORT_THROW_ON_ERROR(OrtCreateSessionOptions(&session_option)); 130 | //ORT_THROW_ON_ERROR(OrtSessionOptionsAppendExecutionProvider_Mkldnn(session_option, 1)); 131 | ORT_THROW_ON_ERROR(OrtSessionOptionsAppendExecutionProvider_Tensorrt(session_option, 0)); 132 | ORT_THROW_ON_ERROR(OrtSessionOptionsAppendExecutionProvider_CUDA(session_option, 0)); 133 | ORT_THROW_ON_ERROR(OrtCreateSession(env, L"tiny_yolov2.onnx", session_option, &session_)); 134 | OrtValue *input_tensor_1 = input_tensor_; 135 | OrtValue *output_tensor_1 = output_tensor_; 136 | 137 | Mat img = imread("..\\..\\test_imgs\\detection\\000001.jpg"); 138 | const int row = height_; 139 | const int col = width_; 140 | Mat dst(row, col, CV_8UC3); 141 | Mat dst2; 142 | resize(img, dst, 
Size(col, row)); 143 | cvtColor(dst,dst,CV_BGR2RGB); 144 | 145 | float* output = input_image_.data(); 146 | fill(input_image_.begin(), input_image_.end(), 0.f); 147 | Scalar rgb_mean = mean(dst); 148 | for (int c = 0;c < 3;c++) { 149 | for (int i = 0;i < row;i++) { 150 | for (int j = 0;j < col;j++) { 151 | output[c*row*col + i*col + j] = (dst.ptr(i)[j * 3 + c]); 152 | } 153 | } 154 | } 155 | 156 | double timeStart = (double)getTickCount(); 157 | for (int i = 0; i < 1000; i++) { 158 | OrtRun(session_, nullptr, input_names, &input_tensor_1, 1, output_names, 1, &output_tensor_1); 159 | } 160 | double nTime = ((double)getTickCount() - timeStart) / getTickFrequency(); 161 | cout << "running time :" << nTime << "sec\n" << endl; 162 | 163 | Vec4i true_point; 164 | int result_type; 165 | for (int i = 0; i < 13 * 13; i++) { 166 | for (int j = 0; j < 5; j++) { 167 | float sum = 0; 168 | for (int k = 0; k < 20; k++) { 169 | results_extra[k] = results_[i + 13 * 13 * (25 * j + k + 5)]; 170 | sum += exp(results_extra[k]); 171 | } 172 | result_type = std::distance(results_extra.begin(), std::max_element(results_extra.begin(), results_extra.end())); 173 | float probability = exp(results_extra[result_type]) / sum; 174 | if (sigmoid(results_[i + 13 * 13 * (25 * j + 4)])*probability >= confidence_threshold) { 175 | true_point[0] = (sigmoid(results_[i + 13 * 13 * (25 * j)]) + i % 13)*32.0; 176 | true_point[1] = (sigmoid(results_[i + 13 * 13 * (25 * j + 1)]) + i / 13)*32.0; 177 | true_point[2] = exp(results_[i + 13 * 13 * (25 * j + 2)])*anchors[2 * j] * 32.0; 178 | true_point[3] = exp(results_[i + 13 * 13 * (25 * j + 3)])*anchors[2 * j + 1] * 32.0; 179 | True_Point.push_back(true_point); 180 | Type.push_back(result_type); 181 | } 182 | } 183 | } 184 | vectorsrcRects; 185 | vectorresRects; 186 | Rect rect; 187 | for (int i = 0;i < True_Point.size();i++) { 188 | rect = Rect(True_Point[i][0] - True_Point[i][2] / 2.0, True_Point[i][1] - True_Point[i][3] / 2.0, True_Point[i][2], 
True_Point[i][3]); 189 | srcRects.push_back(rect); 190 | } 191 | nms(srcRects, resRects, nms_threshold); 192 | for (int i = 0;i < resRects.size();i++) { 193 | rectangle(dst, resRects[i], Scalar(colors[Type[True_TypeIndex[i]]]), 3, 1, 0); 194 | cout << Type[True_TypeIndex[i]] << endl; 195 | putText(dst,classes[Type[True_TypeIndex[i]]],Point(resRects[i].x+5,resRects[i].y+13), FONT_HERSHEY_COMPLEX,0.5,Scalar(colors[Type[True_TypeIndex[i]]]),1,8); 196 | } 197 | imwrite("..\\..\\test_imgs\\detection\\test.jpg", dst); 198 | system("pause"); 199 | return 0; 200 | } 201 | 202 | -------------------------------------------------------------------------------- /Windows/onnx_tiny_yolov2/ort_test/ort_test.vcxproj.filters: -------------------------------------------------------------------------------- 1 |  2 | 3 | 4 | 5 | {4FC737F1-C7A5-4376-A066-2A32D752A2FF} 6 | cpp;c;cc;cxx;def;odl;idl;hpj;bat;asm;asmx 7 | 8 | 9 | {93995380-89BD-4b04-88EB-625FBE52EBFB} 10 | h;hh;hpp;hxx;hm;inl;inc;xsd 11 | 12 | 13 | {67DA6AB6-F800-4c08-8B7A-83BB121AAD01} 14 | rc;ico;cur;bmp;dlg;rc2;rct;bin;rgs;gif;jpg;jpeg;jpe;resx;tiff;tif;png;wav;mfcribbon-ms 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 头文件 23 | 24 | 25 | 头文件 26 | 27 | 28 | 29 | 30 | 源文件 31 | 32 | 33 | 源文件 34 | 35 | 36 | -------------------------------------------------------------------------------- /Windows/onnx_tiny_yolov2/ort_test/ort_test.vcxproj.user: -------------------------------------------------------------------------------- 1 |  2 | 3 | 4 | -------------------------------------------------------------------------------- /Windows/onnx_tiny_yolov2/ort_test/stdafx.cpp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tenglike1997/onnxruntime-projects/7e99a75f1ce36329c4dbf2863023e7e05c73a991/Windows/onnx_tiny_yolov2/ort_test/stdafx.cpp -------------------------------------------------------------------------------- /Windows/onnx_tiny_yolov2/ort_test/stdafx.h: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/tenglike1997/onnxruntime-projects/7e99a75f1ce36329c4dbf2863023e7e05c73a991/Windows/onnx_tiny_yolov2/ort_test/stdafx.h -------------------------------------------------------------------------------- /Windows/onnx_tiny_yolov2/ort_test/targetver.h: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tenglike1997/onnxruntime-projects/7e99a75f1ce36329c4dbf2863023e7e05c73a991/Windows/onnx_tiny_yolov2/ort_test/targetver.h -------------------------------------------------------------------------------- /include/onnxruntime/core/common/code_location.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License. 3 | 4 | #pragma once 5 | 6 | #include 7 | #include 8 | #include 9 | 10 | namespace onnxruntime { 11 | /** 12 | CodeLocation captures information on where in the source code a message came from. 13 | */ 14 | struct CodeLocation { 15 | /** 16 | @param file_path Usually the value of __FILE__ 17 | @param line Usually the value of __LINE__ 18 | @param func Usually the value of __PRETTY_FUNCTION__ or __FUNCTION__ 19 | */ 20 | CodeLocation(const char* file_path, const int line, const char* func) 21 | : file_and_path{file_path}, line_num{line}, function{func} { 22 | } 23 | 24 | /** 25 | @param file_path Usually the value of __FILE__ 26 | @param line Usually the value of __LINE__ 27 | @param func Usually the value of __PRETTY_FUNCTION__ or __FUNCTION__ 28 | @param stacktrace Stacktrace from source of message. 
29 | */ 30 | CodeLocation(const char* file_path, const int line, const char* func, const std::vector& stacktrace) 31 | : file_and_path{file_path}, line_num{line}, function{func}, stacktrace(stacktrace) { 32 | } 33 | 34 | std::string FileNoPath() const { 35 | // assuming we always have work to do, so not trying to avoid creating a new string if 36 | // no path was removed. 37 | return file_and_path.substr(file_and_path.find_last_of("/\\") + 1); 38 | } 39 | 40 | enum Format { 41 | kFilename, 42 | kFilenameAndPath 43 | }; 44 | 45 | std::string ToString(Format format = Format::kFilename) const { 46 | std::ostringstream out; 47 | out << (format == Format::kFilename ? FileNoPath() : file_and_path) << ":" << line_num << " " << function; 48 | return out.str(); 49 | } 50 | 51 | const std::string file_and_path; 52 | const int line_num; 53 | const std::string function; 54 | const std::vector stacktrace; 55 | }; 56 | 57 | } // namespace onnxruntime 58 | -------------------------------------------------------------------------------- /include/onnxruntime/core/common/const_pointer_container.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License. 3 | 4 | #pragma once 5 | 6 | #include 7 | 8 | namespace onnxruntime { 9 | /** 10 | Container has T* entries. e.g. std::vector, and this class provides const access to those 11 | via iterators and direct access, as the standard behavior only makes the pointer constant, 12 | and not what is pointed too. i.e. you get a const pointer to T not a pointer to const T without this wrapper. 
13 | See https://stackoverflow.com/questions/8017036/understanding-const-iterator-with-pointers 14 | */ 15 | template 16 | class ConstPointerContainer { 17 | public: 18 | using T = typename std::remove_pointer::type; 19 | 20 | class ConstIterator { 21 | public: 22 | using const_iterator = typename Container::const_iterator; 23 | using iterator_category = std::input_iterator_tag; 24 | using value_type = T*; 25 | using difference_type = std::ptrdiff_t; 26 | using pointer = T**; 27 | using reference = T*&; 28 | 29 | /** Construct iterator for container that will return const T* entries.*/ 30 | explicit ConstIterator(const_iterator position) noexcept : current_{position}, item_{nullptr} {} 31 | ConstIterator(const ConstIterator& other) = default; 32 | ConstIterator& operator=(const ConstIterator& other) = default; 33 | 34 | bool operator==(const ConstIterator& other) const noexcept { return current_ == other.current_; } 35 | bool operator!=(const ConstIterator& other) const noexcept { return current_ != other.current_; } 36 | 37 | ConstIterator& operator++() { 38 | ++current_; 39 | return *this; 40 | } 41 | 42 | ConstIterator operator++(int) { 43 | ConstIterator tmp{*this}; 44 | ++(*this); 45 | return tmp; 46 | } 47 | 48 | const T*& operator*() const { 49 | item_ = *current_; 50 | return item_; 51 | } 52 | 53 | const T** operator->() const { return &(operator*()); }; 54 | 55 | private: 56 | const_iterator current_; 57 | mutable const T* item_; 58 | }; 59 | 60 | /** 61 | Construct wrapper class that will provide const access to the pointers in a container of non-const pointers. 62 | @param data Container with non-const pointers. e.g. 
std::vector 63 | */ 64 | explicit ConstPointerContainer(const Container& data) noexcept : data_(data) {} 65 | 66 | size_t size() const noexcept { return data_.size(); } 67 | bool empty() const noexcept { return data_.empty(); } 68 | 69 | ConstIterator cbegin() const noexcept { return ConstIterator(data_.cbegin()); } 70 | ConstIterator cend() const noexcept { return ConstIterator(data_.cend()); } 71 | 72 | ConstIterator begin() const noexcept { return ConstIterator(data_.cbegin()); } 73 | ConstIterator end() const noexcept { return ConstIterator(data_.cend()); } 74 | 75 | const T* operator[](size_t index) const { return data_[index]; } 76 | 77 | const T* at(size_t index) const { 78 | ORT_ENFORCE(index < data_.size()); 79 | return data_[index]; 80 | } 81 | 82 | private: 83 | const Container& data_; 84 | }; 85 | } // namespace onnxruntime 86 | -------------------------------------------------------------------------------- /include/onnxruntime/core/common/exceptions.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License. 
3 | 4 | #pragma once 5 | 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | 13 | #include "core/common/common.h" 14 | #include "core/common/code_location.h" 15 | 16 | namespace onnxruntime { 17 | 18 | class NotImplementedException : public std::logic_error { 19 | public: 20 | explicit NotImplementedException(const char* _Message = "Function not yet implemented") noexcept : std::logic_error(_Message){}; 21 | explicit NotImplementedException(const std::string& _Message = "Function not yet implemented") noexcept : std::logic_error(_Message){}; 22 | }; 23 | 24 | class TypeMismatchException : public std::logic_error { 25 | public: 26 | TypeMismatchException() noexcept : logic_error("Type mismatch"){}; 27 | }; 28 | 29 | class OnnxRuntimeException : public std::exception { 30 | public: 31 | OnnxRuntimeException(const CodeLocation& location, const std::string& msg) noexcept 32 | : OnnxRuntimeException(location, nullptr, msg) { 33 | } 34 | 35 | /** 36 | Create a new exception that captures the location it was thrown from. 37 | @param location Location in the source code the exception is being thrown from 38 | @param failed_condition Optional string containing the condition that failed. 39 | e.g. "tensor.Size() == input.Size()". May be nullptr. 40 | @param msg Message containing additional information about the exception cause. 
41 | */ 42 | OnnxRuntimeException(const CodeLocation& location, const char* failed_condition, const std::string& msg) 43 | : location_{location} { 44 | std::ostringstream ss; 45 | 46 | ss << location.ToString(CodeLocation::kFilenameAndPath); // output full path in case just the filename is ambiguous 47 | if (failed_condition != nullptr) { 48 | ss << " " << failed_condition << " was false."; 49 | } 50 | 51 | ss << " " << msg << "\n"; 52 | if (!location.stacktrace.empty()) { 53 | ss << "Stacktrace:\n"; 54 | // skip the first entry in the stacktrace as we have that information from location.ToString() 55 | std::copy(++location.stacktrace.begin(), location.stacktrace.end(), std::ostream_iterator(ss, "\n")); 56 | } 57 | 58 | what_ = ss.str(); 59 | } 60 | 61 | const char* what() const noexcept override { 62 | return what_.c_str(); 63 | } 64 | 65 | private: 66 | const CodeLocation location_; 67 | const std::vector stacktrace_; 68 | std::string what_; 69 | }; 70 | 71 | } // namespace onnxruntime 72 | -------------------------------------------------------------------------------- /include/onnxruntime/core/common/logging/capture.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License. 3 | 4 | #pragma once 5 | 6 | #include 7 | #include 8 | #include "core/common/common.h" 9 | #include "core/common/code_location.h" 10 | #include "core/common/logging/severity.h" 11 | 12 | namespace onnxruntime { 13 | namespace logging { 14 | 15 | class Logger; 16 | enum class DataType; 17 | 18 | /** 19 | Class to capture the details of a log message. 20 | */ 21 | class Capture { 22 | public: 23 | /** 24 | Initializes a new instance of the Capture class. 25 | @param logger The logger. 26 | @param severity The severity. 27 | @param category The category. 28 | @param dataType Type of the data. 29 | @param location The file location the log message is coming from. 
30 | */ 31 | Capture(const Logger& logger, logging::Severity severity, const char* category, 32 | logging::DataType dataType, const CodeLocation& location) 33 | : logger_{&logger}, severity_{severity}, category_{category}, data_type_{dataType}, location_{location} { 34 | } 35 | 36 | /** 37 | The stream that can capture the message via operator<<. 38 | @returns Output stream. 39 | */ 40 | std::ostream& Stream() noexcept { 41 | return stream_; 42 | } 43 | 44 | #ifdef _MSC_VER 45 | // add SAL annotation for printf format string. requires Code Analysis to run to validate usage. 46 | #define msvc_printf_check _Printf_format_string_ 47 | #define __attribute__(x) // Disable for MSVC. Supported by GCC and CLang. 48 | #else 49 | #define msvc_printf_check 50 | #endif 51 | 52 | /** 53 | Captures a printf style log message. 54 | @param name="format">The printf format. 55 | @param name="">Arguments to the printf format if needed. 56 | @remarks 57 | A maximum of 2K of output will be captured currently. 58 | Non-static method, so 'this' is implicit first arg, and we use format(printf(2,3) 59 | */ 60 | void CapturePrintf(msvc_printf_check const char* format, ...) __attribute__((format(printf, 2, 3))); 61 | 62 | /** 63 | Process a printf style log message. 64 | @param format The printf format. 65 | @param ... Arguments to the printf format if needed. 66 | @remarks 67 | A maximum of 2K of output will be captured currently. 68 | Note: As va_list is 'char *', we have to disambiguate this from CapturePrintf 69 | so that something like "One string: %s", "the string" does not consider "the string" 70 | to be the va_list. 
71 | */ 72 | void ProcessPrintf(msvc_printf_check const char* format, va_list args); 73 | 74 | logging::Severity Severity() const noexcept { 75 | return severity_; 76 | } 77 | 78 | char SeverityPrefix() const noexcept { 79 | // Carefully setup so severity_ is a valid index 80 | GSL_SUPPRESS(bounds .2) { 81 | return logging::SEVERITY_PREFIX[static_cast(severity_)]; 82 | } 83 | } 84 | 85 | const char* Category() const noexcept { 86 | return category_; 87 | } 88 | 89 | logging::DataType DataType() const noexcept { 90 | return data_type_; 91 | } 92 | 93 | const CodeLocation& Location() const noexcept { 94 | return location_; 95 | } 96 | 97 | std::string Message() const noexcept { 98 | return stream_.str(); 99 | } 100 | 101 | ~Capture(); 102 | 103 | private: 104 | ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Capture); 105 | 106 | const Logger* logger_; 107 | const logging::Severity severity_; 108 | const char* category_; 109 | const logging::DataType data_type_; 110 | const CodeLocation location_; 111 | 112 | std::ostringstream stream_; 113 | }; 114 | } // namespace logging 115 | } // namespace onnxruntime 116 | -------------------------------------------------------------------------------- /include/onnxruntime/core/common/logging/isink.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License. 3 | 4 | #pragma once 5 | 6 | #include 7 | 8 | #include "core/common/logging/logging.h" 9 | 10 | namespace onnxruntime { 11 | namespace logging { 12 | class ISink { 13 | public: 14 | ISink() = default; 15 | 16 | /** 17 | Sends the message to the sink. 18 | @param timestamp The timestamp. 19 | @param logger_id The logger identifier. 20 | @param message The captured message. 
21 | */ 22 | void Send(const Timestamp& timestamp, const std::string& logger_id, const Capture& message) { 23 | SendImpl(timestamp, logger_id, message); 24 | } 25 | 26 | /** 27 | Sends a Profiling Event Record to the sink. 28 | @param Profiling Event Record 29 | */ 30 | virtual void SendProfileEvent(profiling::EventRecord&) const {}; 31 | 32 | virtual ~ISink() = default; 33 | 34 | private: 35 | // Make Code Analysis happy by disabling all for now. Enable as needed. 36 | ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(ISink); 37 | 38 | virtual void SendImpl(const Timestamp& timestamp, const std::string& logger_id, const Capture& message) = 0; 39 | }; 40 | } // namespace logging 41 | } // namespace onnxruntime 42 | -------------------------------------------------------------------------------- /include/onnxruntime/core/common/logging/severity.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License. 3 | 4 | #pragma once 5 | 6 | namespace onnxruntime { 7 | namespace logging { 8 | // mild violation of naming convention. the 'k' lets us use token concatenation in the macro 9 | // ::onnxruntime::Logging::Severity::k##severity. It's not legal to have ::onnxruntime::Logging::Severity::##severity 10 | // the uppercase makes the LOG macro usage look as expected for passing an enum value as it will be LOGS(logger, ERROR) 11 | enum class Severity { 12 | kVERBOSE = 0, 13 | kINFO = 1, 14 | kWARNING = 2, 15 | kERROR = 3, 16 | kFATAL = 4 17 | }; 18 | 19 | constexpr const char* SEVERITY_PREFIX = "VIWEF"; 20 | 21 | } // namespace logging 22 | } // namespace onnxruntime 23 | -------------------------------------------------------------------------------- /include/onnxruntime/core/common/ml_status.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 
2 | // Licensed under the MIT License. 3 | 4 | #pragma once 5 | 6 | #include 7 | 8 | namespace onnxruntime { 9 | 10 | enum class MLStatus : uint32_t { 11 | OK = 0, 12 | FAIL = 1, 13 | INVALID_ARGUMENT = 2, 14 | NO_SUCHFILE = 3, 15 | NO_MODEL = 4, 16 | ENGINE_ERROR = 5, 17 | RUNTIME_EXCEPTION = 6, 18 | INVALID_PROTOBUF = 7, 19 | MODEL_LOADED = 8, 20 | NOT_IMPLEMENTED = 9, 21 | INVALID_GRAPH = 10, 22 | SHAPE_INFERENCE_NOT_REGISTERED = 11, 23 | REQUIREMENT_NOT_REGISTERED = 12 24 | }; 25 | 26 | inline const char* MLStatusToString(MLStatus status) noexcept { 27 | switch (status) { 28 | case MLStatus::OK: 29 | return "SUCCESS"; 30 | case MLStatus::INVALID_ARGUMENT: 31 | return "INVALID_ARGUMENT"; 32 | case MLStatus::NO_SUCHFILE: 33 | return "NO_SUCHFILE"; 34 | case MLStatus::NO_MODEL: 35 | return "NO_MODEL"; 36 | case MLStatus::ENGINE_ERROR: 37 | return "ENGINE_ERROR"; 38 | case MLStatus::RUNTIME_EXCEPTION: 39 | return "RUNTIME_EXCEPTION"; 40 | case MLStatus::INVALID_PROTOBUF: 41 | return "INVALID_PROTOBUF"; 42 | case MLStatus::MODEL_LOADED: 43 | return "MODEL_LOADED"; 44 | case MLStatus::NOT_IMPLEMENTED: 45 | return "NOT_IMPLEMENTED"; 46 | case MLStatus::INVALID_GRAPH: 47 | return "INVALID_GRAPH"; 48 | case MLStatus::SHAPE_INFERENCE_NOT_REGISTERED: 49 | return "SHAPE_INFERENCE_NOT_REGISTERED"; 50 | case MLStatus::REQUIREMENT_NOT_REGISTERED: 51 | return "REQUIREMENT_NOT_REGISTERED"; 52 | default: 53 | return "GENERAL ERROR"; 54 | } 55 | } 56 | 57 | } // namespace onnxruntime 58 | -------------------------------------------------------------------------------- /include/onnxruntime/core/common/status.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 | Licensed under the Apache License, Version 2.0 (the "License"); 3 | you may not use this file except in compliance with the License. 
4 | You may obtain a copy of the License at 5 | http://www.apache.org/licenses/LICENSE-2.0 6 | Unless required by applicable law or agreed to in writing, software 7 | distributed under the License is distributed on an "AS IS" BASIS, 8 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | See the License for the specific language governing permissions and 10 | limitations under the License. 11 | ==============================================================================*/ 12 | // Modifications Copyright (c) Microsoft. 13 | 14 | #pragma once 15 | 16 | #include 17 | #include 18 | #include 19 | #include "core/common/ml_status.h" 20 | 21 | namespace onnxruntime { 22 | namespace common { 23 | 24 | enum StatusCategory { 25 | NONE = 0, 26 | SYSTEM = 1, 27 | ONNXRUNTIME = 2, 28 | }; 29 | 30 | /** 31 | Error code for ONNXRuntime. 32 | */ 33 | enum StatusCode { 34 | OK = static_cast(MLStatus::OK), 35 | FAIL = static_cast(MLStatus::FAIL), 36 | INVALID_ARGUMENT = static_cast(MLStatus::INVALID_ARGUMENT), 37 | NO_SUCHFILE = static_cast(MLStatus::NO_SUCHFILE), 38 | NO_MODEL = static_cast(MLStatus::NO_MODEL), 39 | ENGINE_ERROR = static_cast(MLStatus::ENGINE_ERROR), 40 | RUNTIME_EXCEPTION = static_cast(MLStatus::RUNTIME_EXCEPTION), 41 | INVALID_PROTOBUF = static_cast(MLStatus::INVALID_PROTOBUF), 42 | MODEL_LOADED = static_cast(MLStatus::MODEL_LOADED), 43 | NOT_IMPLEMENTED = static_cast(MLStatus::NOT_IMPLEMENTED), 44 | INVALID_GRAPH = static_cast(MLStatus::INVALID_GRAPH), 45 | SHAPE_INFERENCE_NOT_REGISTERED = static_cast(MLStatus::SHAPE_INFERENCE_NOT_REGISTERED), 46 | REQUIREMENT_NOT_REGISTERED = static_cast(MLStatus::REQUIREMENT_NOT_REGISTERED), 47 | }; 48 | 49 | class Status { 50 | public: 51 | Status() noexcept = default; 52 | 53 | Status(StatusCategory category, int code, const std::string& msg); 54 | 55 | Status(StatusCategory category, int code, const char* msg); 56 | 57 | Status(StatusCategory category, int code); 58 | 59 | Status(const Status& 
other) 60 | : state_((other.state_ == nullptr) ? nullptr : std::make_unique(*other.state_)) {} 61 | 62 | Status& operator=(const Status& other) { 63 | if (state_ != other.state_) { 64 | if (other.state_ == nullptr) { 65 | state_.reset(); 66 | } else { 67 | state_ = std::make_unique(*other.state_); 68 | } 69 | } 70 | return *this; 71 | } 72 | 73 | Status(Status&& other) = default; 74 | Status& operator=(Status&& other) = default; 75 | ~Status() = default; 76 | 77 | bool IsOK() const { 78 | return (state_ == nullptr); 79 | } 80 | 81 | int Code() const noexcept; 82 | 83 | StatusCategory Category() const noexcept; 84 | 85 | const std::string& ErrorMessage() const noexcept; 86 | 87 | std::string ToString() const; 88 | 89 | bool operator==(const Status& other) const { 90 | return (this->state_ == other.state_) || (ToString() == other.ToString()); 91 | } 92 | 93 | bool operator!=(const Status& other) const { 94 | return !(*this == other); 95 | } 96 | 97 | static Status OK() { 98 | return Status(); 99 | } 100 | 101 | private: 102 | static const std::string& EmptyString() noexcept; 103 | 104 | struct State { 105 | State(StatusCategory cat0, int code0, const std::string& msg0) 106 | : category(cat0), code(code0), msg(msg0) {} 107 | 108 | State(StatusCategory cat0, int code0, const char* msg0) 109 | : category(cat0), code(code0), msg(msg0) {} 110 | 111 | const StatusCategory category; 112 | const int code; 113 | const std::string msg; 114 | }; 115 | 116 | // As long as Code() is OK, state_ == nullptr. 
117 | std::unique_ptr state_; 118 | }; 119 | 120 | inline std::ostream& operator<<(std::ostream& out, const Status& status) { 121 | return out << status.ToString(); 122 | } 123 | 124 | } // namespace common 125 | } // namespace onnxruntime 126 | -------------------------------------------------------------------------------- /include/onnxruntime/core/framework/alloc_kind.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License. 3 | 4 | #pragma once 5 | #include 6 | 7 | namespace onnxruntime { 8 | // The ml-Values fall into the following categories with respect to their 9 | // memory management: 10 | // - inference inputs: owned (allocated and freed) by caller, and is by 11 | // default read-only by the runtime. 12 | // - inference outputs: allocated by runtime, ownership transferred to 13 | // caller. TODO: Make sure this semantics is clear in InferenceSession API. 14 | // - weights (constant tensors): can be allocated once (statically), and 15 | // reused by all inference calls within an InferenceSession. 16 | // - tensor values: The lifetimes of these tensor-values are statically 17 | // determined, which is used for memory reuse/sharing optimizations. The 18 | // runtime allocates/frees these values at the right time (as determined 19 | // by the static allocation plan). Note that this is simplified since we 20 | // do not try to optimize for "slice" like ops, where we may be able to 21 | // conditionally reuse memory/data in some cases but not others. 22 | // Generalizing this is future work. 
23 | 24 | enum class AllocKind { 25 | kAllocate = 0, 26 | kReuse = 1, 27 | kPreExisting = 2, 28 | kAllocateStatically = 3, 29 | kAllocateOutput = 4, 30 | kShare = 5 31 | }; 32 | 33 | std::ostream& operator<<(std::ostream& out, AllocKind alloc_kind); 34 | } // namespace onnxruntime 35 | -------------------------------------------------------------------------------- /include/onnxruntime/core/framework/customregistry.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License. 3 | 4 | #pragma once 5 | 6 | #include "core/common/status.h" 7 | #include "core/common/logging/logging.h" 8 | #include "core/graph/schema_registry.h" 9 | #include "core/framework/op_kernel.h" 10 | #include "core/framework/kernel_def_builder.h" 11 | #include "core/framework/kernel_registry.h" 12 | 13 | namespace onnxruntime { 14 | 15 | /** 16 | Represents a registry that contains both custom kernels and custom schemas. 17 | */ 18 | class CustomRegistry final { 19 | public: 20 | CustomRegistry() : 21 | kernel_registry_(std::make_shared()), 22 | opschema_registry_(std::make_shared()) {} 23 | 24 | /** 25 | * Register a kernel definition together with kernel factory method to this session. 26 | * If any conflict happened between registered kernel def and built-in kernel def, 27 | * registered kernel will have higher priority. 28 | * Call this before invoking Initialize(). 29 | * @return OK if success. 
30 | */ 31 | common::Status RegisterCustomKernel(KernelDefBuilder& kernel_def_builder, const KernelCreateFn& kernel_creator); 32 | 33 | common::Status RegisterCustomKernel(KernelCreateInfo&); 34 | 35 | common::Status RegisterOpSet(std::vector& schemas, const std::string& domain, 36 | int baseline_opset_version, int opset_version); 37 | 38 | const std::shared_ptr& GetKernelRegistry(); 39 | 40 | const std::shared_ptr& GetOpschemaRegistry(); 41 | 42 | private: 43 | ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(CustomRegistry); 44 | std::shared_ptr kernel_registry_; 45 | std::shared_ptr opschema_registry_; 46 | 47 | }; 48 | 49 | } // namespace onnxruntime 50 | -------------------------------------------------------------------------------- /include/onnxruntime/core/framework/execution_provider.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License. 3 | 4 | #pragma once 5 | 6 | #include 7 | #include "gsl/pointers" 8 | 9 | #include "core/common/status.h" 10 | #include "core/framework/tensor.h" 11 | #include "core/framework/func_api.h" 12 | #include "core/framework/data_transfer.h" 13 | 14 | namespace onnxruntime { 15 | class GraphViewer; 16 | class Node; 17 | } // namespace onnxruntime 18 | namespace onnxruntime { 19 | 20 | struct ComputeCapability; 21 | class KernelRegistry; 22 | class KernelRegistryManager; 23 | 24 | /** 25 | Logical device representation. 26 | */ 27 | typedef std::map AllocatorMap; 28 | 29 | // if we are export the fused function to dll, the function will still in the same binary as lotus 30 | // use std function to give execution provider some chance to capture some state. 
31 | using CreateFunctionStateFunc = std::function; 32 | using ComputeFunc = std::function; 33 | using DestroyFunctionStateFunc = std::function; 34 | 35 | struct NodeComputeInfo { 36 | CreateFunctionStateFunc create_state_func; 37 | ComputeFunc compute_func; 38 | DestroyFunctionStateFunc release_state_func; 39 | }; 40 | 41 | class IExecutionProvider { 42 | protected: 43 | IExecutionProvider(const std::string& type) : type_{type} {} 44 | 45 | public: 46 | virtual ~IExecutionProvider() = default; 47 | 48 | /** 49 | Get all IAllocators for <*this> execution provider. 50 | */ 51 | const std::vector>& GetAllocators() const { 52 | return allocator_list_; 53 | } 54 | 55 | /** 56 | * Get an allocator with specified device id and MemType. Return nullptr if it doesn't exist 57 | */ 58 | virtual AllocatorPtr GetAllocator(int id, OrtMemType mem_type) const; 59 | 60 | /** 61 | * Returns a data transfer object that implements methods to copy to and 62 | * from this device. 63 | * If no copy is required for the successful operation of this provider, 64 | * return a nullptr. 65 | */ 66 | virtual std::unique_ptr GetDataTransfer() const { 67 | return nullptr; 68 | } 69 | 70 | /** 71 | Get execution provider's capability for the specified . 72 | Return a bunch of IndexedSubGraphs <*this> execution provider can run if 73 | the sub-graph contains only one node or can fuse to run if the sub-graph 74 | contains more than one node. The node indexes contained in sub-graphs may 75 | have overlap, and it's ONNXRuntime's responsibility to do the partition 76 | and decide whether a node will be assigned to <*this> execution provider. 77 | */ 78 | virtual std::vector> 79 | GetCapability(const onnxruntime::GraphViewer& graph_viewer, 80 | const std::vector& kernel_registries) const; 81 | 82 | /** 83 | Get kernel registry per execution provider type. 84 | The KernelRegistry share pointer returned is shared across sessions. 
85 | 86 | NOTE: this is a tricky but final solution to achieve the following goals, 87 | 1. The execution provider type based kernel registry should be shared 88 | across sessions. 89 | Only one copy of this kind of kernel registry exists in ONNXRuntime 90 | with multiple sessions/models. 91 | 2. Adding an execution provider into ONNXRuntime does not need to touch ONNXRuntime 92 | framework/session code. 93 | 3. onnxruntime (framework/session) does not depend on any specific 94 | execution provider lib. 95 | */ 96 | virtual std::shared_ptr GetKernelRegistry() const; 97 | 98 | /** 99 | Returns an opaque handle whose exact type varies based on the provider 100 | and is interpreted accordingly by the corresponding kernel implementation. 101 | For Direct3D operator kernels, this may return an IUnknown supporting 102 | QueryInterface to ID3D12GraphicsCommandList1. 103 | */ 104 | virtual const void* GetExecutionHandle() const noexcept { 105 | return nullptr; 106 | } 107 | 108 | /** 109 | @return type of the execution provider; should match that set in the node 110 | through the SetExecutionProvider API. Example valid return values are: 111 | kCpuExecutionProvider, kCudaExecutionProvider 112 | */ 113 | const std::string& Type() const { return type_; } 114 | 115 | /** 116 | Blocks until the device has completed all preceding requested tasks. 117 | Currently this is primarily used by the IOBinding object to ensure that all 118 | inputs have been copied to the device before execution begins.
119 | */ 120 | virtual common::Status Sync() const; 121 | 122 | /** 123 | Called when InferenceSession::Run started 124 | NOTE that due to async execution in provider, the actual work of previous 125 | Run may not be finished on device. This function should be regarded as the 126 | point after which a new Run would start to submit commands from CPU 127 | */ 128 | virtual common::Status OnRunStart(); 129 | 130 | /** 131 | Called when InferenceSession::Run ended 132 | NOTE that due to async execution in provider, the actual work of this Run 133 | may not be finished on device. This function should be regarded as the point 134 | that all commands of current Run have been submitted by CPU 135 | */ 136 | virtual common::Status OnRunEnd(); 137 | 138 | void InsertAllocator(AllocatorPtr allocator); 139 | 140 | /** 141 | Given a list of fused_node, return create_state/compute/release_state func for each node. 142 | */ 143 | virtual common::Status Compile(const std::vector& fused_node, 144 | std::vector& node_compute_funcs); 145 | 146 | /** 147 | Given a list of fused_node, return a dll that exposes functions for each node. 148 | For each node, there should be three symbols: 149 | Create_State_${node_name} 150 | Compute_${node_name} 151 | Release_State_${node_name} 152 | */ 153 | virtual common::Status Compile(const std::vector& fused_node, 154 | std::string& dll_path); 155 | 156 | private: 157 | const std::string type_; 158 | AllocatorMap allocators_; 159 | 160 | // convenience list of the allocators so GetAllocators doesn't have to build a new vector each time 161 | // contains the same instances as allocators_ 162 | std::vector> allocator_list_; 163 | }; 164 | } // namespace onnxruntime 165 | -------------------------------------------------------------------------------- /include/onnxruntime/core/framework/fence.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved.
2 | // Licensed under the MIT License. 3 | 4 | #pragma once 5 | 6 | #include "core/common/common.h" 7 | #include "core/graph/basic_types.h" 8 | 9 | namespace onnxruntime { 10 | 11 | /* 12 | We use a simple fence mechanism for async compute. Assumptions in this fence mechanism: 13 | * Execution provider command queues, which execute in the order of submission 14 | * No fence needed for kernels within one execution provider command queue 15 | * Fence is used to synchronize between command queues, and execution providers 16 | 17 | Fence usage: 18 | 1. Fence object would be created by allocation planner for input/output when KernelDef::ExecQueueId() is not zero 19 | 2. If fence object exists, executor would call BeforeUsingAs* prior to kernel::Compute(), and AfterUsedAs* afterwards 20 | */ 21 | class IFence { 22 | public: 23 | virtual ~IFence() = default; 24 | 25 | /** 26 | Called by executor before OrtValue is used as input in a compute kernel in provider_type and exec queue_id 27 | This should wait in the specified provider's exec queue for previous write to OrtValue to finish 28 | */ 29 | virtual void BeforeUsingAsInput(onnxruntime::ProviderType provider_type, int queue_id) = 0; 30 | 31 | /** 32 | Called by executor before OrtValue is used as output in a compute kernel in provider_type and exec queue_id 33 | This should wait in the specified provider's exec queue for previous read to OrtValue to finish 34 | */ 35 | virtual void BeforeUsingAsOutput(onnxruntime::ProviderType provider_type, int queue_id) = 0; 36 | 37 | /** 38 | Called by executor after OrtValue is used as input in a compute kernel in provider_type and exec queue_id 39 | This should update the read fence of the MLValue 40 | */ 41 | virtual void AfterUsedAsInput(int queue_id) = 0; 42 | 43 | /** 44 | Called by executor after OrtValue is used as output in a compute kernel in provider_type and exec queue_id 45 | This should update the write fence of the MLValue 46 | */ 47 | virtual void
AfterUsedAsOutput(int queue_id) = 0; 48 | 49 | /** 50 | Called by executor before release OrtValue to see whether async data read is finished or not. This is non-blocking. 51 | */ 52 | virtual bool CanRelease() = 0; 53 | }; 54 | using Fence_t = IFence*; 55 | using FencePtr = std::shared_ptr; 56 | 57 | } // namespace onnxruntime 58 | -------------------------------------------------------------------------------- /include/onnxruntime/core/framework/framework_common.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License. 3 | 4 | #pragma once 5 | 6 | #include 7 | #include 8 | #include 9 | #include "run_options.h" 10 | 11 | namespace onnxruntime { // forward declarations 12 | class Model; 13 | class GraphTransformer; 14 | class NodeArg; 15 | } // namespace onnxruntime 16 | 17 | namespace onnxruntime { 18 | using InputDefList = std::vector; 19 | using OutputDefList = std::vector; 20 | 21 | using NameMLValMap = std::unordered_map; 22 | } // namespace onnxruntime 23 | -------------------------------------------------------------------------------- /include/onnxruntime/core/framework/func_api.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "core/common/common.h" 3 | 4 | namespace onnxruntime { 5 | 6 | // AllocateFunc(void* handle, size_t alignment, size_t size) 7 | using AllocateFunc = void* (*)(void*, size_t, size_t); 8 | using DestroyFunc = void (*)(void*, void*); 9 | using AllocatorHandle = void*; 10 | 11 | typedef struct { 12 | //right now we only include allocation for host memory 13 | AllocateFunc allocate_func; 14 | DestroyFunc release_func; 15 | AllocatorHandle allocator_handle; 16 | const char* node_name; 17 | } ComputeContext; 18 | 19 | using FunctionState = void*; 20 | // take the ComputeContext, and create a function state. 
21 | using CreateFunctionStateC = int (*)(ComputeContext*, FunctionState*); 22 | // pass in the function state and input/output tensors, perform compute and return status 23 | using ComputeFuncC = Status (*)(FunctionState, const OrtCustomOpApi*, OrtKernelContext*); 24 | // release the function state. 25 | using DestroyFunctionStateC = void (*)(FunctionState); 26 | } // namespace onnxruntime 27 | -------------------------------------------------------------------------------- /include/onnxruntime/core/framework/kernel_registry.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License. 3 | 4 | #pragma once 5 | 6 | #include "core/framework/op_kernel.h" 7 | 8 | namespace onnxruntime { 9 | /** 10 | * Each provider has a KernelRegistry. Often, the KernelRegistry only belongs to that specific provider. 11 | * 12 | */ 13 | class KernelRegistry { 14 | public: 15 | KernelRegistry() = default; 16 | 17 | // Register a kernel with kernel definition and function to create the kernel. 18 | Status Register(KernelDefBuilder& kernel_def_builder, 19 | const KernelCreateFn& kernel_creator); 20 | 21 | Status Register(KernelCreateInfo&& create_info); 22 | 23 | // factory functions should always return a unique_ptr for maximum flexibility 24 | // for its clients unless the factory is managing the lifecycle of the pointer 25 | // itself. 
26 | // TODO(Task:132) Make usage of unique_ptr/shared_ptr as out param consistent 27 | Status TryCreateKernel(const onnxruntime::Node& node, 28 | const IExecutionProvider& execution_provider, 29 | const std::unordered_map& constant_initialized_tensors, 30 | const OrtValueNameIdxMap& mlvalue_name_idx_map, 31 | const FuncManager& funcs_mgr, 32 | const DataTransferManager& data_transfer_mgr, 33 | std::unique_ptr& op_kernel) const; 34 | 35 | // Check if an execution provider can create kernel for a node and return 36 | // the kernel if so 37 | const KernelCreateInfo* TryFindKernel(const onnxruntime::Node& node, 38 | onnxruntime::ProviderType exec_provider) const; 39 | 40 | bool IsEmpty() const { return kernel_creator_fn_map_.empty(); } 41 | 42 | #ifdef onnxruntime_PYBIND_EXPORT_OPSCHEMA 43 | // This is used by the opkernel doc generator to enlist all registered operators for a given provider's opkernel 44 | const KernelCreateMap& GetKernelCreateMap() const 45 | { 46 | return kernel_creator_fn_map_; 47 | } 48 | #endif 49 | 50 | private: 51 | // Check whether the types of inputs/outputs of the given node match the extra 52 | // type-constraints of the given kernel. This serves two purposes: first, to 53 | // select the right kernel implementation based on the types of the arguments 54 | // when we have multiple kernels, e.g., Clip and Clip; second, to 55 | // accommodate (and check) mapping of ONNX (specification) type to the onnxruntime 56 | // implementation type (e.g., if we want to implement ONNX's float16 as a regular 57 | // float in onnxruntime). (The second, however, requires a globally uniform mapping.) 58 | // 59 | // Note that this is not intended for type-checking the node against the ONNX 60 | // type specification of the corresponding op, which is done before this check. 61 | // 62 | // if this function is called before graph partition, then node.provider is not set. 
63 | // In this case, kernel_def.provider must equal to exec_provider 64 | // otherwise, kernel_def.provider must equal to node.provider. exec_provider is ignored. 65 | static bool VerifyKernelDef(const onnxruntime::Node& node, 66 | const KernelDef& kernel_def, 67 | std::string& error_str, 68 | onnxruntime::ProviderType exec_provider = ""); 69 | 70 | // Kernel create function map from op name to kernel creation info. 71 | KernelCreateMap kernel_creator_fn_map_; 72 | }; 73 | } // namespace onnxruntime 74 | -------------------------------------------------------------------------------- /include/onnxruntime/core/framework/ml_value.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License. 3 | 4 | #pragma once 5 | 6 | #include 7 | #include "core/common/common.h" 8 | #include "core/common/exceptions.h" 9 | #include "core/framework/allocator.h" 10 | #include "core/framework/data_types.h" 11 | #include "core/framework/tensor.h" 12 | 13 | /** 14 | Represents both tensors and non-tensors. 
15 | */ 16 | struct OrtValue { 17 | public: 18 | OrtValue() : data_(nullptr) {} 19 | virtual ~OrtValue() = default; 20 | 21 | OrtValue(void* pData, onnxruntime::MLDataType type, onnxruntime::DeleteFunc deleter) { 22 | Init(pData, type, deleter); 23 | } 24 | 25 | void Init(void* pData, onnxruntime::MLDataType type, onnxruntime::DeleteFunc deleter) { 26 | data_.reset(pData, deleter); 27 | type_ = type; 28 | } 29 | 30 | bool IsAllocated() const { 31 | return data_ && type_; 32 | } 33 | 34 | template 35 | const T& Get() const { 36 | ORT_ENFORCE(onnxruntime::DataTypeImpl::GetType() == type_, onnxruntime::DataTypeImpl::GetType(), " != ", type_); 37 | return *static_cast(data_.get()); 38 | } 39 | 40 | template 41 | T* GetMutable() { 42 | ORT_ENFORCE(onnxruntime::DataTypeImpl::GetType() == type_, onnxruntime::DataTypeImpl::GetType(), " != ", type_); 43 | return static_cast(data_.get()); 44 | } 45 | 46 | bool IsTensor() const noexcept { 47 | return onnxruntime::DataTypeImpl::GetType() == type_; 48 | } 49 | 50 | onnxruntime::MLDataType Type() const { 51 | return type_; 52 | } 53 | 54 | onnxruntime::Fence_t Fence() const { 55 | return fence_.get(); 56 | } 57 | 58 | void SetFence(onnxruntime::FencePtr fence) { 59 | fence_ = fence; 60 | } 61 | 62 | void ShareFenceWith(OrtValue& v) { 63 | fence_ = v.fence_; 64 | } 65 | 66 | private: 67 | std::shared_ptr data_; 68 | onnxruntime::MLDataType type_{nullptr}; 69 | onnxruntime::FencePtr fence_; 70 | }; 71 | 72 | //TODO: remove the following line 73 | #define MLValue OrtValue 74 | -------------------------------------------------------------------------------- /include/onnxruntime/core/framework/op_kernel_info.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License. 
3 | 4 | #pragma once 5 | 6 | #include "core/framework/execution_provider.h" 7 | #include "core/framework/kernel_def_builder.h" 8 | #include "core/framework/ml_value.h" 9 | #include "core/framework/op_node_proto_helper.h" 10 | #include "core/graph/graph_viewer.h" 11 | #include "gsl/span" 12 | #include "gsl/gsl_util" 13 | 14 | namespace onnxruntime { 15 | 16 | class OrtValueNameIdxMap; 17 | class FuncManager; 18 | class DataTransferManager; 19 | 20 | // A very light-weight class, which works as an aggregated 21 | // view of all data needed for constructing a Kernel instance. 22 | // NOTE: it does not own/hold any objects. 23 | class OpKernelInfo : public OpNodeProtoHelper { 24 | public: 25 | explicit OpKernelInfo(const onnxruntime::Node& node, 26 | const KernelDef& kernel_def, 27 | const IExecutionProvider& execution_provider, 28 | const std::unordered_map& constant_initialized_tensors, 29 | const OrtValueNameIdxMap& mlvalue_name_idx_map, 30 | const FuncManager& funcs_mgr, 31 | const DataTransferManager& data_transfer_mgr); 32 | 33 | OpKernelInfo(const OpKernelInfo& other); 34 | 35 | const OrtMemoryInfo& GetMemoryInfo(int device_id, OrtMemType mem_type) const; 36 | 37 | AllocatorPtr GetAllocator(int device_id, OrtMemType mem_type) const; 38 | 39 | const KernelDef& GetKernelDef() const; 40 | 41 | const IExecutionProvider* GetExecutionProvider() const noexcept; 42 | 43 | const DataTransferManager& GetDataTransferManager() const noexcept; 44 | 45 | const onnxruntime::Node& node() const noexcept; 46 | 47 | bool TryGetConstantInput(int input_index, const Tensor** constant_input_value) const; 48 | 49 | common::Status GetFusedFuncs(ComputeFunc* compute, 50 | CreateFunctionStateFunc* create, 51 | DestroyFunctionStateFunc* release) const; 52 | 53 | private: 54 | ORT_DISALLOW_MOVE(OpKernelInfo); 55 | ORT_DISALLOW_ASSIGNMENT(OpKernelInfo); 56 | 57 | const onnxruntime::Node& node_; 58 | const KernelDef& kernel_def_; 59 | // For non cpu/cuda case, this pointer should be set so 
that function kernel 60 | // will delegate kernel compute call to compute call. 61 | gsl::not_null execution_provider_; 62 | const std::unordered_map& constant_initialized_tensors_; 63 | const OrtValueNameIdxMap& ort_value_name_idx_map_; 64 | const FuncManager& funcs_mgr_; 65 | const DataTransferManager& data_transfer_mgr_; 66 | ProtoHelperNodeContext proto_helper_context_; 67 | }; 68 | 69 | } // namespace onnxruntime 70 | -------------------------------------------------------------------------------- /include/onnxruntime/core/framework/op_node_proto_helper.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License. 3 | 4 | #pragma once 5 | 6 | #include "core/common/status.h" 7 | #include "core/graph/graph_viewer.h" 8 | #include "gsl/span" 9 | 10 | #ifdef __has_attribute 11 | #define ORT_HAVE_ATTRIBUTE(x) __has_attribute(x) 12 | #else 13 | #define ORT_HAVE_ATTRIBUTE(x) 0 14 | #endif 15 | 16 | #if ORT_HAVE_ATTRIBUTE(nodiscard) 17 | #define MUST_USE_RESULT [[nodiscard]] 18 | #elif defined(__clang__) && ORT_HAVE_ATTRIBUTE(warn_unused_result) 19 | #define MUST_USE_RESULT __attribute__((warn_unused_result)) 20 | #else 21 | #define MUST_USE_RESULT 22 | #endif 23 | 24 | class IMLOpKernel; 25 | 26 | namespace onnxruntime { 27 | 28 | /** 29 | A set of wrappers with common signatures for use with both OpKernelInfo 30 | (as its base class) and InferenceContext. 
Used by ABI kernels for both 31 | shape / type inference and kernel construction 32 | */ 33 | template 34 | class OpNodeProtoHelper { 35 | public: 36 | explicit OpNodeProtoHelper(const Impl_t* impl) : impl_(impl) {} 37 | 38 | /** 39 | Get a single attribute 40 | Call this function for a required attribute or when a default value for an optional attribute is specified in the op schema 41 | */ 42 | template 43 | MUST_USE_RESULT Status GetAttr(const std::string& name, T* value) const; 44 | 45 | /** 46 | Get a single attribute 47 | Call this function only when a default value for an optional attribute isn't specified in the op schema 48 | */ 49 | template 50 | T GetAttrOrDefault(const std::string& name, const T& default_value) const { 51 | T tmp; 52 | return GetAttr(name, &tmp).IsOK() ? tmp : default_value; 53 | } 54 | 55 | /** 56 | Get a single attribute 57 | Call this function only when a default value for an optional attribute isn't specified in the op schema 58 | */ 59 | template 60 | void GetAttrOrDefault(const std::string& name, T* value, const T& default_value) const { 61 | if (!GetAttr(name, value).IsOK()) 62 | *value = default_value; 63 | } 64 | 65 | /** 66 | Get repeated attributes 67 | Call this function only when a default value for an optional attribute isn't specified in the op schema 68 | */ 69 | template 70 | MUST_USE_RESULT std::vector GetAttrsOrDefault(const std::string& name, const std::vector& default_value = std::vector{}) const { 71 | std::vector tmp; 72 | return GetAttrs(name, tmp).IsOK() ? 
tmp : default_value; 73 | } 74 | 75 | /** 76 | Get repeated attributes 77 | */ 78 | template 79 | MUST_USE_RESULT Status GetAttrs(const std::string& name, std::vector& values) const; 80 | 81 | template 82 | MUST_USE_RESULT Status GetAttrs(const std::string& name, gsl::span values) const; 83 | 84 | uint32_t GetPrimitiveAttrElementCount(ONNX_NAMESPACE::AttributeProto_AttributeType type, 85 | const std::string& name) const noexcept; 86 | 87 | bool HasPrimitiveAttribute(ONNX_NAMESPACE::AttributeProto_AttributeType type, 88 | const std::string& name) const noexcept; 89 | 90 | uint32_t GetInputCount() const { 91 | return gsl::narrow_cast(impl_->getNumInputs()); 92 | } 93 | 94 | uint32_t GetOutputCount() const { 95 | return gsl::narrow_cast(impl_->getNumOutputs()); 96 | } 97 | 98 | const ONNX_NAMESPACE::TypeProto* GetInputType(size_t index) const { 99 | return impl_->getInputType(index); 100 | } 101 | 102 | const ONNX_NAMESPACE::TypeProto* GetOutputType(size_t index) const { 103 | // Work around lack of a const method from the onnx InferenceContext interface 104 | return const_cast(impl_)->getOutputType(index); 105 | } 106 | 107 | // Try to query an attribute, returning nullptr if it doesn't exist 108 | const ONNX_NAMESPACE::AttributeProto* TryGetAttribute(const std::string& name) const { 109 | return impl_->getAttribute(name); 110 | } 111 | 112 | const ONNX_NAMESPACE::AttributeProto* GetAttribute(const std::string& name) const { 113 | const ONNX_NAMESPACE::AttributeProto* attr = TryGetAttribute(name); 114 | ORT_ENFORCE(attr != nullptr); 115 | return attr; 116 | } 117 | 118 | private: 119 | OpNodeProtoHelper() = delete; 120 | const Impl_t* impl_ = nullptr; 121 | }; 122 | 123 | // The methods on the following class are called by OpNodeProtoHelper, implementing 124 | // the same signatures as InferenceContext other than const-ness. 
125 | class ProtoHelperNodeContext { 126 | public: 127 | explicit ProtoHelperNodeContext(const onnxruntime::Node& node) : node_(node) {} 128 | ProtoHelperNodeContext() = delete; 129 | 130 | const ONNX_NAMESPACE::AttributeProto* getAttribute(const std::string& name) const; 131 | size_t getNumInputs() const; 132 | const ONNX_NAMESPACE::TypeProto* getInputType(size_t index) const; 133 | size_t getNumOutputs() const; 134 | const ONNX_NAMESPACE::TypeProto* getOutputType(size_t index) const; 135 | 136 | private: 137 | const onnxruntime::Node& node_; 138 | }; 139 | 140 | } // namespace onnxruntime 141 | -------------------------------------------------------------------------------- /include/onnxruntime/core/framework/run_options.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License. 3 | 4 | #pragma once 5 | 6 | #include 7 | #include 8 | #include "core/session/onnxruntime_c_api.h" 9 | 10 | /** 11 | * Configuration information for a Run call. 12 | */ 13 | struct OrtRunOptions { 14 | /// Log severity. See https://github.com/microsoft/onnxruntime/blob/master/include/onnxruntime/core/common/logging/severity.h 15 | /// Default = -1 (use the log severity from the InferenceSession that the Run is for). 16 | int run_log_severity_level = -1; 17 | int run_log_verbosity_level = 0; ///< VLOG level if debug build and run_log_severity_level is 0 (VERBOSE). 18 | std::string run_tag; ///< A tag for the Run() calls using this. 19 | 20 | // Set to 'true' to ensure the termination of all the outstanding Run() calls 21 | // that use this OrtRunOptions instance. Some of the outstanding Run() calls may 22 | // be forced to terminate with an error status. 23 | bool terminate = false; 24 | 25 | OrtRunOptions() = default; 26 | ~OrtRunOptions() = default; 27 | 28 | // Disable copy, move and assignment. 
we don't want accidental copies, to 29 | // ensure that the instance provided to the Run() call never changes and the 30 | // terminate mechanism will work. 31 | OrtRunOptions(const OrtRunOptions&) = delete; 32 | OrtRunOptions(OrtRunOptions&&) = delete; 33 | OrtRunOptions& operator=(const OrtRunOptions&) = delete; 34 | OrtRunOptions& operator=(OrtRunOptions&&) = delete; 35 | }; 36 | 37 | namespace onnxruntime { 38 | using RunOptions = OrtRunOptions; 39 | } 40 | -------------------------------------------------------------------------------- /include/onnxruntime/core/framework/sparse_tensor.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | // Copyright (c) Microsoft Corporation. All rights reserved. 4 | // Licensed under the MIT License. 5 | 6 | #include "core/framework/data_types.h" 7 | #include "core/framework/tensor_shape.h" 8 | #include "core/framework/tensor.h" 9 | 10 | using namespace onnxruntime::common; 11 | 12 | namespace onnxruntime { 13 | 14 | /** 15 | * @brief This class implements SparseTensor. 16 | * We represent a SparseTensor as a triple <values, indices, shape>. "values" and "indices" themselves 17 | * are implemented as Tensors. 18 | * We follow the Tensor design for memory ownership/management: a sparse-tensor does not own the "value" 19 | * or "indices" tensors. 20 | */ 21 | 22 | class SparseTensor final { 23 | public: 24 | SparseTensor(MLDataType elt_type, 25 | const TensorShape& shape, 26 | size_t nnz, 27 | void* values_data, 28 | void* indices_data, 29 | const OrtMemoryInfo& memory_info); 30 | 31 | SparseTensor(MLDataType elt_type, 32 | const TensorShape& shape, 33 | size_t nnz, 34 | std::shared_ptr allocator); 35 | 36 | ~SparseTensor() = default; 37 | 38 | // For now, disallow all copy, assignment, and move. 
39 | ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(SparseTensor); 40 | 41 | // Returns the number of entries in the values tensor (aka "NNZ" or "number of nonzero values") 42 | size_t NumValues() const { return values_.Shape().Size(); } 43 | 44 | const Tensor& Values() const { 45 | return values_; 46 | } 47 | 48 | const Tensor& Indices() const { 49 | return indices_; 50 | } 51 | 52 | const TensorShape& Shape() const { 53 | return shape_; 54 | } 55 | 56 | Tensor& MutableValues() { 57 | return values_; 58 | } 59 | 60 | Tensor& MutableIndices() { 61 | return indices_; 62 | } 63 | 64 | //TensorShape& MutableShape() { 65 | // return shape_; 66 | //} 67 | 68 | private: 69 | Tensor values_; 70 | Tensor indices_; 71 | TensorShape shape_; // The shape of corresponding dense-tensor. 72 | }; 73 | 74 | } // namespace onnxruntime 75 | -------------------------------------------------------------------------------- /include/onnxruntime/core/framework/tensor.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License. 
3 | 4 | #pragma once 5 | 6 | #include 7 | #include 8 | #include 9 | 10 | #include "gsl/span" 11 | #include "core/common/common.h" 12 | #include "core/framework/allocator.h" 13 | #include "core/framework/data_types.h" 14 | #include "core/framework/tensor_shape.h" 15 | #include "onnxruntime_config.h" 16 | 17 | namespace onnxruntime { 18 | 19 | // TODO: Do we need this class or is IAllocator::MakeUniquePtr sufficient/better 20 | class BufferDeleter { 21 | public: 22 | BufferDeleter() : alloc_(nullptr) {} 23 | BufferDeleter(AllocatorPtr alloc) 24 | : alloc_(alloc) {} 25 | 26 | void operator()(void* p) const { 27 | if (alloc_) 28 | alloc_->Free(p); 29 | } 30 | 31 | private: 32 | // TODO: we may need consider the lifetime of alloc carefully 33 | // The alloc_ here is the allocator that used to allocate the buffer 34 | // And need go with the unique_ptr together. If it is using our internal 35 | // allocator, it is ok as our allocators are global managed. But if it 36 | // is provide by user, user need to be very careful about it. 37 | // A weak_ptr may be a choice to reduce the impact, but that require to 38 | // change our current allocator mgr to use shared_ptr. Will revisit it 39 | // later. 40 | AllocatorPtr alloc_; 41 | }; 42 | 43 | using BufferUniquePtr = std::unique_ptr; 44 | using BufferNakedPtr = void*; 45 | //TODO:ensure dtype_!=nullptr 46 | #ifdef __GNUC__ 47 | #pragma GCC diagnostic push 48 | #ifdef HAS_NULL_DEREFERENCE 49 | #pragma GCC diagnostic ignored "-Wnull-dereference" 50 | #endif 51 | #endif 52 | /* 53 | We want to keep tensor as simple as possible, it is just a placeholder 54 | for a piece of memory, with additional shape information. 55 | Memory is owned and managed by Executor / Workspace, so Tensor just uses 56 | it, and won't do any allocation / release. 57 | */ 58 | class Tensor final { 59 | public: 60 | /** 61 | * Create tensor with given type, shape, pre-allocate memory and allocator info. 
62 | * This function won't check if the preallocated buffer(p_data) has enough room for the shape. 63 | * \param data A preallocated buffer. Can be NULL if the shape is empty. 64 | * Tensor does not own the data and will not delete it 65 | * \param alloc Where the buffer('data') was allocated from 66 | */ 67 | Tensor(MLDataType p_type, const TensorShape& shape, void* p_data, const OrtMemoryInfo& alloc, 68 | int64_t offset = 0); 69 | 70 | /** 71 | * Deprecated. The orginal design is this Tensor class won't do any allocation / release. 72 | * However, this function will allocate the buffer for the shape, and do placement new if p_type is string tensor. 73 | */ 74 | Tensor(MLDataType p_type, const TensorShape& shape, std::shared_ptr allocator, int64_t offset = 0); 75 | 76 | ~Tensor(); 77 | 78 | //Move is allowed 79 | ORT_DISALLOW_COPY_AND_ASSIGNMENT(Tensor); 80 | 81 | Tensor(Tensor&& other) noexcept; 82 | 83 | Tensor& operator=(Tensor&& other) noexcept; 84 | 85 | /** 86 | Returns the data type. 87 | */ 88 | MLDataType DataType() const { return dtype_; } 89 | 90 | /** 91 | Returns the shape of the tensor. 92 | */ 93 | const TensorShape& Shape() const noexcept { return shape_; } 94 | 95 | /** 96 | Returns the location of the tensor's memory 97 | */ 98 | const OrtMemoryInfo& Location() const { return alloc_info_; } 99 | 100 | /** 101 | May return nullptr if tensor size is zero 102 | */ 103 | template 104 | T* MutableData() { 105 | // Type check 106 | ORT_ENFORCE(DataTypeImpl::GetType() == dtype_, "Tensor type mismatch. ", 107 | DataTypeImpl::GetType(), "!=", dtype_); 108 | return reinterpret_cast(static_cast(p_data_) + byte_offset_); 109 | } 110 | 111 | /** 112 | May return nullptr if tensor size is zero 113 | */ 114 | template 115 | gsl::span MutableDataAsSpan() { 116 | // Type check 117 | ORT_ENFORCE(DataTypeImpl::GetType() == dtype_, "Tensor type mismatch. 
", 118 | DataTypeImpl::GetType(), "!=", dtype_); 119 | T* data = reinterpret_cast(static_cast(p_data_) + byte_offset_); 120 | return gsl::make_span(data, shape_.Size()); 121 | } 122 | 123 | template 124 | const T* Data() const { 125 | // Type check 126 | ORT_ENFORCE(DataTypeImpl::GetType() == dtype_, "Tensor type mismatch. ", 127 | DataTypeImpl::GetType(), "!=", dtype_); 128 | return reinterpret_cast(static_cast(p_data_) + byte_offset_); 129 | } 130 | 131 | template 132 | gsl::span DataAsSpan() const { 133 | // Type check 134 | ORT_ENFORCE(DataTypeImpl::GetType() == dtype_, "Tensor type mismatch. ", 135 | DataTypeImpl::GetType(), "!=", dtype_); 136 | const T* data = reinterpret_cast(static_cast(p_data_) + byte_offset_); 137 | return gsl::make_span(data, shape_.Size()); 138 | } 139 | 140 | void* MutableDataRaw(MLDataType type) { 141 | ORT_ENFORCE(type == dtype_, "Tensor type mismatch.", type, "!=", dtype_); 142 | return p_data_; 143 | } 144 | 145 | const void* DataRaw(MLDataType type) const { 146 | ORT_ENFORCE(type == dtype_, "Tensor type mismatch.", type, "!=", dtype_); 147 | return p_data_; 148 | } 149 | 150 | void* MutableDataRaw() noexcept { 151 | return p_data_; 152 | } 153 | 154 | const void* DataRaw() const noexcept { 155 | return p_data_; 156 | } 157 | 158 | /** 159 | * Resizes the tensor without touching underlying storage. 160 | * This requires the total size of the tensor to remains constant. 161 | * @warning this function is NOT thread-safe. 162 | */ 163 | inline void Reshape(const TensorShape& new_shape) { 164 | ORT_ENFORCE(shape_.Size() == new_shape.Size(), 165 | "Tensor size (" + std::to_string(shape_.Size()) + 166 | ") != new size (" + std::to_string(new_shape.Size()) + ")"); 167 | shape_ = new_shape; 168 | } 169 | 170 | /** 171 | The number of bytes of data. 
172 | */ 173 | size_t SizeInBytes() const { 174 | size_t ret; 175 | int64_t l = shape_.Size(); 176 | if (l >= static_cast(std::numeric_limits::max())) { 177 | ORT_THROW("tensor size overflow"); 178 | } 179 | if (!IAllocator::CalcMemSizeForArray(static_cast(shape_.Size()), dtype_->Size(), &ret)) { 180 | ORT_THROW("tensor size overflow"); 181 | } 182 | return ret; 183 | } 184 | 185 | // More API methods. 186 | private: 187 | void Init(MLDataType p_type, 188 | const TensorShape& shape, 189 | void* p_raw_data, 190 | AllocatorPtr deleter, 191 | int64_t offset = 0); 192 | 193 | void ReleaseBuffer(); 194 | 195 | void* p_data_; 196 | /** 197 | if buffer_deleter_ is null, it means tensor does not own the buffer. 198 | otherwise tensor will use the deleter to release the buffer when 199 | tensor is released. 200 | */ 201 | AllocatorPtr buffer_deleter_; 202 | 203 | TensorShape shape_; 204 | MLDataType dtype_; 205 | OrtMemoryInfo alloc_info_; 206 | int64_t byte_offset_; 207 | }; 208 | #ifdef __GNUC__ 209 | #pragma GCC diagnostic pop 210 | #endif 211 | } // namespace onnxruntime 212 | -------------------------------------------------------------------------------- /include/onnxruntime/core/framework/tensor_shape.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License. 3 | 4 | #pragma once 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include "onnxruntime_config.h" 11 | 12 | namespace ONNX_NAMESPACE { 13 | class TensorShapeProto; 14 | } 15 | 16 | namespace onnxruntime { 17 | #ifdef __GNUC__ 18 | #pragma GCC diagnostic push 19 | #ifdef HAS_NULL_DEREFERENCE 20 | #pragma GCC diagnostic ignored "-Wnull-dereference" 21 | #endif 22 | #endif 23 | class TensorShape : private std::vector { 24 | // TODO - Use a custom STL allocator to avoid heap allocations in the common case. 
25 | // We use negative numbers for unknown symbolic dimension. Each negative 26 | // number represents a unique symbolic dimension. 27 | // Private inheritance is used to prevent ambiguity of element versus dimension size 28 | public: 29 | TensorShape() = default; 30 | 31 | TensorShape(const TensorShape& /*other*/) = default; 32 | TensorShape& operator=(const TensorShape& /*other*/) = default; 33 | 34 | TensorShape(TensorShape&& /*other*/) = default; 35 | TensorShape& operator=(TensorShape&& /*other*/) = default; 36 | 37 | TensorShape(const std::vector& dims) : std::vector(dims) {} 38 | 39 | TensorShape(std::vector&& dims) : std::vector(std::move(dims)) {} 40 | 41 | TensorShape(const std::initializer_list& dims) : std::vector(dims) {} 42 | 43 | TensorShape(const int64_t* dimension_sizes, size_t dimension_count); 44 | 45 | TensorShape(const std::vector& dims, size_t start, size_t end); 46 | 47 | /** 48 | Return the dimension specified by <idx>. 49 | */ 50 | const int64_t& operator[](size_t idx) const { 51 | return std::vector::operator[](static_cast(idx)); 52 | } 53 | 54 | int64_t& operator[](size_t idx) { 55 | return std::vector::operator[](static_cast(idx)); 56 | } 57 | 58 | bool operator==(const TensorShape& other) const noexcept { 59 | auto thisVector = static_cast*>(this); 60 | auto otherVector = static_cast*>(&other); 61 | return *thisVector == *otherVector; 62 | } 63 | 64 | bool operator!=(const TensorShape& other) const noexcept { 65 | return !(*this == other); 66 | } 67 | 68 | size_t NumDimensions() const noexcept { 69 | return size(); 70 | } 71 | 72 | /** 73 | Copy dims into an array with given size 74 | */ 75 | void CopyDims(int64_t* dims, size_t num_dims) const { 76 | memcpy(dims, data(), sizeof(value_type) * std::min(num_dims, NumDimensions())); 77 | } 78 | 79 | /** 80 | Return underlying vector representation. 81 | */ 82 | const std::vector& GetDims() const { return *this; } 83 | 84 | /** 85 | * Return the total number of elements. 
Returns 1 for an empty (rank 0) TensorShape. 86 | * 87 | * May return -1 88 | */ 89 | int64_t Size() const; 90 | 91 | /** 92 | Return the total number of elements up to the specified dimension. 93 | If the dimension interval is empty (dimension == 0), return 1. 94 | @param dimension Return size up to this dimension. Value must be between 0 and this->NumDimensions(), inclusive. 95 | */ 96 | int64_t SizeToDimension(size_t dimension) const; 97 | 98 | /** 99 | Return the total number of elements from the specified dimension to the end of the tensor shape. 100 | If the dimension interval is empty (dimension == this->NumDimensions()), return 1. 101 | @param dimension Return size from this dimension to the end. Value must be between 0 and this->NumDimensions(), 102 | inclusive. 103 | */ 104 | int64_t SizeFromDimension(size_t dimension) const; 105 | 106 | /** 107 | Return a new TensorShape of the dimensions from dimstart to dimend. 108 | */ 109 | TensorShape Slice(size_t dimstart, size_t dimend) const; 110 | 111 | /** 112 | Return a new TensorShape of the dimensions from dimstart to end. 113 | */ 114 | TensorShape Slice(size_t dimstart) const; 115 | 116 | /** 117 | output dimensions nicely formatted 118 | */ 119 | std::string ToString() const; 120 | 121 | /** 122 | Calculate size between start and end. 123 | Assumes start and end are between 0 and this->NumDimensions(), inclusive, and that 124 | start < end. 
125 | */ 126 | int64_t SizeHelper(size_t start, size_t end) const; 127 | 128 | /** 129 | empty shape or 1D shape (1) is regarded as scalar tensor 130 | */ 131 | bool IsScalar() const { 132 | size_t len = size(); 133 | return len == 0 || (len == 1 && operator[](0) == 1); 134 | } 135 | 136 | static const TensorShape& ReinterpretBaseType(const std::vector& dimensions) { 137 | static_assert(sizeof(TensorShape) == sizeof(std::vector), "Size of TensorShape prevents safe casting from vector"); 138 | return *static_cast(&dimensions); 139 | } 140 | }; 141 | #ifdef __GNUC__ 142 | #pragma GCC diagnostic pop 143 | #endif 144 | // operator<< to nicely output to a stream 145 | std::ostream& operator<<(std::ostream& out, const ::onnxruntime::TensorShape& shape); 146 | 147 | std::ostream& operator<<(std::ostream& out, const ONNX_NAMESPACE::TensorShapeProto& shape_proto); 148 | 149 | } // namespace onnxruntime 150 | -------------------------------------------------------------------------------- /include/onnxruntime/core/graph/basic_types.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License. 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | namespace ONNX_NAMESPACE { 11 | class ValueInfoProto; 12 | class TensorProto; 13 | class TypeProto; 14 | class AttributeProto; 15 | } // namespace ONNX_NAMESPACE 16 | 17 | namespace onnxruntime { 18 | using NodeIndex = size_t; 19 | using Version = int64_t; 20 | using NodeArgInfo = ONNX_NAMESPACE::ValueInfoProto; 21 | using InitializedTensorSet = std::unordered_map; 22 | using ArgNameToTypeMap = std::unordered_map; 23 | using ProviderType = const std::string&; 24 | // TODO - Evaluate switching the types below to support transparent comparators and enable 25 | // lookups based on gsl::cstring_span<> and std::string_view. 
This would reduce allocations 26 | // converting to std::string, but requires conversion to std::map<std::string, foo, std::less<>> 27 | // instead of std::unordered_map<std::string, foo, [std::less<foo>]>. 28 | 29 | using NodeAttributes = std::unordered_map; 30 | class IOnnxRuntimeOpSchemaCollection; 31 | using IOnnxRuntimeOpSchemaCollectionPtr = std::shared_ptr; 32 | } // namespace onnxruntime 33 | 34 | namespace onnxruntime { 35 | class OpKernel; 36 | class OpKernelInfo; 37 | 38 | using KernelCreateFn = std::function; 39 | } // namespace onnxruntime 40 | -------------------------------------------------------------------------------- /include/onnxruntime/core/graph/constants.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License. 3 | 4 | #pragma once 5 | 6 | #include 7 | #include 8 | #include 9 | 10 | #include "core/common/common.h" 11 | 12 | namespace onnxruntime { 13 | constexpr const char* kNoOp = "NoOp"; 14 | constexpr const char* kConstant = "Constant"; 15 | constexpr const char* kFunctionOp = "_kFunctionOp"; 16 | constexpr const char* kConstantValue = "value"; 17 | constexpr const char* kOnnxDomain = ""; 18 | constexpr const char* kOnnxDomainAlias = "ai.onnx"; 19 | constexpr const char* kMLDomain = "ai.onnx.ml"; 20 | constexpr const char* kMSDomain = "com.microsoft"; 21 | constexpr const char* kMSNchwcDomain = "com.microsoft.nchwc"; 22 | constexpr const char* kMSAutoMLDomain = "com.microsoft.automl"; 23 | constexpr const char* kNGraphDomain = "com.intel.ai"; 24 | constexpr const char* kCpuExecutionProvider = "CPUExecutionProvider"; 25 | constexpr const char* kCudaExecutionProvider = "CUDAExecutionProvider"; 26 | constexpr const char* kMklDnnExecutionProvider = "MKLDNNExecutionProvider"; 27 | constexpr const char* kNGraphExecutionProvider = "NGRAPHExecutionProvider"; 28 | constexpr const char* kOpenVINOExecutionProvider = "OpenVINOExecutionProvider"; 29 | constexpr const char* 
kNupharExecutionProvider = "NupharExecutionProvider"; 30 | constexpr const char* kBrainSliceExecutionProvider = "BrainSliceExecutionProvider"; 31 | constexpr const char* kTensorrtExecutionProvider = "TensorrtExecutionProvider"; 32 | constexpr const char* kNnapiExecutionProvider = "NnapiExecutionProvider"; 33 | } // namespace onnxruntime 34 | 35 | -------------------------------------------------------------------------------- /include/onnxruntime/core/graph/function.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License. 3 | 4 | #pragma once 5 | 6 | #include "core/common/common.h" 7 | #include "core/graph/indexed_sub_graph.h" 8 | 9 | namespace onnxruntime { 10 | class Graph; 11 | class Node; 12 | } // namespace onnxruntime 13 | 14 | namespace onnxruntime { 15 | 16 | /** 17 | @class Function 18 | Class representing a Function. 19 | */ 20 | class Function { 21 | public: 22 | virtual ~Function() = default; 23 | 24 | /** Gets the OpSchema for the Function. */ 25 | virtual const ONNX_NAMESPACE::OpSchema& OpSchema() const = 0; 26 | 27 | /** Gets the Graph instance for the Function body subgraph. */ 28 | virtual const onnxruntime::Graph& Body() const = 0; 29 | 30 | /** Gets the IndexedSubGraph for the Function. */ 31 | virtual const IndexedSubGraph& GetIndexedSubGraph() const = 0; 32 | }; 33 | 34 | /** 35 | Create a new Function instance. 36 | @param graph The graph containing the Function. 37 | @param customized_func the IndexedSubGraph to use for the Function. 
38 | */ 39 | std::unique_ptr MakeFunction(const onnxruntime::Graph& graph, 40 | std::unique_ptr customized_func); 41 | } // namespace onnxruntime 42 | -------------------------------------------------------------------------------- /include/onnxruntime/core/graph/graph_nodes.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License. 3 | 4 | #pragma once 5 | 6 | #include 7 | #include 8 | #include 9 | 10 | namespace onnxruntime { 11 | 12 | class Node; 13 | 14 | /** 15 | Class to filter out null entries from either a vector of unique_ptr<Node> or a vector of [const] Node* and 16 | provide an iterator interface that returns [const] Node& for the valid entries. 17 | */ 18 | template 19 | class ValidNodes { 20 | public: 21 | template 22 | class NodeIterator; 23 | 24 | /** 25 | Construct a ValidNodes instance to provide iteration over all valid nodes in the TNodesContainer 26 | @param[in] nodes Nodes to iterate, skipping invalid entries. 
27 | */ 28 | explicit ValidNodes(TNodesContainer& nodes) noexcept : nodes_(nodes) {} 29 | 30 | using ConstNodeIterator = NodeIterator; 31 | using MutableNodeIterator = NodeIterator; 32 | 33 | ConstNodeIterator cbegin() const noexcept { 34 | return {nodes_.cbegin(), nodes_.cend()}; 35 | } 36 | 37 | ConstNodeIterator cend() const noexcept { 38 | return {nodes_.cend(), nodes_.cend()}; 39 | } 40 | 41 | ConstNodeIterator begin() const noexcept { 42 | return cbegin(); 43 | } 44 | 45 | ConstNodeIterator end() const noexcept { 46 | return cend(); 47 | } 48 | 49 | MutableNodeIterator begin() noexcept { 50 | return {nodes_.begin(), nodes_.end()}; 51 | } 52 | 53 | MutableNodeIterator end() noexcept { 54 | return {nodes_.end(), nodes_.end()}; 55 | } 56 | 57 | bool empty() const noexcept { return nodes_.empty(); } 58 | 59 | /** 60 | @class NodeIterator 61 | Iterator to provide const and non-const access to valid Node instances in a Graph. 62 | @remarks Skips invalid nodes. 63 | */ 64 | template 65 | class NodeIterator { 66 | // get the type being returned by the iterator. can't use TIterator::value_type as that is always non-const 67 | using IterType = typename std::remove_reference::reference>::type; 68 | // and determine what we will return based on its constness 69 | using T = typename std::conditional::value, 70 | const Node, // return const Node if this is a const iterator 71 | Node>::type; // else return Node 72 | 73 | public: 74 | using iterator_category = std::input_iterator_tag; 75 | using value_type = T; 76 | using difference_type = typename TIterator::difference_type; 77 | using pointer = T*; 78 | using reference = T&; 79 | using const_reference = std::add_const_t; 80 | 81 | /** Construct a NodeIterator and move to the first valid node. 
*/ 82 | NodeIterator(const TIterator current, const TIterator end) noexcept : current_{current}, end_{end} { 83 | // skip to next valid node, stopping at end if none are found 84 | while (current_ < end && *current_ == nullptr) { 85 | ++current_; 86 | } 87 | } 88 | 89 | bool operator==(const NodeIterator& other) const noexcept { 90 | return (current_ == other.current_); 91 | } 92 | 93 | bool operator!=(const NodeIterator& other) const noexcept { 94 | return (current_ != other.current_); 95 | } 96 | 97 | void operator++() { 98 | if (current_ < end_) { 99 | while (++current_ != end_) { 100 | if (*current_ != nullptr) break; 101 | } 102 | } 103 | } 104 | 105 | NodeIterator operator++(int) { 106 | NodeIterator tmp{*this}; 107 | ++(*this); 108 | 109 | return tmp; 110 | } 111 | 112 | /** Return the current Node&. This will be const if the iterator was returned from a const GraphNodes instance. */ 113 | reference operator*() { 114 | // if iterator is valid we always have a non-nullptr node 115 | // if this is a nullptr we're at end_ and this shouldn't be being called 116 | return **current_; 117 | } 118 | 119 | pointer operator->() { 120 | return current_->get(); 121 | } 122 | 123 | private: 124 | TIterator current_; 125 | TIterator end_; 126 | }; 127 | 128 | private: 129 | TNodesContainer& nodes_; 130 | }; 131 | 132 | /** 133 | Class that provides iteration over all valid nodes in the Graph. 134 | */ 135 | class GraphNodes : public ValidNodes>> { 136 | public: 137 | GraphNodes(std::vector>& nodes) : ValidNodes(nodes) {} 138 | }; 139 | 140 | } // namespace onnxruntime 141 | -------------------------------------------------------------------------------- /include/onnxruntime/core/graph/graph_viewer.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License. 
3 | 4 | #pragma once 5 | 6 | #include "core/graph/graph.h" 7 | 8 | namespace onnxruntime { 9 | class Function; 10 | struct IndexedSubGraph; 11 | } // namespace onnxruntime 12 | 13 | namespace onnxruntime { 14 | 15 | /** 16 | @class GraphViewer 17 | Class that provides a read-only view of the Graph. 18 | @remarks If the underlying Graph is changed, GetNodesInTopologicalOrder and GetRootNodes may become invalid. 19 | */ 20 | class GraphViewer { 21 | public: 22 | /** 23 | Construct a GraphViewer from the provided Graph instance. 24 | */ 25 | explicit GraphViewer(const Graph& graph); 26 | 27 | /** Gets the Graph name. */ 28 | const std::string& Name() const noexcept; 29 | 30 | /** Gets the Graph description. */ 31 | const std::string& Description() const noexcept; 32 | 33 | /** 34 | Gets a tensor created from an initializer. 35 | @param tensor_name The tensor name 36 | @param[out] value Sets the pointer to the TensorProto if found, or nullptr if not. 37 | @returns True if found. False if not. 38 | */ 39 | bool GetInitializedTensor(const std::string& tensor_name, const ONNX_NAMESPACE::TensorProto*& value) const; 40 | 41 | /** Returns true if an initializer value can be overridden by a graph input with the same name. */ 42 | bool CanOverrideInitializer() const noexcept; 43 | 44 | /** 45 | Gets the Graph inputs, excluding initializers. 46 | @returns Collection of NodeArg pointers for the graph inputs, excluding inputs that have matching initializers. 47 | @remarks No nullptr values in the returned collection. The order will be the same as in the GraphProto. 48 | */ 49 | const std::vector& GetInputs() const noexcept; 50 | 51 | /** 52 | Gets the Graph inputs, including any initializers. 53 | @returns Collection of NodeArg pointers for all the graph inputs. 54 | @remarks No nullptr values in the returned collection. The order will be the same as in the GraphProto. 
55 | */ 56 | const std::vector& GetInputsIncludingInitializers() const noexcept; 57 | 58 | /** 59 | Gets the Graph outputs. 60 | @returns Collection of NodeArg pointers for all the graph outputs. 61 | @remarks No nullptr values in the returned collection. The order will be the same as in the GraphProto. 62 | */ 63 | const std::vector& GetOutputs() const noexcept; 64 | 65 | /** Gets all ValueInfo NodeArg instances in the Graph. */ 66 | const std::vector& GetValueInfo() const noexcept; 67 | 68 | /** 69 | Gets the Node instance at the specified index. 70 | @param node_index Index to retrieve Node from. 71 | @remarks May return nullptr if index no longer points to a valid node due to the node being freed. 72 | */ 73 | const Node* GetNode(NodeIndex node_index) const; 74 | 75 | /** Gets an iterator over all the valid Nodes in the Graph. */ 76 | const GraphNodes& Nodes() const noexcept; 77 | 78 | /** Gets the number of valid nodes in the Graph. */ 79 | int NumberOfNodes() const noexcept; 80 | 81 | /** Gets the maximum NodeIndex value used by Nodes in the Graph. */ 82 | int MaxNodeIndex() const noexcept; 83 | 84 | /** Gets the NodeIndex values for the Graph nodes, sorted into topological order. */ 85 | const std::vector& GetNodesInTopologicalOrder() const; 86 | 87 | /** 88 | Gets the NodeIndex values for the root nodes in the Graph. 89 | The root nodes are the topmost nodes in the Graph that receive inputs from the Graph inputs 90 | and no other nodes in the Graph. 91 | */ 92 | const std::vector& GetRootNodes() const; 93 | 94 | /** Gets all tensors created from initializers. */ 95 | const InitializedTensorSet& GetAllInitializedTensors() const noexcept; 96 | 97 | /** 98 | Gets the NodeArg instance for the given name. 99 | @returns A NodeArg if found, a nullptr if not. 100 | */ 101 | const NodeArg* GetNodeArg(const std::string& name) const; 102 | 103 | /** Gets the map of operator domains to their opset versions. 
*/ 104 | const std::unordered_map& DomainToVersionMap() const noexcept { 105 | return graph_->DomainToVersionMap(); 106 | } 107 | 108 | /** Checks if this is a Subgraph */ 109 | bool IsSubgraph() const; 110 | 111 | /** 112 | returns true if 'name' is an initializer, and is constant and cannot be overridden at runtime. 113 | @param check_outer_scope If true and the 'graph_' is a subgraph, check parent graph/s for 'name' if not found in 'graph_'. 114 | */ 115 | bool IsConstantInitializer(const std::string& name, bool check_outer_scope) const; 116 | 117 | /** Get the Node containing this Graph if IsSubgraph is true. Returns nullptr otherwise. */ 118 | const Node* ParentNode() const noexcept { return graph_->ParentNode(); } 119 | 120 | private: 121 | ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(GraphViewer); 122 | 123 | const Graph* graph_; 124 | 125 | // The NodeIndex values of the graph nodes sorted in topological order. 126 | std::vector nodes_in_topological_order_; 127 | // Graph root nodes. 128 | std::vector root_nodes_; 129 | }; 130 | } // namespace onnxruntime 131 | -------------------------------------------------------------------------------- /include/onnxruntime/core/graph/indexed_sub_graph.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License. 3 | 4 | #pragma once 5 | 6 | #include 7 | 8 | #include "core/graph/basic_types.h" 9 | #include "core/graph/onnx_protobuf.h" 10 | 11 | namespace onnxruntime { 12 | 13 | class OpKernel; 14 | class OpKernelInfo; 15 | 16 | /** 17 | @class IndexedSubGraph 18 | 19 | Class containing information about a subgraph of Nodes from a Graph. 20 | It contains a NodeIndex array of the Nodes covered by the subgraph, 21 | and the meta definition needed for representing this subgraph as a FunctionProto, 22 | which could be serialized/saved to a model file. 
23 | */ 24 | struct IndexedSubGraph { 25 | struct MetaDef { 26 | std::string name; ///< Name of customized SubGraph/FunctionProto 27 | std::string domain; ///< Domain of customized SubGraph/FunctionProto 28 | int since_version; ///< Since version of customized SubGraph/FunctionProto. 29 | 30 | ONNX_NAMESPACE::OperatorStatus status; ///< Status of customized SubGraph/FunctionProto. 31 | 32 | std::vector inputs; ///< Inputs of customized SubGraph/FunctionProto. 33 | std::vector outputs; ///< Outputs of customized SubGraph/FunctionProto. 34 | NodeAttributes attributes; ///< Attributes of customized SubGraph/FunctionProto. 35 | 36 | std::string doc_string; ///< Doc string of customized SubGraph/FunctionProto. 37 | }; 38 | 39 | /** Nodes covered by this subgraph. The NodeIndex values are from the parent Graph.*/ 40 | std::vector nodes; 41 | 42 | /** Set the meta definition needed to represent this subgraph as a FunctionProto 43 | It's needed IF AND ONLY IF there are multiple indexes contained in #nodes. */ 44 | void SetMetaDef(std::unique_ptr& meta_def_) { 45 | meta_def = std::move(meta_def_); 46 | } 47 | 48 | /** Gets the meta definition needed to represent this subgraph as a FunctionProto. 49 | @returns MetaDef instance if it has been set. nullptr if not. */ 50 | const MetaDef* GetMetaDef() const { 51 | return meta_def.get(); 52 | } 53 | 54 | private: 55 | // subgraph meta definition. 56 | std::unique_ptr meta_def; 57 | }; 58 | 59 | } // namespace onnxruntime 60 | -------------------------------------------------------------------------------- /include/onnxruntime/core/graph/node_arg.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License. 
3 | 4 | #pragma once 5 | 6 | #include "core/graph/onnx_protobuf.h" 7 | 8 | namespace onnxruntime { 9 | 10 | // Node argument definition, for both input and output, 11 | // including arg name, arg type (contains both type and shape). 12 | // 13 | // Design Question: in my opinion, shape should not be part of type. 14 | // We may align the protobuf design with our operator registry interface, 15 | // which has type specified for each operator, but no shape. Well, shape 16 | // should be inferred with a separate shape inference function given 17 | // input shapes, or input tensor data sometimes. 18 | // With shape as part of type (current protobuf design), 19 | // 1) we'll have to split the "TypeProto" into type and shape in this internal 20 | // representation interface so that it could be easily used when doing type 21 | // inference and matching with operator registry. 22 | // 2) SetType should be always called before SetShape, otherwise, SetShape() 23 | // will fail. Because shape is located in a TypeProto. 24 | // Thoughts? 25 | // 26 | 27 | /** 28 | @class NodeArg 29 | Class representing a data type that is input or output for a Node, including the shape if it is a Tensor. 30 | */ 31 | class NodeArg { 32 | public: 33 | /** 34 | Construct a new NodeArg. 35 | @param name The name to use. 36 | @param p_arg_type Optional TypeProto specifying type and shape information. 37 | */ 38 | NodeArg(const std::string& name, 39 | const ONNX_NAMESPACE::TypeProto* p_arg_type); 40 | 41 | NodeArg(NodeArg&& other) = default; 42 | 43 | /** Gets the name. */ 44 | const std::string& Name() const noexcept; 45 | 46 | /** Gets the data type. */ 47 | ONNX_NAMESPACE::DataType Type() const noexcept; 48 | 49 | /** Gets the TypeProto 50 | @returns TypeProto if type is set. nullptr otherwise. */ 51 | const ONNX_NAMESPACE::TypeProto* TypeAsProto() const noexcept; 52 | 53 | /** Gets the shape if NodeArg is for a Tensor. 54 | @returns TensorShapeProto if shape is set. 
nullptr if there's no shape specified. */ 55 | const ONNX_NAMESPACE::TensorShapeProto* Shape() const; 56 | 57 | /** Sets the shape. 58 | @remarks Shape can only be set if the TypeProto was provided to the ctor, or #SetType has been called, 59 | as the shape information is stored as part of TypeProto. */ 60 | void SetShape(const ONNX_NAMESPACE::TensorShapeProto& shape); 61 | 62 | /** Validate and merge type [and shape] info from input_type. 63 | @returns Success unless there is existing type or shape info that can't be cleanly updated. */ 64 | common::Status UpdateTypeAndShape(const ONNX_NAMESPACE::TypeProto& input_type); 65 | 66 | /** Validate and merge type [and shape] info from node_arg. 67 | @returns Success unless there is existing type or shape info that can't be cleanly updated. */ 68 | common::Status UpdateTypeAndShape(const NodeArg& node_arg); 69 | 70 | /** Gets this NodeArg as a ValueInfoProto. */ 71 | const NodeArgInfo& ToProto() const noexcept { return node_arg_info_; } 72 | 73 | /** Gets a flag indicating whether this NodeArg exists or not. 74 | Optional inputs are allowed in ONNX and an empty #Name represents a non-existent input argument. */ 75 | bool Exists() const noexcept; 76 | 77 | private: 78 | ORT_DISALLOW_COPY_AND_ASSIGNMENT(NodeArg); 79 | friend class Graph; 80 | 81 | void SetType(ONNX_NAMESPACE::DataType p_type); 82 | void SetType(const ONNX_NAMESPACE::TypeProto& type_proto); 83 | 84 | NodeArg& operator=(NodeArg&& other) = delete; 85 | 86 | // Node arg PType. 87 | ONNX_NAMESPACE::DataType type_; 88 | 89 | // Node arg name, type and shape. 90 | NodeArgInfo node_arg_info_; 91 | 92 | // Flag indicates whether <*this> node arg exists or not. 
93 | bool exists_; 94 | }; 95 | } // namespace onnxruntime 96 | -------------------------------------------------------------------------------- /include/onnxruntime/core/graph/onnx_protobuf.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License. 3 | 4 | #pragma once 5 | //TODO(): delete this file from public interface 6 | #ifdef __GNUC__ 7 | #pragma GCC diagnostic push 8 | #pragma GCC diagnostic ignored "-Wignored-qualifiers" 9 | #pragma GCC diagnostic ignored "-Wunused-parameter" 10 | #else 11 | #pragma warning(push) 12 | #pragma warning(disable : 4018) /*'expression' : signed/unsigned mismatch */ 13 | #pragma warning(disable : 4065) /*switch statement contains 'default' but no 'case' labels*/ 14 | #pragma warning(disable : 4100) 15 | #pragma warning(disable : 4146) /*unary minus operator applied to unsigned type, result still unsigned*/ 16 | #pragma warning(disable : 4244) /*'conversion' conversion from 'type1' to 'type2', possible loss of data*/ 17 | #pragma warning(disable : 4251) /*'identifier' : class 'type' needs to have dll-interface to be used by clients of class 'type2'*/ 18 | #pragma warning(disable : 4267) /*'var' : conversion from 'size_t' to 'type', possible loss of data*/ 19 | #pragma warning(disable : 4305) /*'identifier' : truncation from 'type1' to 'type2'*/ 20 | #pragma warning(disable : 4307) /*'operator' : integral constant overflow*/ 21 | #pragma warning(disable : 4309) /*'conversion' : truncation of constant value*/ 22 | #pragma warning(disable : 4334) /*'operator' : result of 32-bit shift implicitly converted to 64 bits (was 64-bit shift intended?)*/ 23 | #pragma warning(disable : 4355) /*'this' : used in base member initializer list*/ 24 | #pragma warning(disable : 4506) /*no definition for inline function 'function'*/ 25 | #pragma warning(disable : 4800) /*'type' : forcing value to bool 'true' or 'false' 
(performance warning)*/ 26 | #pragma warning(disable : 4996) /*The compiler encountered a deprecated declaration.*/ 27 | #endif 28 | #include "onnx/defs/schema.h" 29 | #include "onnx/onnx_pb.h" 30 | #ifdef __GNUC__ 31 | #pragma GCC diagnostic pop 32 | #else 33 | #pragma warning(pop) 34 | #endif 35 | -------------------------------------------------------------------------------- /include/onnxruntime/core/graph/schema_registry.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License. 3 | 4 | #pragma once 5 | #include "core/graph/constants.h" 6 | #include "core/common/common.h" 7 | #include "core/common/status.h" 8 | #include "core/platform/ort_mutex.h" 9 | 10 | #ifdef __GNUC__ 11 | #pragma GCC diagnostic push 12 | #pragma GCC diagnostic ignored "-Wignored-qualifiers" 13 | #pragma GCC diagnostic ignored "-Wunused-parameter" 14 | #endif 15 | #include "onnx/defs/schema.h" 16 | #ifdef __GNUC__ 17 | #pragma GCC diagnostic pop 18 | #endif 19 | #include 20 | #include 21 | #include "sstream" 22 | 23 | namespace onnxruntime { 24 | using OpName_Domain_Version_Schema_Map = std::unordered_map< 25 | std::string, 26 | std::unordered_map>>; 27 | 28 | /** 29 | @struct SchemaRegistryVersion 30 | onnxruntime schema registry is a supplement to the built-in ONNX schema. 
31 | Every schema registry represent a collection of schema deltas from baseline_opset_version to opset_version 32 | */ 33 | struct SchemaRegistryVersion { 34 | int baseline_opset_version; 35 | int opset_version; 36 | }; 37 | 38 | using DomainToVersionMap = std::unordered_map; 39 | using DomainToVersionRangeMap = std::unordered_map; 40 | 41 | class IOnnxRuntimeOpSchemaCollection : public ONNX_NAMESPACE::ISchemaRegistry { 42 | public: 43 | virtual DomainToVersionMap GetLatestOpsetVersions(bool is_onnx_only) const = 0; 44 | 45 | using ISchemaRegistry::GetSchema; 46 | 47 | const ONNX_NAMESPACE::OpSchema* GetSchema(const std::string& key, const int maxInclusiveVersion, 48 | const std::string& domain) const final { 49 | const ONNX_NAMESPACE::OpSchema* latest_schema = nullptr; 50 | int earliest_opset_where_unchanged = std::numeric_limits::max(); 51 | GetSchemaAndHistory(key, maxInclusiveVersion, domain, &latest_schema, &earliest_opset_where_unchanged); 52 | 53 | assert(latest_schema == nullptr || (latest_schema->SinceVersion() <= maxInclusiveVersion && 54 | earliest_opset_where_unchanged == latest_schema->SinceVersion())); 55 | 56 | return latest_schema; 57 | } 58 | 59 | virtual void GetSchemaAndHistory( 60 | const std::string& key, 61 | int maxInclusiveVersion, 62 | const std::string& domain, 63 | const ONNX_NAMESPACE::OpSchema** latest_schema, 64 | int* earliest_opset_where_unchanged) const = 0; 65 | }; 66 | 67 | /** 68 | @class OnnxRuntimeOpSchemaRegistry 69 | 70 | OnnxRuntimeOpSchemaRegistry is used to provide supplement for built-in ONNX schemas. 71 | Each OnnxRuntimeOpSchemaRegistry must register complete opsets delta from a baseline version to max opset version. 
72 | (Please notice that baseline opsets are not include in the delta) 73 | 74 | For example, ONNXRuntime is build with ONNX 1.2 which is at opset7, to use ONNX opset8 and opset9, 75 | user could create a OnnxRuntimeOpSchemaRegistry and config it as {baseline_opset_version = 7, opset_version = 9} 76 | it means this OnnxRuntimeOpSchemaRegistry contains the complete delta from opset7 to opset9. 77 | */ 78 | class OnnxRuntimeOpSchemaRegistry : public IOnnxRuntimeOpSchemaCollection { 79 | public: 80 | OnnxRuntimeOpSchemaRegistry() = default; 81 | 82 | common::Status SetBaselineAndOpsetVersionForDomain( 83 | const std::string& domain, 84 | int baseline_opset_version, 85 | int opset_version); 86 | 87 | DomainToVersionMap GetLatestOpsetVersions(bool is_onnx_only) const override; 88 | 89 | // OnnxRuntimeOpSchemaRegistry must register complete delta for a opset. 90 | common::Status RegisterOpSet( 91 | std::vector& schemas, 92 | const std::string& domain, 93 | int baseline_opset_version, 94 | int opset_version); 95 | 96 | using IOnnxRuntimeOpSchemaCollection::GetSchema; 97 | 98 | void GetSchemaAndHistory(const std::string& key, int maxInclusiveVersion, const std::string& domain, 99 | const ONNX_NAMESPACE::OpSchema** latest_schema, 100 | int* earliest_opset_where_unchanged) const override; 101 | 102 | bool empty() const { 103 | return map_.empty(); 104 | } 105 | 106 | private: 107 | common::Status RegisterOpSchema(ONNX_NAMESPACE::OpSchema&& op_schema); 108 | 109 | common::Status RegisterOpSchemaInternal(ONNX_NAMESPACE::OpSchema&& op_schema); 110 | 111 | OrtMutex mutex_; 112 | 113 | OpName_Domain_Version_Schema_Map map_; 114 | DomainToVersionRangeMap domain_version_range_map_; 115 | }; 116 | 117 | /** 118 | @class SchemaRegistryManager 119 | 120 | SchemaRegistryManager provides a view based on built-in ONNX schema and a list of 121 | OnnxRuntimeOpSchemaRegistry as supplement. 
122 | 123 | The user needs to make sure the customized schema registry is valid, otherwise the behavior is undefined. 124 | 125 | @todo We may add more consistency checks later. 126 | */ 127 | class SchemaRegistryManager : public onnxruntime::IOnnxRuntimeOpSchemaCollection { 128 | public: 129 | /** 130 | Register a new schema registry instance. 131 | @remarks The schema registry priority is the reverse of registration order. i.e. the last registry added will be 132 | searched first for a matching OpSchema. 133 | */ 134 | void RegisterRegistry(std::shared_ptr registry); 135 | 136 | /** Gets the latest opset versions. 137 | @param is_onnx_only If true, return the latest ONNX schemas. If false, return the latest schemas for all domains. 138 | */ 139 | DomainToVersionMap GetLatestOpsetVersions(bool is_onnx_only) const override; 140 | 141 | /** 142 | Gets the OpSchema and its history. 143 | Searches custom schema registries starting with the last one added. \ 144 | If the OpSchema is not found the default ONNX schema registry is searched. 145 | 146 | @param key Operator type. 147 | @param max_inclusive_version Maximum opset version allowed, inclusive. 148 | @param domain The domain of the operator. 149 | @param[out] latest_schema Returns the latest OpSchema if found. nullptr otherwise. 150 | @param[out] earliest_opset_where_unchanged The earliest opset version preceding max_inclusive_version where the 151 | operator is known to be unchanged. 
152 | */ 153 | void GetSchemaAndHistory(const std::string& key, int max_inclusive_version, const std::string& domain, 154 | const ONNX_NAMESPACE::OpSchema** latest_schema, 155 | int* earliest_opset_where_unchanged) const override; 156 | 157 | private: 158 | std::deque> registries; 159 | }; 160 | 161 | } // namespace onnxruntime 162 | -------------------------------------------------------------------------------- /include/onnxruntime/core/optimizer/graph_transformer.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License. 3 | 4 | #pragma once 5 | #include 6 | #include 7 | 8 | #include "core/common/common.h" 9 | #include "core/graph/graph_viewer.h" 10 | #include "core/optimizer/graph_transformer_level.h" 11 | 12 | namespace onnxruntime { 13 | 14 | /** 15 | @class GraphTransformer 16 | 17 | The interface for in-place transformation of a Graph. 18 | */ 19 | class GraphTransformer { 20 | public: 21 | GraphTransformer(const std::string& name, const std::unordered_set& compatible_execution_providers = {}) 22 | : name_(name), compatible_provider_types_(compatible_execution_providers) { 23 | } 24 | 25 | virtual ~GraphTransformer() = default; 26 | 27 | /** Gets the name of this graph transformer. */ 28 | const std::string& Name() const noexcept { 29 | return name_; 30 | } 31 | 32 | const std::unordered_set& GetCompatibleExecutionProviders() const noexcept { 33 | return compatible_provider_types_; 34 | } 35 | 36 | /** Apply the in-place transformation defined by this transformer to the provided Graph instance. 37 | @param[out] modified Set to true if the Graph was modified. 38 | @returns Status with success or error information. 39 | */ 40 | common::Status Apply(Graph& graph, bool& modified) const; 41 | 42 | protected: 43 | /** Helper method to call ApplyImpl on any subgraphs in the Node. 
*/ 44 | common::Status Recurse(Node& node, bool& modified, int graph_level) const { 45 | int subgraph_level = ++graph_level; 46 | for (auto& entry : node.GetAttributeNameToMutableSubgraphMap()) { 47 | auto& subgraph = *entry.second; 48 | ORT_RETURN_IF_ERROR(ApplyImpl(subgraph, modified, subgraph_level)); 49 | } 50 | 51 | return Status::OK(); 52 | } 53 | 54 | private: 55 | ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(GraphTransformer); 56 | 57 | // Apply the transform to the graph. 58 | // graph_level is 0 for the main graph, and is incremented when descending into the subgraph of a node. 59 | // You MUST call Recurse for all valid Nodes in the graph to ensure any subgraphs in control flow nodes 60 | // (Scan/If/Loop) are processed as well. 61 | // You should avoid calling Graph::Resolve in ApplyImpl unless you are 100% sure it's required. In most cases 62 | // the call to Graph::Resolve in Apply prior to ApplyImpl being called, and after ApplyImpl for the main graph 63 | // completes (if 'modified' is true) should suffice. 64 | virtual common::Status ApplyImpl(Graph& graph, bool& modified, int graph_level) const = 0; 65 | 66 | const std::string name_; 67 | const std::unordered_set compatible_provider_types_; 68 | }; 69 | } // namespace onnxruntime 70 | -------------------------------------------------------------------------------- /include/onnxruntime/core/optimizer/graph_transformer_level.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License. 3 | 4 | #pragma once 5 | 6 | #include "core/common/common.h" 7 | 8 | namespace onnxruntime { 9 | 10 | enum class TransformerLevel : int { 11 | Default = 0, 12 | Level1, 13 | Level2, 14 | Level3, 15 | // Convenience enum to always get the max available value. 16 | // This way when we add more levels code which iterates over this enum does not need to change. 
17 | MaxTransformerLevel 18 | }; 19 | 20 | } // namespace onnxruntime 21 | -------------------------------------------------------------------------------- /include/onnxruntime/core/optimizer/graph_transformer_utils.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License. 3 | 4 | #pragma once 5 | 6 | #include "core/optimizer/graph_transformer.h" 7 | #include "core/optimizer/rule_based_graph_transformer.h" 8 | #include "core/optimizer/rewrite_rule.h" 9 | 10 | namespace onnxruntime { 11 | 12 | namespace transformer_utils { 13 | 14 | /** Generates all predefined rules for this level. 15 | If rules_to_enable is not empty, it returns the intersection of predefined rules and rules_to_enable. 16 | TODO: This is visible for testing at the moment, but we should rather make it private. */ 17 | std::vector> GenerateRewriteRules(TransformerLevel level, 18 | const std::vector& rules_to_enable = {}); 19 | 20 | /** Generates all predefined (both rule-based and non-rule-based) transformers for this level. 21 | If transformers_and_rules_to_enable is not empty, it returns the intersection between the predefined transformers/rules 22 | and the transformers_and_rules_to_enable. */ 23 | std::vector> GenerateTransformers(TransformerLevel level, 24 | const std::vector& rules_and_transformers_to_enable = {}); 25 | 26 | /** Given a TransformerLevel, this method generates a name for the rule-based graph transformer of that level. */ 27 | std::string GenerateRuleBasedTransformerName(TransformerLevel level); 28 | 29 | } // namespace transformer_utils 30 | } // namespace onnxruntime 31 | -------------------------------------------------------------------------------- /include/onnxruntime/core/optimizer/rewrite_rule.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 
2 | // Licensed under the MIT License. 3 | 4 | #pragma once 5 | 6 | #include "core/common/common.h" 7 | #include "core/graph/graph_viewer.h" 8 | 9 | namespace onnxruntime { 10 | 11 | /** 12 | @class RewriteRule 13 | 14 | The base class for a rewrite rule. A rewrite rule represents a semantics-preserving transformation of a 15 | computation graph. It can be used to represent, for example, the elimination of operators that serve as 16 | no-ops (e.g., dropout during inference), as well as inlining of "function" definitions or the dual operation 17 | of replacing a complex expression by an equivalent function-call). Unlike the more general GraphTransformer, 18 | a rewrite rule is a more local transformation that is triggered on a particular node of the graph. 19 | 20 | Each rule has a set of conditions and a body. The conditions have to be satisfied for the body of the rule 21 | to be triggered. Therefore, when creating a new rewrite rule, two main functions have to be implemented: 22 | - SatisfyCondition defines the condition checks. It is advisable to add the more selective checks first, 23 | because those will lead to discarding fast rules that cannot be applied on a node. 24 | - Apply is the actual body of the rule that will be executed if SatisfyCondition returns true for a particular 25 | node. Note that additional, more complex checks can be included in the Apply if putting them in the 26 | SatisfyCondition would lead to duplicate work (e.g., when we make a check on a Node attribute but we need 27 | that attribute to execute the rule too). 28 | In general, simple fast checks are a better fit for SatisfyCondition, whereas more complex ones can be added 29 | in the Apply. 30 | 31 | In order to avoid evaluating the SatisfyCondition for each rule and each node of the graph, each rewrite rule 32 | should specify the target op types for which a rule will be evaluated, by overriding the TargetOpTypes() function. 
33 | If the op type of a node is not included in the target op types of a rule, that rule would not be considered at all. 34 | If the list of op types is left empty, that rule will be triggered for every op type. 35 | */ 36 | class RewriteRule { 37 | public: 38 | /** 39 | @enum RewriteRuleEffect 40 | 41 | Enum used to indicate the effect of rule application on a graph's node. 42 | */ 43 | enum class RewriteRuleEffect : uint8_t { 44 | kNone, // The rewrite rule has not modified the graph. 45 | kUpdatedCurrentNode, // The rewrite rule updated (but did not remove) the node on which it was triggered. 46 | kRemovedCurrentNode, // The rewrite rule removed the node on which it was triggered. 47 | kModifiedRestOfGraph // The rewrite rule modified nodes other than the one it was triggered on. 48 | }; 49 | 50 | RewriteRule(const std::string& name) : name_(name) {} 51 | 52 | virtual ~RewriteRule() = default; 53 | 54 | /** Gets the name of this rewrite rule. */ 55 | const std::string& Name() const noexcept { 56 | return name_; 57 | } 58 | 59 | /** Returns the node op types for which this rule will be triggered. If the op type of a node is not included in the 60 | target op types of a rule, that rule would not be considered at all. Returning an empty list indicates that we 61 | will attempt to trigger the rule for every op type. */ 62 | virtual std::vector TargetOpTypes() const noexcept = 0; 63 | 64 | /** Checks if the condition of the rule is satisfied, and if so applies the body of the rule. 65 | @param[in] graph The Graph. 66 | @param[in] node The Node to apply the rewrite to. 67 | @param[out] rule_effect Enum to indicate if and how the graph was modified as a result of the rule application. 68 | @returns Status indicating success or providing error information. When SatisfyCondition is false, returns Status::OK() and does not write rule_effect. */ 69 | common::Status CheckConditionAndApply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const { 70 | return SatisfyCondition(graph, node) ? 
Apply(graph, node, rule_effect) : Status::OK(); 71 | } 72 | 73 | private: 74 | ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(RewriteRule); 75 | 76 | const std::string name_; 77 | 78 | /** Checks if the Node of the given Graph satisfies the conditions of this rule. The body of the rule will be 79 | evaluated if this condition function returns true. This can include a more complex pattern matching (conditions 80 | on the ascending or descending nodes of the node for which this rule was triggered) or some other properties 81 | of the nodes. */ 82 | virtual bool SatisfyCondition(const Graph& graph, const Node& node) const = 0; 83 | 84 | /** This is the actual body of the rule that performs the graph transformation. The transformation happens in-place. 85 | The return-value of node may be different from the input-value due to rewriting. 86 | The value of "rule_effect" indicates whether and how the graph was modified by the rule. */ 87 | virtual common::Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const = 0; 88 | }; 89 | } // namespace onnxruntime 90 | -------------------------------------------------------------------------------- /include/onnxruntime/core/optimizer/rule_based_graph_transformer.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License. 3 | 4 | #pragma once 5 | 6 | #include "core/common/common.h" 7 | #include "core/graph/graph_viewer.h" 8 | #include "core/optimizer/graph_transformer.h" 9 | #include "core/optimizer/rewrite_rule.h" 10 | 11 | namespace onnxruntime { 12 | 13 | /** 14 | @class RuleBasedGraphTransformer 15 | 16 | Rule-based graph transformer that provides an API to register rewrite rules 17 | and an API to apply all applicable rules to a Graph. 18 | 19 | Represents an IGraphTransformer determined by a set of rewrite rules. 
20 | The transformer will apply all the rewrite rules iteratively as determined by the underlying rewriting strategy. 21 | Several rewriting-strategies are possible when traversing the graph and applying rewrite rules, 22 | each with different trade offs. At the moment, we define one that performs top-down traversal of nodes. 23 | 24 | @TODO: Is a bottom-up traversal more efficient? 25 | @TODO: Is it worth adding the max number of passes a rule should be applied for? 26 | @TODO: We need to define a contract about whether a rewrite rule is allowed to leave 27 | the graph in an inconsistent state (this will determine when and where we will be 28 | calling Graph::resolve(). 29 | */ 30 | class RuleBasedGraphTransformer : public GraphTransformer { 31 | public: 32 | RuleBasedGraphTransformer(const std::string& name, 33 | const std::unordered_set& compatible_execution_providers = {}) 34 | : GraphTransformer(name, compatible_execution_providers) {} 35 | 36 | /** Registers a rewrite rule in this transformer. */ 37 | Status Register(std::unique_ptr rule); 38 | 39 | /** Gets the list of registered rewrite rules that will be triggered on nodes with the given op type 40 | by this rule-based transformer. 41 | @returns a pointer to the vector containing the rewrite rules registered for op_type, or nullptr if op_type has no entry in the op-type-to-rules map. */ 42 | const std::vector>* GetRewriteRulesForOpType(const std::string& op_type) const { 43 | auto rules = op_type_to_rules_.find(op_type); 44 | return (rules != op_type_to_rules_.cend()) ? &rules->second : nullptr; 45 | } 46 | 47 | /** Gets the rewrite rules that are evaluated on all nodes irrespective of their op type. 48 | @returns a pointer to the vector containing all such rewrite rules. The pointer is never nullptr; the vector is empty if there is no such rule. */ 49 | const std::vector>* GetAnyOpRewriteRules() const { 50 | return &any_op_type_rules_; 51 | } 52 | 53 | /** Returns the total number of rules that are registered in this transformer. 
*/ 54 | size_t RulesCount() const; 55 | 56 | protected: 57 | /** Applies the given set of rewrite rules on the Node of this Graph. 58 | @param[in] graph The Graph. 59 | @param[in] node The Node to apply the rules to. 60 | @param[in] rules The vector of RewriteRules that will be applied to the Node. 61 | @param[out] rule_effect Enum that indicates whether and how the graph was modified as a result of 62 | applying rules on this node. 63 | @returns Status indicating success or providing error information. */ 64 | common::Status ApplyRulesOnNode(Graph& graph, Node& node, 65 | const std::vector>& rules, 66 | RewriteRule::RewriteRuleEffect& rule_effect) const; 67 | 68 | private: 69 | using RuleEffect = RewriteRule::RewriteRuleEffect; 70 | 71 | // The list of unique pointers for all rules (so that rules can be registered for several op types). 72 | std::vector> rules_; 73 | // Map that associates a node's op type with the vector of rules that are registered to be triggered for that node. 74 | std::unordered_map>> op_type_to_rules_; 75 | // Rules that will be evaluated regardless of the op type of the node. 76 | std::vector> any_op_type_rules_; 77 | 78 | // Performs a single top-down traversal of the graph and applies all registered rules. 79 | common::Status ApplyImpl(Graph& graph, bool& modified, int graph_level) const override; 80 | }; 81 | 82 | } // namespace onnxruntime 83 | -------------------------------------------------------------------------------- /include/onnxruntime/core/platform/ort_mutex.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License. 
3 | 4 | #pragma once 5 | #ifdef _WIN32 6 | #include 7 | #include 8 | namespace onnxruntime { 9 | using OrtMutex = std::mutex; 10 | using OrtCondVar = std::condition_variable; 11 | } // namespace onnxruntime 12 | #else 13 | #ifdef USE_NSYNC 14 | #include "nsync.h" 15 | #include //for unique_lock 16 | #include //for cv_status 17 | #else 18 | #include 19 | #include 20 | #include 21 | #include 22 | #include 23 | #endif 24 | namespace onnxruntime { 25 | 26 | class OrtMutex { 27 | #ifdef USE_NSYNC 28 | nsync::nsync_mu data_ = NSYNC_MU_INIT; 29 | #else 30 | pthread_mutex_t data_ = PTHREAD_MUTEX_INITIALIZER; 31 | #endif 32 | 33 | public: 34 | constexpr OrtMutex() = default; 35 | #ifdef USE_NSYNC 36 | ~OrtMutex() = default; 37 | #else 38 | ~OrtMutex(); 39 | #endif 40 | 41 | OrtMutex(const OrtMutex&) = delete; 42 | OrtMutex& operator=(const OrtMutex&) = delete; 43 | 44 | void lock(); 45 | bool try_lock() noexcept; 46 | void unlock() noexcept; 47 | 48 | #ifdef USE_NSYNC 49 | using native_handle_type = nsync::nsync_mu*; 50 | #else 51 | using native_handle_type = pthread_mutex_t*; 52 | #endif 53 | native_handle_type native_handle() { return &data_; } 54 | }; 55 | 56 | class OrtCondVar { 57 | #ifdef USE_NSYNC 58 | nsync::nsync_cv native_cv_object = NSYNC_CV_INIT; 59 | #else 60 | pthread_cond_t native_cv_object = PTHREAD_COND_INITIALIZER; 61 | #endif 62 | public: 63 | constexpr OrtCondVar() noexcept = default; 64 | 65 | #ifdef USE_NSYNC 66 | ~OrtCondVar() = default; 67 | #else 68 | ~OrtCondVar(); 69 | #endif 70 | 71 | OrtCondVar(const OrtCondVar&) = delete; 72 | OrtCondVar& operator=(const OrtCondVar&) = delete; 73 | 74 | void notify_one() noexcept; 75 | void notify_all() noexcept; 76 | 77 | void wait(std::unique_lock& __lk); 78 | template 79 | void wait(std::unique_lock& __lk, _Predicate __pred); 80 | 81 | /** 82 | * returns cv_status::timeout if the wait terminates when Rel_time has elapsed. Otherwise, the method returns cv_status::no_timeout. 
83 | * @param cond_mutex A unique_lock object. 84 | * @param rel_time A chrono::duration object that specifies the amount of time before the thread wakes up. 85 | * @return returns cv_status::timeout if the wait terminates when Rel_time has elapsed. Otherwise, the method returns cv_status::no_timeout 86 | */ 87 | template 88 | std::cv_status 89 | wait_for(std::unique_lock& cond_mutex, 90 | const std::chrono::duration& rel_time); 91 | #ifdef USE_NSYNC 92 | using native_handle_type = nsync::nsync_cv*; 93 | #else 94 | using native_handle_type = pthread_cond_t*; 95 | #endif 96 | 97 | native_handle_type native_handle() { return &native_cv_object; } 98 | 99 | private: 100 | void timed_wait_impl(std::unique_lock& __lk, 101 | std::chrono::time_point); 102 | }; 103 | 104 | template 105 | void OrtCondVar::wait(std::unique_lock& __lk, _Predicate __pred) { 106 | while (!__pred()) 107 | wait(__lk); 108 | } 109 | 110 | template 111 | std::cv_status 112 | OrtCondVar::wait_for(std::unique_lock& cond_mutex, 113 | const std::chrono::duration& rel_time) { 114 | //TODO: is it possible to use nsync_from_time_point_ ? 115 | using namespace std::chrono; 116 | if (rel_time <= duration::zero()) 117 | return std::cv_status::timeout; 118 | using SystemTimePointFloat = time_point >; 119 | using SystemTimePoint = time_point; 120 | SystemTimePointFloat max_time = SystemTimePoint::max(); 121 | steady_clock::time_point steady_now = steady_clock::now(); 122 | system_clock::time_point system_now = system_clock::now(); 123 | if (max_time - rel_time > system_now) { 124 | nanoseconds remain = duration_cast(rel_time); 125 | if (remain < rel_time) 126 | ++remain; 127 | timed_wait_impl(cond_mutex, system_now + remain); 128 | } else 129 | timed_wait_impl(cond_mutex, SystemTimePoint::max()); 130 | return steady_clock::now() - steady_now < rel_time ? 
std::cv_status::no_timeout : std::cv_status::timeout; 131 | } 132 | }; // namespace onnxruntime 133 | #endif -------------------------------------------------------------------------------- /include/onnxruntime/core/platform/threadpool.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License. 3 | 4 | #pragma once 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | #if defined(__GNUC__) 11 | #pragma GCC diagnostic push 12 | #pragma GCC diagnostic ignored "-Wunused-parameter" 13 | #else 14 | #pragma warning(push) 15 | #pragma warning(disable : 4267) 16 | #endif 17 | #include 18 | #if defined(__GNUC__) 19 | #pragma GCC diagnostic pop 20 | #else 21 | #pragma warning(pop) 22 | #endif 23 | 24 | namespace onnxruntime { 25 | 26 | namespace concurrency { 27 | 28 | /** 29 | * Generic class for instantiating thread pools. 30 | * Don't put any object of this type into a global variable in a Win32 DLL. 31 | */ 32 | class ThreadPool { 33 | public: 34 | /* 35 | Initializes a thread pool given the current environment. 36 | */ 37 | ThreadPool(const std::string& name, int num_threads); 38 | 39 | /* 40 | Enqueue a unit of work. 41 | */ 42 | void Schedule(std::function fn); 43 | 44 | /* 45 | Schedule work in the interval [0, total). 46 | */ 47 | void ParallelFor(int32_t total, std::function fn); 48 | 49 | /* 50 | Schedule work in the interval [first, last]. 
51 | */ 52 | void ParallelForRange(int64_t first, int64_t last, std::function fn); 53 | 54 | // This is not supported until the latest Eigen 55 | // void SetStealPartitions(const std::vector>& partitions); 56 | 57 | int NumThreads() const; 58 | 59 | int CurrentThreadId() const; 60 | 61 | Eigen::ThreadPool& GetHandler() { return impl_; } 62 | 63 | private: 64 | Eigen::ThreadPool impl_; 65 | }; 66 | 67 | } // namespace concurrency 68 | } // namespace onnxruntime 69 | -------------------------------------------------------------------------------- /include/onnxruntime/core/providers/cpu/cpu_provider_factory.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License. 3 | 4 | #include "onnxruntime_c_api.h" 5 | 6 | #ifdef __cplusplus 7 | extern "C" { 8 | #endif 9 | 10 | /** 11 | * \param use_arena zero: false. non-zero: true. 12 | */ 13 | ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_CPU, _In_ OrtSessionOptions* options, int use_arena) 14 | ORT_ALL_ARGS_NONNULL; 15 | 16 | #ifdef __cplusplus 17 | } 18 | #endif 19 | -------------------------------------------------------------------------------- /include/onnxruntime/core/providers/cuda/cuda_provider_factory.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License. 3 | 4 | #include "onnxruntime_c_api.h" 5 | 6 | #ifdef __cplusplus 7 | extern "C" { 8 | #endif 9 | 10 | /** 11 | * \param device_id cuda device id, starts from zero. 
12 | */ 13 | ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_CUDA, _In_ OrtSessionOptions* options, int device_id); 14 | 15 | #ifdef __cplusplus 16 | } 17 | #endif 18 | -------------------------------------------------------------------------------- /include/onnxruntime/core/providers/mkldnn/mkldnn_provider_factory.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License. 3 | 4 | #include "onnxruntime_c_api.h" 5 | 6 | #ifdef __cplusplus 7 | extern "C" { 8 | #endif 9 | 10 | /** 11 | * \param use_arena zero: false. non-zero: true. 12 | */ 13 | ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_Mkldnn, _In_ OrtSessionOptions* options, int use_arena); 14 | 15 | #ifdef __cplusplus 16 | } 17 | #endif 18 | -------------------------------------------------------------------------------- /include/onnxruntime/core/providers/ngraph/ngraph_provider_factory.h: -------------------------------------------------------------------------------- 1 | // Copyright(C) 2019 Intel Corporation 2 | // Licensed under the MIT License 3 | 4 | #include "onnxruntime_c_api.h" 5 | 6 | #ifdef __cplusplus 7 | extern "C" { 8 | #endif 9 | 10 | /** 11 | * \param ng_backend_type nGraph backend type, passed as a string. 12 | */ 13 | ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_NGraph, _In_ OrtSessionOptions* options, _In_ const char* ng_backend_type); 14 | 15 | #ifdef __cplusplus 16 | } 17 | #endif 18 | -------------------------------------------------------------------------------- /include/onnxruntime/core/providers/nnapi/nnapi_provider_factory.h: -------------------------------------------------------------------------------- 1 | // Copyright 2019 JD.com Inc. 
JD AI 2 | 3 | #include "onnxruntime_c_api.h" 4 | 5 | #ifdef __cplusplus 6 | extern "C" { 7 | #endif 8 | 9 | ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_Nnapi, _In_ OrtSessionOptions* options); 10 | 11 | #ifdef __cplusplus 12 | } 13 | #endif 14 | 15 | 16 | -------------------------------------------------------------------------------- /include/onnxruntime/core/providers/nuphar/nuphar_provider_factory.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License. 3 | #pragma once 4 | #include "core/session/onnxruntime_c_api.h" 5 | 6 | #ifdef __cplusplus 7 | extern "C" { 8 | #endif 9 | /** 10 | * \param allow_unaligned_buffers integer flag controlling whether unaligned buffers are allowed. 11 | * \param settings_str Nuphar settings string. 12 | */ 13 | ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_Nuphar, _In_ OrtSessionOptions* options, int allow_unaligned_buffers, _In_ const char* settings_str); 14 | 15 | #ifdef __cplusplus 16 | } 17 | #endif 18 | -------------------------------------------------------------------------------- /include/onnxruntime/core/providers/openvino/openvino_provider_factory.h: -------------------------------------------------------------------------------- 1 | // Copyright(C) 2019 Intel Corporation 2 | // Licensed under the MIT License 3 | 4 | #include "onnxruntime_c_api.h" 5 | 6 | #ifdef __cplusplus 7 | extern "C" { 8 | #endif 9 | 10 | /** 11 | * \param device_id OpenVINO device id, passed as a string. 12 | */ 13 | ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_OpenVINO, 14 | _In_ OrtSessionOptions* options, const char* device_id); 15 | 16 | #ifdef __cplusplus 17 | } 18 | #endif 19 | -------------------------------------------------------------------------------- /include/onnxruntime/core/providers/providers.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 
All rights reserved. 2 | // Licensed under the MIT License. 3 | 4 | namespace onnxruntime { 5 | class IExecutionProvider; 6 | 7 | struct IExecutionProviderFactory { 8 | virtual ~IExecutionProviderFactory() = default; 9 | virtual std::unique_ptr CreateProvider() = 0; 10 | }; 11 | } // namespace onnxruntime 12 | -------------------------------------------------------------------------------- /include/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License. 3 | 4 | #include "onnxruntime_c_api.h" 5 | 6 | #ifdef __cplusplus 7 | extern "C" { 8 | #endif 9 | 10 | ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_Tensorrt, _In_ OrtSessionOptions* options, int device_id); 11 | 12 | #ifdef __cplusplus 13 | } 14 | #endif 15 | 16 | -------------------------------------------------------------------------------- /include/onnxruntime/core/session/automl_data_containers.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License. 
3 | 4 | // This header contains shared definitions for reading and writing 5 | // via C/C++ API Opaque data types that are registered within ORT 6 | #pragma once 7 | 8 | #include 9 | 10 | #ifdef __cplusplus 11 | extern "C" { 12 | #endif 13 | // This structure is used to initialize and read 14 | // OrtValue of opaque(com.microsoft.automl,DateTimeFeaturizer_TimePoint) 15 | struct DateTimeFeaturizerTimePointData { 16 | int32_t year; 17 | uint8_t month; 18 | uint8_t day; 19 | uint8_t hour; 20 | uint8_t minute; 21 | uint8_t second; 22 | uint8_t dayOfWeek; 23 | uint16_t dayOfYear; 24 | uint8_t quarterOfYear; 25 | uint8_t weekOfMonth; 26 | }; 27 | 28 | #ifdef __cplusplus 29 | } // extern "C" 30 | #endif 31 | -------------------------------------------------------------------------------- /include/onnxruntime/core/session/environment.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License. 3 | 4 | #pragma once 5 | 6 | #include 7 | #include 8 | #include "core/common/common.h" 9 | #include "core/common/status.h" 10 | 11 | namespace onnxruntime { 12 | /** 13 | Provides the runtime environment for onnxruntime. 14 | Create one instance for the duration of execution. 15 | */ 16 | class Environment { 17 | public: 18 | /** 19 | Create and initialize the runtime environment. 20 | */ 21 | static Status Create(std::unique_ptr& environment); 22 | 23 | /** 24 | This function will call ::google::protobuf::ShutdownProtobufLibrary 25 | */ 26 | ~Environment(); 27 | 28 | /** 29 | Returns whether any runtime environment instance has been initialized. 
30 | */ 31 | static bool IsInitialized() { return is_initialized_; } 32 | 33 | private: 34 | ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Environment); 35 | 36 | Environment() = default; 37 | Status Initialize(); 38 | 39 | static std::atomic is_initialized_; 40 | }; 41 | } // namespace onnxruntime 42 | -------------------------------------------------------------------------------- /test_imgs/classification/cls_001.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tenglike1997/onnxruntime-projects/7e99a75f1ce36329c4dbf2863023e7e05c73a991/test_imgs/classification/cls_001.jpg -------------------------------------------------------------------------------- /test_imgs/classification/cls_002.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tenglike1997/onnxruntime-projects/7e99a75f1ce36329c4dbf2863023e7e05c73a991/test_imgs/classification/cls_002.jpg -------------------------------------------------------------------------------- /test_imgs/detection/000001.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tenglike1997/onnxruntime-projects/7e99a75f1ce36329c4dbf2863023e7e05c73a991/test_imgs/detection/000001.jpg -------------------------------------------------------------------------------- /test_imgs/detection/000002.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tenglike1997/onnxruntime-projects/7e99a75f1ce36329c4dbf2863023e7e05c73a991/test_imgs/detection/000002.jpg -------------------------------------------------------------------------------- /test_imgs/detection/000003.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tenglike1997/onnxruntime-projects/7e99a75f1ce36329c4dbf2863023e7e05c73a991/test_imgs/detection/000003.jpg 
-------------------------------------------------------------------------------- /test_imgs/detection/000004.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tenglike1997/onnxruntime-projects/7e99a75f1ce36329c4dbf2863023e7e05c73a991/test_imgs/detection/000004.jpg -------------------------------------------------------------------------------- /test_imgs/detection/000005.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tenglike1997/onnxruntime-projects/7e99a75f1ce36329c4dbf2863023e7e05c73a991/test_imgs/detection/000005.jpg -------------------------------------------------------------------------------- /test_imgs/detection/000006.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tenglike1997/onnxruntime-projects/7e99a75f1ce36329c4dbf2863023e7e05c73a991/test_imgs/detection/000006.jpg -------------------------------------------------------------------------------- /test_imgs/detection/000007.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tenglike1997/onnxruntime-projects/7e99a75f1ce36329c4dbf2863023e7e05c73a991/test_imgs/detection/000007.jpg -------------------------------------------------------------------------------- /test_imgs/detection/000008.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tenglike1997/onnxruntime-projects/7e99a75f1ce36329c4dbf2863023e7e05c73a991/test_imgs/detection/000008.jpg -------------------------------------------------------------------------------- /test_imgs/detection/000009.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/tenglike1997/onnxruntime-projects/7e99a75f1ce36329c4dbf2863023e7e05c73a991/test_imgs/detection/000009.jpg -------------------------------------------------------------------------------- /test_imgs/detection/000010.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tenglike1997/onnxruntime-projects/7e99a75f1ce36329c4dbf2863023e7e05c73a991/test_imgs/detection/000010.jpg -------------------------------------------------------------------------------- /test_imgs/detection/000011.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tenglike1997/onnxruntime-projects/7e99a75f1ce36329c4dbf2863023e7e05c73a991/test_imgs/detection/000011.jpg -------------------------------------------------------------------------------- /test_imgs/detection/000012.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tenglike1997/onnxruntime-projects/7e99a75f1ce36329c4dbf2863023e7e05c73a991/test_imgs/detection/000012.jpg -------------------------------------------------------------------------------- /test_imgs/detection/000013.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tenglike1997/onnxruntime-projects/7e99a75f1ce36329c4dbf2863023e7e05c73a991/test_imgs/detection/000013.jpg -------------------------------------------------------------------------------- /test_imgs/detection/000014.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tenglike1997/onnxruntime-projects/7e99a75f1ce36329c4dbf2863023e7e05c73a991/test_imgs/detection/000014.jpg -------------------------------------------------------------------------------- /test_imgs/detection/000015.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/tenglike1997/onnxruntime-projects/7e99a75f1ce36329c4dbf2863023e7e05c73a991/test_imgs/detection/000015.jpg -------------------------------------------------------------------------------- /test_imgs/detection/000016.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tenglike1997/onnxruntime-projects/7e99a75f1ce36329c4dbf2863023e7e05c73a991/test_imgs/detection/000016.jpg -------------------------------------------------------------------------------- /test_imgs/detection/000017.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tenglike1997/onnxruntime-projects/7e99a75f1ce36329c4dbf2863023e7e05c73a991/test_imgs/detection/000017.jpg -------------------------------------------------------------------------------- /test_imgs/detection/000018.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tenglike1997/onnxruntime-projects/7e99a75f1ce36329c4dbf2863023e7e05c73a991/test_imgs/detection/000018.jpg -------------------------------------------------------------------------------- /test_imgs/segmentation/00000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tenglike1997/onnxruntime-projects/7e99a75f1ce36329c4dbf2863023e7e05c73a991/test_imgs/segmentation/00000.png -------------------------------------------------------------------------------- /test_imgs/segmentation/00001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tenglike1997/onnxruntime-projects/7e99a75f1ce36329c4dbf2863023e7e05c73a991/test_imgs/segmentation/00001.png -------------------------------------------------------------------------------- 
/test_imgs/segmentation/00002.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tenglike1997/onnxruntime-projects/7e99a75f1ce36329c4dbf2863023e7e05c73a991/test_imgs/segmentation/00002.png -------------------------------------------------------------------------------- /test_imgs/segmentation/00003.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tenglike1997/onnxruntime-projects/7e99a75f1ce36329c4dbf2863023e7e05c73a991/test_imgs/segmentation/00003.png -------------------------------------------------------------------------------- /test_imgs/segmentation/00004.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tenglike1997/onnxruntime-projects/7e99a75f1ce36329c4dbf2863023e7e05c73a991/test_imgs/segmentation/00004.png -------------------------------------------------------------------------------- /test_imgs/segmentation/00005.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tenglike1997/onnxruntime-projects/7e99a75f1ce36329c4dbf2863023e7e05c73a991/test_imgs/segmentation/00005.png -------------------------------------------------------------------------------- /test_imgs/segmentation/00006.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tenglike1997/onnxruntime-projects/7e99a75f1ce36329c4dbf2863023e7e05c73a991/test_imgs/segmentation/00006.png -------------------------------------------------------------------------------- /test_imgs/segmentation/00007.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tenglike1997/onnxruntime-projects/7e99a75f1ce36329c4dbf2863023e7e05c73a991/test_imgs/segmentation/00007.png 
-------------------------------------------------------------------------------- /test_imgs/segmentation/00008.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tenglike1997/onnxruntime-projects/7e99a75f1ce36329c4dbf2863023e7e05c73a991/test_imgs/segmentation/00008.png -------------------------------------------------------------------------------- /test_imgs/segmentation/00009.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tenglike1997/onnxruntime-projects/7e99a75f1ce36329c4dbf2863023e7e05c73a991/test_imgs/segmentation/00009.png -------------------------------------------------------------------------------- /test_imgs/segmentation/00010.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tenglike1997/onnxruntime-projects/7e99a75f1ce36329c4dbf2863023e7e05c73a991/test_imgs/segmentation/00010.png -------------------------------------------------------------------------------- /test_imgs/style_transfer/church.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tenglike1997/onnxruntime-projects/7e99a75f1ce36329c4dbf2863023e7e05c73a991/test_imgs/style_transfer/church.jpg -------------------------------------------------------------------------------- /test_imgs/super_resolution/LowResolution.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tenglike1997/onnxruntime-projects/7e99a75f1ce36329c4dbf2863023e7e05c73a991/test_imgs/super_resolution/LowResolution.png -------------------------------------------------------------------------------- /test_imgs/super_resolution/RizeResolution.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/tenglike1997/onnxruntime-projects/7e99a75f1ce36329c4dbf2863023e7e05c73a991/test_imgs/super_resolution/RizeResolution.png -------------------------------------------------------------------------------- /test_imgs/super_resolution/SuperResolution.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tenglike1997/onnxruntime-projects/7e99a75f1ce36329c4dbf2863023e7e05c73a991/test_imgs/super_resolution/SuperResolution.png -------------------------------------------------------------------------------- /test_imgs/super_resolution/rawimg.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tenglike1997/onnxruntime-projects/7e99a75f1ce36329c4dbf2863023e7e05c73a991/test_imgs/super_resolution/rawimg.jpg -------------------------------------------------------------------------------- /test_models/candy.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tenglike1997/onnxruntime-projects/7e99a75f1ce36329c4dbf2863023e7e05c73a991/test_models/candy.onnx -------------------------------------------------------------------------------- /test_models/erfnet.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tenglike1997/onnxruntime-projects/7e99a75f1ce36329c4dbf2863023e7e05c73a991/test_models/erfnet.onnx -------------------------------------------------------------------------------- /test_models/mobilenetv2-1.0.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tenglike1997/onnxruntime-projects/7e99a75f1ce36329c4dbf2863023e7e05c73a991/test_models/mobilenetv2-1.0.onnx -------------------------------------------------------------------------------- /test_models/mosaic.onnx: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/tenglike1997/onnxruntime-projects/7e99a75f1ce36329c4dbf2863023e7e05c73a991/test_models/mosaic.onnx -------------------------------------------------------------------------------- /test_models/pointilism.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tenglike1997/onnxruntime-projects/7e99a75f1ce36329c4dbf2863023e7e05c73a991/test_models/pointilism.onnx -------------------------------------------------------------------------------- /test_models/rain_princess.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tenglike1997/onnxruntime-projects/7e99a75f1ce36329c4dbf2863023e7e05c73a991/test_models/rain_princess.onnx -------------------------------------------------------------------------------- /test_models/super_resolution.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tenglike1997/onnxruntime-projects/7e99a75f1ce36329c4dbf2863023e7e05c73a991/test_models/super_resolution.onnx -------------------------------------------------------------------------------- /test_models/tiny_yolov2.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tenglike1997/onnxruntime-projects/7e99a75f1ce36329c4dbf2863023e7e05c73a991/test_models/tiny_yolov2.onnx -------------------------------------------------------------------------------- /test_models/udnie.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tenglike1997/onnxruntime-projects/7e99a75f1ce36329c4dbf2863023e7e05c73a991/test_models/udnie.onnx --------------------------------------------------------------------------------