├── requirements.txt ├── .gitmodules ├── .gitignore ├── convert-model-to-ort.sh ├── include ├── cpu_provider_factory.h ├── provider_options.h ├── onnxruntime_run_options_config_keys.h ├── coreml_provider_factory.h ├── onnxruntime_session_options_config_keys.h ├── onnxruntime_float16.h └── onnxruntime_cxx_inline.h ├── ltcg_patch_for_windows.patch ├── LICENSE └── README.md /requirements.txt: -------------------------------------------------------------------------------- 1 | #tensorflow==2.8 2 | #tf2onnx 3 | onnxruntime==1.16.3 4 | onnx==1.13.1 5 | bin2c -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "onnxruntime"] 2 | path = onnxruntime 3 | url = https://github.com/microsoft/onnxruntime.git 4 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | venv 2 | .vscode/settings.json 3 | *.ort 4 | *.onnx 5 | model/* 6 | *.config 7 | *.pyc 8 | libs 9 | onnxruntime.xcframework 10 | *.a 11 | 12 | -------------------------------------------------------------------------------- /convert-model-to-ort.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Usage: ./convert-model-to-ort.sh path/to/model.onnx 3 | # Converts the model to ORT format (plus a .config used to slim the ORT build) and serializes it with bin2c. 4 | set -euo pipefail 5 | 6 | if [ "$#" -lt 1 ]; then 7 | echo "usage: $0 path/to/model.onnx" >&2 8 | exit 1 9 | fi 10 | 11 | # The converter writes foo.ort next to foo.onnx; fall back to ./model.ort for other inputs (e.g. a directory). 12 | case "$1" in 13 | *.onnx) ort_file="${1%.onnx}.ort" ;; 14 | *) ort_file="model.ort" ;; 15 | esac 16 | 17 | #python -m tf2onnx.convert --saved-model model --output model.onnx --opset 13 18 | python -m onnxruntime.tools.convert_onnx_models_to_ort "$1" --enable_type_reduction 19 | rm -rf ./model/ 20 | mkdir -p ./model 21 | # python verify_model.py 22 | python -m bin2c -o ./model/model.ort "$ort_file" -------------------------------------------------------------------------------- /include/cpu_provider_factory.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License.
3 | 4 | #include "onnxruntime_c_api.h" 5 | 6 | #ifdef __cplusplus 7 | extern "C" { 8 | #endif 9 | 10 | /** 11 | * \param use_arena zero: false. non-zero: true. 12 | */ 13 | ORT_EXPORT 14 | ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_CPU, _In_ OrtSessionOptions* options, int use_arena) 15 | ORT_ALL_ARGS_NONNULL; 16 | 17 | #ifdef __cplusplus 18 | } 19 | #endif 20 | -------------------------------------------------------------------------------- /include/provider_options.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License. 3 | 4 | #pragma once 5 | 6 | #include <string> 7 | #include <unordered_map> 8 | #include <vector> 9 | 10 | namespace onnxruntime { 11 | 12 | // data types for execution provider options 13 | 14 | using ProviderOptions = std::unordered_map<std::string, std::string>; 15 | using ProviderOptionsVector = std::vector<ProviderOptions>; 16 | using ProviderOptionsMap = std::unordered_map<std::string, ProviderOptions>; 17 | 18 | } // namespace onnxruntime 19 | -------------------------------------------------------------------------------- /ltcg_patch_for_windows.patch: -------------------------------------------------------------------------------- 1 | diff --git a/cmake/adjust_global_compile_flags.cmake b/cmake/adjust_global_compile_flags.cmake 2 | index 68522a7dda..9040c9b372 100644 3 | --- a/cmake/adjust_global_compile_flags.cmake 4 | +++ b/cmake/adjust_global_compile_flags.cmake 5 | @@ -74,10 +74,7 @@ if (onnxruntime_MINIMAL_BUILD) 6 | endif() 7 | 8 | if (MSVC) 9 | - # turn on LTO (which adds some compiler flags and turns on LTCG) unless it's a Debug build to minimize binary size 10 | - if (NOT CMAKE_BUILD_TYPE STREQUAL "Debug") 11 | - set(onnxruntime_ENABLE_LTO ON) 12 | - endif() 13 | + set(onnxruntime_ENABLE_LTO OFF) 14 | 15 | # undocumented internal flag to allow analysis of a minimal build binary size 16 | if (ADD_DEBUG_INFO_TO_MINIMAL_BUILD) 17 | 
-------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Oli Larkin 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /include/onnxruntime_run_options_config_keys.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License. 3 | 4 | #pragma once 5 | 6 | /* 7 | * This file defines RunOptions Config Keys and format of the Config Values. 
8 | * 9 | * The Naming Convention for a RunOptions Config Key, 10 | * "[Area][.[SubArea1].[SubArea2]...].[Keyname]" 11 | * Such as "ep.cuda.use_arena" 12 | * The Config Key cannot be empty 13 | * The maximum length of the Config Key is 128 14 | * 15 | * The string format of a RunOptions Config Value is defined individually for each Config. 16 | * The maximum length of the Config Value is 1024 17 | */ 18 | 19 | // Key for enabling shrinkage of user-listed device memory arenas. 20 | // Expects a list of semicolon-separated device:device_id pairs (device and id separated by a colon) in the following format: 21 | // "device_0:device_id_0;device_1:device_id_1" 22 | // No white-spaces allowed in the provided list string. 23 | // Currently, the only supported devices are : "cpu", "gpu" (case sensitive). 24 | // If "cpu" is included in the list, DisableCpuMemArena() API must not be called (i.e.) arena for cpu should be enabled. 25 | // Example usage: "cpu:0;gpu:0" (or) "gpu:0" 26 | // By default, the value for this key is empty (i.e.) no memory arenas are shrunk 27 | static const char* const kOrtRunOptionsConfigEnableMemoryArenaShrinkage = "memory.enable_memory_arena_shrinkage"; 28 | 29 | // Set to '1' to not synchronize execution providers with CPU at the end of session run. 30 | // By default it is set to '0'. 31 | // Taking the CUDA EP as an example, setting it to '1' omits calling cudaStreamSynchronize on the compute stream. 32 | static const char* const kOrtRunOptionsConfigDisableSynchronizeExecutionProviders = "disable_synchronize_execution_providers"; 33 | -------------------------------------------------------------------------------- /include/coreml_provider_factory.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License.
3 | #pragma once 4 | 5 | #include "onnxruntime_c_api.h" 6 | 7 | // COREMLFlags are bool options we want to set for CoreML EP 8 | // This enum is defined as bit flags, and cannot have negative value 9 | // To generate a uint32_t coreml_flags value for use with OrtSessionOptionsAppendExecutionProvider_CoreML below, 10 | // uint32_t coreml_flags = 0; 11 | // coreml_flags |= COREML_FLAG_USE_CPU_ONLY; 12 | enum COREMLFlags { 13 | COREML_FLAG_USE_NONE = 0x000, 14 | 15 | // Using CPU only in CoreML EP, this may decrease the perf but will provide 16 | // reference output value without precision loss, which is useful for validation 17 | COREML_FLAG_USE_CPU_ONLY = 0x001, 18 | 19 | // Enable CoreML EP on subgraph 20 | COREML_FLAG_ENABLE_ON_SUBGRAPH = 0x002, 21 | 22 | // By default CoreML Execution provider will be enabled for all compatible Apple devices 23 | // Enabling this option will only enable CoreML EP for Apple devices with ANE (Apple Neural Engine) 24 | // Please note, enabling this option does not guarantee that the entire model will be executed using ANE only 25 | COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE = 0x004, 26 | 27 | // Only allow CoreML EP to take nodes with inputs with static shapes. By default it will also allow inputs with 28 | // dynamic shapes. However, the performance may be negatively impacted if inputs have dynamic shapes.
29 | COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES = 0x008, 30 | 31 | // Keep COREML_FLAG_LAST at the end of the enum definition 32 | // And assign the last COREMLFlag to it 33 | COREML_FLAG_LAST = COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES, 34 | }; 35 | 36 | #ifdef __cplusplus 37 | extern "C" { 38 | #endif 39 | 40 | ORT_EXPORT ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_CoreML, 41 | _In_ OrtSessionOptions* options, uint32_t coreml_flags); 42 | 43 | #ifdef __cplusplus 44 | } 45 | #endif 46 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ONNX Runtime static library builder 2 | 3 | Converts an [ONNX](https://onnx.ai) model to ORT format, serializes it to C++ source code, and generates custom slimmed ONNX Runtime static libs (plus an xcframework for Apple platforms). 4 | 5 | The goal here is to create a flexible but tiny inference engine for a specific model, e.g. the [iPlug2 example](https://github.com/olilarkin/iPlug2OnnxRuntime). 6 | 7 | *NOTE: due to the risk of ODR violations, global/static variable conflicts, and dependency symbol clashes with DAW hosts that use ORT themselves, think hard before you use this in an audio plugin!* 8 | 9 | The scripts here are configured to create a minimal ORT binary using only the CPU provider. If you want to experiment with GPU inference, Core ML, etc., you will have to modify the build scripts. 10 | 11 | ## Requirements: 12 | 13 | CMake 3.26+ (the exact minimum is the `cmake_minimum_required` version in `onnxruntime/cmake/CMakeLists.txt` of the checked-out submodule) and Python 3 14 | 15 | ## Instructions: 16 | 17 | 1. Check out the ONNX Runtime submodule `$ git submodule update --init` 18 | 19 | 2. Create a [virtual environment](https://packaging.python.org/tutorials/installing-packages/#creating-virtual-environments) and activate it 20 | 21 | Windows 22 | ```bash 23 | $ py -3 -m venv venv 24 | $ source ./venv/Scripts/activate 25 | ``` 26 | 27 | macOS/Linux 28 | ```bash 29 | $ python3 -m venv venv 30 | $ source ./venv/bin/activate 31 | ``` 32 | 33 | 3.
Install dependencies `$ pip install -r requirements.txt` 34 | 35 | 4. Run `$ ./convert-model-to-ort.sh model.onnx` 36 | This converts the .onnx file to .ort and produces a .config file which slims the onnxruntime library build in the next step. 37 | It also serializes the .ort format model to C++ source code, which can be used to bake the model into your app binary. If the model 38 | is large, this might not be a great solution, and it may be better to load the .ort file at runtime. 39 | 40 | 5. Build customized ONNX Runtime static libraries 41 | 42 | ```mac 43 | $ ./build-mac.sh 44 | ``` 45 | 46 | ```ios 47 | $ ./build-ios.sh 48 | $ ./build-ios-simulator.sh 49 | ``` 50 | 51 | ```xcframework build 52 | $ ./build-xcframework.sh 53 | ``` 54 | 55 | **Note:** Windows static lib builds can get very large due to the LTO/LTCG settings in onnxruntime. 56 | You can turn that off by applying ltcg_patch_for_windows.patch to the onnxruntime submodule, e.g. `$ git -C onnxruntime apply ../ltcg_patch_for_windows.patch`. 57 | Because Debug and Release builds use different MSVC runtimes, two binaries (Debug and Release) need to be built for Windows. 58 | 59 | ```windows 60 | $ ./build-win.sh 61 | ``` 62 | -------------------------------------------------------------------------------- /include/onnxruntime_session_options_config_keys.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License. 3 | 4 | #pragma once 5 | 6 | /* 7 | * This file defines SessionOptions Config Keys and format of the Config Values.
16 | * The maximum length of the Config Value is 1024 17 | */ 18 | 19 | // Key for disabling pre-packing. 20 | // If the config value is set to "1" then the prepacking is disabled, otherwise prepacking is enabled (default value) 21 | static const char* const kOrtSessionOptionsConfigDisablePrepacking = "session.disable_prepacking"; 22 | 23 | // A value of "1" means allocators registered in the env will be used. "0" means the allocators created in the session 24 | // will be used. Use this to override the usage of env allocators on a per session level. 25 | static const char* const kOrtSessionOptionsConfigUseEnvAllocators = "session.use_env_allocators"; 26 | 27 | // Set to 'ORT' (case sensitive) to load an ORT format model. 28 | // If unset, model type will default to ONNX unless inferred from filename ('.ort' == ORT format) or bytes to be ORT 29 | static const char* const kOrtSessionOptionsConfigLoadModelFormat = "session.load_model_format"; 30 | 31 | // Set to 'ORT' (case sensitive) to save optimized model in ORT format when SessionOptions.optimized_model_filepath is set. 32 | // If unset, format will default to ONNX unless optimized_model_filepath ends in '.ort'. 33 | static const char* const kOrtSessionOptionsConfigSaveModelFormat = "session.save_model_format"; 34 | 35 | // If a value is "1", flush-to-zero and denormal-as-zero are applied. The default is "0". 36 | // When multiple sessions are created, a main thread doesn't override changes from succeeding session options, 37 | // but threads in session thread pools follow option changes. 38 | // When ORT runs with OpenMP, the same rule is applied, i.e. the first session option to flush-to-zero and 39 | // denormal-as-zero is only applied to global OpenMP thread pool, which doesn't support per-session thread pool. 40 | // Note that an alternative way not using this option at runtime is to train and export a model without denormals 41 | // and that's recommended because turning this option on may hurt model accuracy.
42 | static const char* const kOrtSessionOptionsConfigSetDenormalAsZero = "session.set_denormal_as_zero"; 43 | 44 | // Controls whether ORT applies QDQ (QuantizeLinear/DequantizeLinear) fusion logic to quantized models. 45 | // "0": enable. ORT does fusion logic for QDQ format. 46 | // "1": disable. ORT doesn't do fusion logic for QDQ format. 47 | // Its default value is "0" unless the DirectML execution provider is registered, in which case it defaults to "1". 48 | static const char* const kOrtSessionOptionsDisableQuantQDQ = "session.disable_quant_qdq"; 49 | 50 | // Controls whether to disable the Double QDQ remover and Identical Children Consolidation 51 | // "0": not to disable. ORT does remove the middle 2 Nodes from Q->(DQ->Q)->DQ sequences 52 | // "1": disable. ORT doesn't remove the middle 2 Nodes from Q->(DQ->Q)->DQ sequences 53 | // Its default value is "0" 54 | static const char* const kOrtSessionOptionsDisableDoubleQDQRemover = "session.disable_double_qdq_remover"; 55 | 56 | // If set to "1", enables the removal of QuantizeLinear/DequantizeLinear node pairs once all QDQ handling has been 57 | // completed. e.g. If after all QDQ handling has completed and we have -> FloatOp -> Q -> DQ -> FloatOp -> the 58 | // Q -> DQ could potentially be removed. This will provide a performance benefit by avoiding going from float to 59 | // 8-bit and back to float, but could impact accuracy. The impact on accuracy will be model specific and depend on 60 | // other factors like whether the model was created using Quantization Aware Training or Post Training Quantization. 61 | // As such, it's best to test to determine if enabling this works well for your scenario. 62 | // The default value is "0" 63 | // Available since version 1.11. 64 | static const char* const kOrtSessionOptionsEnableQuantQDQCleanup = "session.enable_quant_qdq_cleanup"; 65 | 66 | // Enable or disable gelu approximation in graph optimization. "0": disable; "1": enable. The default is "0".
67 | // GeluApproximation has side effects which may change the inference results. It is disabled by default due to this. 68 | static const char* const kOrtSessionOptionsEnableGeluApproximation = "optimization.enable_gelu_approximation"; 69 | 70 | // This setting controls whether to enable ahead-of-time (AOT) function inlining. 71 | // AOT function inlining examines the graph and attempts to inline as many locally defined functions in the model 72 | // as possible with the help of enabled execution providers. 73 | // This can reduce the number of function calls and improve performance because it is done before 74 | // Level1 optimizers and constant folding. However, under some circumstances, when the EPs are not available, 75 | // one can disable the AOT inlining, produce an optimized model and postpone AOT until run time. 76 | // "0": enable; "1": disable. 77 | // Its default value is "0". 78 | static const char* const kOrtSessionOptionsDisableAheadOfTimeFunctionInlining = "session.disable_aot_function_inlining"; 79 | 80 | #ifdef ENABLE_TRAINING 81 | // Specifies a list of op types for memory footprint reduction. 82 | // The value should be a ","-delimited list of pair of 83 | // <subgraph string : optimization strategy : number of subgraph to apply>. 84 | // For example, "Gelu+Cast+:1:0,Dropout+:1:1". 85 | // A valid "subgraph string" should be one subgraph representation output by ORT graph transformations. 86 | // "optimization strategy" currently has valid values: 0 - disabled, 1 - recompute. 87 | // "number of subgraph to apply" is used to control how many subgraphs to apply optimization, to avoid "oversaving" 88 | // the memory. 89 | static const char* const kOrtSessionOptionsMemoryOptimizerEnabler = "optimization.memory_optimizer_config"; 90 | 91 | // Specifies the level for detecting subgraphs for memory footprint reduction. 92 | // The value should be an integer. The default value is 0.
93 | static const char* const kOrtSessionOptionsMemoryOptimizerProbeLevel = "optimization.enable_memory_probe_recompute_level"; 94 | #endif 95 | 96 | // Enable or disable using device allocator for allocating initialized tensor memory. "1": enable; "0": disable. The default is "0". 97 | // Using device allocators means the memory allocation is made using malloc/new. 98 | static const char* const kOrtSessionOptionsUseDeviceAllocatorForInitializers = "session.use_device_allocator_for_initializers"; 99 | 100 | // Configure whether to allow the inter_op/intra_op threads to spin a number of times before blocking 101 | // "0": thread will block if it finds no job to run 102 | // "1": default, thread will spin a number of times before blocking 103 | static const char* const kOrtSessionOptionsConfigAllowInterOpSpinning = "session.inter_op.allow_spinning"; 104 | static const char* const kOrtSessionOptionsConfigAllowIntraOpSpinning = "session.intra_op.allow_spinning"; 105 | 106 | // Key for using model bytes directly for ORT format 107 | // If a session is created from an input byte array containing the ORT format model data, 108 | // by default we will copy the model bytes at the time of session creation to ensure the model bytes 109 | // buffer is valid. 110 | // Setting this option to "1" disables copying the model bytes, and uses the model bytes directly. The caller 111 | // has to guarantee that the model bytes are valid until the ORT session using the model bytes is destroyed. 112 | static const char* const kOrtSessionOptionsConfigUseORTModelBytesDirectly = "session.use_ort_model_bytes_directly"; 113 | 114 | /// <summary> 115 | /// Key for using the ORT format model flatbuffer bytes directly for initializers. 116 | /// This avoids copying the bytes and reduces peak memory usage during model loading and initialization. 117 | /// Requires `session.use_ort_model_bytes_directly` to be true.
118 | /// If set, the flatbuffer bytes provided when creating the InferenceSession MUST remain valid for the entire 119 | /// duration of the InferenceSession. 120 | /// </summary> 121 | static const char* const kOrtSessionOptionsConfigUseORTModelBytesForInitializers = 122 | "session.use_ort_model_bytes_for_initializers"; 123 | 124 | // This should only be specified when exporting an ORT format model for use on a different platform. 125 | // If the ORT format model will be used on ARM platforms set to "1". For other platforms set to "0" 126 | // Available since version 1.11. 127 | static const char* const kOrtSessionOptionsQDQIsInt8Allowed = "session.qdqisint8allowed"; 128 | 129 | // x64 SSE4.1/AVX2/AVX512(with no VNNI) has an overflow problem with quantized matrix multiplication with U8S8. 130 | // To avoid this we need to use slower U8U8 matrix multiplication instead. This option, if 131 | // turned on, uses slower U8U8 matrix multiplications. Only effective with AVX2 or AVX512 132 | // platforms. 133 | static const char* const kOrtSessionOptionsAvx2PrecisionMode = "session.x64quantprecision"; 134 | 135 | // Specifies how minimal build graph optimizations are handled in a full build. 136 | // These optimizations are at the extended level or higher. 137 | // Possible values and their effects are: 138 | // "save": Save runtime optimizations when saving an ORT format model. 139 | // "apply": Only apply optimizations available in a minimal build. 140 | // ""/<unspecified>: Apply optimizations available in a full build. 141 | // Available since version 1.11. 142 | static const char* const kOrtSessionOptionsConfigMinimalBuildOptimizations = 143 | "optimization.minimal_build_optimizations"; 144 | 145 | // Note: The options specific to an EP should be specified prior to appending that EP to the session options object in 146 | // order for them to take effect. 147 | 148 | // Specifies a list of stop op types.
Nodes of a type in the stop op types and nodes downstream from them will not be 149 | // run by the NNAPI EP. 150 | // The value should be a ","-delimited list of op types. For example, "Add,Sub". 151 | // If not specified, the default set of stop ops is used. To specify an empty stop ops types list and disable stop op 152 | // exclusion, set the value to "". 153 | static const char* const kOrtSessionOptionsConfigNnapiEpPartitioningStopOps = "ep.nnapi.partitioning_stop_ops"; 154 | 155 | // Enables dynamic block-sizing for multithreading. 156 | // With a positive value, thread pool will split a task of N iterations into blocks of size starting from: 157 | // N / (num_of_threads * dynamic_block_base) 158 | // As execution progresses, the size will decrease according to the diminishing residual of N, 159 | // meaning the task will be distributed in smaller granularity for better parallelism. 160 | // For some models, it helps to reduce the variance of E2E inference latency and boost performance. 161 | // The feature is disabled by default; specify any positive integer, e.g. "4", to enable it. 162 | // Available since version 1.11. 163 | static const char* const kOrtSessionOptionsConfigDynamicBlockBase = "session.dynamic_block_base"; 164 | 165 | // This option decreases CPU usage between infrequent 166 | // requests and forces any spinning thread-pool threads to stop immediately when the last 167 | // concurrent Run() call returns. 168 | // Spinning is restarted on the next Run() call. 169 | // Applies only to internal thread-pools 170 | static const char* const kOrtSessionOptionsConfigForceSpinningStop = "session.force_spinning_stop"; 171 | 172 | // "1": all inconsistencies encountered during shape and type inference 173 | // will result in failures. 174 | // "0": in some cases warnings will be logged but processing will continue. The default. 175 | // May be useful to expose bugs in models.
176 | static const char* const kOrtSessionOptionsConfigStrictShapeTypeInference = "session.strict_shape_type_inference"; 177 | 178 | // "1": every model using a more recent opset than the latest released one will fail 179 | // "0": the model may or may not work if onnxruntime cannot find an implementation, this option 180 | // is used for development purpose. 181 | static const char* const kOrtSessionOptionsConfigStrictAllowReleasedOpsetsOnly = "session.allow_released_opsets_only"; 182 | 183 | // Path to the file that stores the configuration for partitioning nodes among logic streams 184 | static const char* const kNodePartitionConfigFile = "session.node_partition_config_file"; 185 | 186 | // This option allows setting affinities for intra op threads. 187 | // Affinity string follows format: 188 | // logical_processor_id,logical_processor_id;logical_processor_id,logical_processor_id 189 | // Semicolons separate the configurations of different threads, while commas separate the processors the i-th thread is expected to attach to. 190 | // e.g.1,2,3;4,5 191 | // specifies affinities for two threads, with the 1st thread attached to the 1st, 2nd, and 3rd processor, and 2nd thread to the 4th and 5th. 192 | // To ease the configuration, an "interval" is also allowed: 193 | // e.g. 1-8;9-16;17-24 194 | // orders that the 1st thread runs on first eight processors, 2nd thread runs on next eight processors, and so forth. 195 | // Note: 196 | // 1. Once set, the number of thread affinities must equal intra_op_num_threads - 1, since ort does not set affinity on the main thread which 197 | // is started and managed by the calling app; 198 | // 2. For windows, ort will infer the group id from a logical processor id, for example, assuming there are two groups each with 64 logical processors, 199 | // an id of 64 will be inferred as the last processor of the 1st group, while 65 will be interpreted as the 1st processor of the second group.
200 | // Hence 64-65 is an invalid configuration, because a Windows thread cannot be attached to processors across a group boundary. 201 | static const char* const kOrtSessionOptionsConfigIntraOpThreadAffinities = "session.intra_op_thread_affinities"; 202 | 203 | // This option will dump out the model to assist debugging any issues with layout transformation, 204 | // and is primarily intended for developer usage. It is only relevant if an execution provider that requests 205 | // NHWC layout is enabled such as NNAPI, XNNPACK or QNN. 206 | // 207 | // Default is off. Set to "1" to enable. 208 | // 209 | // If modified by layout transformation the model will be dumped after these steps: 210 | // 1) insertion of the layout transformation Transpose nodes 211 | // 2) after those are optimized using the transpose optimizer, 212 | // 3) after the L1 transformers are applied to the updated graph. 213 | // The model will be saved to filename post_layout_transform_step_<step number>.onnx. 214 | static const char* const kDebugLayoutTransformation = "session.debug_layout_transformation"; 215 | 216 | // Graph nodes that are not supported by the execution providers (EPs) explicitly added to the session are 217 | // assigned (i.e., "fallback") to the CPU EP by default. 218 | // 219 | // This option allows the user to disable the fallback of unsupported graph nodes to the CPU EP. 220 | // If this option is set to "1", session creation will fail if the execution providers other than the CPU EP cannot 221 | // fully support all of the nodes in the graph. 222 | // 223 | // It is invalid to set this option and explicitly add the CPU EP to the session. In this case, session creation 224 | // will also fail with an error. 225 | // 226 | // Option values: 227 | // - "0": CPU EP fallback is not disabled. [DEFAULT] 228 | // - "1": CPU EP fallback is disabled.
229 | static const char* const kOrtSessionOptionsDisableCPUEPFallback = "session.disable_cpu_ep_fallback"; 230 | 231 | // Use this config when serializing a large model after optimization to specify an external initializers file 232 | static const char* const kOrtSessionOptionsOptimizedModelExternalInitializersFileName = 233 | "session.optimized_model_external_initializers_file_name"; 234 | 235 | // Use this config to control the minimum size of the initializer when externalizing it during serialization 236 | static const char* const kOrtSessionOptionsOptimizedModelExternalInitializersMinSizeInBytes = 237 | "session.optimized_model_external_initializers_min_size_in_bytes"; 238 | -------------------------------------------------------------------------------- /include/onnxruntime_float16.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License. 3 | 4 | #pragma once 5 | 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | namespace onnxruntime_float16 { 12 | 13 | namespace detail { 14 | 15 | enum class endian { 16 | #if defined(_WIN32) 17 | little = 0, 18 | big = 1, 19 | native = little, 20 | #elif defined(__GNUC__) || defined(__clang__) 21 | little = __ORDER_LITTLE_ENDIAN__, 22 | big = __ORDER_BIG_ENDIAN__, 23 | native = __BYTE_ORDER__, 24 | #else 25 | #error onnxruntime_float16::detail::endian is not implemented in this environment. 26 | #endif 27 | }; 28 | 29 | static_assert( 30 | endian::native == endian::little || endian::native == endian::big, 31 | "Only little-endian or big-endian native byte orders are supported."); 32 | 33 | } // namespace detail 34 | 35 | /// 36 | /// Shared implementation between public and internal classes. CRTP pattern. 
///
/// Shared implementation between the public and internal float16 classes (CRTP pattern).
/// The value is held as a raw IEEE 754 binary16 bit pattern in `val`. Derived must provide
/// `static Derived FromBits(uint16_t)`; Abs() and Negate() use it to build their results.
///
template <typename Derived>
struct Float16Impl {
 protected:
  /// <summary>
  /// Converts from float to uint16_t float16 representation
  /// </summary>
  /// <param name="v">float value to convert</param>
  /// <returns>float16 bit pattern</returns>
  constexpr static uint16_t ToUint16Impl(float v) noexcept;

  /// <summary>
  /// Converts float16 to float
  /// </summary>
  /// <returns>float representation of float16 value</returns>
  float ToFloatImpl() const noexcept;

  /// <summary>
  /// Computes the bit pattern of the absolute value (sign bit cleared).
  /// </summary>
  /// <returns>Absolute value bits</returns>
  uint16_t AbsImpl() const noexcept {
    return static_cast<uint16_t>(val & ~kSignMask);
  }

  /// <summary>
  /// Computes the bit pattern with the sign flipped. NaN is returned unchanged.
  /// </summary>
  /// <returns>Flipped sign bits</returns>
  uint16_t NegateImpl() const noexcept {
    return IsNaN() ? val : static_cast<uint16_t>(val ^ kSignMask);
  }

 public:
  // uint16_t special values
  static constexpr uint16_t kSignMask = 0x8000U;
  static constexpr uint16_t kBiasedExponentMask = 0x7C00U;
  static constexpr uint16_t kPositiveInfinityBits = 0x7C00U;
  static constexpr uint16_t kNegativeInfinityBits = 0xFC00U;
  static constexpr uint16_t kPositiveQNaNBits = 0x7E00U;
  static constexpr uint16_t kNegativeQNaNBits = 0xFE00U;
  // NOTE(review): 0x4170 encodes 2.71875 (an approximation of Euler's number e), not the binary16 machine
  // epsilon (0x1400 == 2^-10) nor the smallest subnormal (0x0001). The value is kept unchanged for
  // compatibility with upstream ONNX Runtime; do not rely on it as an epsilon.
  static constexpr uint16_t kEpsilonBits = 0x4170U;
  static constexpr uint16_t kMinValueBits = 0xFBFFU;  // Lowest finite value (-65504), not the minimum normal number
  static constexpr uint16_t kMaxValueBits = 0x7BFFU;  // Largest finite (and normal) value (65504)
  static constexpr uint16_t kOneBits = 0x3C00U;
  static constexpr uint16_t kMinusOneBits = 0xBC00U;

  /// Raw IEEE 754 binary16 bit pattern.
  uint16_t val{0};

  Float16Impl() = default;

  /// <summary>
  /// Checks if the value is negative (sign bit set; also true for -0 and negative-signed NaN)
  /// </summary>
  /// <returns>true if negative</returns>
  bool IsNegative() const noexcept {
    return static_cast<int16_t>(val) < 0;
  }

  /// <summary>
  /// Tests if the value is NaN
  /// </summary>
  /// <returns>true if NaN</returns>
  bool IsNaN() const noexcept {
    return AbsImpl() > kPositiveInfinityBits;
  }

  /// <summary>
  /// Tests if the value is finite
  /// </summary>
  /// <returns>true if finite</returns>
  bool IsFinite() const noexcept {
    return AbsImpl() < kPositiveInfinityBits;
  }

  /// <summary>
  /// Tests if the value represents positive infinity.
  /// </summary>
  /// <returns>true if positive infinity</returns>
  bool IsPositiveInfinity() const noexcept {
    return val == kPositiveInfinityBits;
  }

  /// <summary>
  /// Tests if the value represents negative infinity
  /// </summary>
  /// <returns>true if negative infinity</returns>
  bool IsNegativeInfinity() const noexcept {
    return val == kNegativeInfinityBits;
  }

  /// <summary>
  /// Tests if the value is either positive or negative infinity.
  /// </summary>
  /// <returns>True if absolute value is infinity</returns>
  bool IsInfinity() const noexcept {
    return AbsImpl() == kPositiveInfinityBits;
  }

  /// <summary>
  /// Tests if the value is NaN or zero. Useful for comparisons.
  /// </summary>
  /// <returns>True if NaN or zero.</returns>
  bool IsNaNOrZero() const noexcept {
    auto abs = AbsImpl();
    return (abs == 0 || abs > kPositiveInfinityBits);
  }

  /// <summary>
  /// Tests if the value is normal (not zero, subnormal, infinite, or NaN).
  /// </summary>
  /// <returns>True if so</returns>
  bool IsNormal() const noexcept {
    auto abs = AbsImpl();
    return (abs < kPositiveInfinityBits)           // is finite
           && (abs != 0)                           // is not zero
           && ((abs & kBiasedExponentMask) != 0);  // is not subnormal (has a non-zero exponent)
  }

  /// <summary>
  /// Tests if the value is subnormal (denormal).
  /// </summary>
  /// <returns>True if so</returns>
  bool IsSubnormal() const noexcept {
    auto abs = AbsImpl();
    return (abs < kPositiveInfinityBits)           // is finite
           && (abs != 0)                           // is not zero
           && ((abs & kBiasedExponentMask) == 0);  // is subnormal (has a zero exponent)
  }

  /// <summary>
  /// Creates an instance that represents absolute value.
  /// </summary>
  /// <returns>Absolute value</returns>
  Derived Abs() const noexcept { return Derived::FromBits(AbsImpl()); }

  /// <summary>
  /// Creates a new instance with the sign flipped.
  /// </summary>
  /// <returns>Flipped sign instance</returns>
  Derived Negate() const noexcept { return Derived::FromBits(NegateImpl()); }

  /// <summary>
  /// IEEE defines that positive and negative zero are equal, this gives us a quick equality check
  /// for two values by or'ing the private bits together and stripping the sign. They are both zero,
  /// and therefore equivalent, if the resulting value is still zero.
  /// </summary>
  /// <param name="lhs">first value</param>
  /// <param name="rhs">second value</param>
  /// <returns>True if both arguments represent zero</returns>
  static bool AreZero(const Float16Impl& lhs, const Float16Impl& rhs) noexcept {
    return static_cast<uint16_t>((lhs.val | rhs.val) & ~kSignMask) == 0;
  }

  /// <summary>
  /// IEEE 754 equality: NaN is unequal to everything (including itself) and +0 == -0.
  /// </summary>
  bool operator==(const Float16Impl& rhs) const noexcept {
    if (IsNaN() || rhs.IsNaN()) {
      // IEEE defines that NaN is not equal to anything, including itself.
      return false;
    }
    // +0 (0x0000) and -0 (0x8000) differ in bits but compare equal under IEEE; this also keeps
    // operator== consistent with operator<, which already treats them as equal via AreZero.
    return val == rhs.val || AreZero(*this, rhs);
  }

  bool operator!=(const Float16Impl& rhs) const noexcept { return !(*this == rhs); }

  bool operator<(const Float16Impl& rhs) const noexcept {
    if (IsNaN() || rhs.IsNaN()) {
      // IEEE defines that NaN is unordered with respect to everything, including itself.
      return false;
    }

    const bool left_is_negative = IsNegative();
    if (left_is_negative != rhs.IsNegative()) {
      // When the signs of left and right differ, we know that left is less than right if it is
      // the negative value. The exception to this is if both values are zero, in which case IEEE
      // says they should be equal, even if the signs differ.
      return left_is_negative && !AreZero(*this, rhs);
    }
    // Same sign: a larger bit pattern means a larger magnitude, so the ordering flips for negatives.
    return (val != rhs.val) && ((val < rhs.val) ^ left_is_negative);
  }
};

// The following Float16_t conversions are based on the code from
// Eigen library.

// The conversion routines are Copyright (c) Fabian Giesen, 2016.
// The original license follows:
//
// Copyright (c) Fabian Giesen, 2016
// All rights reserved.
// Redistribution and use in source and binary forms, with or without
// modification, are permitted.
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
239 | 240 | namespace detail { 241 | union float32_bits { 242 | unsigned int u; 243 | float f; 244 | }; 245 | } // namespace detail 246 | 247 | template 248 | inline constexpr uint16_t Float16Impl::ToUint16Impl(float v) noexcept { 249 | detail::float32_bits f{}; 250 | f.f = v; 251 | 252 | constexpr detail::float32_bits f32infty = {255 << 23}; 253 | constexpr detail::float32_bits f16max = {(127 + 16) << 23}; 254 | constexpr detail::float32_bits denorm_magic = {((127 - 15) + (23 - 10) + 1) << 23}; 255 | constexpr unsigned int sign_mask = 0x80000000u; 256 | uint16_t val = static_cast(0x0u); 257 | 258 | unsigned int sign = f.u & sign_mask; 259 | f.u ^= sign; 260 | 261 | // NOTE all the integer compares in this function can be safely 262 | // compiled into signed compares since all operands are below 263 | // 0x80000000. Important if you want fast straight SSE2 code 264 | // (since there's no unsigned PCMPGTD). 265 | 266 | if (f.u >= f16max.u) { // result is Inf or NaN (all exponent bits set) 267 | val = (f.u > f32infty.u) ? 0x7e00 : 0x7c00; // NaN->qNaN and Inf->Inf 268 | } else { // (De)normalized number or zero 269 | if (f.u < (113 << 23)) { // resulting FP16 is subnormal or zero 270 | // use a magic value to align our 10 mantissa bits at the bottom of 271 | // the float. as long as FP addition is round-to-nearest-even this 272 | // just works. 273 | f.f += denorm_magic.f; 274 | 275 | // and one integer subtract of the bias later, we have our final float! 276 | val = static_cast(f.u - denorm_magic.u); 277 | } else { 278 | unsigned int mant_odd = (f.u >> 13) & 1; // resulting mantissa is odd 279 | 280 | // update exponent, rounding bias part 1 281 | // Equivalent to `f.u += ((unsigned int)(15 - 127) << 23) + 0xfff`, but 282 | // without arithmetic overflow. 283 | f.u += 0xc8000fffU; 284 | // rounding bias part 2 285 | f.u += mant_odd; 286 | // take the bits! 
287 | val = static_cast(f.u >> 13); 288 | } 289 | } 290 | 291 | val |= static_cast(sign >> 16); 292 | return val; 293 | } 294 | 295 | template 296 | inline float Float16Impl::ToFloatImpl() const noexcept { 297 | constexpr detail::float32_bits magic = {113 << 23}; 298 | constexpr unsigned int shifted_exp = 0x7c00 << 13; // exponent mask after shift 299 | detail::float32_bits o{}; 300 | 301 | o.u = (val & 0x7fff) << 13; // exponent/mantissa bits 302 | unsigned int exp = shifted_exp & o.u; // just the exponent 303 | o.u += (127 - 15) << 23; // exponent adjust 304 | 305 | // handle exponent special cases 306 | if (exp == shifted_exp) { // Inf/NaN? 307 | o.u += (128 - 16) << 23; // extra exp adjust 308 | } else if (exp == 0) { // Zero/Denormal? 309 | o.u += 1 << 23; // extra exp adjust 310 | o.f -= magic.f; // re-normalize 311 | } 312 | 313 | // Attempt to workaround the Internal Compiler Error on ARM64 314 | // for bitwise | operator, including std::bitset 315 | #if (defined _MSC_VER) && (defined _M_ARM || defined _M_ARM64 || defined _M_ARM64EC) 316 | if (IsNegative()) { 317 | return -o.f; 318 | } 319 | #else 320 | // original code: 321 | o.u |= (val & 0x8000U) << 16U; // sign bit 322 | #endif 323 | return o.f; 324 | } 325 | 326 | /// Shared implementation between public and internal classes. CRTP pattern. 327 | template 328 | struct BFloat16Impl { 329 | protected: 330 | /// 331 | /// Converts from float to uint16_t float16 representation 332 | /// 333 | /// 334 | /// 335 | static uint16_t ToUint16Impl(float v) noexcept; 336 | 337 | /// 338 | /// Converts bfloat16 to float 339 | /// 340 | /// float representation of bfloat16 value 341 | float ToFloatImpl() const noexcept; 342 | 343 | /// 344 | /// Creates an instance that represents absolute value. 345 | /// 346 | /// Absolute value 347 | uint16_t AbsImpl() const noexcept { 348 | return static_cast(val & ~kSignMask); 349 | } 350 | 351 | /// 352 | /// Creates a new instance with the sign flipped. 
353 | /// 354 | /// Flipped sign instance 355 | uint16_t NegateImpl() const noexcept { 356 | return IsNaN() ? val : static_cast(val ^ kSignMask); 357 | } 358 | 359 | public: 360 | // uint16_t special values 361 | static constexpr uint16_t kSignMask = 0x8000U; 362 | static constexpr uint16_t kBiasedExponentMask = 0x7F80U; 363 | static constexpr uint16_t kPositiveInfinityBits = 0x7F80U; 364 | static constexpr uint16_t kNegativeInfinityBits = 0xFF80U; 365 | static constexpr uint16_t kPositiveQNaNBits = 0x7FC1U; 366 | static constexpr uint16_t kNegativeQNaNBits = 0xFFC1U; 367 | static constexpr uint16_t kSignaling_NaNBits = 0x7F80U; 368 | static constexpr uint16_t kEpsilonBits = 0x0080U; 369 | static constexpr uint16_t kMinValueBits = 0xFF7FU; 370 | static constexpr uint16_t kMaxValueBits = 0x7F7FU; 371 | static constexpr uint16_t kRoundToNearest = 0x7FFFU; 372 | static constexpr uint16_t kOneBits = 0x3F80U; 373 | static constexpr uint16_t kMinusOneBits = 0xBF80U; 374 | 375 | uint16_t val{0}; 376 | 377 | BFloat16Impl() = default; 378 | 379 | /// 380 | /// Checks if the value is negative 381 | /// 382 | /// true if negative 383 | bool IsNegative() const noexcept { 384 | return static_cast(val) < 0; 385 | } 386 | 387 | /// 388 | /// Tests if the value is NaN 389 | /// 390 | /// true if NaN 391 | bool IsNaN() const noexcept { 392 | return AbsImpl() > kPositiveInfinityBits; 393 | } 394 | 395 | /// 396 | /// Tests if the value is finite 397 | /// 398 | /// true if finite 399 | bool IsFinite() const noexcept { 400 | return AbsImpl() < kPositiveInfinityBits; 401 | } 402 | 403 | /// 404 | /// Tests if the value represents positive infinity. 
405 | /// 406 | /// true if positive infinity 407 | bool IsPositiveInfinity() const noexcept { 408 | return val == kPositiveInfinityBits; 409 | } 410 | 411 | /// 412 | /// Tests if the value represents negative infinity 413 | /// 414 | /// true if negative infinity 415 | bool IsNegativeInfinity() const noexcept { 416 | return val == kNegativeInfinityBits; 417 | } 418 | 419 | /// 420 | /// Tests if the value is either positive or negative infinity. 421 | /// 422 | /// True if absolute value is infinity 423 | bool IsInfinity() const noexcept { 424 | return AbsImpl() == kPositiveInfinityBits; 425 | } 426 | 427 | /// 428 | /// Tests if the value is NaN or zero. Useful for comparisons. 429 | /// 430 | /// True if NaN or zero. 431 | bool IsNaNOrZero() const noexcept { 432 | auto abs = AbsImpl(); 433 | return (abs == 0 || abs > kPositiveInfinityBits); 434 | } 435 | 436 | /// 437 | /// Tests if the value is normal (not zero, subnormal, infinite, or NaN). 438 | /// 439 | /// True if so 440 | bool IsNormal() const noexcept { 441 | auto abs = AbsImpl(); 442 | return (abs < kPositiveInfinityBits) // is finite 443 | && (abs != 0) // is not zero 444 | && ((abs & kBiasedExponentMask) != 0); // is not subnormal (has a non-zero exponent) 445 | } 446 | 447 | /// 448 | /// Tests if the value is subnormal (denormal). 449 | /// 450 | /// True if so 451 | bool IsSubnormal() const noexcept { 452 | auto abs = AbsImpl(); 453 | return (abs < kPositiveInfinityBits) // is finite 454 | && (abs != 0) // is not zero 455 | && ((abs & kBiasedExponentMask) == 0); // is subnormal (has a zero exponent) 456 | } 457 | 458 | /// 459 | /// Creates an instance that represents absolute value. 460 | /// 461 | /// Absolute value 462 | Derived Abs() const noexcept { return Derived::FromBits(AbsImpl()); } 463 | 464 | /// 465 | /// Creates a new instance with the sign flipped. 
466 | /// 467 | /// Flipped sign instance 468 | Derived Negate() const noexcept { return Derived::FromBits(NegateImpl()); } 469 | 470 | /// 471 | /// IEEE defines that positive and negative zero are equal, this gives us a quick equality check 472 | /// for two values by or'ing the private bits together and stripping the sign. They are both zero, 473 | /// and therefore equivalent, if the resulting value is still zero. 474 | /// 475 | /// first value 476 | /// second value 477 | /// True if both arguments represent zero 478 | static bool AreZero(const BFloat16Impl& lhs, const BFloat16Impl& rhs) noexcept { 479 | // IEEE defines that positive and negative zero are equal, this gives us a quick equality check 480 | // for two values by or'ing the private bits together and stripping the sign. They are both zero, 481 | // and therefore equivalent, if the resulting value is still zero. 482 | return static_cast((lhs.val | rhs.val) & ~kSignMask) == 0; 483 | } 484 | }; 485 | 486 | template 487 | inline uint16_t BFloat16Impl::ToUint16Impl(float v) noexcept { 488 | uint16_t result; 489 | if (std::isnan(v)) { 490 | result = kPositiveQNaNBits; 491 | } else { 492 | auto get_msb_half = [](float fl) { 493 | uint16_t result; 494 | #ifdef __cpp_if_constexpr 495 | if constexpr (detail::endian::native == detail::endian::little) { 496 | #else 497 | if (detail::endian::native == detail::endian::little) { 498 | #endif 499 | std::memcpy(&result, reinterpret_cast(&fl) + sizeof(uint16_t), sizeof(uint16_t)); 500 | } else { 501 | std::memcpy(&result, &fl, sizeof(uint16_t)); 502 | } 503 | return result; 504 | }; 505 | 506 | uint16_t upper_bits = get_msb_half(v); 507 | union { 508 | uint32_t U32; 509 | float F32; 510 | }; 511 | F32 = v; 512 | U32 += (upper_bits & 1) + kRoundToNearest; 513 | result = get_msb_half(F32); 514 | } 515 | return result; 516 | } 517 | 518 | template 519 | inline float BFloat16Impl::ToFloatImpl() const noexcept { 520 | if (IsNaN()) { 521 | return 
std::numeric_limits::quiet_NaN(); 522 | } 523 | float result; 524 | char* const first = reinterpret_cast(&result); 525 | char* const second = first + sizeof(uint16_t); 526 | #ifdef __cpp_if_constexpr 527 | if constexpr (detail::endian::native == detail::endian::little) { 528 | #else 529 | if (detail::endian::native == detail::endian::little) { 530 | #endif 531 | std::memset(first, 0, sizeof(uint16_t)); 532 | std::memcpy(second, &val, sizeof(uint16_t)); 533 | } else { 534 | std::memcpy(first, &val, sizeof(uint16_t)); 535 | std::memset(second, 0, sizeof(uint16_t)); 536 | } 537 | return result; 538 | } 539 | 540 | } // namespace onnxruntime_float16 541 | -------------------------------------------------------------------------------- /include/onnxruntime_cxx_inline.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License. 3 | 4 | // Do not include this file directly. Please include "onnxruntime_cxx_api.h" instead. 5 | // If interested in trying out features of the new experimental C++ API, include "experimental_onnxruntime_cxx_api.h" instead. 6 | // 7 | // These are the inline implementations of the C++ header APIs. They're in this separate file as to not clutter 8 | // the main C++ file with implementation details. 
9 | 10 | #include 11 | 12 | namespace Ort { 13 | 14 | namespace detail { 15 | inline void ThrowStatus(const Status& st) { 16 | std::string error_message = st.GetErrorMessage(); 17 | OrtErrorCode error_code = st.GetErrorCode(); 18 | ORT_CXX_API_THROW(std::move(error_message), error_code); 19 | } 20 | } // namespace detail 21 | 22 | inline void ThrowOnError(OrtStatus* ort_status) { 23 | if (ort_status) { 24 | Ort::Status st(ort_status); 25 | detail::ThrowStatus(st); 26 | } 27 | } 28 | 29 | inline void ThrowOnError(const Status& st) { 30 | if (st) { 31 | detail::ThrowStatus(st); 32 | } 33 | } 34 | 35 | inline Status::Status(OrtStatus* status) noexcept : Base{status} { 36 | } 37 | 38 | inline Status::Status(const std::exception& e) noexcept { 39 | p_ = GetApi().CreateStatus(ORT_FAIL, e.what()); 40 | } 41 | 42 | inline Status::Status(const Exception& e) noexcept { 43 | p_ = GetApi().CreateStatus(e.GetOrtErrorCode(), e.what()); 44 | } 45 | 46 | inline Status::Status(const char* message, OrtErrorCode code) noexcept { 47 | p_ = GetApi().CreateStatus(code, message); 48 | } 49 | 50 | inline std::string Status::GetErrorMessage() const { 51 | std::string message(GetApi().GetErrorMessage(p_)); 52 | return message; 53 | } 54 | 55 | inline OrtErrorCode Status::GetErrorCode() const { 56 | return GetApi().GetErrorCode(p_); 57 | } 58 | 59 | inline bool Status::IsOK() const noexcept { 60 | return (p_ == nullptr); 61 | } 62 | 63 | // This template converts a C++ type into it's ONNXTensorElementDataType 64 | template 65 | struct TypeToTensorType; 66 | template <> 67 | struct TypeToTensorType { 68 | static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; 69 | }; 70 | template <> 71 | struct TypeToTensorType { 72 | static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16; 73 | }; 74 | template <> 75 | struct TypeToTensorType { 76 | static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16; 77 | 
}; 78 | template <> 79 | struct TypeToTensorType { 80 | static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE; 81 | }; 82 | template <> 83 | struct TypeToTensorType { 84 | static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8; 85 | }; 86 | template <> 87 | struct TypeToTensorType { 88 | static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16; 89 | }; 90 | template <> 91 | struct TypeToTensorType { 92 | static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32; 93 | }; 94 | template <> 95 | struct TypeToTensorType { 96 | static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64; 97 | }; 98 | template <> 99 | struct TypeToTensorType { 100 | static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8; 101 | }; 102 | template <> 103 | struct TypeToTensorType { 104 | static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16; 105 | }; 106 | template <> 107 | struct TypeToTensorType { 108 | static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32; 109 | }; 110 | template <> 111 | struct TypeToTensorType { 112 | static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64; 113 | }; 114 | template <> 115 | struct TypeToTensorType { 116 | static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL; 117 | }; 118 | 119 | template <> 120 | struct TypeToTensorType { 121 | static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN; 122 | }; 123 | template <> 124 | struct TypeToTensorType { 125 | static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FNUZ; 126 | }; 127 | template <> 128 | struct TypeToTensorType { 129 | static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2; 130 | }; 131 | 
template <> 132 | struct TypeToTensorType { 133 | static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ; 134 | }; 135 | 136 | inline bool BFloat16_t::operator==(const BFloat16_t& rhs) const noexcept { 137 | if (IsNaN() || rhs.IsNaN()) { 138 | // IEEE defines that NaN is not equal to anything, including itself. 139 | return false; 140 | } 141 | return val == rhs.val; 142 | } 143 | 144 | inline bool BFloat16_t::operator<(const BFloat16_t& rhs) const noexcept { 145 | if (IsNaN() || rhs.IsNaN()) { 146 | // IEEE defines that NaN is unordered with respect to everything, including itself. 147 | return false; 148 | } 149 | 150 | const bool left_is_negative = IsNegative(); 151 | if (left_is_negative != rhs.IsNegative()) { 152 | // When the signs of left and right differ, we know that left is less than right if it is 153 | // the negative value. The exception to this is if both values are zero, in which case IEEE 154 | // says they should be equal, even if the signs differ. 
155 | return left_is_negative && !AreZero(*this, rhs); 156 | } 157 | return (val != rhs.val) && ((val < rhs.val) ^ left_is_negative); 158 | } 159 | 160 | inline MemoryAllocation::MemoryAllocation(OrtAllocator* allocator, void* p, size_t size) 161 | : allocator_(allocator), p_(p), size_(size) { 162 | } 163 | 164 | inline MemoryAllocation::~MemoryAllocation() { 165 | if (p_ != nullptr) { 166 | // We do not throw out of destructor 167 | auto ret = GetApi().AllocatorFree(allocator_, p_); 168 | static_cast(ret); 169 | } 170 | } 171 | 172 | inline MemoryAllocation::MemoryAllocation(MemoryAllocation&& o) noexcept : allocator_(nullptr), p_(nullptr), size_(0) { 173 | *this = std::move(o); 174 | } 175 | 176 | inline MemoryAllocation& MemoryAllocation::operator=(MemoryAllocation&& o) noexcept { 177 | OrtAllocator* alloc = nullptr; 178 | void* p = nullptr; 179 | size_t sz = 0; 180 | 181 | // Swap out this 182 | std::swap(alloc, allocator_); 183 | std::swap(p, p_); 184 | std::swap(sz, size_); 185 | 186 | // Swap with incoming 187 | std::swap(allocator_, o.allocator_); 188 | std::swap(p_, o.p_); 189 | std::swap(size_, o.size_); 190 | 191 | // Destroy this instance if needed 192 | MemoryAllocation this_alloc(alloc, p, sz); 193 | return *this; 194 | } 195 | 196 | namespace detail { 197 | 198 | template 199 | inline void* AllocatorImpl::Alloc(size_t size) { 200 | void* out; 201 | ThrowOnError(GetApi().AllocatorAlloc(this->p_, size, &out)); 202 | return out; 203 | } 204 | 205 | template 206 | inline MemoryAllocation AllocatorImpl::GetAllocation(size_t size) { 207 | void* out; 208 | ThrowOnError(GetApi().AllocatorAlloc(this->p_, size, &out)); 209 | MemoryAllocation result(this->p_, out, size); 210 | return result; 211 | } 212 | 213 | template 214 | inline void AllocatorImpl::Free(void* p) { 215 | ThrowOnError(GetApi().AllocatorFree(this->p_, p)); 216 | } 217 | 218 | template 219 | inline ConstMemoryInfo AllocatorImpl::GetInfo() const { 220 | const OrtMemoryInfo* out; 221 | 
ThrowOnError(GetApi().AllocatorGetInfo(this->p_, &out)); 222 | return ConstMemoryInfo{out}; 223 | } 224 | 225 | } // namespace detail 226 | 227 | inline AllocatorWithDefaultOptions::AllocatorWithDefaultOptions() { 228 | ThrowOnError(GetApi().GetAllocatorWithDefaultOptions(&this->p_)); 229 | } 230 | 231 | inline Allocator::Allocator(const Session& sess, const OrtMemoryInfo* mem_info) { 232 | ThrowOnError(GetApi().CreateAllocator(sess, mem_info, &this->p_)); 233 | } 234 | 235 | namespace detail { 236 | 237 | template 238 | inline std::string MemoryInfoImpl::GetAllocatorName() const { 239 | const char* name = nullptr; 240 | ThrowOnError(GetApi().MemoryInfoGetName(this->p_, &name)); 241 | return std::string(name); 242 | } 243 | 244 | template 245 | inline OrtAllocatorType MemoryInfoImpl::GetAllocatorType() const { 246 | OrtAllocatorType type; 247 | ThrowOnError(GetApi().MemoryInfoGetType(this->p_, &type)); 248 | return type; 249 | } 250 | 251 | template 252 | inline int MemoryInfoImpl::GetDeviceId() const { 253 | int id = 0; 254 | ThrowOnError(GetApi().MemoryInfoGetId(this->p_, &id)); 255 | return id; 256 | } 257 | 258 | template 259 | inline OrtMemoryInfoDeviceType MemoryInfoImpl::GetDeviceType() const { 260 | OrtMemoryInfoDeviceType type; 261 | GetApi().MemoryInfoGetDeviceType(this->p_, &type); 262 | return type; 263 | } 264 | 265 | template 266 | inline OrtMemType MemoryInfoImpl::GetMemoryType() const { 267 | OrtMemType type; 268 | ThrowOnError(GetApi().MemoryInfoGetMemType(this->p_, &type)); 269 | return type; 270 | } 271 | 272 | template 273 | template 274 | inline bool MemoryInfoImpl::operator==(const MemoryInfoImpl& o) const { 275 | int comp_result = 0; 276 | ThrowOnError(Ort::GetApi().CompareMemoryInfo(this->p_, o, &comp_result)); 277 | return comp_result == 0; 278 | } 279 | 280 | } // namespace detail 281 | 282 | inline MemoryInfo MemoryInfo::CreateCpu(OrtAllocatorType type, OrtMemType mem_type) { 283 | OrtMemoryInfo* p; 284 | 
ThrowOnError(GetApi().CreateCpuMemoryInfo(type, mem_type, &p)); 285 | return MemoryInfo(p); 286 | } 287 | 288 | inline MemoryInfo::MemoryInfo(const char* name, OrtAllocatorType type, int id, OrtMemType mem_type) { 289 | ThrowOnError(GetApi().CreateMemoryInfo(name, type, id, mem_type, &this->p_)); 290 | } 291 | 292 | namespace detail { 293 | template 294 | inline std::vector ConstIoBindingImpl::GetOutputNames() const { 295 | AllocatorWithDefaultOptions allocator; 296 | return binding_utils::GetOutputNamesHelper(this->p_, allocator); 297 | } 298 | 299 | template 300 | inline std::vector ConstIoBindingImpl::GetOutputNames(OrtAllocator* allocator) const { 301 | return binding_utils::GetOutputNamesHelper(this->p_, allocator); 302 | } 303 | 304 | template 305 | inline std::vector ConstIoBindingImpl::GetOutputValues() const { 306 | AllocatorWithDefaultOptions allocator; 307 | return binding_utils::GetOutputValuesHelper(this->p_, allocator); 308 | } 309 | 310 | template 311 | inline std::vector ConstIoBindingImpl::GetOutputValues(OrtAllocator* allocator) const { 312 | return binding_utils::GetOutputValuesHelper(this->p_, allocator); 313 | } 314 | 315 | template 316 | inline void IoBindingImpl::BindInput(const char* name, const Value& value) { 317 | ThrowOnError(GetApi().BindInput(this->p_, name, value)); 318 | } 319 | 320 | template 321 | inline void IoBindingImpl::BindOutput(const char* name, const Value& value) { 322 | ThrowOnError(GetApi().BindOutput(this->p_, name, value)); 323 | } 324 | 325 | template 326 | inline void IoBindingImpl::BindOutput(const char* name, const OrtMemoryInfo* mem_info) { 327 | ThrowOnError(GetApi().BindOutputToDevice(this->p_, name, mem_info)); 328 | } 329 | 330 | template 331 | inline void IoBindingImpl::ClearBoundInputs() { 332 | GetApi().ClearBoundInputs(this->p_); 333 | } 334 | 335 | template 336 | inline void IoBindingImpl::ClearBoundOutputs() { 337 | GetApi().ClearBoundOutputs(this->p_); 338 | } 339 | 340 | template 341 | inline void 
IoBindingImpl::SynchronizeInputs() { 342 | ThrowOnError(GetApi().SynchronizeBoundInputs(this->p_)); 343 | } 344 | 345 | template 346 | inline void IoBindingImpl::SynchronizeOutputs() { 347 | ThrowOnError(GetApi().SynchronizeBoundOutputs(this->p_)); 348 | } 349 | 350 | namespace binding_utils { 351 | inline std::vector GetOutputNamesHelper(const OrtIoBinding* binding, OrtAllocator* allocator) { 352 | std::vector result; 353 | auto free_fn = detail::AllocatedFree(allocator); 354 | using Ptr = std::unique_ptr; 355 | 356 | char* buffer = nullptr; 357 | size_t* lengths = nullptr; 358 | size_t count = 0; 359 | ThrowOnError(GetApi().GetBoundOutputNames(binding, allocator, &buffer, &lengths, &count)); 360 | 361 | if (count == 0) { 362 | return result; 363 | } 364 | 365 | Ptr buffer_g(buffer, free_fn); 366 | Ptr lengths_g(lengths, free_fn); 367 | 368 | result.reserve(count); 369 | for (size_t i = 0; i < count; ++i) { 370 | auto sz = *lengths; 371 | result.emplace_back(buffer, sz); 372 | buffer += sz; 373 | ++lengths; 374 | } 375 | return result; 376 | } 377 | 378 | inline std::vector GetOutputValuesHelper(const OrtIoBinding* binding, OrtAllocator* allocator) { 379 | std::vector result; 380 | size_t owned = 0; 381 | size_t output_count = 0; 382 | // Lambda to release the buffer when no longer needed and 383 | // make sure that we destroy all instances on exception 384 | auto free_fn = [&owned, &output_count, allocator](OrtValue** buffer) { 385 | if (buffer) { 386 | while (owned < output_count) { 387 | auto* p = buffer + owned++; 388 | GetApi().ReleaseValue(*p); 389 | } 390 | allocator->Free(allocator, buffer); 391 | } 392 | }; 393 | using Ptr = std::unique_ptr; 394 | 395 | OrtValue** output_buffer = nullptr; 396 | ThrowOnError(GetApi().GetBoundOutputValues(binding, allocator, &output_buffer, &output_count)); 397 | if (output_count == 0) { 398 | return result; 399 | } 400 | 401 | Ptr buffer_g(output_buffer, free_fn); 402 | 403 | result.reserve(output_count); 404 | for (size_t 
i = 0; i < output_count; ++i) { 405 | result.emplace_back(output_buffer[i]); 406 | ++owned; 407 | } 408 | return result; 409 | } 410 | 411 | } // namespace binding_utils 412 | } // namespace detail 413 | 414 | inline IoBinding::IoBinding(Session& session) { 415 | ThrowOnError(GetApi().CreateIoBinding(session, &this->p_)); 416 | } 417 | 418 | inline ArenaCfg::ArenaCfg(size_t max_mem, int arena_extend_strategy, int initial_chunk_size_bytes, int max_dead_bytes_per_chunk) { 419 | ThrowOnError(GetApi().CreateArenaCfg(max_mem, arena_extend_strategy, initial_chunk_size_bytes, max_dead_bytes_per_chunk, &p_)); 420 | } 421 | 422 | inline ThreadingOptions::ThreadingOptions() { 423 | ThrowOnError(GetApi().CreateThreadingOptions(&p_)); 424 | } 425 | 426 | inline ThreadingOptions& ThreadingOptions::SetGlobalIntraOpNumThreads(int intra_op_num_threads) { 427 | ThrowOnError(GetApi().SetGlobalIntraOpNumThreads(p_, intra_op_num_threads)); 428 | return *this; 429 | } 430 | 431 | inline ThreadingOptions& ThreadingOptions::SetGlobalInterOpNumThreads(int inter_op_num_threads) { 432 | ThrowOnError(GetApi().SetGlobalInterOpNumThreads(p_, inter_op_num_threads)); 433 | return *this; 434 | } 435 | 436 | inline ThreadingOptions& ThreadingOptions::SetGlobalSpinControl(int allow_spinning) { 437 | ThrowOnError(GetApi().SetGlobalSpinControl(p_, allow_spinning)); 438 | return *this; 439 | } 440 | 441 | inline ThreadingOptions& ThreadingOptions::SetGlobalDenormalAsZero() { 442 | ThrowOnError(GetApi().SetGlobalDenormalAsZero(p_)); 443 | return *this; 444 | } 445 | 446 | inline ThreadingOptions& ThreadingOptions::SetGlobalCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn) { 447 | ThrowOnError(GetApi().SetGlobalCustomCreateThreadFn(p_, ort_custom_create_thread_fn)); 448 | return *this; 449 | } 450 | 451 | inline ThreadingOptions& ThreadingOptions::SetGlobalCustomThreadCreationOptions(void* ort_custom_thread_creation_options) { 452 | 
ThrowOnError(GetApi().SetGlobalCustomThreadCreationOptions(p_, ort_custom_thread_creation_options)); 453 | return *this; 454 | } 455 | 456 | /* Installs the global join callback via SetGlobalCustomJoinThreadFn; returns *this for chaining. */ inline ThreadingOptions& ThreadingOptions::SetGlobalCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn) { 457 | ThrowOnError(GetApi().SetGlobalCustomJoinThreadFn(p_, ort_custom_join_thread_fn)); 458 | return *this; 459 | } 460 | 461 | /* Env constructors: each creates the OrtEnv with its matching Create* C API call, then sets the language projection to ORT_PROJECTION_NODEJS when logid is exactly "onnxruntime-node" and to ORT_PROJECTION_CPLUSPLUS otherwise. Failures surface as exceptions from ThrowOnError. logid must be non-null because it is passed to strcmp. */ inline Env::Env(OrtLoggingLevel logging_level, _In_ const char* logid) { 462 | ThrowOnError(GetApi().CreateEnv(logging_level, logid, &p_)); 463 | if (strcmp(logid, "onnxruntime-node") == 0) { 464 | ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_NODEJS)); 465 | } else { 466 | ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_CPLUSPLUS)); 467 | } 468 | } 469 | 470 | /* Variant that passes logging_function/logger_param to CreateEnvWithCustomLogger. */ inline Env::Env(OrtLoggingLevel logging_level, const char* logid, OrtLoggingFunction logging_function, void* logger_param) { 471 | ThrowOnError(GetApi().CreateEnvWithCustomLogger(logging_function, logger_param, logging_level, logid, &p_)); 472 | if (strcmp(logid, "onnxruntime-node") == 0) { 473 | ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_NODEJS)); 474 | } else { 475 | ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_CPLUSPLUS)); 476 | } 477 | } 478 | 479 | /* Variant that creates the env with global thread pools configured by tp_options. */ inline Env::Env(const OrtThreadingOptions* tp_options, OrtLoggingLevel logging_level, _In_ const char* logid) { 480 | ThrowOnError(GetApi().CreateEnvWithGlobalThreadPools(logging_level, logid, tp_options, &p_)); 481 | if (strcmp(logid, "onnxruntime-node") == 0) { 482 | ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_NODEJS)); 483 | } else { 484 | ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_CPLUSPLUS)); 485 | } 486 | } 487 | 488 | /* Variant combining a custom logger with global thread pools. */ inline Env::Env(const OrtThreadingOptions* tp_options, OrtLoggingFunction logging_function, void*
logger_param, 489 | OrtLoggingLevel logging_level, _In_ const char* logid) { 490 | ThrowOnError(GetApi().CreateEnvWithCustomLoggerAndGlobalThreadPools(logging_function, logger_param, logging_level, logid, tp_options, &p_)); 491 | if (strcmp(logid, "onnxruntime-node") == 0) { 492 | ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_NODEJS)); 493 | } else { 494 | ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_CPLUSPLUS)); 495 | } 496 | } 497 | 498 | /* The following Env mutators forward to the same-named C API call on p_ and return *this for chaining. */ inline Env& Env::EnableTelemetryEvents() { 499 | ThrowOnError(GetApi().EnableTelemetryEvents(p_)); 500 | return *this; 501 | } 502 | 503 | inline Env& Env::DisableTelemetryEvents() { 504 | ThrowOnError(GetApi().DisableTelemetryEvents(p_)); 505 | return *this; 506 | } 507 | 508 | inline Env& Env::UpdateEnvWithCustomLogLevel(OrtLoggingLevel log_severity_level) { 509 | ThrowOnError(GetApi().UpdateEnvWithCustomLogLevel(p_, log_severity_level)); 510 | return *this; 511 | } 512 | 513 | inline Env& Env::CreateAndRegisterAllocator(const OrtMemoryInfo* mem_info, const OrtArenaCfg* arena_cfg) { 514 | ThrowOnError(GetApi().CreateAndRegisterAllocator(p_, mem_info, arena_cfg)); 515 | return *this; 516 | } 517 | 518 | /* Flattens options into parallel key/value C-string arrays for the C API. The pointers borrow from options and are only used for the duration of the call; with empty options, num_entries is 0. */ inline Env& Env::CreateAndRegisterAllocatorV2(const std::string& provider_type, const OrtMemoryInfo* mem_info, const std::unordered_map& options, const OrtArenaCfg* arena_cfg) { 519 | std::vector keys, values; 520 | auto num_entries = options.size(); 521 | if (num_entries > 0) { 522 | keys.reserve(num_entries); 523 | values.reserve(num_entries); 524 | for (const auto& entry : options) { 525 | keys.push_back(entry.first.c_str()); 526 | values.push_back(entry.second.c_str()); 527 | } 528 | } 529 | ThrowOnError(GetApi().CreateAndRegisterAllocatorV2(p_, provider_type.c_str(), mem_info, arena_cfg, keys.data(), values.data(), num_entries)); 530 | return *this; 531 | } 532 | 533 | /* Creates a custom-op domain named domain; Add() registers an OrtCustomOp into it via CustomOpDomain_Add. */ inline CustomOpDomain::CustomOpDomain(const char* domain) { 534 |
ThrowOnError(GetApi().CreateCustomOpDomain(domain, &p_)); 535 | } 536 | 537 | inline void CustomOpDomain::Add(const OrtCustomOp* op) { 538 | ThrowOnError(GetApi().CustomOpDomain_Add(p_, op)); 539 | } 540 | 541 | /* RunOptions: thin wrappers over the RunOptions* C API. Setters return *this for chaining; getters return the value written by the C API and throw on failure. */ inline RunOptions::RunOptions() { 542 | ThrowOnError(GetApi().CreateRunOptions(&p_)); 543 | } 544 | 545 | inline RunOptions& RunOptions::SetRunLogVerbosityLevel(int level) { 546 | ThrowOnError(GetApi().RunOptionsSetRunLogVerbosityLevel(p_, level)); 547 | return *this; 548 | } 549 | 550 | inline RunOptions& RunOptions::SetRunLogSeverityLevel(int level) { 551 | ThrowOnError(GetApi().RunOptionsSetRunLogSeverityLevel(p_, level)); 552 | return *this; 553 | } 554 | 555 | inline int RunOptions::GetRunLogVerbosityLevel() const { 556 | int out; 557 | ThrowOnError(GetApi().RunOptionsGetRunLogVerbosityLevel(p_, &out)); 558 | return out; 559 | } 560 | 561 | inline int RunOptions::GetRunLogSeverityLevel() const { 562 | int out; 563 | ThrowOnError(GetApi().RunOptionsGetRunLogSeverityLevel(p_, &out)); 564 | return out; 565 | } 566 | 567 | inline RunOptions& RunOptions::SetRunTag(const char* run_tag) { 568 | ThrowOnError(GetApi().RunOptionsSetRunTag(p_, run_tag)); 569 | return *this; 570 | } 571 | 572 | /* Returns the pointer written by RunOptionsGetRunTag without copying it. NOTE(review): presumably owned by this RunOptions - confirm its lifetime before caching it. */ inline const char* RunOptions::GetRunTag() const { 573 | const char* out; 574 | ThrowOnError(GetApi().RunOptionsGetRunTag(p_, &out)); 575 | return out; 576 | } 577 | 578 | inline RunOptions& RunOptions::AddConfigEntry(const char* config_key, const char* config_value) { 579 | ThrowOnError(GetApi().AddRunConfigEntry(p_, config_key, config_value)); 580 | return *this; 581 | } 582 | 583 | inline RunOptions& RunOptions::SetTerminate() { 584 | ThrowOnError(GetApi().RunOptionsSetTerminate(p_)); 585 | return *this; 586 | } 587 | 588 | inline RunOptions& RunOptions::UnsetTerminate() { 589 | ThrowOnError(GetApi().RunOptionsUnsetTerminate(p_)); 590 | return *this; 591 | } 592 | 593 | namespace detail { 594 | 595 | /* NOTE(review): template parameter lists and template arguments (e.g. <typename T>, std::vector<...>) appear to have been stripped from this listing, so the code as shown does not compile. Clone() wraps the copy produced by CloneSessionOptions in a SessionOptions. */ template 596 | inline Ort::SessionOptions
ConstSessionOptionsImpl::Clone() const { 597 | OrtSessionOptions* out; 598 | ThrowOnError(GetApi().CloneSessionOptions(this->p_, &out)); 599 | return SessionOptions{out}; 600 | } 601 | 602 | /* Two-call pattern: first query the value size (it includes the trailing terminator, which is trimmed afterwards), then fill a buffer of that size. Any C API failure is thrown by ThrowOnError. */ template 603 | inline std::string ConstSessionOptionsImpl::GetConfigEntry(const char* config_key) const { 604 | size_t size = 0; 605 | // Feed nullptr for the data buffer to query the true size of the string value 606 | Ort::ThrowOnError(GetApi().GetSessionConfigEntry(this->p_, config_key, nullptr, &size)); 607 | 608 | std::string out; 609 | out.resize(size); 610 | Ort::ThrowOnError(GetApi().GetSessionConfigEntry(this->p_, config_key, &out[0], &size)); 611 | out.resize(size - 1); // remove the terminating character '\0' 612 | 613 | return out; 614 | } 615 | 616 | template 617 | inline bool ConstSessionOptionsImpl::HasConfigEntry(const char* config_key) const { 618 | int out = 0; 619 | Ort::ThrowOnError(GetApi().HasSessionConfigEntry(this->p_, config_key, &out)); 620 | return static_cast(out); 621 | } 622 | 623 | /* Returns def when HasConfigEntry reports the key absent, otherwise GetConfigEntry (two C API lookups). */ template 624 | inline std::string ConstSessionOptionsImpl::GetConfigEntryOrDefault(const char* config_key, const std::string& def) { 625 | if (!this->HasConfigEntry(config_key)) { 626 | return def; 627 | } 628 | 629 | return this->GetConfigEntry(config_key); 630 | } 631 | 632 | /* SessionOptionsImpl setters: each forwards to the corresponding C API call on this->p_ and returns *this for chaining. */ template 633 | inline SessionOptionsImpl& SessionOptionsImpl::SetIntraOpNumThreads(int intra_op_num_threads) { 634 | ThrowOnError(GetApi().SetIntraOpNumThreads(this->p_, intra_op_num_threads)); 635 | return *this; 636 | } 637 | 638 | template 639 | inline SessionOptionsImpl& SessionOptionsImpl::SetInterOpNumThreads(int inter_op_num_threads) { 640 | ThrowOnError(GetApi().SetInterOpNumThreads(this->p_, inter_op_num_threads)); 641 | return *this; 642 | } 643 | 644 | template 645 | inline SessionOptionsImpl& SessionOptionsImpl::SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level) { 646 | ThrowOnError(GetApi().SetSessionGraphOptimizationLevel(this->p_, graph_optimization_level));
647 | return *this; 648 | } 649 | 650 | template 651 | inline SessionOptionsImpl& SessionOptionsImpl::SetOptimizedModelFilePath(const ORTCHAR_T* optimized_model_filepath) { 652 | ThrowOnError(GetApi().SetOptimizedModelFilePath(this->p_, optimized_model_filepath)); 653 | return *this; 654 | } 655 | 656 | template 657 | inline SessionOptionsImpl& SessionOptionsImpl::EnableProfiling(const ORTCHAR_T* profile_file_prefix) { 658 | ThrowOnError(GetApi().EnableProfiling(this->p_, profile_file_prefix)); 659 | return *this; 660 | } 661 | 662 | template 663 | inline SessionOptionsImpl& SessionOptionsImpl::DisableProfiling() { 664 | ThrowOnError(GetApi().DisableProfiling(this->p_)); 665 | return *this; 666 | } 667 | 668 | template 669 | inline SessionOptionsImpl& SessionOptionsImpl::EnableOrtCustomOps() { 670 | ThrowOnError(GetApi().EnableOrtCustomOps(this->p_)); 671 | return *this; 672 | } 673 | 674 | template 675 | inline SessionOptionsImpl& SessionOptionsImpl::EnableMemPattern() { 676 | ThrowOnError(GetApi().EnableMemPattern(this->p_)); 677 | return *this; 678 | } 679 | 680 | template 681 | inline SessionOptionsImpl& SessionOptionsImpl::DisableMemPattern() { 682 | ThrowOnError(GetApi().DisableMemPattern(this->p_)); 683 | return *this; 684 | } 685 | 686 | template 687 | inline SessionOptionsImpl& SessionOptionsImpl::EnableCpuMemArena() { 688 | ThrowOnError(GetApi().EnableCpuMemArena(this->p_)); 689 | return *this; 690 | } 691 | 692 | template 693 | inline SessionOptionsImpl& SessionOptionsImpl::DisableCpuMemArena() { 694 | ThrowOnError(GetApi().DisableCpuMemArena(this->p_)); 695 | return *this; 696 | } 697 | 698 | template 699 | inline SessionOptionsImpl& SessionOptionsImpl::SetExecutionMode(ExecutionMode execution_mode) { 700 | ThrowOnError(GetApi().SetSessionExecutionMode(this->p_, execution_mode)); 701 | return *this; 702 | } 703 | 704 | template 705 | inline SessionOptionsImpl& SessionOptionsImpl::SetLogId(const char* logid) { 706 |
ThrowOnError(GetApi().SetSessionLogId(this->p_, logid)); 707 | return *this; 708 | } 709 | 710 | template 711 | inline SessionOptionsImpl& SessionOptionsImpl::SetLogSeverityLevel(int level) { 712 | ThrowOnError(GetApi().SetSessionLogSeverityLevel(this->p_, level)); 713 | return *this; 714 | } 715 | 716 | template 717 | inline SessionOptionsImpl& SessionOptionsImpl::Add(OrtCustomOpDomain* custom_op_domain) { 718 | ThrowOnError(GetApi().AddCustomOpDomain(this->p_, custom_op_domain)); 719 | return *this; 720 | } 721 | 722 | template 723 | inline SessionOptionsImpl& SessionOptionsImpl::AddConfigEntry(const char* config_key, const char* config_value) { 724 | ThrowOnError(GetApi().AddSessionConfigEntry(this->p_, config_key, config_value)); 725 | return *this; 726 | } 727 | 728 | template 729 | inline SessionOptionsImpl& SessionOptionsImpl::AddInitializer(const char* name, const OrtValue* ort_val) { 730 | ThrowOnError(GetApi().AddInitializer(this->p_, name, ort_val)); 731 | return *this; 732 | } 733 | 734 | template 735 | inline SessionOptionsImpl& SessionOptionsImpl::DisablePerSessionThreads() { 736 | ThrowOnError(GetApi().DisablePerSessionThreads(this->p_)); 737 | return *this; 738 | } 739 | 740 | /* Throws ORT_INVALID_ARGUMENT when names and ort_values differ in length. Otherwise builds parallel arrays of name pointers (borrowed from names) and value handles and passes them to AddExternalInitializers in a single call. */ template 741 | inline SessionOptionsImpl& SessionOptionsImpl::AddExternalInitializers(const std::vector& names, 742 | const std::vector& ort_values) { 743 | const size_t inputs_num = names.size(); 744 | if (inputs_num != ort_values.size()) { 745 | ORT_CXX_API_THROW("Expecting names and ort_values to have the same length", ORT_INVALID_ARGUMENT); 746 | } 747 | std::vector names_ptr; 748 | std::vector ort_values_ptrs; 749 | names_ptr.reserve(inputs_num); 750 | ort_values_ptrs.reserve(inputs_num); 751 | for (size_t i = 0; i < inputs_num; ++i) { 752 | names_ptr.push_back(names[i].c_str()); 753 | ort_values_ptrs.push_back(ort_values[i]); 754 | } 755 | ThrowOnError(GetApi().AddExternalInitializers(this->p_, names_ptr.data(), ort_values_ptrs.data(), inputs_num)); 756 | return *this;
757 | } 758 | 759 | /* Each AppendExecutionProvider_XXX overload passes &provider_options to the matching SessionOptionsAppendExecutionProvider_XXX C API call and returns *this. */ template 760 | inline SessionOptionsImpl& SessionOptionsImpl::AppendExecutionProvider_CUDA(const OrtCUDAProviderOptions& provider_options) { 761 | ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_CUDA(this->p_, &provider_options)); 762 | return *this; 763 | } 764 | 765 | template 766 | inline SessionOptionsImpl& SessionOptionsImpl::AppendExecutionProvider_CUDA_V2(const OrtCUDAProviderOptionsV2& provider_options) { 767 | ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_CUDA_V2(this->p_, &provider_options)); 768 | return *this; 769 | } 770 | 771 | template 772 | inline SessionOptionsImpl& SessionOptionsImpl::AppendExecutionProvider_ROCM(const OrtROCMProviderOptions& provider_options) { 773 | ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_ROCM(this->p_, &provider_options)); 774 | return *this; 775 | } 776 | 777 | template 778 | inline SessionOptionsImpl& SessionOptionsImpl::AppendExecutionProvider_TensorRT(const OrtTensorRTProviderOptions& provider_options) { 779 | ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_TensorRT(this->p_, &provider_options)); 780 | return *this; 781 | } 782 | 783 | template 784 | inline SessionOptionsImpl& SessionOptionsImpl::AppendExecutionProvider_TensorRT_V2(const OrtTensorRTProviderOptionsV2& provider_options) { 785 | ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_TensorRT_V2(this->p_, &provider_options)); 786 | return *this; 787 | } 788 | 789 | template 790 | inline SessionOptionsImpl& SessionOptionsImpl::AppendExecutionProvider_MIGraphX(const OrtMIGraphXProviderOptions& provider_options) { 791 | ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_MIGraphX(this->p_, &provider_options)); 792 | return *this; 793 | } 794 | 795 | template 796 | inline SessionOptionsImpl& SessionOptionsImpl::AppendExecutionProvider_CANN(const OrtCANNProviderOptions& provider_options) { 797 | ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_CANN(this->p_,
&provider_options)); 798 | return *this; 799 | } 800 | 801 | template 802 | inline SessionOptionsImpl& SessionOptionsImpl::AppendExecutionProvider_Dnnl(const OrtDnnlProviderOptions& provider_options) { 803 | ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_Dnnl(this->p_, &provider_options)); 804 | return *this; 805 | } 806 | 807 | /* Generic EP registration by name: flattens provider_options into parallel key/value C-string arrays (borrowed from the map for the duration of the call) and passes them with num_entries. */ template 808 | inline SessionOptionsImpl& SessionOptionsImpl::AppendExecutionProvider( 809 | const std::string& provider_name, 810 | const std::unordered_map& provider_options) { 811 | auto num_entries = provider_options.size(); 812 | std::vector keys, values; 813 | if (num_entries > 0) { 814 | keys.reserve(num_entries); 815 | values.reserve(num_entries); 816 | 817 | for (const auto& entry : provider_options) { 818 | keys.push_back(entry.first.c_str()); 819 | values.push_back(entry.second.c_str()); 820 | } 821 | } 822 | 823 | ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider(this->p_, provider_name.c_str(), 824 | keys.data(), values.data(), num_entries)); 825 | 826 | return *this; 827 | } 828 | 829 | template 830 | inline SessionOptionsImpl& SessionOptionsImpl::SetCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn) { 831 | ThrowOnError(GetApi().SessionOptionsSetCustomCreateThreadFn(this->p_, ort_custom_create_thread_fn)); 832 | return *this; 833 | } 834 | 835 | template 836 | inline SessionOptionsImpl& SessionOptionsImpl::SetCustomThreadCreationOptions(void* ort_custom_thread_creation_options) { 837 | ThrowOnError(GetApi().SessionOptionsSetCustomThreadCreationOptions(this->p_, ort_custom_thread_creation_options)); 838 | return *this; 839 | } 840 | 841 | template 842 | inline SessionOptionsImpl& SessionOptionsImpl::SetCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn) { 843 | ThrowOnError(GetApi().SessionOptionsSetCustomJoinThreadFn(this->p_, ort_custom_join_thread_fn)); 844 | return *this; 845 | } 846 | 847 | template 848 | inline SessionOptionsImpl&
SessionOptionsImpl::AppendExecutionProvider_OpenVINO(const OrtOpenVINOProviderOptions& provider_options) { 849 | ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_OpenVINO(this->p_, &provider_options)); 850 | return *this; 851 | } 852 | 853 | template 854 | inline SessionOptionsImpl& SessionOptionsImpl::RegisterCustomOpsLibrary(const ORTCHAR_T* library_name, 855 | const CustomOpConfigs& custom_op_configs) { 856 | // Add custom op config entries before registering the custom op library. Otherwise, the config entries _may_ be ignored by 857 | // the custom op library. 858 | for (const auto& config_iter : custom_op_configs.GetFlattenedConfigs()) { 859 | AddConfigEntry(config_iter.first.c_str(), config_iter.second.c_str()); 860 | } 861 | 862 | ThrowOnError(GetApi().RegisterCustomOpsLibrary_V2(this->p_, library_name)); 863 | return *this; 864 | } 865 | 866 | template 867 | inline SessionOptionsImpl& SessionOptionsImpl::RegisterCustomOpsUsingFunction(const char* registration_function_name) { 868 | ThrowOnError(GetApi().RegisterCustomOpsUsingFunction(this->p_, registration_function_name)); 869 | return *this; 870 | } 871 | 872 | /// Session 873 | template 874 | inline size_t ConstSessionImpl::GetInputCount() const { 875 | size_t out; 876 | ThrowOnError(GetApi().SessionGetInputCount(this->p_, &out)); 877 | return out; 878 | } 879 | 880 | template 881 | inline size_t ConstSessionImpl::GetOutputCount() const { 882 | size_t out; 883 | ThrowOnError(GetApi().SessionGetOutputCount(this->p_, &out)); 884 | return out; 885 | } 886 | 887 | template 888 | inline size_t ConstSessionImpl::GetOverridableInitializerCount() const { 889 | size_t out; 890 | ThrowOnError(GetApi().SessionGetOverridableInitializerCount(this->p_, &out)); 891 | return out; 892 | } 893 | 894 | /* The ...NameAllocated getters return an AllocatedStringPtr that frees the string with the same allocator that produced it. */ template 895 | inline AllocatedStringPtr ConstSessionImpl::GetInputNameAllocated(size_t index, OrtAllocator* allocator) const { 896 | char* out; 897 | ThrowOnError(GetApi().SessionGetInputName(this->p_, index,
allocator, &out)); 898 | return AllocatedStringPtr(out, detail::AllocatedFree(allocator)); 899 | } 900 | 901 | template 902 | inline AllocatedStringPtr ConstSessionImpl::GetOutputNameAllocated(size_t index, OrtAllocator* allocator) const { 903 | char* out; 904 | ThrowOnError(GetApi().SessionGetOutputName(this->p_, index, allocator, &out)); 905 | return AllocatedStringPtr(out, detail::AllocatedFree(allocator)); 906 | } 907 | 908 | template 909 | inline AllocatedStringPtr ConstSessionImpl::GetOverridableInitializerNameAllocated(size_t index, OrtAllocator* allocator) const { 910 | char* out; 911 | ThrowOnError(GetApi().SessionGetOverridableInitializerName(this->p_, index, allocator, &out)); 912 | return AllocatedStringPtr(out, detail::AllocatedFree(allocator)); 913 | } 914 | 915 | template 916 | inline uint64_t ConstSessionImpl::GetProfilingStartTimeNs() const { 917 | uint64_t out; 918 | ThrowOnError(GetApi().SessionGetProfilingStartTimeNs(this->p_, &out)); 919 | return out; 920 | } 921 | 922 | template 923 | inline ModelMetadata ConstSessionImpl::GetModelMetadata() const { 924 | OrtModelMetadata* out; 925 | ThrowOnError(GetApi().SessionGetModelMetadata(this->p_, &out)); 926 | return ModelMetadata{out}; 927 | } 928 | 929 | template 930 | inline TypeInfo ConstSessionImpl::GetInputTypeInfo(size_t index) const { 931 | OrtTypeInfo* out; 932 | ThrowOnError(GetApi().SessionGetInputTypeInfo(this->p_, index, &out)); 933 | return TypeInfo{out}; 934 | } 935 | 936 | template 937 | inline TypeInfo ConstSessionImpl::GetOutputTypeInfo(size_t index) const { 938 | OrtTypeInfo* out; 939 | ThrowOnError(GetApi().SessionGetOutputTypeInfo(this->p_, index, &out)); 940 | return TypeInfo{out}; 941 | } 942 | 943 | template 944 | inline TypeInfo ConstSessionImpl::GetOverridableInitializerTypeInfo(size_t index) const { 945 | OrtTypeInfo* out; 946 | ThrowOnError(GetApi().SessionGetOverridableInitializerTypeInfo(this->p_, index, &out)); 947 | return TypeInfo{out}; 948 | } 949 | 950 | template 951
| /* Convenience overload: creates output_count empty (nullptr) Values, runs the pointer-based Run overload into them, and returns them. */ inline std::vector SessionImpl::Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count, 952 | const char* const* output_names, size_t output_count) { 953 | std::vector output_values; 954 | output_values.reserve(output_count); 955 | for (size_t i = 0; i < output_count; i++) 956 | output_values.emplace_back(nullptr); 957 | Run(run_options, input_names, input_values, input_count, output_names, output_values.data(), output_count); 958 | return output_values; 959 | } 960 | 961 | /* Value is layout-compatible with OrtValue* (enforced by the static_assert), so the Value arrays are handed to the C API via reinterpret_cast. */ template 962 | inline void SessionImpl::Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count, 963 | const char* const* output_names, Value* output_values, size_t output_count) { 964 | static_assert(sizeof(Value) == sizeof(OrtValue*), "Value is really just an array of OrtValue* in memory, so we can reinterpret_cast safely"); 965 | auto ort_input_values = reinterpret_cast(input_values); 966 | auto ort_output_values = reinterpret_cast(output_values); 967 | ThrowOnError(GetApi().Run(this->p_, run_options, input_names, ort_input_values, input_count, output_names, output_count, ort_output_values)); 968 | } 969 | 970 | template 971 | inline void SessionImpl::Run(const RunOptions& run_options, const IoBinding& io_binding) { 972 | ThrowOnError(GetApi().RunWithBinding(this->p_, run_options, io_binding)); 973 | } 974 | 975 | /* Uses the same Value/OrtValue* reinterpret_cast as the pointer-based Run, without repeating the static_assert. NOTE(review): output_values (and presumably input_values) must stay alive until callback is invoked - confirm against the RunAsync C API documentation. */ template 976 | inline void SessionImpl::RunAsync(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count, 977 | const char* const* output_names, Value* output_values, size_t output_count, RunAsyncCallbackFn callback, void* user_data) { 978 | auto ort_input_values = reinterpret_cast(input_values); 979 | auto ort_output_values = reinterpret_cast(output_values); 980 | ThrowOnError(GetApi().RunAsync(this->p_, run_options, input_names, 981 | ort_input_values, input_count, output_names, output_count, 982 |
ort_output_values, callback, user_data)); 983 | } 984 | 985 | /* Returns the string produced by SessionEndProfiling, owned by an AllocatedStringPtr that frees it with allocator. */ template 986 | inline AllocatedStringPtr SessionImpl::EndProfilingAllocated(OrtAllocator* allocator) { 987 | char* out = nullptr; 988 | ThrowOnError(GetApi().SessionEndProfiling(this->p_, allocator, &out)); 989 | return AllocatedStringPtr(out, detail::AllocatedFree(allocator)); 990 | } 991 | 992 | } // namespace detail 993 | 994 | inline SessionOptions::SessionOptions() { 995 | ThrowOnError(GetApi().CreateSessionOptions(&this->p_)); 996 | } 997 | 998 | /// CustomOpConfigs 999 | /* Builds the flat key "custom_op.<custom_op_name>.<config>". */ inline std::string detail::MakeCustomOpConfigEntryKey(const char* custom_op_name, const char* config) { 1000 | std::string config_key = "custom_op."; 1001 | 1002 | config_key += custom_op_name; 1003 | config_key += "."; 1004 | config_key += config; 1005 | 1006 | return config_key; 1007 | } 1008 | 1009 | /* Stores config_value under the flattened key; adding the same op/key pair again overwrites the earlier value. */ inline CustomOpConfigs& CustomOpConfigs::AddConfig(const char* custom_op_name, const char* config_key, const char* config_value) { 1010 | const std::string full_flat_key = detail::MakeCustomOpConfigEntryKey(custom_op_name, config_key); 1011 | flat_configs_[full_flat_key] = config_value; 1012 | return *this; 1013 | } 1014 | 1015 | inline const std::unordered_map& CustomOpConfigs::GetFlattenedConfigs() const { 1016 | return flat_configs_; 1017 | } 1018 | 1019 | /* Session constructors: create the session from a model file path or from an in-memory model buffer, optionally passing prepacked_weights_container to the C API. */ inline Session::Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options) { 1020 | ThrowOnError(GetApi().CreateSession(env, model_path, options, &this->p_)); 1021 | } 1022 | 1023 | inline Session::Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options, 1024 | OrtPrepackedWeightsContainer* prepacked_weights_container) { 1025 | ThrowOnError(GetApi().CreateSessionWithPrepackedWeightsContainer(env, model_path, options, prepacked_weights_container, &this->p_)); 1026 | } 1027 | 1028 | inline Session::Session(const Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options) { 1029 |
ThrowOnError(GetApi().CreateSessionFromArray(env, model_data, model_data_length, options, &this->p_)); 1030 | } 1031 | 1032 | inline Session::Session(const Env& env, const void* model_data, size_t model_data_length, 1033 | const SessionOptions& options, OrtPrepackedWeightsContainer* prepacked_weights_container) { 1034 | ThrowOnError(GetApi().CreateSessionFromArrayWithPrepackedWeightsContainer(env, model_data, model_data_length, options, 1035 | prepacked_weights_container, &this->p_)); 1036 | } 1037 | 1038 | /* ModelMetadata string getters: each returns an AllocatedStringPtr that frees the string with the allocator passed in. */ inline AllocatedStringPtr ModelMetadata::GetProducerNameAllocated(OrtAllocator* allocator) const { 1039 | char* out; 1040 | ThrowOnError(GetApi().ModelMetadataGetProducerName(p_, allocator, &out)); 1041 | return AllocatedStringPtr(out, detail::AllocatedFree(allocator)); 1042 | } 1043 | 1044 | inline AllocatedStringPtr ModelMetadata::GetGraphNameAllocated(OrtAllocator* allocator) const { 1045 | char* out; 1046 | ThrowOnError(GetApi().ModelMetadataGetGraphName(p_, allocator, &out)); 1047 | return AllocatedStringPtr(out, detail::AllocatedFree(allocator)); 1048 | } 1049 | 1050 | inline AllocatedStringPtr ModelMetadata::GetDomainAllocated(OrtAllocator* allocator) const { 1051 | char* out; 1052 | ThrowOnError(GetApi().ModelMetadataGetDomain(p_, allocator, &out)); 1053 | return AllocatedStringPtr(out, detail::AllocatedFree(allocator)); 1054 | } 1055 | 1056 | inline AllocatedStringPtr Ort::ModelMetadata::GetDescriptionAllocated(OrtAllocator* allocator) const { 1057 | char* out; 1058 | ThrowOnError(GetApi().ModelMetadataGetDescription(p_, allocator, &out)); 1059 | return AllocatedStringPtr(out, detail::AllocatedFree(allocator)); 1060 | } 1061 | 1062 | inline AllocatedStringPtr ModelMetadata::GetGraphDescriptionAllocated(OrtAllocator* allocator) const { 1063 | char* out; 1064 | ThrowOnError(GetApi().ModelMetadataGetGraphDescription(p_, allocator, &out)); 1065 | return AllocatedStringPtr(out, detail::AllocatedFree(allocator)); 1066 | } 1067 | 1068 | inline
AllocatedStringPtr ModelMetadata::LookupCustomMetadataMapAllocated(const char* key, OrtAllocator* allocator) const { 1069 | char* out; 1070 | ThrowOnError(GetApi().ModelMetadataLookupCustomMetadataMap(p_, allocator, key, &out)); 1071 | return AllocatedStringPtr(out, detail::AllocatedFree(allocator)); 1072 | } 1073 | 1074 | /* Returns the custom metadata keys, each owned by an AllocatedStringPtr; returns an empty vector when num_keys <= 0. array_guard frees the returned array of pointers on exit. strings_guard frees every key string if reserve() throws; it is released once capacity is secured, after which each key's ownership moves into its AllocatedStringPtr. */ inline std::vector ModelMetadata::GetCustomMetadataMapKeysAllocated(OrtAllocator* allocator) const { 1075 | auto deletor = detail::AllocatedFree(allocator); 1076 | std::vector result; 1077 | 1078 | char** out = nullptr; 1079 | int64_t num_keys = 0; 1080 | ThrowOnError(GetApi().ModelMetadataGetCustomMetadataMapKeys(p_, allocator, &out, &num_keys)); 1081 | if (num_keys <= 0) { 1082 | return result; 1083 | } 1084 | 1085 | // array of pointers will be freed 1086 | std::unique_ptr array_guard(out, deletor); 1087 | // reserve may throw 1088 | auto strings_deletor = [&deletor, num_keys](char** out) { for(int64_t i = 0; i < num_keys; ++i) deletor(out[i]); }; 1089 | std::unique_ptr strings_guard(out, strings_deletor); 1090 | result.reserve(static_cast(num_keys)); 1091 | strings_guard.release(); 1092 | for (int64_t i = 0; i < num_keys; ++i) { 1093 | result.push_back(AllocatedStringPtr(out[i], deletor)); 1094 | } 1095 | 1096 | return result; 1097 | } 1098 | 1099 | inline int64_t ModelMetadata::GetVersion() const { 1100 | int64_t out; 1101 | ThrowOnError(GetApi().ModelMetadataGetVersion(p_, &out)); 1102 | return out; 1103 | } 1104 | 1105 | namespace detail { 1106 | 1107 | template 1108 | inline ONNXTensorElementDataType TensorTypeAndShapeInfoImpl::GetElementType() const { 1109 | ONNXTensorElementDataType out; 1110 | ThrowOnError(GetApi().GetTensorElementType(this->p_, &out)); 1111 | return out; 1112 | } 1113 | 1114 | /* NOTE(review): the static_cast on the result looks redundant because out is already size_t. */ template 1115 | inline size_t TensorTypeAndShapeInfoImpl::GetElementCount() const { 1116 | size_t out; 1117 | ThrowOnError(GetApi().GetTensorShapeElementCount(this->p_, &out)); 1118 | return static_cast(out); 1119 | } 1120 | 1121 | template 1122
| inline size_t TensorTypeAndShapeInfoImpl::GetDimensionsCount() const { 1123 | size_t out; 1124 | ThrowOnError(GetApi().GetDimensionsCount(this->p_, &out)); 1125 | return out; 1126 | } 1127 | 1128 | template 1129 | inline void TensorTypeAndShapeInfoImpl::GetDimensions(int64_t* values, size_t values_count) const { 1130 | ThrowOnError(GetApi().GetDimensions(this->p_, values, values_count)); 1131 | } 1132 | 1133 | template 1134 | inline void TensorTypeAndShapeInfoImpl::GetSymbolicDimensions(const char** values, size_t values_count) const { 1135 | ThrowOnError(GetApi().GetSymbolicDimensions(this->p_, values, values_count)); 1136 | } 1137 | 1138 | /* Returns one entry per dimension: the vector is sized by GetDimensionsCount() and filled by GetDimensions. */ template 1139 | inline std::vector TensorTypeAndShapeInfoImpl::GetShape() const { 1140 | std::vector out(GetDimensionsCount(), 0); 1141 | ThrowOnError(GetApi().GetDimensions(this->p_, out.data(), out.size())); 1142 | return out; 1143 | } 1144 | 1145 | /* The Cast* getters below wrap const pointers in Const* wrapper types. NOTE(review): presumably non-owning views into this TypeInfo - confirm lifetime. */ template 1146 | inline ConstTensorTypeAndShapeInfo TypeInfoImpl::GetTensorTypeAndShapeInfo() const { 1147 | const OrtTensorTypeAndShapeInfo* out; 1148 | ThrowOnError(GetApi().CastTypeInfoToTensorInfo(this->p_, &out)); 1149 | return ConstTensorTypeAndShapeInfo{out}; 1150 | } 1151 | 1152 | template 1153 | inline ConstSequenceTypeInfo TypeInfoImpl::GetSequenceTypeInfo() const { 1154 | const OrtSequenceTypeInfo* out; 1155 | ThrowOnError(GetApi().CastTypeInfoToSequenceTypeInfo(this->p_, &out)); 1156 | return ConstSequenceTypeInfo{out}; 1157 | } 1158 | 1159 | template 1160 | inline ConstMapTypeInfo TypeInfoImpl::GetMapTypeInfo() const { 1161 | const OrtMapTypeInfo* out; 1162 | ThrowOnError(GetApi().CastTypeInfoToMapTypeInfo(this->p_, &out)); 1163 | return ConstMapTypeInfo{out}; 1164 | } 1165 | 1166 | template 1167 | inline ONNXType TypeInfoImpl::GetONNXType() const { 1168 | ONNXType out; 1169 | ThrowOnError(GetApi().GetOnnxTypeFromTypeInfo(this->p_, &out)); 1170 | return out; 1171 | } 1172 | 1173 | template 1174 | inline TypeInfo SequenceTypeInfoImpl::GetSequenceElementType() const {
1175 | OrtTypeInfo* output; 1176 | ThrowOnError(GetApi().GetSequenceElementType(this->p_, &output)); 1177 | return TypeInfo{output}; 1178 | } 1179 | 1180 | template 1181 | inline TypeInfo OptionalTypeInfoImpl::GetOptionalElementType() const { 1182 | OrtTypeInfo* info; 1183 | ThrowOnError(GetApi().GetOptionalContainedTypeInfo(this->p_, &info)); 1184 | return TypeInfo{info}; 1185 | } 1186 | 1187 | template 1188 | inline ONNXTensorElementDataType MapTypeInfoImpl::GetMapKeyType() const { 1189 | ONNXTensorElementDataType out; 1190 | ThrowOnError(GetApi().GetMapKeyType(this->p_, &out)); 1191 | return out; 1192 | } 1193 | 1194 | template 1195 | inline TypeInfo MapTypeInfoImpl::GetMapValueType() const { 1196 | OrtTypeInfo* output; 1197 | ThrowOnError(GetApi().GetMapValueType(this->p_, &output)); 1198 | return TypeInfo{output}; 1199 | } 1200 | 1201 | template 1202 | inline ConstOptionalTypeInfo TypeInfoImpl::GetOptionalTypeInfo() const { 1203 | const OrtOptionalTypeInfo* info; 1204 | ThrowOnError(GetApi().CastTypeInfoToOptionalTypeInfo(this->p_, &info)); 1205 | return ConstOptionalTypeInfo{info}; 1206 | } 1207 | 1208 | } // namespace detail 1209 | 1210 | namespace detail { 1211 | 1212 | /* Passes &out and sizeof(R) to GetOpaqueValue for the opaque type identified by domain/type_name, which fills out. */ template 1213 | template 1214 | inline void ConstValueImpl::GetOpaqueData(const char* domain, const char* type_name, R& out) const { 1215 | ThrowOnError(GetApi().GetOpaqueValue(domain, type_name, this->p_, &out, sizeof(R))); 1216 | } 1217 | 1218 | template 1219 | inline bool ConstValueImpl::IsTensor() const { 1220 | int out; 1221 | ThrowOnError(GetApi().IsTensor(this->p_, &out)); 1222 | return out != 0; 1223 | } 1224 | 1225 | template 1226 | inline bool ConstValueImpl::HasValue() const { 1227 | int out; 1228 | ThrowOnError(GetApi().HasValue(this->p_, &out)); 1229 | return out != 0; 1230 | } 1231 | 1232 | template 1233 | inline size_t ConstValueImpl::GetCount() const { 1234 | size_t out; 1235 | ThrowOnError(GetApi().GetValueCount(this->p_, &out)); 1236 | return out; 1237 | } 1238 | 1239 |
template 1240 | inline Value ConstValueImpl::GetValue(int index, OrtAllocator* allocator) const { 1241 | OrtValue* out; 1242 | ThrowOnError(GetApi().GetValue(this->p_, index, allocator, &out)); 1243 | return Value{out}; 1244 | } 1245 | 1246 | template 1247 | inline size_t ConstValueImpl::GetStringTensorDataLength() const { 1248 | size_t out; 1249 | ThrowOnError(GetApi().GetStringTensorDataLength(this->p_, &out)); 1250 | return out; 1251 | } 1252 | 1253 | template 1254 | inline size_t ConstValueImpl::GetStringTensorElementLength(size_t element_index) const { 1255 | size_t out; 1256 | ThrowOnError(GetApi().GetStringTensorElementLength(this->p_, element_index, &out)); 1257 | return out; 1258 | } 1259 | 1260 | template 1261 | template 1262 | inline const R* ConstValueImpl::GetTensorData() const { 1263 | R* out; 1264 | ThrowOnError(GetApi().GetTensorMutableData(const_cast(this->p_), (void**)&out)); 1265 | return out; 1266 | } 1267 | 1268 | template 1269 | inline const void* ConstValueImpl::GetTensorRawData() const { 1270 | void* out; 1271 | ThrowOnError(GetApi().GetTensorMutableData(const_cast(this->p_), &out)); 1272 | return out; 1273 | } 1274 | 1275 | template 1276 | inline TypeInfo ConstValueImpl::GetTypeInfo() const { 1277 | OrtTypeInfo* output; 1278 | ThrowOnError(GetApi().GetTypeInfo(this->p_, &output)); 1279 | return TypeInfo{output}; 1280 | } 1281 | 1282 | template 1283 | inline TensorTypeAndShapeInfo ConstValueImpl::GetTensorTypeAndShapeInfo() const { 1284 | OrtTensorTypeAndShapeInfo* output; 1285 | ThrowOnError(GetApi().GetTensorTypeAndShape(this->p_, &output)); 1286 | return TensorTypeAndShapeInfo{output}; 1287 | } 1288 | 1289 | template 1290 | inline ConstMemoryInfo ConstValueImpl::GetTensorMemoryInfo() const { 1291 | const OrtMemoryInfo* mem_info; 1292 | ThrowOnError(GetApi().GetTensorMemoryInfo(this->p_, &mem_info)); 1293 | return ConstMemoryInfo(mem_info); 1294 | } 1295 | 1296 | template 1297 | inline void ConstValueImpl::GetStringTensorElement(size_t 
buffer_length, size_t element_index, void* buffer) const { 1298 | ThrowOnError(GetApi().GetStringTensorElement(this->p_, buffer_length, element_index, buffer)); 1299 | } 1300 | 1301 | template 1302 | inline std::string ConstValueImpl::GetStringTensorElement(size_t element_index) const { 1303 | size_t buffer_length; 1304 | ThrowOnError(GetApi().GetStringTensorElementLength(this->p_, element_index, &buffer_length)); 1305 | 1306 | std::string s; 1307 | s.resize(buffer_length); 1308 | ThrowOnError(GetApi().GetStringTensorElement(this->p_, buffer_length, element_index, &s[0])); 1309 | return s; 1310 | } 1311 | 1312 | template 1313 | inline void ConstValueImpl::GetStringTensorContent(void* buffer, size_t buffer_length, size_t* offsets, size_t offsets_count) const { 1314 | ThrowOnError(GetApi().GetStringTensorContent(this->p_, buffer, buffer_length, offsets, offsets_count)); 1315 | } 1316 | 1317 | #if !defined(DISABLE_SPARSE_TENSORS) 1318 | template 1319 | inline OrtSparseFormat ConstValueImpl::GetSparseFormat() const { 1320 | OrtSparseFormat format; 1321 | ThrowOnError(GetApi().GetSparseTensorFormat(this->p_, &format)); 1322 | return format; 1323 | } 1324 | 1325 | template 1326 | inline TensorTypeAndShapeInfo ConstValueImpl::GetSparseTensorValuesTypeAndShapeInfo() const { 1327 | OrtTensorTypeAndShapeInfo* output; 1328 | ThrowOnError(GetApi().GetSparseTensorValuesTypeAndShape(this->p_, &output)); 1329 | return TensorTypeAndShapeInfo{output}; 1330 | } 1331 | 1332 | template 1333 | inline TensorTypeAndShapeInfo ConstValueImpl::GetSparseTensorIndicesTypeShapeInfo(OrtSparseIndicesFormat indices_format) const { 1334 | OrtTensorTypeAndShapeInfo* output; 1335 | ThrowOnError(GetApi().GetSparseTensorIndicesTypeShape(this->p_, indices_format, &output)); 1336 | return TensorTypeAndShapeInfo{output}; 1337 | } 1338 | 1339 | template 1340 | template 1341 | inline const R* ConstValueImpl::GetSparseTensorIndicesData(OrtSparseIndicesFormat indices_format, size_t& num_indices) const { 1342 
| const void* out; 1343 | ThrowOnError(GetApi().GetSparseTensorIndices(this->p_, indices_format, &num_indices, &out)); 1344 | return reinterpret_cast(out); 1345 | } 1346 | 1347 | template 1348 | inline bool ConstValueImpl::IsSparseTensor() const { 1349 | int out; 1350 | ThrowOnError(GetApi().IsSparseTensor(this->p_, &out)); 1351 | return out != 0; 1352 | } 1353 | 1354 | template 1355 | template 1356 | inline const R* ConstValueImpl::GetSparseTensorValues() const { 1357 | const void* out; 1358 | ThrowOnError(GetApi().GetSparseTensorValues(this->p_, &out)); 1359 | return reinterpret_cast(out); 1360 | } 1361 | 1362 | #endif 1363 | 1364 | template 1365 | void ValueImpl::FillStringTensor(const char* const* s, size_t s_len) { 1366 | ThrowOnError(GetApi().FillStringTensor(this->p_, s, s_len)); 1367 | } 1368 | 1369 | template 1370 | void ValueImpl::FillStringTensorElement(const char* s, size_t index) { 1371 | ThrowOnError(GetApi().FillStringTensorElement(this->p_, s, index)); 1372 | } 1373 | 1374 | template 1375 | inline char* ValueImpl::GetResizedStringTensorElementBuffer(size_t index, size_t buffer_length) { 1376 | char* result; 1377 | ThrowOnError(GetApi().GetResizedStringTensorElementBuffer(this->p_, index, buffer_length, &result)); 1378 | return result; 1379 | } 1380 | 1381 | template 1382 | void* ValueImpl::GetTensorMutableRawData() { 1383 | void* out; 1384 | ThrowOnError(GetApi().GetTensorMutableData(this->p_, &out)); 1385 | return out; 1386 | } 1387 | 1388 | template 1389 | template 1390 | R* ValueImpl::GetTensorMutableData() { 1391 | R* out; 1392 | ThrowOnError(GetApi().GetTensorMutableData(this->p_, (void**)&out)); 1393 | return out; 1394 | } 1395 | 1396 | template 1397 | template 1398 | R& ValueImpl::At(const std::vector& location) { 1399 | static_assert(!std::is_same::value, "this api does not support std::string"); 1400 | R* out; 1401 | ThrowOnError(GetApi().TensorAt(this->p_, location.data(), location.size(), (void**)&out)); 1402 | return *out; 1403 | } 1404 
| 1405 | #if !defined(DISABLE_SPARSE_TENSORS) 1406 | template 1407 | void ValueImpl::UseCooIndices(int64_t* indices_data, size_t indices_num) { 1408 | ThrowOnError(GetApi().UseCooIndices(this->p_, indices_data, indices_num)); 1409 | } 1410 | 1411 | template 1412 | void ValueImpl::UseCsrIndices(int64_t* inner_data, size_t inner_num, int64_t* outer_data, size_t outer_num) { 1413 | ThrowOnError(GetApi().UseCsrIndices(this->p_, inner_data, inner_num, outer_data, outer_num)); 1414 | } 1415 | 1416 | template 1417 | void ValueImpl::UseBlockSparseIndices(const Shape& indices_shape, int32_t* indices_data) { 1418 | ThrowOnError(GetApi().UseBlockSparseIndices(this->p_, indices_shape.shape, indices_shape.shape_len, indices_data)); 1419 | } 1420 | 1421 | template 1422 | void ValueImpl::FillSparseTensorCoo(const OrtMemoryInfo* mem_info, const OrtSparseValuesParam& values_param, 1423 | const int64_t* indices_data, size_t indices_num) { 1424 | ThrowOnError(GetApi().FillSparseTensorCoo(this->p_, mem_info, values_param.values_shape, 1425 | values_param.values_shape_len, values_param.data.p_data, 1426 | indices_data, indices_num)); 1427 | } 1428 | 1429 | template 1430 | void ValueImpl::FillSparseTensorCsr(const OrtMemoryInfo* data_mem_info, 1431 | const OrtSparseValuesParam& values, 1432 | const int64_t* inner_indices_data, size_t inner_indices_num, 1433 | const int64_t* outer_indices_data, size_t outer_indices_num) { 1434 | ThrowOnError(GetApi().FillSparseTensorCsr(this->p_, data_mem_info, values.values_shape, values.values_shape_len, values.data.p_data, 1435 | inner_indices_data, inner_indices_num, 1436 | outer_indices_data, outer_indices_num)); 1437 | } 1438 | 1439 | template 1440 | void ValueImpl::FillSparseTensorBlockSparse(const OrtMemoryInfo* data_mem_info, 1441 | const OrtSparseValuesParam& values, 1442 | const Shape& indices_shape, 1443 | const int32_t* indices_data) { 1444 | ThrowOnError(GetApi().FillSparseTensorBlockSparse(this->p_, data_mem_info, values.values_shape, 
values.values_shape_len, values.data.p_data, 1445 | indices_shape.shape, indices_shape.shape_len, 1446 | indices_data)); 1447 | } 1448 | 1449 | #endif // !defined(DISABLE_SPARSE_TENSORS) 1450 | 1451 | } // namespace detail 1452 | 1453 | template 1454 | inline Value Value::CreateTensor(const OrtMemoryInfo* info, T* p_data, size_t p_data_element_count, const int64_t* shape, size_t shape_len) { 1455 | return CreateTensor(info, p_data, p_data_element_count * sizeof(T), shape, shape_len, TypeToTensorType::type); 1456 | } 1457 | 1458 | inline Value Value::CreateTensor(const OrtMemoryInfo* info, void* p_data, size_t p_data_byte_count, const int64_t* shape, size_t shape_len, 1459 | ONNXTensorElementDataType type) { 1460 | OrtValue* out; 1461 | ThrowOnError(GetApi().CreateTensorWithDataAsOrtValue(info, p_data, p_data_byte_count, shape, shape_len, type, &out)); 1462 | return Value{out}; 1463 | } 1464 | 1465 | template 1466 | inline Value Value::CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len) { 1467 | return CreateTensor(allocator, shape, shape_len, TypeToTensorType::type); 1468 | } 1469 | 1470 | inline Value Value::CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type) { 1471 | OrtValue* out; 1472 | ThrowOnError(GetApi().CreateTensorAsOrtValue(allocator, shape, shape_len, type, &out)); 1473 | return Value{out}; 1474 | } 1475 | 1476 | #if !defined(DISABLE_SPARSE_TENSORS) 1477 | 1478 | template 1479 | inline Value Value::CreateSparseTensor(const OrtMemoryInfo* info, T* p_data, const Shape& dense_shape, 1480 | const Shape& values_shape) { 1481 | return CreateSparseTensor(info, p_data, dense_shape, values_shape, TypeToTensorType::type); 1482 | } 1483 | 1484 | inline Value Value::CreateSparseTensor(const OrtMemoryInfo* info, void* p_data, const Shape& dense_shape, 1485 | const Shape& values_shape, ONNXTensorElementDataType type) { 1486 | OrtValue* out; 1487 | 
ThrowOnError(GetApi().CreateSparseTensorWithValuesAsOrtValue(info, p_data, dense_shape.shape, dense_shape.shape_len, 1488 | values_shape.shape, values_shape.shape_len, type, &out)); 1489 | return Value{out}; 1490 | } 1491 | 1492 | template 1493 | inline Value Value::CreateSparseTensor(OrtAllocator* allocator, const Shape& dense_shape) { 1494 | return CreateSparseTensor(allocator, dense_shape, TypeToTensorType::type); 1495 | } 1496 | 1497 | inline Value Value::CreateSparseTensor(OrtAllocator* allocator, const Shape& dense_shape, 1498 | ONNXTensorElementDataType type) { 1499 | OrtValue* out; 1500 | ThrowOnError(GetApi().CreateSparseTensorAsOrtValue(allocator, dense_shape.shape, dense_shape.shape_len, type, &out)); 1501 | return Value{out}; 1502 | } 1503 | #endif // !defined(DISABLE_SPARSE_TENSORS) 1504 | 1505 | inline Value Value::CreateMap(const Value& keys, const Value& values) { 1506 | OrtValue* out; 1507 | const OrtValue* inputs[2] = {keys, values}; 1508 | ThrowOnError(GetApi().CreateValue(inputs, 2, ONNX_TYPE_MAP, &out)); 1509 | return Value{out}; 1510 | } 1511 | 1512 | inline Value Value::CreateSequence(const std::vector& values) { 1513 | OrtValue* out; 1514 | std::vector values_ort{values.data(), values.data() + values.size()}; 1515 | ThrowOnError(GetApi().CreateValue(values_ort.data(), values_ort.size(), ONNX_TYPE_SEQUENCE, &out)); 1516 | return Value{out}; 1517 | } 1518 | 1519 | template 1520 | inline Value Value::CreateOpaque(const char* domain, const char* type_name, const T& data_container) { 1521 | OrtValue* out; 1522 | ThrowOnError(GetApi().CreateOpaqueValue(domain, type_name, &data_container, sizeof(T), &out)); 1523 | return Value{out}; 1524 | } 1525 | 1526 | // 1527 | // Custom OP Inlines 1528 | // 1529 | inline Logger::Logger(const OrtLogger* logger) : logger_(logger) { 1530 | Ort::ThrowOnError(GetApi().Logger_GetLoggingSeverityLevel(this->logger_, &this->cached_severity_level_)); 1531 | } 1532 | 1533 | inline OrtLoggingLevel 
Logger::GetLoggingSeverityLevel() const noexcept { 1534 | return cached_severity_level_; 1535 | } 1536 | 1537 | inline Status Logger::LogMessage(OrtLoggingLevel log_severity_level, const ORTCHAR_T* file_path, int line_number, 1538 | const char* func_name, const char* message) const noexcept { 1539 | OrtStatus* status = GetApi().Logger_LogMessage(logger_, log_severity_level, message, file_path, line_number, 1540 | func_name); 1541 | return Status{status}; 1542 | } 1543 | 1544 | // Disable warnings about the format string not being a literal (-Wformat-nonliteral and -Wformat-security) 1545 | // for gcc and clang. The alternative is to use actual C-style variadic parameters and apply 1546 | // __attribute__(format(printf...)), which does not work with variadic templates. 1547 | #if defined(__GNUC__) 1548 | #pragma GCC diagnostic push 1549 | #pragma GCC diagnostic ignored "-Wformat-nonliteral" 1550 | #pragma GCC diagnostic ignored "-Wformat-security" 1551 | #elif defined(__clang__) 1552 | #pragma clang diagnostic push 1553 | #pragma clang diagnostic ignored "-Wformat-nonliteral" 1554 | #pragma clang diagnostic ignored "-Wformat-security" 1555 | #endif 1556 | template 1557 | inline Status Logger::LogFormattedMessage(OrtLoggingLevel log_severity_level, const ORTCHAR_T* file_path, 1558 | int line_number, const char* func_name, const char* format, 1559 | Args&&... 
args) const noexcept { 1560 | int msg_len = std::snprintf(nullptr, 0U, format, std::forward(args)...); 1561 | 1562 | if (msg_len < 0) { // Formatting error 1563 | return Status("Failed to log message due to formatting error", OrtErrorCode::ORT_FAIL); 1564 | } 1565 | 1566 | OrtStatus* status = nullptr; 1567 | const size_t buffer_size = static_cast(msg_len) + 1U; 1568 | 1569 | constexpr size_t kStackBufferSize = 1024; 1570 | 1571 | if (buffer_size < kStackBufferSize) { 1572 | char buffer[kStackBufferSize]; 1573 | snprintf(buffer, kStackBufferSize, format, std::forward(args)...); 1574 | status = GetApi().Logger_LogMessage(logger_, log_severity_level, buffer, file_path, line_number, func_name); 1575 | } else { 1576 | // std::make_unique is only supported starting at C++14. 1577 | #if (__cplusplus >= 201402L) || (_MSC_VER >= 1900) 1578 | auto buffer = std::make_unique(buffer_size); 1579 | #else 1580 | std::unique_ptr buffer(new char[buffer_size]); 1581 | #endif 1582 | std::snprintf(buffer.get(), buffer_size, format, std::forward(args)...); 1583 | status = GetApi().Logger_LogMessage(logger_, log_severity_level, buffer.get(), file_path, line_number, func_name); 1584 | } 1585 | 1586 | return Status{status}; 1587 | } 1588 | // Re-enable -Wformat-nonliteral and -Wformat-security 1589 | #if defined(__GNUC__) 1590 | #pragma GCC diagnostic pop 1591 | #elif defined(__clang__) 1592 | #pragma clang diagnostic pop 1593 | #endif 1594 | 1595 | inline KernelContext::KernelContext(OrtKernelContext* context) : ctx_(context) { 1596 | } 1597 | 1598 | inline size_t KernelContext::GetInputCount() const { 1599 | size_t out = 0; 1600 | Ort::ThrowOnError(GetApi().KernelContext_GetInputCount(ctx_, &out)); 1601 | return out; 1602 | } 1603 | 1604 | inline size_t KernelContext::GetOutputCount() const { 1605 | size_t out = 0; 1606 | Ort::ThrowOnError(GetApi().KernelContext_GetOutputCount(ctx_, &out)); 1607 | return out; 1608 | } 1609 | 1610 | inline ConstValue KernelContext::GetInput(size_t index) 
const { 1611 | const OrtValue* out = nullptr; 1612 | Ort::ThrowOnError(GetApi().KernelContext_GetInput(ctx_, index, &out)); 1613 | return ConstValue{out}; 1614 | } 1615 | 1616 | inline UnownedValue KernelContext::GetOutput(size_t index, const int64_t* dim_values, size_t dim_count) const { 1617 | OrtValue* out = nullptr; 1618 | Ort::ThrowOnError(GetApi().KernelContext_GetOutput(ctx_, index, dim_values, dim_count, &out)); 1619 | return UnownedValue(out); 1620 | } 1621 | 1622 | inline UnownedValue KernelContext::GetOutput(size_t index, const std::vector& dims) const { 1623 | OrtValue* out = nullptr; 1624 | Ort::ThrowOnError(GetApi().KernelContext_GetOutput(ctx_, index, dims.data(), dims.size(), &out)); 1625 | return UnownedValue(out); 1626 | } 1627 | 1628 | inline void* KernelContext::GetGPUComputeStream() const { 1629 | void* out = nullptr; 1630 | Ort::ThrowOnError(GetApi().KernelContext_GetGPUComputeStream(ctx_, &out)); 1631 | return out; 1632 | } 1633 | 1634 | inline OrtAllocator* KernelContext::GetAllocator(const OrtMemoryInfo& memory_info) const { 1635 | OrtAllocator* out = nullptr; 1636 | Ort::ThrowOnError(GetApi().KernelContext_GetAllocator(ctx_, &memory_info, &out)); 1637 | return out; 1638 | } 1639 | 1640 | inline Logger KernelContext::GetLogger() const { 1641 | const OrtLogger* out = nullptr; 1642 | ThrowOnError(GetApi().KernelContext_GetLogger(this->ctx_, &out)); 1643 | return Logger{out}; 1644 | } 1645 | 1646 | inline OpAttr::OpAttr(const char* name, const void* data, int len, OrtOpAttrType type) { 1647 | Ort::ThrowOnError(GetApi().CreateOpAttr(name, data, len, type, &p_)); 1648 | } 1649 | 1650 | namespace detail { 1651 | template 1652 | inline KernelInfo KernelInfoImpl::Copy() const { 1653 | OrtKernelInfo* info_copy = nullptr; 1654 | Ort::ThrowOnError(GetApi().CopyKernelInfo(this->p_, &info_copy)); 1655 | return KernelInfo{info_copy}; 1656 | } 1657 | 1658 | template 1659 | inline size_t KernelInfoImpl::GetInputCount() const { 1660 | size_t out = 0; 1661 | 
ThrowOnError(GetApi().KernelInfo_GetInputCount(this->p_, &out)); 1662 | return out; 1663 | } 1664 | 1665 | template 1666 | inline size_t KernelInfoImpl::GetOutputCount() const { 1667 | size_t out = 0; 1668 | ThrowOnError(GetApi().KernelInfo_GetOutputCount(this->p_, &out)); 1669 | return out; 1670 | } 1671 | 1672 | template 1673 | inline std::string KernelInfoImpl::GetInputName(size_t index) const { 1674 | size_t size = 0; 1675 | 1676 | // Feed nullptr for the data buffer to query the true size of the string value 1677 | Ort::ThrowOnError(GetApi().KernelInfo_GetInputName(this->p_, index, nullptr, &size)); 1678 | 1679 | std::string out; 1680 | out.resize(size); 1681 | Ort::ThrowOnError(GetApi().KernelInfo_GetInputName(this->p_, index, &out[0], &size)); 1682 | out.resize(size - 1); // remove the terminating character '\0' 1683 | 1684 | return out; 1685 | } 1686 | 1687 | template 1688 | inline std::string KernelInfoImpl::GetOutputName(size_t index) const { 1689 | size_t size = 0; 1690 | 1691 | // Feed nullptr for the data buffer to query the true size of the string value 1692 | Ort::ThrowOnError(GetApi().KernelInfo_GetOutputName(this->p_, index, nullptr, &size)); 1693 | 1694 | std::string out; 1695 | out.resize(size); 1696 | Ort::ThrowOnError(GetApi().KernelInfo_GetOutputName(this->p_, index, &out[0], &size)); 1697 | out.resize(size - 1); // remove the terminating character '\0' 1698 | 1699 | return out; 1700 | } 1701 | 1702 | template 1703 | inline TypeInfo KernelInfoImpl::GetInputTypeInfo(size_t index) const { 1704 | OrtTypeInfo* out = nullptr; 1705 | ThrowOnError(GetApi().KernelInfo_GetInputTypeInfo(this->p_, index, &out)); 1706 | return TypeInfo{out}; 1707 | } 1708 | 1709 | template 1710 | inline TypeInfo KernelInfoImpl::GetOutputTypeInfo(size_t index) const { 1711 | OrtTypeInfo* out = nullptr; 1712 | ThrowOnError(GetApi().KernelInfo_GetOutputTypeInfo(this->p_, index, &out)); 1713 | return TypeInfo{out}; 1714 | } 1715 | 1716 | template 1717 | inline Value 
KernelInfoImpl::GetTensorAttribute(const char* name, OrtAllocator* allocator) const { 1718 | OrtValue* out = nullptr; 1719 | ThrowOnError(GetApi().KernelInfoGetAttribute_tensor(this->p_, name, allocator, &out)); 1720 | return Value{out}; 1721 | } 1722 | 1723 | template 1724 | inline ConstValue KernelInfoImpl::GetTensorConstantInput(size_t index, int* is_constant) const { 1725 | const OrtValue* out = nullptr; 1726 | ThrowOnError(GetApi().KernelInfoGetConstantInput_tensor(this->p_, index, is_constant, &out)); 1727 | return ConstValue{out}; 1728 | } 1729 | 1730 | template 1731 | inline std::string KernelInfoImpl::GetNodeName() const { 1732 | size_t size = 0; 1733 | 1734 | // Feed nullptr for the data buffer to query the true size of the string value 1735 | Ort::ThrowOnError(GetApi().KernelInfo_GetNodeName(this->p_, nullptr, &size)); 1736 | 1737 | std::string out; 1738 | out.resize(size); 1739 | Ort::ThrowOnError(GetApi().KernelInfo_GetNodeName(this->p_, &out[0], &size)); 1740 | out.resize(size - 1); // remove the terminating character '\0' 1741 | 1742 | return out; 1743 | } 1744 | 1745 | template 1746 | inline Logger KernelInfoImpl::GetLogger() const { 1747 | const OrtLogger* out = nullptr; 1748 | ThrowOnError(GetApi().KernelInfo_GetLogger(this->p_, &out)); 1749 | return Logger{out}; 1750 | } 1751 | 1752 | inline void attr_utils::GetAttr(const OrtKernelInfo* p, const char* name, float& out) { 1753 | Ort::ThrowOnError(GetApi().KernelInfoGetAttribute_float(p, name, &out)); 1754 | } 1755 | 1756 | inline void attr_utils::GetAttr(const OrtKernelInfo* p, const char* name, int64_t& out) { 1757 | Ort::ThrowOnError(GetApi().KernelInfoGetAttribute_int64(p, name, &out)); 1758 | } 1759 | 1760 | inline void attr_utils::GetAttr(const OrtKernelInfo* p, const char* name, std::string& result) { 1761 | size_t size = 0; 1762 | // Feed nullptr for the data buffer to query the true size of the string attribute 1763 | Ort::ThrowOnError(GetApi().KernelInfoGetAttribute_string(p, name, 
nullptr, &size)); 1764 | 1765 | std::string out; 1766 | out.resize(size); 1767 | Ort::ThrowOnError(GetApi().KernelInfoGetAttribute_string(p, name, &out[0], &size)); 1768 | out.resize(size - 1); // remove the terminating character '\0' 1769 | out.swap(result); 1770 | } 1771 | 1772 | inline void attr_utils::GetAttrs(const OrtKernelInfo* p, const char* name, std::vector& result) { 1773 | size_t size = 0; 1774 | // Feed nullptr for the data buffer to query the true size of the attribute 1775 | Ort::ThrowOnError(GetApi().KernelInfoGetAttributeArray_float(p, name, nullptr, &size)); 1776 | 1777 | std::vector out; 1778 | out.resize(size); 1779 | Ort::ThrowOnError(GetApi().KernelInfoGetAttributeArray_float(p, name, out.data(), &size)); 1780 | out.swap(result); 1781 | } 1782 | 1783 | inline void attr_utils::GetAttrs(const OrtKernelInfo* p, const char* name, std::vector& result) { 1784 | size_t size = 0; 1785 | 1786 | // Feed nullptr for the data buffer to query the true size of the attribute 1787 | Ort::ThrowOnError(GetApi().KernelInfoGetAttributeArray_int64(p, name, nullptr, &size)); 1788 | 1789 | std::vector out; 1790 | out.resize(size); 1791 | Ort::ThrowOnError(GetApi().KernelInfoGetAttributeArray_int64(p, name, out.data(), &size)); 1792 | out.swap(result); 1793 | } 1794 | } // namespace detail 1795 | 1796 | inline KernelInfo::KernelInfo(OrtKernelInfo* info) : detail::KernelInfoImpl{info} {} 1797 | 1798 | inline Op::Op(OrtOp* p) : Base(p) {} 1799 | 1800 | inline Op Op::Create(const OrtKernelInfo* info, const char* op_name, const char* domain, int version, 1801 | const char** type_constraint_names, 1802 | const ONNXTensorElementDataType* type_constraint_values, 1803 | size_t type_constraint_count, 1804 | const OpAttr* attr_values, size_t attr_count, 1805 | size_t input_count, size_t output_count) { 1806 | static_assert(sizeof(OpAttr) == sizeof(OrtOpAttr*), 1807 | "OpAttr's is expected to be just an array of OrtOpAttr in memory so we can reinterpret safely"); 1808 | auto 
attr_input_values = reinterpret_cast(attr_values); 1809 | OrtOp* op; 1810 | Ort::ThrowOnError(GetApi().CreateOp(info, op_name, domain, version, type_constraint_names, type_constraint_values, 1811 | static_cast(type_constraint_count), 1812 | attr_input_values, 1813 | static_cast(attr_count), 1814 | static_cast(input_count), 1815 | static_cast(output_count), &op)); 1816 | return Op{op}; 1817 | } 1818 | 1819 | inline void Op::Invoke(const OrtKernelContext* context, 1820 | const Value* input_values, 1821 | size_t input_count, 1822 | Value* output_values, 1823 | size_t output_count) { 1824 | static_assert(sizeof(Value) == sizeof(OrtValue*), 1825 | "Value is really just an array of OrtValue* in memory, so we can reinterpret_cast safely"); 1826 | auto ort_input_values = reinterpret_cast(input_values); 1827 | auto ort_output_values = reinterpret_cast(output_values); 1828 | Ort::ThrowOnError(GetApi().InvokeOp(context, p_, ort_input_values, static_cast(input_count), 1829 | ort_output_values, static_cast(output_count))); 1830 | } 1831 | 1832 | inline void Op::Invoke(const OrtKernelContext* context, 1833 | const OrtValue* const* input_values, 1834 | size_t input_count, 1835 | OrtValue* const* output_values, 1836 | size_t output_count) { 1837 | Ort::ThrowOnError(GetApi().InvokeOp(context, p_, input_values, static_cast(input_count), 1838 | output_values, static_cast(output_count))); 1839 | } 1840 | 1841 | inline std::string GetVersionString() { 1842 | return OrtGetApiBase()->GetVersionString(); 1843 | } 1844 | 1845 | inline std::string GetBuildInfoString() { 1846 | return GetApi().GetBuildInfoString(); 1847 | } 1848 | 1849 | inline std::vector GetAvailableProviders() { 1850 | char** providers; 1851 | int len; 1852 | 1853 | auto release_fn = [&len](char** providers) { 1854 | // This should always return nullptr. 
1855 | ThrowOnError(GetApi().ReleaseAvailableProviders(providers, len)); 1856 | }; 1857 | 1858 | ThrowOnError(GetApi().GetAvailableProviders(&providers, &len)); 1859 | std::unique_ptr guard(providers, release_fn); 1860 | std::vector available_providers; 1861 | available_providers.reserve(static_cast(len)); 1862 | for (int i = 0; i < len; ++i) { 1863 | available_providers.emplace_back(providers[i]); 1864 | } 1865 | return available_providers; 1866 | } 1867 | 1868 | template 1869 | void CustomOpBase::GetSessionConfigs(std::unordered_map& out, 1870 | ConstSessionOptions options) const { 1871 | const TOp* derived = static_cast(this); 1872 | std::vector keys = derived->GetSessionConfigKeys(); 1873 | 1874 | out.reserve(keys.size()); 1875 | 1876 | std::string config_entry_key = detail::MakeCustomOpConfigEntryKey(derived->GetName(), ""); 1877 | const size_t prefix_size = config_entry_key.length(); 1878 | 1879 | for (const auto& key : keys) { 1880 | config_entry_key.resize(prefix_size); 1881 | config_entry_key.append(key); 1882 | out[key] = options.GetConfigEntryOrDefault(config_entry_key.c_str(), ""); 1883 | } 1884 | } 1885 | 1886 | } // namespace Ort 1887 | --------------------------------------------------------------------------------