├── extensions ├── include │ ├── OnnxRuntimeExtensions.h │ └── OrtExt.h └── OrtExt.mm ├── objectivec ├── test │ ├── testdata │ │ ├── identity_string.ort │ │ ├── single_add.basic.ort │ │ ├── single_add.onnx │ │ ├── gen_models.sh │ │ └── single_add_gen.py │ ├── assert_arc_enabled.mm │ ├── ort_training_utils_test.mm │ ├── test_utils.h │ ├── ort_env_test.mm │ ├── test_utils.mm │ ├── assertion_utils.h │ ├── ort_value_test.mm │ ├── ort_checkpoint_test.mm │ ├── ort_session_test.mm │ └── ort_training_session_test.mm ├── ReadMe.md ├── docs │ ├── main_page.md │ ├── jazzy_config.yaml │ └── readme.md ├── assert_arc_enabled.mm ├── format_objc.sh ├── ort_env_internal.h ├── ort_checkpoint_internal.h ├── include │ ├── onnxruntime_training.h │ ├── onnxruntime.h │ ├── ort_custom_op_registration.h │ ├── ort_env.h │ ├── ort_xnnpack_execution_provider.h │ ├── ort_enums.h │ ├── ort_coreml_execution_provider.h │ ├── ort_value.h │ ├── ort_checkpoint.h │ ├── ort_training_session.h │ └── ort_session.h ├── ort_training_session_internal.h ├── ort_session_internal.h ├── ort_enums_internal.h ├── ort_xnnpack_execution_provider.mm ├── cxx_utils.h ├── ort_env.mm ├── ort_value_internal.h ├── error_utils.mm ├── error_utils.h ├── cxx_api.h ├── ort_coreml_execution_provider.mm ├── cxx_utils.mm ├── ort_checkpoint.mm ├── ort_enums.mm ├── ort_training_session.mm ├── ort_value.mm └── ort_session.mm ├── swift ├── OnnxRuntimeBindingsTests │ ├── Resources │ │ └── single_add.basic.ort │ └── SwiftOnnxRuntimeBindingsTests.swift └── OnnxRuntimeExtensionsTests │ ├── Resources │ └── decode_image.onnx │ └── SwiftOnnxRuntimeExtensionsTests.swift ├── .config └── tsaoptions.json ├── .pipelines ├── templates │ ├── use-xcode-version.yml │ ├── component-governance-component-detection-steps.yml │ └── main.yml ├── mac-ios-spm-release-validation-pipeline.yml ├── official.yml ├── mac-ios-spm-dev-validation-pipeline.yml └── mac-ios-spm-extensions-validation-pipeline.yml ├── CODE_OF_CONDUCT.md ├── LICENSE ├── SUPPORT.md ├── 
README.md ├── SECURITY.md ├── release_instructions.md ├── .github └── workflows │ └── codeql.yml ├── Package.swift └── .gitignore /extensions/include/OnnxRuntimeExtensions.h: -------------------------------------------------------------------------------- 1 | // This is the umbrella header 2 | 3 | #include "OrtExt.h" -------------------------------------------------------------------------------- /objectivec/test/testdata/identity_string.ort: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/onnxruntime-swift-package-manager/HEAD/objectivec/test/testdata/identity_string.ort -------------------------------------------------------------------------------- /objectivec/ReadMe.md: -------------------------------------------------------------------------------- 1 | NOTE: Flat directory structure to work with both the Objective-C build and the Swift Package Manager build which is done 2 | via ../Package.swift 3 | -------------------------------------------------------------------------------- /objectivec/test/testdata/single_add.basic.ort: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/onnxruntime-swift-package-manager/HEAD/objectivec/test/testdata/single_add.basic.ort -------------------------------------------------------------------------------- /objectivec/docs/main_page.md: -------------------------------------------------------------------------------- 1 | # ONNX Runtime Objective-C API 2 | 3 | ONNX Runtime provides an Objective-C API. 4 | 5 | It can be used from Objective-C/C++ or Swift with a bridging header. 
6 | -------------------------------------------------------------------------------- /swift/OnnxRuntimeBindingsTests/Resources/single_add.basic.ort: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/onnxruntime-swift-package-manager/HEAD/swift/OnnxRuntimeBindingsTests/Resources/single_add.basic.ort -------------------------------------------------------------------------------- /swift/OnnxRuntimeExtensionsTests/Resources/decode_image.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/onnxruntime-swift-package-manager/HEAD/swift/OnnxRuntimeExtensionsTests/Resources/decode_image.onnx -------------------------------------------------------------------------------- /objectivec/test/testdata/single_add.onnx: -------------------------------------------------------------------------------- 1 | :S 2 |  3 | A 4 | BCAdd"Add SingleAddZ 5 | A 6 | 7 |  8 | Z 9 | B 10 | 11 |  12 | b 13 | C 14 | 15 |  16 | B 17 |  -------------------------------------------------------------------------------- /objectivec/assert_arc_enabled.mm: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License. 3 | 4 | static_assert(__has_feature(objc_arc), "Objective-C ARC must be enabled."); 5 | -------------------------------------------------------------------------------- /objectivec/test/assert_arc_enabled.mm: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License. 
3 | 4 | static_assert(__has_feature(objc_arc), "Objective-C ARC must be enabled."); 5 | -------------------------------------------------------------------------------- /objectivec/format_objc.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # formats Objective-C/C++ code 4 | 5 | set -e 6 | 7 | SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" 8 | 9 | clang-format -i $(find ${SCRIPT_DIR} -name "*.h" -o -name "*.m" -o -name "*.mm") 10 | -------------------------------------------------------------------------------- /objectivec/ort_env_internal.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License. 3 | 4 | #import "ort_env.h" 5 | 6 | #import "cxx_api.h" 7 | 8 | NS_ASSUME_NONNULL_BEGIN 9 | 10 | @interface ORTEnv () 11 | 12 | - (Ort::Env&)CXXAPIOrtEnv; 13 | 14 | @end 15 | 16 | NS_ASSUME_NONNULL_END 17 | -------------------------------------------------------------------------------- /.config/tsaoptions.json: -------------------------------------------------------------------------------- 1 | { 2 | "notificationAliases": ["onnxruntime-org@microsoft.com"], 3 | "areaPath": "ONNX Runtime\\Engineering Systems", 4 | "codebaseName": "onnxruntime-swift-package-manager", 5 | "instanceUrl": "https://aiinfra.visualstudio.com/", 6 | "projectName": "ONNX Runtime", 7 | "ignoreBranchName": false, 8 | "template": "AIINFRA_TSA" 9 | } 10 | -------------------------------------------------------------------------------- /objectivec/ort_checkpoint_internal.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License. 
3 | 4 | #import "ort_checkpoint.h" 5 | 6 | #import "cxx_api.h" 7 | 8 | NS_ASSUME_NONNULL_BEGIN 9 | 10 | @interface ORTCheckpoint () 11 | 12 | - (Ort::CheckpointState&)CXXAPIOrtCheckpoint; 13 | 14 | @end 15 | 16 | NS_ASSUME_NONNULL_END 17 | -------------------------------------------------------------------------------- /objectivec/test/testdata/gen_models.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | 5 | # Get directory this script is in 6 | DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" 7 | 8 | cd ${DIR} 9 | 10 | python3 ./single_add_gen.py 11 | 12 | ORT_CONVERT_ONNX_MODELS_TO_ORT_OPTIMIZATION_LEVEL=basic python3 -m onnxruntime.tools.convert_onnx_models_to_ort --optimization_style=Fixed . 13 | -------------------------------------------------------------------------------- /objectivec/include/onnxruntime_training.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License. 3 | 4 | // this header contains the entire ONNX Runtime training Objective-C API 5 | // the headers below can also be imported individually 6 | 7 | #import "onnxruntime.h" 8 | #import "ort_checkpoint.h" 9 | #import "ort_training_session.h" 10 | -------------------------------------------------------------------------------- /extensions/OrtExt.mm: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License. 
3 | 4 | #import "OrtExt.h" 5 | 6 | #include "onnxruntime_extensions/onnxruntime_extensions.h" 7 | 8 | @implementation OrtExt 9 | 10 | + (nonnull ORTCAPIRegisterCustomOpsFnPtr)getRegisterCustomOpsFunctionPointer { 11 | return RegisterCustomOps; 12 | } 13 | 14 | @end -------------------------------------------------------------------------------- /objectivec/ort_training_session_internal.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License. 3 | 4 | #import "ort_training_session.h" 5 | 6 | #import "cxx_api.h" 7 | 8 | NS_ASSUME_NONNULL_BEGIN 9 | 10 | @interface ORTTrainingSession () 11 | 12 | - (Ort::TrainingSession&)CXXAPIOrtTrainingSession; 13 | 14 | @end 15 | 16 | NS_ASSUME_NONNULL_END 17 | -------------------------------------------------------------------------------- /.pipelines/templates/use-xcode-version.yml: -------------------------------------------------------------------------------- 1 | # Specify use of a specific Xcode version. 2 | 3 | parameters: 4 | - name: xcodeVersion 5 | type: string 6 | default: "14.3.1" 7 | 8 | steps: 9 | - bash: | 10 | set -e -x 11 | XCODE_DEVELOPER_DIR="/Applications/Xcode_${{ parameters.xcodeVersion }}.app/Contents/Developer" 12 | sudo xcode-select --switch "${XCODE_DEVELOPER_DIR}" 13 | 14 | displayName: Use Xcode ${{ parameters.xcodeVersion }} -------------------------------------------------------------------------------- /objectivec/ort_session_internal.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License. 
3 | 4 | #import "ort_session.h" 5 | 6 | #import "cxx_api.h" 7 | 8 | NS_ASSUME_NONNULL_BEGIN 9 | 10 | @interface ORTSessionOptions () 11 | 12 | - (Ort::SessionOptions&)CXXAPIOrtSessionOptions; 13 | 14 | @end 15 | 16 | @interface ORTRunOptions () 17 | 18 | - (Ort::RunOptions&)CXXAPIOrtRunOptions; 19 | 20 | @end 21 | 22 | NS_ASSUME_NONNULL_END 23 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Microsoft Open Source Code of Conduct 2 | 3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 4 | 5 | Resources: 6 | 7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/) 8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) 9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns 10 | -------------------------------------------------------------------------------- /objectivec/include/onnxruntime.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License. 
3 | 4 | // this header contains the entire ONNX Runtime Objective-C API 5 | // the headers below can also be imported individually 6 | 7 | #import "ort_coreml_execution_provider.h" 8 | #import "ort_custom_op_registration.h" 9 | #import "ort_enums.h" 10 | #import "ort_env.h" 11 | #import "ort_session.h" 12 | #import "ort_value.h" 13 | #import "ort_xnnpack_execution_provider.h" 14 | -------------------------------------------------------------------------------- /objectivec/docs/jazzy_config.yaml: -------------------------------------------------------------------------------- 1 | module: ONNX Runtime Objective-C API 2 | author: ONNX Runtime Authors 3 | author_url: https://www.onnxruntime.ai 4 | github_url: https://github.com/microsoft/onnxruntime 5 | 6 | objc: true 7 | # Specify the training header as umbrella_header so that every public header is included. 8 | # The training header is a superset of the inference-only header. 9 | umbrella_header: ../include/onnxruntime_training.h 10 | framework_root: .. 11 | 12 | readme: ./main_page.md 13 | 14 | hide_documentation_coverage: true 15 | undocumented_text: "" 16 | -------------------------------------------------------------------------------- /objectivec/test/ort_training_utils_test.mm: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License. 
3 | 4 | #import 5 | #import "ort_training_session.h" 6 | 7 | NS_ASSUME_NONNULL_BEGIN 8 | 9 | @interface ORTTrainingUtilsTest : XCTestCase 10 | @end 11 | 12 | @implementation ORTTrainingUtilsTest 13 | 14 | - (void)setUp { 15 | [super setUp]; 16 | 17 | self.continueAfterFailure = NO; 18 | } 19 | 20 | - (void)testSetSeed { 21 | ORTSetSeed(2718); 22 | } 23 | 24 | @end 25 | 26 | NS_ASSUME_NONNULL_END 27 | -------------------------------------------------------------------------------- /objectivec/test/test_utils.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License. 3 | 4 | #import 5 | #import 6 | 7 | #import "ort_session.h" 8 | #import "ort_env.h" 9 | #import "ort_value.h" 10 | 11 | NS_ASSUME_NONNULL_BEGIN 12 | 13 | namespace test_utils { 14 | 15 | NSString* _Nullable createTemporaryDirectory(XCTestCase* testCase); 16 | 17 | NSArray* getFloatArrayFromData(NSData* data); 18 | 19 | } // namespace test_utils 20 | 21 | NS_ASSUME_NONNULL_END 22 | -------------------------------------------------------------------------------- /objectivec/docs/readme.md: -------------------------------------------------------------------------------- 1 | # Objective-C API Documentation 2 | 3 | The API should be documented with comments in the [public header files](../include). 4 | 5 | ## Documentation Generation 6 | 7 | The [Jazzy](https://github.com/realm/jazzy) tool is used to generate documentation from the code. 8 | 9 | To generate documentation, from the repo root, run: 10 | 11 | ```bash 12 | jazzy --config objectivec/docs/jazzy_config.yaml --output 13 | ``` 14 | 15 | The generated documentation website files will be in ``. 16 | 17 | [main_page.md](./main_page.md) contains content for the main page of the generated documentation website. 
18 | -------------------------------------------------------------------------------- /objectivec/test/testdata/single_add_gen.py: -------------------------------------------------------------------------------- 1 | import onnx 2 | from onnx import TensorProto, helper 3 | 4 | graph = helper.make_graph( 5 | [ # nodes 6 | helper.make_node("Add", ["A", "B"], ["C"], "Add"), 7 | ], 8 | "SingleAdd", # name 9 | [ # inputs 10 | helper.make_tensor_value_info("A", TensorProto.FLOAT, [1]), 11 | helper.make_tensor_value_info("B", TensorProto.FLOAT, [1]), 12 | ], 13 | [ # outputs 14 | helper.make_tensor_value_info("C", TensorProto.FLOAT, [1]), 15 | ], 16 | ) 17 | 18 | model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 12)]) 19 | onnx.save(model, r"single_add.onnx") 20 | -------------------------------------------------------------------------------- /objectivec/ort_enums_internal.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License. 3 | 4 | #import "ort_enums.h" 5 | 6 | #import "cxx_api.h" 7 | 8 | OrtLoggingLevel PublicToCAPILoggingLevel(ORTLoggingLevel logging_level); 9 | 10 | ORTValueType CAPIToPublicValueType(ONNXType capi_type); 11 | 12 | ONNXTensorElementDataType PublicToCAPITensorElementType(ORTTensorElementDataType type); 13 | ORTTensorElementDataType CAPIToPublicTensorElementType(ONNXTensorElementDataType capi_type); 14 | 15 | size_t SizeOfCAPITensorElementType(ONNXTensorElementDataType capi_type); 16 | 17 | GraphOptimizationLevel PublicToCAPIGraphOptimizationLevel(ORTGraphOptimizationLevel opt_level); 18 | -------------------------------------------------------------------------------- /objectivec/test/ort_env_test.mm: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License. 
3 | 4 | #import 5 | 6 | #import "ort_env.h" 7 | 8 | #import "test/assertion_utils.h" 9 | 10 | NS_ASSUME_NONNULL_BEGIN 11 | 12 | @interface ORTEnvTest : XCTestCase 13 | @end 14 | 15 | @implementation ORTEnvTest 16 | - (void)testGetOrtVersion { 17 | NSString* ver = ORTVersion(); 18 | XCTAssertNotNil(ver); 19 | } 20 | 21 | - (void)testInitOk { 22 | NSError* err = nil; 23 | ORTEnv* env = [[ORTEnv alloc] initWithLoggingLevel:ORTLoggingLevelWarning 24 | error:&err]; 25 | ORTAssertNullableResultSuccessful(env, err); 26 | } 27 | 28 | @end 29 | 30 | NS_ASSUME_NONNULL_END 31 | -------------------------------------------------------------------------------- /extensions/include/OrtExt.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License. 3 | 4 | #import 5 | 6 | @interface OrtExt : NSObject 7 | 8 | typedef struct OrtStatus* (*ORTCAPIRegisterCustomOpsFnPtr)(struct OrtSessionOptions* /*options*/, 9 | const struct OrtApiBase* /*api*/); 10 | 11 | // Note: This returns the address of `RegisterCustomOps` function. At swift 12 | // level, user can call this function to get the RegisterCustomOpsFnPtr 13 | // and use the function pointer to register custom ops. See 14 | // SwiftOnnxRuntimeExtensionsTests.swift for an example usage. 
15 | + (nonnull ORTCAPIRegisterCustomOpsFnPtr)getRegisterCustomOpsFunctionPointer; 16 | 17 | @end -------------------------------------------------------------------------------- /.pipelines/templates/component-governance-component-detection-steps.yml: -------------------------------------------------------------------------------- 1 | # component detection for component governance checks 2 | parameters: 3 | - name: condition 4 | type: string 5 | default: 'succeeded' # could be 'ci_only', 'always', 'succeeded' 6 | 7 | steps: 8 | - ${{ if eq(variables['System.TeamProject'], 'Lotus') }}: 9 | - task: ms.vss-governance-buildtask.governance-build-task-component-detection.ComponentGovernanceComponentDetection@0 10 | displayName: 'Component Detection' 11 | condition: 12 | or(or(and(eq('${{parameters.condition}}', 'ci_only'), and(succeeded(), in(variables['Build.Reason'], 'IndividualCI', 'BatchedCI', 'Scheduled'))), 13 | and(eq('${{parameters.condition}}', 'always'), always())), 14 | and(eq('${{parameters.condition}}', 'succeeded'), succeeded())) 15 | -------------------------------------------------------------------------------- /.pipelines/mac-ios-spm-release-validation-pipeline.yml: -------------------------------------------------------------------------------- 1 | jobs: 2 | - job: j 3 | displayName: "Test with released ORT native pod" 4 | 5 | pool: 6 | vmImage: "macOS-13" 7 | 8 | timeoutInMinutes: 60 9 | 10 | steps: 11 | - template: templates/use-xcode-version.yml 12 | 13 | # Note: Running xcodebuild test on `onnxruntime-Package` scheme will perform swift tests for both OnnxRuntimeBindings 14 | # and OnnxRuntimeExtensions targets. 
15 | - script: | 16 | set -e -x 17 | xcodebuild test -scheme onnxruntime-Package -destination 'platform=iOS Simulator,name=iPhone 14' 18 | xcodebuild test -scheme onnxruntime-Package -destination 'platform=macosx' 19 | workingDirectory: "$(Build.SourcesDirectory)" 20 | displayName: "Test Package.swift usage" 21 | 22 | - template: templates/component-governance-component-detection-steps.yml 23 | parameters: 24 | condition: 'succeeded' -------------------------------------------------------------------------------- /objectivec/ort_xnnpack_execution_provider.mm: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License. 3 | 4 | #import "ort_xnnpack_execution_provider.h" 5 | 6 | #import "cxx_api.h" 7 | #import "error_utils.h" 8 | #import "ort_session_internal.h" 9 | 10 | NS_ASSUME_NONNULL_BEGIN 11 | 12 | @implementation ORTXnnpackExecutionProviderOptions 13 | 14 | @end 15 | 16 | @implementation ORTSessionOptions (ORTSessionOptionsXnnpackEP) 17 | 18 | - (BOOL)appendXnnpackExecutionProviderWithOptions:(ORTXnnpackExecutionProviderOptions*)options 19 | error:(NSError**)error { 20 | try { 21 | NSDictionary* provider_options = @{ 22 | @"intra_op_num_threads" : [NSString stringWithFormat:@"%d", options.intra_op_num_threads] 23 | }; 24 | return [self appendExecutionProvider:@"XNNPACK" providerOptions:provider_options error:error]; 25 | } 26 | ORT_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error); 27 | } 28 | 29 | @end 30 | 31 | NS_ASSUME_NONNULL_END 32 | -------------------------------------------------------------------------------- /objectivec/cxx_utils.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License. 
3 | 4 | #import 5 | 6 | #include 7 | #include 8 | #include 9 | 10 | #import "cxx_api.h" 11 | 12 | NS_ASSUME_NONNULL_BEGIN 13 | @class ORTValue; 14 | 15 | namespace utils { 16 | 17 | NSString* toNSString(const std::string& str); 18 | NSString* _Nullable toNullableNSString(const std::optional& str); 19 | 20 | std::string toStdString(NSString* str); 21 | std::optional toStdOptionalString(NSString* _Nullable str); 22 | 23 | std::vector toStdStringVector(NSArray* strs); 24 | NSArray* toNSStringNSArray(const std::vector& strs); 25 | 26 | NSArray* _Nullable wrapUnownedCAPIOrtValues(const std::vector& values, NSError** error); 27 | 28 | std::vector getWrappedCAPIOrtValues(NSArray* values); 29 | 30 | } // namespace utils 31 | 32 | NS_ASSUME_NONNULL_END 33 | -------------------------------------------------------------------------------- /objectivec/ort_env.mm: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License. 
3 | 4 | #import "ort_env_internal.h" 5 | 6 | #include 7 | 8 | #import "cxx_api.h" 9 | 10 | #import "error_utils.h" 11 | #import "ort_enums_internal.h" 12 | 13 | NS_ASSUME_NONNULL_BEGIN 14 | 15 | NSString* _Nullable ORTVersion(void) { 16 | return [NSString stringWithUTF8String:OrtGetApiBase()->GetVersionString()]; 17 | } 18 | 19 | @implementation ORTEnv { 20 | std::optional _env; 21 | } 22 | 23 | - (nullable instancetype)initWithLoggingLevel:(ORTLoggingLevel)loggingLevel 24 | error:(NSError**)error { 25 | if ((self = [super init]) == nil) { 26 | return nil; 27 | } 28 | 29 | try { 30 | const auto CAPILoggingLevel = PublicToCAPILoggingLevel(loggingLevel); 31 | _env = Ort::Env{CAPILoggingLevel}; 32 | return self; 33 | } 34 | ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error) 35 | } 36 | 37 | - (Ort::Env&)CXXAPIOrtEnv { 38 | return *_env; 39 | } 40 | 41 | @end 42 | 43 | NS_ASSUME_NONNULL_END 44 | -------------------------------------------------------------------------------- /objectivec/ort_value_internal.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License. 3 | 4 | #import "ort_value.h" 5 | 6 | #import "cxx_api.h" 7 | 8 | NS_ASSUME_NONNULL_BEGIN 9 | 10 | @interface ORTValue () 11 | 12 | /** 13 | * Creates a value from an existing C++ API Ort::Value and takes ownership from it. 14 | * Note: Ownership is guaranteed to be transferred on success but not otherwise. 15 | * 16 | * @param existingCXXAPIOrtValue The existing C++ API Ort::Value. 17 | * @param externalTensorData Any external tensor data referenced by `existingCXXAPIOrtValue`. 18 | * @param error Optional error information set if an error occurs. 19 | * @return The instance, or nil if an error occurs. 
20 | */ 21 | - (nullable instancetype)initWithCXXAPIOrtValue:(Ort::Value&&)existingCXXAPIOrtValue 22 | externalTensorData:(nullable NSMutableData*)externalTensorData 23 | error:(NSError**)error NS_DESIGNATED_INITIALIZER; 24 | 25 | - (Ort::Value&)CXXAPIOrtValue; 26 | 27 | @end 28 | 29 | NS_ASSUME_NONNULL_END 30 | -------------------------------------------------------------------------------- /objectivec/include/ort_custom_op_registration.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License. 3 | 4 | /** C API type forward declaration. */ 5 | struct OrtStatus; 6 | 7 | /** C API type forward declaration. */ 8 | struct OrtApiBase; 9 | 10 | /** C API type forward declaration. */ 11 | struct OrtSessionOptions; 12 | 13 | /** 14 | * Pointer to a custom op registration function that uses the ONNX Runtime C API. 15 | * 16 | * The signature is defined in the ONNX Runtime C API: 17 | * https://github.com/microsoft/onnxruntime/blob/67f4cd54fab321d83e4a75a40efeee95a6a17079/include/onnxruntime/core/session/onnxruntime_c_api.h#L697 18 | * 19 | * This is a low-level type intended for interoperating with libraries which provide such a function for custom op 20 | * registration, such as [ONNX Runtime Extensions](https://github.com/microsoft/onnxruntime-extensions). 21 | */ 22 | typedef struct OrtStatus* (*ORTCAPIRegisterCustomOpsFnPtr)(struct OrtSessionOptions* /*options*/, 23 | const struct OrtApiBase* /*api*/); 24 | -------------------------------------------------------------------------------- /swift/OnnxRuntimeExtensionsTests/SwiftOnnxRuntimeExtensionsTests.swift: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License. 
3 | 4 | import XCTest 5 | import Foundation 6 | @testable import OnnxRuntimeBindings 7 | @testable import OnnxRuntimeExtensions 8 | 9 | final class SwiftOnnxRuntimeExtensionsTests: XCTestCase { 10 | 11 | let modelPath: String = Bundle.module.url(forResource: "decode_image", withExtension: "onnx")!.path 12 | 13 | func testCreateSessionWithCustomOps() throws { 14 | let env = try ORTEnv(loggingLevel: ORTLoggingLevel.verbose) 15 | let options = try ORTSessionOptions() 16 | try options.setLogSeverityLevel(ORTLoggingLevel.verbose) 17 | try options.setIntraOpNumThreads(1) 18 | 19 | // Register Custom Ops library using function pointer 20 | let ortCustomOpsFnPtr = OrtExt.getRegisterCustomOpsFunctionPointer() 21 | try options.registerCustomOps(functionPointer: ortCustomOpsFnPtr) 22 | 23 | // Create the ORTSession 24 | _ = try ORTSession(env: env, modelPath: modelPath, sessionOptions: options) 25 | } 26 | } -------------------------------------------------------------------------------- /.pipelines/official.yml: -------------------------------------------------------------------------------- 1 | trigger: none 2 | 3 | # The `resources` specify the location and version of the 1ES PT. 
4 | resources: 5 | pipelines: 6 | - pipeline: 'pod' 7 | project: 'Lotus' 8 | source: 'onnxruntime-ios-packaging-pipeline' 9 | repositories: 10 | - repository: 1esPipelines 11 | type: git 12 | name: 1ESPipelineTemplates/1ESPipelineTemplates 13 | ref: refs/tags/release 14 | 15 | extends: 16 | template: v1/1ES.Official.PipelineTemplate.yml@1esPipelines 17 | parameters: 18 | settings: 19 | networkIsolationPolicy: Preferred, GitHub 20 | pool: 21 | name: Azure Pipelines 22 | image: "macOS-14" 23 | os: macOS 24 | sdl: 25 | sourceAnalysisPool: 26 | name: onnxruntime-Win-CPU-2022 27 | os: windows 28 | policheck: 29 | enabled: true 30 | credscan: 31 | enabled: true 32 | codeql: 33 | sourceLanguages: python 34 | tsa: 35 | enabled: true 36 | configFile: '$(Build.SourcesDirectory)\.config\tsaoptions.json' 37 | stages: 38 | - stage: Stage 39 | jobs: 40 | - template: .pipelines/templates/main.yml@self 41 | -------------------------------------------------------------------------------- /.pipelines/mac-ios-spm-dev-validation-pipeline.yml: -------------------------------------------------------------------------------- 1 | trigger: none 2 | 3 | # The `resources` specify the location and version of the 1ES PT. 
4 | resources: 5 | pipelines: 6 | - pipeline: 'pod' 7 | project: 'Lotus' 8 | source: 'onnxruntime-ios-packaging-pipeline' 9 | repositories: 10 | - repository: 1esPipelines 11 | type: git 12 | name: 1ESPipelineTemplates/1ESPipelineTemplates 13 | ref: refs/tags/release 14 | 15 | extends: 16 | template: v1/1ES.Unofficial.PipelineTemplate.yml@1esPipelines 17 | parameters: 18 | settings: 19 | networkIsolationPolicy: Preferred, GitHub 20 | pool: 21 | name: Azure Pipelines 22 | image: "macOS-14" 23 | os: macOS 24 | sdl: 25 | sourceAnalysisPool: 26 | name: onnxruntime-Win-CPU-2022 27 | os: windows 28 | policheck: 29 | enabled: true 30 | credscan: 31 | enabled: true 32 | codeql: 33 | sourceLanguages: python 34 | tsa: 35 | enabled: false 36 | configFile: '$(Build.SourcesDirectory)\.config\tsaoptions.json' 37 | stages: 38 | - stage: Stage 39 | jobs: 40 | - template: .pipelines/templates/main.yml@self 41 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Microsoft Corporation. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE 22 | -------------------------------------------------------------------------------- /objectivec/include/ort_env.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License. 3 | 4 | #import 5 | 6 | #import "ort_enums.h" 7 | 8 | NS_ASSUME_NONNULL_BEGIN 9 | 10 | #ifdef __cplusplus 11 | extern "C" { 12 | #endif 13 | 14 | /** 15 | * Gets the ORT version string in format major.minor.patch. 16 | * 17 | * Available since 1.15. 18 | */ 19 | NSString* _Nullable ORTVersion(void); 20 | 21 | #ifdef __cplusplus 22 | } 23 | #endif 24 | 25 | /** 26 | * The ORT environment. 27 | * It maintains shared state including the default logger. 28 | * 29 | * @note One ORTEnv should be created before and destroyed after other ORT API usage. 30 | */ 31 | @interface ORTEnv : NSObject 32 | 33 | - (instancetype)init NS_UNAVAILABLE; 34 | 35 | /** 36 | * Creates an ORT Environment. 37 | * 38 | * @param loggingLevel The environment logging level. 39 | * @param error Optional error information set if an error occurs. 40 | * @return The instance, or nil if an error occurs. 41 | */ 42 | - (nullable instancetype)initWithLoggingLevel:(ORTLoggingLevel)loggingLevel 43 | error:(NSError**)error NS_DESIGNATED_INITIALIZER; 44 | 45 | @end 46 | 47 | NS_ASSUME_NONNULL_END 48 | -------------------------------------------------------------------------------- /objectivec/include/ort_xnnpack_execution_provider.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License. 
3 | 4 | #import 5 | 6 | #import "ort_session.h" 7 | 8 | NS_ASSUME_NONNULL_BEGIN 9 | 10 | /** 11 | * Options for configuring the Xnnpack execution provider. 12 | */ 13 | @interface ORTXnnpackExecutionProviderOptions : NSObject 14 | 15 | /** 16 | * The number of threads used by the Xnnpack execution provider. 17 | */ 18 | @property int intra_op_num_threads; 19 | 20 | @end 21 | 22 | @interface ORTSessionOptions (ORTSessionOptionsXnnpackEP) 23 | 24 | /** 25 | * Available since 1.14. 26 | * Enables the Xnnpack execution provider in the session configuration options. 27 | * It is appended to the execution provider list which is ordered by 28 | * decreasing priority. 29 | * 30 | * @param options The Xnnpack execution provider configuration options. 31 | * @param error Optional error information set if an error occurs. 32 | * @return Whether the provider was enabled successfully. 33 | */ 34 | - (BOOL)appendXnnpackExecutionProviderWithOptions:(ORTXnnpackExecutionProviderOptions*)options 35 | error:(NSError**)error; 36 | 37 | @end 38 | 39 | NS_ASSUME_NONNULL_END 40 | -------------------------------------------------------------------------------- /SUPPORT.md: -------------------------------------------------------------------------------- 1 | # TODO: The maintainer of this repo has not yet edited this file 2 | 3 | **REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project? 4 | 5 | - **No CSS support:** Fill out this template with information about how to file issues and get help. 6 | - **Yes CSS support:** Fill out an intake form at [aka.ms/onboardsupport](https://aka.ms/onboardsupport). CSS will work with/help you to determine next steps. 7 | - **Not sure?** Fill out an intake as though the answer were "Yes". CSS will help you decide.
8 | 9 | *Then remove this first heading from this SUPPORT.MD file before publishing your repo.* 10 | 11 | # Support 12 | 13 | ## How to file issues and get help 14 | 15 | This project uses GitHub Issues to track bugs and feature requests. Please search the existing 16 | issues before filing new issues to avoid duplicates. For new issues, file your bug or 17 | feature request as a new Issue. 18 | 19 | For help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE 20 | FOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER 21 | CHANNEL. WHERE WILL YOU HELP PEOPLE?**. 22 | 23 | ## Microsoft Support Policy 24 | 25 | Support for this **PROJECT or PRODUCT** is limited to the resources listed above. 26 | -------------------------------------------------------------------------------- /objectivec/error_utils.mm: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License. 
3 | 4 | #import "error_utils.h" 5 | 6 | NS_ASSUME_NONNULL_BEGIN 7 | 8 | static NSString* const kOrtErrorDomain = @"onnxruntime"; 9 | 10 | void ORTSaveCodeAndDescriptionToError(int code, const char* descriptionCstr, NSError** error) { 11 | if (!error) return; 12 | 13 | // ORT error messages can contain non-ASCII text (e.g., file paths), so decode them as UTF-8. 14 | // If the bytes are not valid UTF-8, fall back to ISO Latin-1 (which maps every byte to a character) so that 15 | // `description` is never nil; a nil value in the userInfo dictionary literal below would raise an exception. 16 | NSString* description = [NSString stringWithCString:descriptionCstr 17 | encoding:NSUTF8StringEncoding]; 18 | if (!description) { 19 | description = [NSString stringWithCString:descriptionCstr encoding:NSISOLatin1StringEncoding] ?: @""; 20 | } 21 | 22 | *error = [NSError errorWithDomain:kOrtErrorDomain 23 | code:code 24 | userInfo:@{NSLocalizedDescriptionKey : description}]; 25 | } 26 | 27 | void ORTSaveCodeAndDescriptionToError(int code, NSString* description, NSError** error) { 28 | if (!error) return; 29 | 30 | *error = [NSError errorWithDomain:kOrtErrorDomain 31 | code:code 32 | userInfo:@{NSLocalizedDescriptionKey : description}]; 33 | } 34 | 35 | void ORTSaveOrtExceptionToError(const Ort::Exception& e, NSError** error) { 36 | ORTSaveCodeAndDescriptionToError(e.GetOrtErrorCode(), e.what(), error); 37 | } 38 | 39 | void ORTSaveExceptionToError(const std::exception& e, NSError** error) { 40 | ORTSaveCodeAndDescriptionToError(ORT_RUNTIME_EXCEPTION, e.what(), error); 41 | } 42 | 43 | NS_ASSUME_NONNULL_END 44 | -------------------------------------------------------------------------------- /objectivec/error_utils.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License.
3 | 4 | #import 5 | 6 | #include 7 | 8 | #import "cxx_api.h" 9 | 10 | NS_ASSUME_NONNULL_BEGIN 11 | 12 | void ORTSaveCodeAndDescriptionToError(int code, const char* description, NSError** error); 13 | void ORTSaveCodeAndDescriptionToError(int code, NSString* description, NSError** error); 14 | void ORTSaveOrtExceptionToError(const Ort::Exception& e, NSError** error); 15 | void ORTSaveExceptionToError(const std::exception& e, NSError** error); 16 | 17 | // helper macros to catch and handle C++ exceptions 18 | #define ORT_OBJC_API_IMPL_CATCH(error, failure_return_value) \ 19 | catch (const Ort::Exception& e) { \ 20 | ORTSaveOrtExceptionToError(e, (error)); \ 21 | return (failure_return_value); \ 22 | } \ 23 | catch (const std::exception& e) { \ 24 | ORTSaveExceptionToError(e, (error)); \ 25 | return (failure_return_value); \ 26 | } 27 | 28 | #define ORT_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error) \ 29 | ORT_OBJC_API_IMPL_CATCH(error, NO) 30 | 31 | #define ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error) \ 32 | ORT_OBJC_API_IMPL_CATCH(error, nil) 33 | 34 | NS_ASSUME_NONNULL_END 35 | -------------------------------------------------------------------------------- /objectivec/cxx_api.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License. 
3 | 4 | // wrapper for ORT C/C++ API headers 5 | 6 | #if defined(__clang__) 7 | #pragma clang diagnostic push 8 | // ignore clang documentation-related warnings 9 | // instead, we will rely on Doxygen warnings for the C/C++ API headers 10 | #pragma clang diagnostic ignored "-Wdocumentation" 11 | #endif // defined(__clang__) 12 | 13 | // paths are different when building the Swift Package Manager package as the headers come from the iOS pod archive 14 | // clang-format off 15 | #define STRINGIFY(x) #x 16 | #ifdef SPM_BUILD 17 | #define ORT_C_CXX_HEADER_FILE_PATH(x) STRINGIFY(onnxruntime/x) 18 | #else 19 | #define ORT_C_CXX_HEADER_FILE_PATH(x) STRINGIFY(x) 20 | #endif 21 | // clang-format on 22 | 23 | #if __has_include(ORT_C_CXX_HEADER_FILE_PATH(onnxruntime_training_c_api.h)) 24 | #include ORT_C_CXX_HEADER_FILE_PATH(onnxruntime_training_c_api.h) 25 | #include ORT_C_CXX_HEADER_FILE_PATH(onnxruntime_training_cxx_api.h) 26 | #else 27 | #include ORT_C_CXX_HEADER_FILE_PATH(onnxruntime_c_api.h) 28 | #include ORT_C_CXX_HEADER_FILE_PATH(onnxruntime_cxx_api.h) 29 | #endif 30 | 31 | #if __has_include(ORT_C_CXX_HEADER_FILE_PATH(coreml_provider_factory.h)) 32 | #define ORT_OBJC_API_COREML_EP_AVAILABLE 1 33 | #include ORT_C_CXX_HEADER_FILE_PATH(coreml_provider_factory.h) 34 | #else 35 | #define ORT_OBJC_API_COREML_EP_AVAILABLE 0 36 | #endif 37 | 38 | #if defined(__clang__) 39 | #pragma clang diagnostic pop 40 | #endif // defined(__clang__) 41 | -------------------------------------------------------------------------------- /objectivec/include/ort_enums.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License. 3 | 4 | #import 5 | 6 | NS_ASSUME_NONNULL_BEGIN 7 | 8 | /** 9 | * The ORT logging verbosity levels. 
10 | */ 11 | typedef NS_ENUM(int32_t, ORTLoggingLevel) { 12 | ORTLoggingLevelVerbose, 13 | ORTLoggingLevelInfo, 14 | ORTLoggingLevelWarning, 15 | ORTLoggingLevelError, 16 | ORTLoggingLevelFatal, 17 | }; 18 | 19 | /** 20 | * The ORT value types. 21 | * Currently, a subset of all types is supported. 22 | */ 23 | typedef NS_ENUM(int32_t, ORTValueType) { 24 | ORTValueTypeUnknown, 25 | ORTValueTypeTensor, 26 | }; 27 | 28 | /** 29 | * The ORT tensor element data types. 30 | * Currently, a subset of all types is supported. 31 | */ 32 | typedef NS_ENUM(int32_t, ORTTensorElementDataType) { 33 | ORTTensorElementDataTypeUndefined, 34 | ORTTensorElementDataTypeFloat, 35 | ORTTensorElementDataTypeInt8, 36 | ORTTensorElementDataTypeUInt8, 37 | ORTTensorElementDataTypeInt32, 38 | ORTTensorElementDataTypeUInt32, 39 | ORTTensorElementDataTypeInt64, 40 | ORTTensorElementDataTypeUInt64, 41 | ORTTensorElementDataTypeString, 42 | }; 43 | 44 | /** 45 | * The ORT graph optimization levels. 46 | * See here for more details: 47 | * https://onnxruntime.ai/docs/performance/model-optimizations/graph-optimizations.html 48 | */ 49 | typedef NS_ENUM(int32_t, ORTGraphOptimizationLevel) { 50 | ORTGraphOptimizationLevelNone, 51 | ORTGraphOptimizationLevelBasic, 52 | ORTGraphOptimizationLevelExtended, 53 | ORTGraphOptimizationLevelAll, 54 | }; 55 | 56 | NS_ASSUME_NONNULL_END 57 | -------------------------------------------------------------------------------- /objectivec/test/test_utils.mm: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License. 
3 | 4 | #import "test_utils.h" 5 | 6 | NS_ASSUME_NONNULL_BEGIN 7 | 8 | namespace test_utils { 9 | 10 | NSString* createTemporaryDirectory(XCTestCase* testCase) { 11 | NSString* temporaryDirectory = NSTemporaryDirectory(); 12 | NSString* directoryPath = [temporaryDirectory stringByAppendingPathComponent:@"ort-objective-c-test"]; 13 | 14 | NSError* error = nil; 15 | BOOL created = [[NSFileManager defaultManager] createDirectoryAtPath:directoryPath 16 | withIntermediateDirectories:YES 17 | attributes:nil 18 | error:&error]; 19 | // check the returned BOOL rather than `error`; Cocoa only guarantees `error` is meaningful on failure 20 | XCTAssertTrue(created, @"Error creating temporary directory: %@", error.localizedDescription); 21 | 22 | // add teardown block to delete the temporary directory 23 | [testCase addTeardownBlock:^{ 24 | NSError* removeError = nil; 25 | BOOL removed = [[NSFileManager defaultManager] removeItemAtPath:directoryPath error:&removeError]; 26 | XCTAssertTrue(removed, @"Error removing temporary directory: %@", removeError.localizedDescription); 27 | }]; 28 | 29 | return directoryPath; 30 | } 31 | 32 | NSArray* getFloatArrayFromData(NSData* data) { 33 | NSMutableArray* array = [NSMutableArray array]; 34 | float value; 35 | for (size_t i = 0; i < data.length / sizeof(float); ++i) { 36 | [data getBytes:&value range:NSMakeRange(i * sizeof(float), sizeof(float))]; 37 | [array addObject:@(value)]; 38 | } 39 | return array; 40 | } 41 | 42 | } // namespace test_utils 43 | 44 | NS_ASSUME_NONNULL_END 45 | -------------------------------------------------------------------------------- /objectivec/ort_coreml_execution_provider.mm: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License. 3 | 4 | #import "ort_coreml_execution_provider.h" 5 | 6 | #import "cxx_api.h" 7 | #import "error_utils.h" 8 | #import "ort_session_internal.h" 9 | 10 | NS_ASSUME_NONNULL_BEGIN 11 | 12 | BOOL ORTIsCoreMLExecutionProviderAvailable() { 13 | return ORT_OBJC_API_COREML_EP_AVAILABLE ?
YES : NO; 14 | } 15 | 16 | @implementation ORTCoreMLExecutionProviderOptions 17 | 18 | @end 19 | 20 | @implementation ORTSessionOptions (ORTSessionOptionsCoreMLEP) 21 | 22 | - (BOOL)appendCoreMLExecutionProviderWithOptions:(ORTCoreMLExecutionProviderOptions*)options 23 | error:(NSError**)error { 24 | #if ORT_OBJC_API_COREML_EP_AVAILABLE 25 | try { 26 | const uint32_t flags = 27 | (options.useCPUOnly ? COREML_FLAG_USE_CPU_ONLY : 0) | 28 | (options.useCPUAndGPU ? COREML_FLAG_USE_CPU_AND_GPU : 0) | 29 | (options.enableOnSubgraphs ? COREML_FLAG_ENABLE_ON_SUBGRAPH : 0) | 30 | (options.onlyEnableForDevicesWithANE ? COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE : 0) | 31 | (options.onlyAllowStaticInputShapes ? COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES : 0) | 32 | (options.createMLProgram ? COREML_FLAG_CREATE_MLPROGRAM : 0); 33 | 34 | Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CoreML( 35 | [self CXXAPIOrtSessionOptions], flags)); 36 | return YES; 37 | } 38 | ORT_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error); 39 | #else // !ORT_OBJC_API_COREML_EP_AVAILABLE 40 | static_cast(options); 41 | ORTSaveCodeAndDescriptionToError(ORT_FAIL, "CoreML execution provider is not enabled.", error); 42 | return NO; 43 | #endif 44 | } 45 | 46 | @end 47 | 48 | NS_ASSUME_NONNULL_END 49 | -------------------------------------------------------------------------------- /.pipelines/templates/main.yml: -------------------------------------------------------------------------------- 1 | parameters: 2 | - name: artifactName 3 | type: string 4 | default: 'ios_packaging_artifacts_full' 5 | 6 | jobs: 7 | - job: main 8 | templateContext: 9 | isProduction: false 10 | inputs: 11 | - input: pipelineArtifact 12 | pipeline: 'pod' 13 | artifactName: '${{ parameters.artifactName }}' 14 | targetPath: '$(Build.ArtifactStagingDirectory)/${{ parameters.artifactName }}' 15 | steps: 16 | - script: | 17 | set -e -x 18 | POD_ARCHIVE=$(find . 
-name "pod-archive-onnxruntime-objc*.zip") 19 | unzip ${POD_ARCHIVE} -d unzipped 20 | cp -rf unzipped/objectivec/ $(Build.SourcesDirectory)/objectivec/ 21 | # run from the downloaded pipeline artifact directory, which contains the pod archives 22 | workingDirectory: '$(Build.ArtifactStagingDirectory)/${{ parameters.artifactName }}' 23 | displayName: Copy latest dev version ORT objectivec/ source files 24 | 25 | # copy the pod archive to a path relative to Package.swift and set the env var required by Package.swift to use that. 26 | # xcodebuild will implicitly use Package.swift and build/run the .testTarget targets (tests under swift/, e.g. swift/OnnxRuntimeBindingsTests). 27 | # once that's done, clean up the copy of the pod zip file 28 | - script: | 29 | set -e -x 30 | # locate the onnxruntime-c pod archive within the downloaded pipeline artifact directory 31 | cd "$(Build.ArtifactStagingDirectory)/${{ parameters.artifactName }}" 32 | POD_ARCHIVE=$(find . -name "pod-archive-onnxruntime-c*.zip") 33 | 34 | shasum -a 256 "$(Build.ArtifactStagingDirectory)/${{ parameters.artifactName }}/${POD_ARCHIVE}" 35 | 36 | cd "$(Build.SourcesDirectory)" 37 | cp "$(Build.ArtifactStagingDirectory)/${{ parameters.artifactName }}/${POD_ARCHIVE}" swift/ 38 | export ORT_POD_LOCAL_PATH="swift/${POD_ARCHIVE}" 39 | xcodebuild test -scheme onnxruntime-Package -destination 'platform=iOS Simulator,name=iPhone 16' 40 | 41 | xcodebuild test -scheme onnxruntime-Package -destination 'platform=macosx' 42 | 43 | rm swift/pod-archive-onnxruntime-c-*.zip 44 | workingDirectory: "$(Build.SourcesDirectory)" 45 | displayName: "Print ORT iOS Pod checksum and Test Package.swift usage" 46 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Swift Package Manager for ONNX Runtime 2 | 3 | A lightweight repository that provides [Swift Package Manager (SPM)](https://www.swift.org/package-manager/) support for [ONNX Runtime](https://github.com/microsoft/onnxruntime).
The ONNX Runtime native package is included as a binary dependency of the SPM package. 4 | 5 | 6 | SPM can be used as an alternative to CocoaPods for consuming ONNX Runtime in iOS (and macOS) apps. 7 | 8 | ## Note 9 | 10 | The `objectivec/` directory is copied from the [ORT repo](https://github.com/microsoft/onnxruntime/tree/main/objectivec) and is expected to match it. It is updated periodically, and before each release, to pick up new changes. 11 | 12 | ## Contributing 13 | 14 | This project welcomes contributions and suggestions. Most contributions require you to agree to a 15 | Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us 16 | the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com. 17 | 18 | When you submit a pull request, a CLA bot will automatically determine whether you need to provide 19 | a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions 20 | provided by the bot. You will only need to do this once across all repos using our CLA. 21 | 22 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 23 | For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or 24 | contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments. 25 | 26 | ## Trademarks 27 | 28 | This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft 29 | trademarks or logos is subject to and must follow 30 | [Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general). 31 | Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship. 32 | Any use of third-party trademarks or logos are subject to those third-party's policies.
33 | -------------------------------------------------------------------------------- /objectivec/include/ort_coreml_execution_provider.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License. 3 | 4 | #import 5 | 6 | #import "ort_session.h" 7 | 8 | #ifdef __cplusplus 9 | extern "C" { 10 | #endif 11 | 12 | /** 13 | * Gets whether the CoreML execution provider is available. 14 | */ 15 | BOOL ORTIsCoreMLExecutionProviderAvailable(void); 16 | 17 | #ifdef __cplusplus 18 | } 19 | #endif 20 | 21 | NS_ASSUME_NONNULL_BEGIN 22 | 23 | /** 24 | * Options for configuring the CoreML execution provider. 25 | */ 26 | @interface ORTCoreMLExecutionProviderOptions : NSObject 27 | 28 | /** 29 | * Whether the CoreML execution provider should run on CPU only. 30 | */ 31 | @property BOOL useCPUOnly; 32 | /** 33 | * Whether the CoreML execution provider should run on CPU and GPU only, excluding the Apple Neural Engine (ANE). 34 | */ 35 | @property BOOL useCPUAndGPU; 36 | /** 37 | * Whether the CoreML execution provider is enabled on subgraphs. 38 | */ 39 | @property BOOL enableOnSubgraphs; 40 | 41 | /** 42 | * Whether the CoreML execution provider is only enabled for devices with Apple 43 | * Neural Engine (ANE). 44 | */ 45 | @property BOOL onlyEnableForDevicesWithANE; 46 | 47 | /** 48 | * Only allow CoreML EP to take nodes with inputs with static shapes. By default it will also allow inputs with 49 | * dynamic shapes. However, the performance may be negatively impacted if inputs have dynamic shapes. 50 | */ 51 | @property BOOL onlyAllowStaticInputShapes; 52 | 53 | /** 54 | * Whether to create an MLProgram. By default, a NeuralNetwork model is created. Requires Core ML 5 or later. 55 | */ 56 | @property BOOL createMLProgram; 57 | 58 | @end 59 | 60 | @interface ORTSessionOptions (ORTSessionOptionsCoreMLEP) 61 | 62 | /** 63 | * Enables the CoreML execution provider in the session configuration options.
64 | * It is appended to the execution provider list which is ordered by 65 | * decreasing priority. 66 | * 67 | * @param options The CoreML execution provider configuration options. 68 | * @param error Optional error information set if an error occurs. 69 | * @return Whether the provider was enabled successfully. 70 | */ 71 | - (BOOL)appendCoreMLExecutionProviderWithOptions:(ORTCoreMLExecutionProviderOptions*)options 72 | error:(NSError**)error; 73 | 74 | @end 75 | 76 | NS_ASSUME_NONNULL_END 77 | -------------------------------------------------------------------------------- /objectivec/test/assertion_utils.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License. 3 | 4 | #import 5 | 6 | NS_ASSUME_NONNULL_BEGIN 7 | 8 | #define ORTAssertNullableResultSuccessful(result, error) \ 9 | do { \ 10 | XCTAssertNotNil(result, @"Expected non-nil result but got nil. Error: %@", error); \ 11 | XCTAssertNil(error); \ 12 | } while (0) 13 | 14 | #define ORTAssertBoolResultSuccessful(result, error) \ 15 | do { \ 16 | XCTAssertTrue(result, @"Expected true result but got false. Error: %@", error); \ 17 | XCTAssertNil(error); \ 18 | } while (0) 19 | 20 | #define ORTAssertNullableResultUnsuccessful(result, error) \ 21 | do { \ 22 | XCTAssertNil(result); \ 23 | XCTAssertNotNil(error); \ 24 | } while (0) 25 | 26 | #define ORTAssertBoolResultUnsuccessful(result, error) \ 27 | do { \ 28 | XCTAssertFalse(result); \ 29 | XCTAssertNotNil(error); \ 30 | } while (0) 31 | 32 | #define ORTAssertEqualFloatAndNoError(expected, result, error) \ 33 | do { \ 34 | XCTAssertEqualWithAccuracy(expected, result, 1e-3f, @"Expected %f but got %f. 
Error: %@", expected, result, error); \ 35 | XCTAssertNil(error); \ 36 | } while (0) 37 | 38 | #define ORTAssertEqualFloatArrays(expected, result) \ 39 | do { \ 40 | XCTAssertEqual(expected.count, result.count); \ 41 | for (size_t i = 0; i < MIN(expected.count, result.count); ++i) { /* bounded by both counts so a length mismatch (reported above) cannot index past either array */ \ 42 | XCTAssertEqualWithAccuracy([expected[i] floatValue], [result[i] floatValue], 1e-3f); \ 43 | } \ 44 | } while (0) 45 | 46 | NS_ASSUME_NONNULL_END 47 | -------------------------------------------------------------------------------- /swift/OnnxRuntimeBindingsTests/SwiftOnnxRuntimeBindingsTests.swift: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License. 3 | 4 | import XCTest 5 | import Foundation 6 | @testable import OnnxRuntimeBindings 7 | 8 | final class SwiftOnnxRuntimeBindingsTests: XCTestCase { 9 | let modelPath: String = Bundle.module.url(forResource: "single_add.basic", withExtension: "ort")!.path 10 | 11 | func testGetVersionString() throws { 12 | do { 13 | let version = ORTVersion() 14 | XCTAssertNotNil(version) 15 | } catch let error { 16 | XCTFail(error.localizedDescription) 17 | } 18 | } 19 | 20 | func testCreateSession() throws { 21 | do { 22 | let env = try ORTEnv(loggingLevel: ORTLoggingLevel.verbose) 23 | let options = try ORTSessionOptions() 24 | try options.setLogSeverityLevel(ORTLoggingLevel.verbose) 25 | try options.setIntraOpNumThreads(1) 26 | // Create the ORTSession 27 | _ = try ORTSession(env: env, modelPath: modelPath, sessionOptions: options) 28 | } catch let error { 29 | XCTFail(error.localizedDescription) 30 | } 31 | } 32 | 33 | func testAppendCoreMLEP() throws { 34 | do { 35 | let env = try ORTEnv(loggingLevel: ORTLoggingLevel.verbose) 36 | let sessionOptions: ORTSessionOptions = try ORTSessionOptions() 37 | let coreMLOptions: ORTCoreMLExecutionProviderOptions = ORTCoreMLExecutionProviderOptions() 38 | coreMLOptions.enableOnSubgraphs = true 39
| try sessionOptions.appendCoreMLExecutionProvider(with: coreMLOptions) 40 | 41 | XCTAssertTrue(ORTIsCoreMLExecutionProviderAvailable()) 42 | _ = try ORTSession(env: env, modelPath: modelPath, sessionOptions: sessionOptions) 43 | } catch let error { 44 | XCTFail(error.localizedDescription) 45 | } 46 | } 47 | 48 | func testAppendXnnpackEP() throws { 49 | do { 50 | let env = try ORTEnv(loggingLevel: ORTLoggingLevel.verbose) 51 | let sessionOptions: ORTSessionOptions = try ORTSessionOptions() 52 | let xnnpackOptions: ORTXnnpackExecutionProviderOptions = ORTXnnpackExecutionProviderOptions() 53 | xnnpackOptions.intra_op_num_threads = 2 54 | try sessionOptions.appendXnnpackExecutionProvider(with: xnnpackOptions) 55 | 56 | _ = try ORTSession(env: env, modelPath: modelPath, sessionOptions: sessionOptions) 57 | } catch let error { 58 | XCTFail(error.localizedDescription) 59 | } 60 | } 61 | } 62 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Security 4 | 5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/). 6 | 7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/opensource/security/definition), please report it to us as described below.
8 | 9 | ## Reporting Security Issues 10 | 11 | **Please do not report security vulnerabilities through public GitHub issues.** 12 | 13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/opensource/security/create-report). 14 | 15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/opensource/security/pgpkey). 16 | 17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://aka.ms/opensource/security/msrc). 18 | 19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: 20 | 21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) 22 | * Full paths of source file(s) related to the manifestation of the issue 23 | * The location of the affected source code (tag/branch/commit or direct URL) 24 | * Any special configuration required to reproduce the issue 25 | * Step-by-step instructions to reproduce the issue 26 | * Proof-of-concept or exploit code (if possible) 27 | * Impact of the issue, including how an attacker might exploit the issue 28 | 29 | This information will help us triage your report more quickly. 30 | 31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/opensource/security/bounty) page for more details about our active programs. 32 | 33 | ## Preferred Languages 34 | 35 | We prefer all communications to be in English. 
36 | 37 | ## Policy 38 | 39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/opensource/security/cvd). 40 | 41 | 42 | -------------------------------------------------------------------------------- /release_instructions.md: -------------------------------------------------------------------------------- 1 | # Release Instructions 2 | 3 | These are the steps to create a new onnxruntime-swift-package-manager release. 4 | Typically, an onnxruntime-swift-package-manager release will follow soon after an onnxruntime release. 5 | 6 | ## 1. Sync onnxruntime Objective-C source files 7 | 8 | Check out the corresponding onnxruntime release version, i.e., a release tag. 9 | ``` 10 | cd 11 | git checkout 12 | ``` 13 | 14 | Sync the onnxruntime Objective-C source files. 15 | ``` 16 | # Replace objectivec directory. 17 | rm -r /objectivec 18 | cp -r /objectivec /objectivec 19 | ``` 20 | 21 | 22 | ## 2. Update dependency versions in Package.swift 23 | 24 | Find the onnxruntime and onnxruntime-extensions binary target configuration in the Package.swift file. For example: 25 | 26 | https://github.com/microsoft/onnxruntime-swift-package-manager/blob/bbc428e168a0374eb7d0503cdb7c73fdc1d99751/Package.swift#L98-L104 27 | 28 | https://github.com/microsoft/onnxruntime-swift-package-manager/blob/bbc428e168a0374eb7d0503cdb7c73fdc1d99751/Package.swift#L110-L116 29 | 30 | ```swift 31 | // ORT release 32 | package.targets.append( 33 | Target.binaryTarget(name: "onnxruntime", 34 | url: "https://download.onnxruntime.ai/pod-archive-onnxruntime-c-1.19.2.zip", 35 | // ^^^^^^ Update version 36 | // SHA256 checksum 37 | checksum: "28787ee2f966a2c47eb293322c733c5dc4b5e3327cec321c1fe31a7c698edf68") 38 | // ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ Update checksum 39 | ) 40 | ``` 41 | 42 | Update the version in the URL and checksum for both onnxruntime and onnxruntime-extensions targets. 
43 | To get the checksum value, download the file from the updated URL and compute its SHA256 checksum. 44 | 45 | To compute a SHA256 checksum on Linux or macOS (if `sha256sum` is not available, `shasum -a 256` can be used instead): 46 | ``` 47 | sha256sum pod-archive-onnxruntime-c-x.y.z.zip 48 | ``` 49 | 50 | To compute a SHA256 checksum on Windows with PowerShell: 51 | ``` 52 | (Get-FileHash -Algorithm SHA256 pod-archive-onnxruntime-c-x.y.z.zip).Hash.ToLower() 53 | ``` 54 | 55 | 56 | ## 3. Check in updates and create a release 57 | 58 | Note: This repo is updated relatively infrequently, so we currently do not have a release process that uses a separate release branch. 59 | 60 | Check in the changes made in the previous steps to the `main` branch. 61 | 62 | Create a release tag from the `main` branch. 63 | 64 | The tag should match the tag of the corresponding onnxruntime release excluding the leading `v` in order to make it a valid semantic version string. 65 | E.g., for onnxruntime tag `v1.20.0`, the onnxruntime-swift-package-manager tag should be `1.20.0`. 66 | -------------------------------------------------------------------------------- /objectivec/cxx_utils.mm: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License.
3 | 4 | #import "cxx_utils.h" 5 | 6 | #include 7 | #include 8 | #include 9 | 10 | #import "error_utils.h" 11 | 12 | #import "ort_value_internal.h" 13 | 14 | NS_ASSUME_NONNULL_BEGIN 15 | 16 | namespace utils { 17 | 18 | NSString* toNSString(const std::string& str) { 19 | NSString* nsStr = [NSString stringWithUTF8String:str.c_str()]; 20 | if (!nsStr) { 21 | ORT_CXX_API_THROW("Failed to convert std::string to NSString", ORT_INVALID_ARGUMENT); 22 | } 23 | 24 | return nsStr; 25 | } 26 | 27 | NSString* _Nullable toNullableNSString(const std::optional& str) { 28 | if (str.has_value()) { 29 | return toNSString(*str); 30 | } 31 | return nil; 32 | } 33 | 34 | std::string toStdString(NSString* str) { 35 | return std::string([str UTF8String]); 36 | } 37 | 38 | std::optional toStdOptionalString(NSString* _Nullable str) { 39 | if (str) { 40 | return std::optional([str UTF8String]); 41 | } 42 | return std::nullopt; 43 | } 44 | 45 | std::vector toStdStringVector(NSArray* strs) { 46 | std::vector result; 47 | result.reserve(strs.count); 48 | for (NSString* str in strs) { 49 | result.push_back([str UTF8String]); 50 | } 51 | return result; 52 | } 53 | 54 | NSArray* toNSStringNSArray(const std::vector& strs) { 55 | NSMutableArray* result = [NSMutableArray arrayWithCapacity:strs.size()]; 56 | for (const std::string& str : strs) { 57 | [result addObject:toNSString(str)]; 58 | } 59 | return result; 60 | } 61 | 62 | NSArray* _Nullable wrapUnownedCAPIOrtValues(const std::vector& CAPIValues, NSError** error) { 63 | NSMutableArray* result = [NSMutableArray arrayWithCapacity:CAPIValues.size()]; 64 | for (size_t i = 0; i < CAPIValues.size(); ++i) { 65 | // Wrap the C OrtValue in a C++ Ort::Value to automatically handle its release. 66 | // Then, transfer that C++ Ort::Value to a new ORTValue. 
67 | Ort::Value CXXAPIValue{CAPIValues[i]}; 68 | ORTValue* val = [[ORTValue alloc] initWithCXXAPIOrtValue:std::move(CXXAPIValue) 69 | externalTensorData:nil 70 | error:error]; 71 | if (!val) { 72 | // clean up remaining C OrtValues which haven't been wrapped by a C++ Ort::Value yet 73 | for (size_t j = i + 1; j < CAPIValues.size(); ++j) { 74 | Ort::GetApi().ReleaseValue(CAPIValues[j]); 75 | } 76 | return nil; 77 | } 78 | [result addObject:val]; 79 | } 80 | return result; 81 | } 82 | 83 | std::vector getWrappedCAPIOrtValues(NSArray* values) { 84 | std::vector result; 85 | result.reserve(values.count); 86 | for (ORTValue* val in values) { 87 | result.push_back(static_cast([val CXXAPIOrtValue])); 88 | } 89 | return result; 90 | } 91 | 92 | } // namespace utils 93 | 94 | NS_ASSUME_NONNULL_END 95 | -------------------------------------------------------------------------------- /objectivec/test/ort_value_test.mm: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License. 
#import <XCTest/XCTest.h>

#import "ort_value.h"

#include <vector>

#import "test/assertion_utils.h"

NS_ASSUME_NONNULL_BEGIN

/**
 * Unit tests for creating ORTValue tensors and reading back their type information and data.
 */
@interface ORTValueTest : XCTestCase
@end

@implementation ORTValueTest

- (void)setUp {
  [super setUp];

  // Stop each test at its first failed assertion so later steps do not run on nil results.
  self.continueAfterFailure = NO;
}

/**
 * Creating a one-element int32 tensor reports the expected type, element type, shape, and data.
 */
- (void)testInitTensorOk {
  int32_t value = 42;
  NSMutableData* data = [[NSMutableData alloc] initWithBytes:&value
                                                      length:sizeof(int32_t)];
  NSArray<NSNumber*>* shape = @[ @1 ];

  const ORTTensorElementDataType elementType = ORTTensorElementDataTypeInt32;

  NSError* err = nil;
  ORTValue* ortValue = [[ORTValue alloc] initWithTensorData:data
                                                elementType:elementType
                                                      shape:shape
                                                      error:&err];
  ORTAssertNullableResultSuccessful(ortValue, err);

  // Both type-info accessors must agree with what the tensor was created with.
  auto checkTensorInfo = [&](ORTTensorTypeAndShapeInfo* tensorInfo) {
    XCTAssertEqual(tensorInfo.elementType, elementType);
    XCTAssertEqualObjects(tensorInfo.shape, shape);
  };

  ORTValueTypeInfo* typeInfo = [ortValue typeInfoWithError:&err];
  ORTAssertNullableResultSuccessful(typeInfo, err);
  XCTAssertEqual(typeInfo.type, ORTValueTypeTensor);
  XCTAssertNotNil(typeInfo.tensorTypeAndShapeInfo);
  checkTensorInfo(typeInfo.tensorTypeAndShapeInfo);

  ORTTensorTypeAndShapeInfo* tensorInfo = [ortValue tensorTypeAndShapeInfoWithError:&err];
  ORTAssertNullableResultSuccessful(tensorInfo, err);
  checkTensorInfo(tensorInfo);

  NSData* actualData = [ortValue tensorDataWithError:&err];
  ORTAssertNullableResultSuccessful(actualData, err);
  XCTAssertEqual(actualData.length, sizeof(int32_t));
  int32_t actualValue;
  memcpy(&actualValue, actualData.bytes, sizeof(int32_t));
  XCTAssertEqual(actualValue, value);
}

/**
 * Creating a tensor whose shape needs more elements than the data holds must fail with an error.
 */
- (void)testInitTensorFailsWithDataSmallerThanShape {
  std::vector<int32_t> values{1, 2, 3, 4};
  NSMutableData* data = [[NSMutableData alloc] initWithBytes:values.data()
                                                      length:values.size() * sizeof(int32_t)];
  NSArray<NSNumber*>* shape = @[ @2, @3 ];  // too large

  NSError* err = nil;
  ORTValue* ortValue = [[ORTValue alloc] initWithTensorData:data
                                                elementType:ORTTensorElementDataTypeInt32
                                                      shape:shape
                                                      error:&err];
  ORTAssertNullableResultUnsuccessful(ortValue, err);
}

/**
 * String data used to create a string tensor is returned unchanged by tensorStringDataWithError:.
 */
- (void)testInitTensorWithStringDataSucceeds {
  NSArray<NSString*>* stringData = @[ @"ONNX Runtime", @"is", @"the", @"best", @"AI", @"Framework" ];
  NSError* err = nil;
  ORTValue* stringValue = [[ORTValue alloc] initWithTensorStringData:stringData shape:@[ @3, @2 ] error:&err];
  ORTAssertNullableResultSuccessful(stringValue, err);

  NSArray<NSString*>* returnedStringData = [stringValue tensorStringDataWithError:&err];
  ORTAssertNullableResultSuccessful(returnedStringData, err);

  // XCTAssertEqualObjects compares counts and elements and prints both arrays on failure,
  // unlike XCTAssertTrue(isEqualToArray:).
  XCTAssertEqualObjects(returnedStringData, stringData);
}

@end

NS_ASSUME_NONNULL_END
-------------------------------------------------------------------------------- /objectivec/ort_checkpoint.mm: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License.
#import "ort_checkpoint_internal.h"

#include <optional>
#include <string>
#include <variant>
#import "cxx_api.h"

#import "cxx_utils.h"
#import "error_utils.h"

NS_ASSUME_NONNULL_BEGIN

@implementation ORTCheckpoint {
  // Set by the designated initializer on success. The methods below assume it holds a value.
  std::optional<Ort::CheckpointState> _checkpoint;
}

/**
 * Loads the checkpoint at `path`.
 *
 * Returns nil and sets `error` if loading throws.
 */
- (nullable instancetype)initWithPath:(NSString*)path
                                error:(NSError**)error {
  if ((self = [super init]) == nil) {
    return nil;
  }

  try {
    _checkpoint = Ort::CheckpointState::LoadCheckpoint(path.UTF8String);
    return self;
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error)
}

/**
 * Saves this checkpoint to `path`, optionally including the optimizer state.
 *
 * Returns NO and sets `error` if saving throws.
 */
- (BOOL)saveCheckpointToPath:(NSString*)path
          withOptimizerState:(BOOL)includeOptimizerState
                       error:(NSError**)error {
  try {
    Ort::CheckpointState::SaveCheckpoint([self CXXAPIOrtCheckpoint], path.UTF8String, includeOptimizerState);
    return YES;
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error)
}

/** Adds an int64 property named `name`. Returns NO and sets `error` on failure. */
- (BOOL)addIntPropertyWithName:(NSString*)name
                         value:(int64_t)value
                         error:(NSError**)error {
  try {
    [self CXXAPIOrtCheckpoint].AddProperty(name.UTF8String, value);
    return YES;
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error)
}

/** Adds a float property named `name`. Returns NO and sets `error` on failure. */
- (BOOL)addFloatPropertyWithName:(NSString*)name
                           value:(float)value
                           error:(NSError**)error {
  try {
    [self CXXAPIOrtCheckpoint].AddProperty(name.UTF8String, value);
    return YES;
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error)
}

/** Adds a string property named `name`. Returns NO and sets `error` on failure. */
- (BOOL)addStringPropertyWithName:(NSString*)name
                            value:(NSString*)value
                            error:(NSError**)error {
  try {
    [self CXXAPIOrtCheckpoint].AddProperty(name.UTF8String, value.UTF8String);
    return YES;
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error)
}

/**
 * Returns the string property named `name`.
 *
 * Returns nil and sets `error` if the property cannot be read or is not a string.
 */
- (nullable NSString*)getStringPropertyWithName:(NSString*)name error:(NSError**)error {
  try {
    Ort::Property value = [self CXXAPIOrtCheckpoint].GetProperty(name.UTF8String);
    if (std::string* str = std::get_if<std::string>(&value)) {
      // Pass the stored std::string directly. Building a temporary from c_str() copies the data
      // and truncates it at an embedded NUL.
      return utils::toNSString(*str);
    }
    ORT_CXX_API_THROW("Property is not a string.", ORT_INVALID_ARGUMENT);
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error)
}

/**
 * Returns the int64 property named `name`.
 *
 * Returns 0 and sets `error` if the property cannot be read or is not an integer.
 */
- (int64_t)getIntPropertyWithName:(NSString*)name error:(NSError**)error {
  try {
    Ort::Property value = [self CXXAPIOrtCheckpoint].GetProperty(name.UTF8String);
    if (int64_t* i = std::get_if<int64_t>(&value)) {
      return *i;
    }
    ORT_CXX_API_THROW("Property is not an integer.", ORT_INVALID_ARGUMENT);
  }
  ORT_OBJC_API_IMPL_CATCH(error, 0)
}

/**
 * Returns the float property named `name`.
 *
 * Returns 0.0f and sets `error` if the property cannot be read or is not a float.
 */
- (float)getFloatPropertyWithName:(NSString*)name error:(NSError**)error {
  try {
    Ort::Property value = [self CXXAPIOrtCheckpoint].GetProperty(name.UTF8String);
    if (float* f = std::get_if<float>(&value)) {
      return *f;
    }
    ORT_CXX_API_THROW("Property is not a float.", ORT_INVALID_ARGUMENT);
  }
  ORT_OBJC_API_IMPL_CATCH(error, 0.0f)
}

/** Returns the wrapped C++ checkpoint state. Only valid after a successful initialization. */
- (Ort::CheckpointState&)CXXAPIOrtCheckpoint {
  return *_checkpoint;
}

@end

NS_ASSUME_NONNULL_END
-------------------------------------------------------------------------------- /objectivec/test/ort_checkpoint_test.mm: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License.
#import <XCTest/XCTest.h>

#import "ort_checkpoint.h"
#import "ort_training_session.h"
#import "ort_env.h"
#import "ort_session.h"

#import "test/test_utils.h"
#import "test/assertion_utils.h"

NS_ASSUME_NONNULL_BEGIN

/**
 * Unit tests for loading and saving ORTCheckpoint instances and for their user properties.
 */
@interface ORTCheckpointTest : XCTestCase
@property(readonly, nullable) ORTEnv* ortEnv;
@end

@implementation ORTCheckpointTest

- (void)setUp {
  [super setUp];

  self.continueAfterFailure = NO;

  // Constructing an ORTCheckpoint requires an ORTEnv, so every test gets a fresh one.
  NSError* err = nil;
  _ortEnv = [[ORTEnv alloc] initWithLoggingLevel:ORTLoggingLevelWarning
                                           error:&err];
  ORTAssertNullableResultSuccessful(_ortEnv, err);
}

/** Path of the "checkpoint.ckpt" test resource in this test bundle. */
+ (NSString*)getCheckpointPath {
  NSBundle* bundle = [NSBundle bundleForClass:[ORTCheckpointTest class]];
  NSString* path = [[bundle resourcePath] stringByAppendingPathComponent:@"checkpoint.ckpt"];
  return path;
}

/** Path of the "training_model.onnx" test resource in this test bundle. */
+ (NSString*)getTrainingModelPath {
  NSBundle* bundle = [NSBundle bundleForClass:[ORTCheckpointTest class]];
  NSString* path = [[bundle resourcePath] stringByAppendingPathComponent:@"training_model.onnx"];
  return path;
}

- (void)testSaveCheckpoint {
  NSError* error = nil;
  ORTCheckpoint* checkpoint = [[ORTCheckpoint alloc] initWithPath:[ORTCheckpointTest getCheckpointPath] error:&error];
  ORTAssertNullableResultSuccessful(checkpoint, error);

  // save checkpoint
  NSString* path = [test_utils::createTemporaryDirectory(self) stringByAppendingPathComponent:@"save_checkpoint.ckpt"];
  XCTAssertNotNil(path);
  BOOL result = [checkpoint saveCheckpointToPath:path withOptimizerState:NO error:&error];

  ORTAssertBoolResultSuccessful(result, error);
}

- (void)testInitCheckpoint {
  NSError* error = nil;
  ORTCheckpoint* checkpoint = [[ORTCheckpoint alloc] initWithPath:[ORTCheckpointTest getCheckpointPath] error:&error];
  ORTAssertNullableResultSuccessful(checkpoint, error);
}

- (void)testIntProperty {
  NSError* error = nil;
  // Load checkpoint
  ORTCheckpoint* checkpoint = [[ORTCheckpoint alloc] initWithPath:[ORTCheckpointTest getCheckpointPath] error:&error];
  ORTAssertNullableResultSuccessful(checkpoint, error);

  // Add property
  BOOL result = [checkpoint addIntPropertyWithName:@"test" value:314 error:&error];
  ORTAssertBoolResultSuccessful(result, error);

  // Get property. The return value alone cannot signal failure (0 is a valid value), so also
  // check that no error was reported.
  int64_t value = [checkpoint getIntPropertyWithName:@"test" error:&error];
  XCTAssertNil(error);
  XCTAssertEqual(value, 314);
}

- (void)testFloatProperty {
  NSError* error = nil;
  // Load checkpoint
  ORTCheckpoint* checkpoint = [[ORTCheckpoint alloc] initWithPath:[ORTCheckpointTest getCheckpointPath] error:&error];
  ORTAssertNullableResultSuccessful(checkpoint, error);

  // Add property
  BOOL result = [checkpoint addFloatPropertyWithName:@"test" value:3.14f error:&error];
  ORTAssertBoolResultSuccessful(result, error);

  // Get property. The return value alone cannot signal failure (0.0f is a valid value), so also
  // check that no error was reported.
  float value = [checkpoint getFloatPropertyWithName:@"test" error:&error];
  XCTAssertNil(error);
  XCTAssertEqual(value, 3.14f);
}

- (void)testStringProperty {
  NSError* error = nil;
  // Load checkpoint
  ORTCheckpoint* checkpoint = [[ORTCheckpoint alloc] initWithPath:[ORTCheckpointTest getCheckpointPath] error:&error];
  ORTAssertNullableResultSuccessful(checkpoint, error);

  // Add property
  BOOL result = [checkpoint addStringPropertyWithName:@"test" value:@"hello" error:&error];
  ORTAssertBoolResultSuccessful(result, error);

  // Get property
  NSString* value = [checkpoint getStringPropertyWithName:@"test" error:&error];
  ORTAssertNullableResultSuccessful(value, error);
  XCTAssertEqualObjects(value, @"hello");
}

- (void)testGetPropertyWithMismatchedTypeFails {
  NSError* error = nil;
  // Load checkpoint
  ORTCheckpoint* checkpoint = [[ORTCheckpoint alloc] initWithPath:[ORTCheckpointTest getCheckpointPath] error:&error];
  ORTAssertNullableResultSuccessful(checkpoint, error);

  // Add an int property
  BOOL result = [checkpoint addIntPropertyWithName:@"test" value:314 error:&error];
  ORTAssertBoolResultSuccessful(result, error);

  // Reading it back as a string must fail with an error instead of returning a value.
  NSString* value = [checkpoint getStringPropertyWithName:@"test" error:&error];
  ORTAssertNullableResultUnsuccessful(value, error);
}

- (void)tearDown {
  _ortEnv = nil;

  [super tearDown];
}

@end

NS_ASSUME_NONNULL_END
-------------------------------------------------------------------------------- /objectivec/include/ort_value.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License. 3 | 4 | #import 5 | 6 | #import "ort_enums.h" 7 | 8 | NS_ASSUME_NONNULL_BEGIN 9 | 10 | @class ORTValueTypeInfo; 11 | @class ORTTensorTypeAndShapeInfo; 12 | 13 | /** 14 | * An ORT value encapsulates data used as an input or output to a model at runtime. 15 | */ 16 | @interface ORTValue : NSObject 17 | 18 | - (instancetype)init NS_UNAVAILABLE; 19 | 20 | /** 21 | * Creates a value that is a tensor. 22 | * The tensor data is allocated by the caller. 23 | * 24 | * @param tensorData The tensor data. 25 | * @param elementType The tensor element data type. 26 | * @param shape The tensor shape. 27 | * @param error Optional error information set if an error occurs. 28 | * @return The instance, or nil if an error occurs. 29 | */ 30 | - (nullable instancetype)initWithTensorData:(NSMutableData*)tensorData 31 | elementType:(ORTTensorElementDataType)elementType 32 | shape:(NSArray*)shape 33 | error:(NSError**)error; 34 | 35 | /** 36 | * Creates a value that is a string tensor. 37 | * The string data will be copied into a buffer owned by this ORTValue instance. 38 | * 39 | * Available since 1.16. 40 | * 41 | * @param tensorStringData The tensor string data. 42 | * @param shape The tensor shape. 43 | * @param error Optional error information set if an error occurs. 44 | * @return The instance, or nil if an error occurs. 45 | */ 46 | - (nullable instancetype)initWithTensorStringData:(NSArray*)tensorStringData 47 | shape:(NSArray*)shape 48 | error:(NSError**)error; 49 | 50 | /** 51 | * Gets the type information. 52 | * 53 | * @param error Optional error information set if an error occurs. 54 | * @return The type information, or nil if an error occurs.
55 | */ 56 | - (nullable ORTValueTypeInfo*)typeInfoWithError:(NSError**)error; 57 | 58 | /** 59 | * Gets the tensor type and shape information. 60 | * This assumes that the value is a tensor. 61 | * 62 | * @param error Optional error information set if an error occurs. 63 | * @return The tensor type and shape information, or nil if an error occurs. 64 | */ 65 | - (nullable ORTTensorTypeAndShapeInfo*)tensorTypeAndShapeInfoWithError:(NSError**)error; 66 | 67 | /** 68 | * Gets the tensor data. 69 | * This assumes that the value is a tensor. 70 | * 71 | * This returns the value's underlying data directly, not a copy of it. 72 | * The memory's lifetime may be tied to this value, e.g., if it was allocated 73 | * by ORT. On the other hand, the memory's lifetime is independent of the value 74 | * if the value was created with user-provided data. 75 | * 76 | * @param error Optional error information set if an error occurs. 77 | * @return The tensor data, or nil if an error occurs. 78 | */ 79 | - (nullable NSMutableData*)tensorDataWithError:(NSError**)error; 80 | 81 | /** 82 | * Gets the tensor string data. 83 | * This assumes that the value is a string tensor. 84 | * 85 | * This returns a copy of the value's underlying string data. 86 | * 87 | * Available since 1.16. 88 | * 89 | * @param error Optional error information set if an error occurs. 90 | * @return The copy of the tensor string data, or nil if an error occurs. 91 | */ 92 | - (nullable NSArray*)tensorStringDataWithError:(NSError**)error; 93 | 94 | @end 95 | 96 | /** 97 | * A value's type information. 98 | */ 99 | @interface ORTValueTypeInfo : NSObject 100 | 101 | /** The value type. */ 102 | @property(nonatomic) ORTValueType type; 103 | 104 | /** The tensor type and shape information, if the value is a tensor. */ 105 | @property(nonatomic, nullable) ORTTensorTypeAndShapeInfo* tensorTypeAndShapeInfo; 106 | 107 | @end 108 | 109 | /** 110 | * A tensor's type and shape information.
111 | */ 112 | @interface ORTTensorTypeAndShapeInfo : NSObject 113 | 114 | /** The tensor element data type. */ 115 | @property(nonatomic) ORTTensorElementDataType elementType; 116 | 117 | /** The tensor shape. */ 118 | @property(nonatomic) NSArray* shape; 119 | 120 | @end 121 | 122 | NS_ASSUME_NONNULL_END 123 | -------------------------------------------------------------------------------- /objectivec/include/ort_checkpoint.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License. 3 | 4 | #import 5 | #include 6 | 7 | NS_ASSUME_NONNULL_BEGIN 8 | 9 | /** 10 | * An ORT checkpoint is a snapshot of the state of a model at a given point in time. 11 | * 12 | * This class holds the entire training session state that includes model parameters, 13 | * their gradients, optimizer parameters, and user properties. The `ORTTrainingSession` leverages the 14 | * `ORTCheckpoint` by accessing and updating the contained training state. 15 | * 16 | * Available since 1.16. 17 | * 18 | * @note This class is only available when the training APIs are enabled. 19 | */ 20 | @interface ORTCheckpoint : NSObject 21 | 22 | - (instancetype)init NS_UNAVAILABLE; 23 | 24 | /** 25 | * Creates a checkpoint from a directory on disk. 26 | * 27 | * @param path The path to the checkpoint directory. 28 | * @param error Optional error information set if an error occurs. 29 | * @return The instance, or nil if an error occurs. 30 | * 31 | * @warning The construction of the checkpoint state requires instantiation of `ORTEnv`. 32 | * The initialization will fail if the `ORTEnv` is not properly initialized. 33 | */ 34 | - (nullable instancetype)initWithPath:(NSString*)path 35 | error:(NSError**)error NS_DESIGNATED_INITIALIZER; 36 | 37 | /** 38 | * Saves a checkpoint to a directory on disk. 39 | * 40 | * @param path The path to the checkpoint directory.
41 | * @param includeOptimizerState Flag to indicate whether to save the optimizer state or not. 42 | * @param error Optional error information set if an error occurs. 43 | * @return Whether the checkpoint was saved successfully. 44 | */ 45 | - (BOOL)saveCheckpointToPath:(NSString*)path 46 | withOptimizerState:(BOOL)includeOptimizerState 47 | error:(NSError**)error; 48 | 49 | /** 50 | * Adds an int property to this checkpoint. 51 | * 52 | * @param name The name of the property. 53 | * @param value The value of the property. 54 | * @param error Optional error information set if an error occurs. 55 | * @return Whether the property was added successfully. 56 | */ 57 | - (BOOL)addIntPropertyWithName:(NSString*)name 58 | value:(int64_t)value 59 | error:(NSError**)error; 60 | 61 | /** 62 | * Adds a float property to this checkpoint. 63 | * 64 | * @param name The name of the property. 65 | * @param value The value of the property. 66 | * @param error Optional error information set if an error occurs. 67 | * @return Whether the property was added successfully. 68 | */ 69 | - (BOOL)addFloatPropertyWithName:(NSString*)name 70 | value:(float)value 71 | error:(NSError**)error; 72 | 73 | /** 74 | * Adds a string property to this checkpoint. 75 | * 76 | * @param name The name of the property. 77 | * @param value The value of the property. 78 | * @param error Optional error information set if an error occurs. 79 | * @return Whether the property was added successfully. 80 | */ 81 | 82 | - (BOOL)addStringPropertyWithName:(NSString*)name 83 | value:(NSString*)value 84 | error:(NSError**)error; 85 | 86 | /** 87 | * Gets an int property from this checkpoint. 88 | * 89 | * @param name The name of the property. 90 | * @param error Optional error information set if an error occurs. 91 | * @return The value of the property, or 0 if an error occurs. Because 0 can also be a valid value, check `error` to detect failure.
92 | */ 93 | - (int64_t)getIntPropertyWithName:(NSString*)name 94 | error:(NSError**)error __attribute__((swift_error(nonnull_error))); 95 | 96 | /** 97 | * Gets a float property from this checkpoint. 98 | * 99 | * @param name The name of the property. 100 | * @param error Optional error information set if an error occurs. 101 | * @return The value of the property, or 0.0f if an error occurs. Because 0.0f can also be a valid value, check `error` to detect failure. 102 | */ 103 | - (float)getFloatPropertyWithName:(NSString*)name 104 | error:(NSError**)error __attribute__((swift_error(nonnull_error))); 105 | 106 | /** 107 | * 108 | * Gets a string property from this checkpoint. 109 | * 110 | * @param name The name of the property. 111 | * @param error Optional error information set if an error occurs. 112 | * @return The value of the property, or nil if an error occurs. 113 | */ 114 | - (nullable NSString*)getStringPropertyWithName:(NSString*)name 115 | error:(NSError**)error __attribute__((swift_error(nonnull_error))); 116 | 117 | @end 118 | 119 | NS_ASSUME_NONNULL_END 120 | -------------------------------------------------------------------------------- /.pipelines/mac-ios-spm-extensions-validation-pipeline.yml: -------------------------------------------------------------------------------- 1 | jobs: 2 | - job: j 3 | displayName: "Test with ORT Extensions native pod" 4 | 5 | pool: 6 | vmImage: "macOS-13" 7 | 8 | variables: 9 | artifactsName: "ios_packaging_artifacts_full" 10 | artifactsNameForExt: "ios_packaging_artifacts" 11 | 12 | timeoutInMinutes: 60 13 | 14 | steps: 15 | - template: templates/use-xcode-version.yml 16 | 17 | - script: | 18 | mkdir tmp 19 | cd tmp 20 | git clone -n --depth=1 --filter=tree:0 https://github.com/microsoft/onnxruntime.git 21 | cd onnxruntime 22 | git sparse-checkout set --no-cone objectivec 23 | git checkout 24 | workingDirectory: "$(Build.SourcesDirectory)" 25 | displayName: "Sparse checkout objectivec/ folders from latest ORT main repository" 26 | 27 | - script: | 28 | ls -R "$(Build.SourcesDirectory)/tmp/onnxruntime"
workingDirectory: "$(Build.SourcesDirectory)/tmp" 30 | displayName: "List sparse checkout repo contents" 31 | 32 | # Download artifacts for ORT C pod from iOS packaging pipeline 33 | - task: DownloadPipelineArtifact@2 34 | inputs: 35 | buildType: 'specific' 36 | project: 'Lotus' 37 | definition: 995 #'definitionid' is obtained from `System.DefinitionId` of ORT CI: onnxruntime-ios-packaging-pipeline 38 | buildVersionToDownload: 'latestFromBranch' 39 | branchName: 'refs/heads/main' 40 | targetPath: '$(Build.ArtifactStagingDirectory)' 41 | 42 | - script: | 43 | set -e -x 44 | ls 45 | workingDirectory: '$(Build.ArtifactStagingDirectory)/$(artifactsName)' 46 | displayName: "List staged artifacts for ORT C Pod" 47 | 48 | # Download artifacts for ORT Ext C pod from extensions iOS packaging pipeline 49 | - task: DownloadPipelineArtifact@2 50 | inputs: 51 | buildType: 'specific' 52 | project: 'Lotus' 53 | definition: 1206 #'definitionid' is obtained from `System.DefinitionId` of extensions CI: extensions.ios_packaging 54 | buildVersionToDownload: 'latestFromBranch' 55 | branchName: 'refs/heads/main' 56 | targetPath: '$(Build.ArtifactStagingDirectory)' 57 | 58 | - script: | 59 | set -e -x 60 | ls 61 | workingDirectory: '$(Build.ArtifactStagingDirectory)/$(artifactsNameForExt)' 62 | displayName: "List staged artifacts for ORT Ext C Pod" 63 | 64 | # Note: Running xcodebuild test on the `onnxruntime-Package` scheme runs the Swift unit tests for both the OnnxRuntimeBindings 65 | # and OnnxRuntimeExtensions targets. 66 | - script: | 67 | set -e -x 68 | cd "$(Build.ArtifactStagingDirectory)/$(artifactsName)" 69 | POD_ARCHIVE=$(find .
-name "pod-archive-onnxruntime-c-*.zip") 70 | 71 | echo "Printing the checksum here to make it easier when updating for actual release pod archive file in Package.swift" 72 | shasum -a 256 "$(Build.ArtifactStagingDirectory)/$(artifactsName)/${POD_ARCHIVE}" 73 | 74 | cd "$(Build.ArtifactStagingDirectory)/$(artifactsNameForExt)" 75 | POD_ARCHIVE_EXT=$(find . -name "pod-archive-onnxruntime-extensions-c-*.zip") 76 | 77 | echo "Printing the checksum here to make it easier when updating for actual release extensions pod archive file in Package.swift" 78 | shasum -a 256 "$(Build.ArtifactStagingDirectory)/$(artifactsNameForExt)/${POD_ARCHIVE_EXT}" 79 | 80 | cd "$(Build.SourcesDirectory)/tmp/onnxruntime" 81 | cp -r "$(Build.SourcesDirectory)/swift" . 82 | cp "$(Build.SourcesDirectory)/Package.swift" . 83 | cp -r "$(Build.SourcesDirectory)/extensions" . 84 | 85 | cp "$(Build.ArtifactStagingDirectory)/$(artifactsName)/${POD_ARCHIVE}" swift/ 86 | export ORT_POD_LOCAL_PATH="swift/${POD_ARCHIVE}" 87 | cp "$(Build.ArtifactStagingDirectory)/$(artifactsNameForExt)/${POD_ARCHIVE_EXT}" swift/ 88 | export ORT_EXTENSIONS_POD_LOCAL_PATH="swift/${POD_ARCHIVE_EXT}" 89 | 90 | xcodebuild test -scheme onnxruntime-Package -destination 'platform=iOS Simulator,name=iPhone 14' 91 | 92 | xcodebuild test -scheme onnxruntime-Package -destination 'platform=macosx' 93 | 94 | rm swift/pod-archive-onnxruntime-c-*.zip 95 | rm swift/pod-archive-onnxruntime-extensions-c-*.zip 96 | workingDirectory: "$(Build.SourcesDirectory)/tmp" 97 | displayName: "Test Package.swift usage for ORT and ORT Extensions" 98 | 99 | - template: templates/component-governance-component-detection-steps.yml 100 | parameters: 101 | condition: 'succeeded' -------------------------------------------------------------------------------- /.github/workflows/codeql.yml: -------------------------------------------------------------------------------- 1 | # For most projects, this workflow file will not need changing; you simply need 2 | #
to commit it to your repository. 3 | # 4 | # You may wish to alter this file to override the set of languages analyzed, 5 | # or to provide custom queries or build logic. 6 | # 7 | # ******** NOTE ******** 8 | # We have attempted to detect the languages in your repository. Please check 9 | # the `language` matrix defined below to confirm you have the correct set of 10 | # supported CodeQL languages. 11 | # 12 | name: "CodeQL Advanced" 13 | 14 | on: 15 | push: 16 | branches: [ "main" ] 17 | pull_request: 18 | branches: [ "main" ] 19 | schedule: 20 | - cron: '37 1 * * 5' 21 | 22 | jobs: 23 | analyze: 24 | name: Analyze (${{ matrix.language }}) 25 | # Runner size impacts CodeQL analysis time. To learn more, please see: 26 | # - https://gh.io/recommended-hardware-resources-for-running-codeql 27 | # - https://gh.io/supported-runners-and-hardware-resources 28 | # - https://gh.io/using-larger-runners (GitHub.com only) 29 | # Consider using larger runners or machines with greater resources for possible analysis time improvements.
30 | runs-on: ${{ (matrix.language == 'swift' && 'macos-latest') || 'ubuntu-latest' }} 31 | permissions: 32 | # required for all workflows 33 | security-events: write 34 | 35 | # required to fetch internal or private CodeQL packs 36 | packages: read 37 | 38 | # only required for workflows in private repositories 39 | actions: read 40 | contents: read 41 | 42 | strategy: 43 | fail-fast: false 44 | matrix: 45 | include: 46 | - language: swift 47 | build-mode: autobuild 48 | # CodeQL supports the following values for 'language': 'actions', 'c-cpp', 'csharp', 'go', 'java-kotlin', 'javascript-typescript', 'python', 'ruby', 'rust', 'swift' 49 | # Use `c-cpp` to analyze code written in C, C++ or both 50 | # Use 'java-kotlin' to analyze code written in Java, Kotlin or both 51 | # Use 'javascript-typescript' to analyze code written in JavaScript, TypeScript or both 52 | # To learn more about changing the languages that are analyzed or customizing the build mode for your analysis, 53 | # see https://docs.github.com/en/code-security/code-scanning/creating-an-advanced-setup-for-code-scanning/customizing-your-advanced-setup-for-code-scanning. 54 | # If you are analyzing a compiled language, you can modify the 'build-mode' for that language to customize how 55 | # your codebase is analyzed, see https://docs.github.com/en/code-security/code-scanning/creating-an-advanced-setup-for-code-scanning/codeql-code-scanning-for-compiled-languages 56 | steps: 57 | - name: Checkout repository 58 | uses: actions/checkout@v4 59 | 60 | # Add any setup steps before running the `github/codeql-action/init` action. 61 | # This includes steps like installing compilers or runtimes (`actions/setup-node` 62 | # or others). This is typically only required for manual builds. 63 | # - name: Setup runtime (example) 64 | # uses: actions/setup-example@v1 65 | 66 | # Initializes the CodeQL tools for scanning.
67 | - name: Initialize CodeQL 68 | uses: github/codeql-action/init@v3 69 | with: 70 | languages: ${{ matrix.language }} 71 | build-mode: ${{ matrix.build-mode }} 72 | # If you wish to specify custom queries, you can do so here or in a config file. 73 | # By default, queries listed here will override any specified in a config file. 74 | # Prefix the list here with "+" to use these queries and those in the config file. 75 | 76 | # For more details on CodeQL's query packs, refer to: https://docs.github.com/en/code-security/code-scanning/automatically-scanning-your-code-for-vulnerabilities-and-errors/configuring-code-scanning#using-queries-in-ql-packs 77 | # queries: security-extended,security-and-quality 78 | 79 | # If the analyze step fails for one of the languages you are analyzing with 80 | # "We were unable to automatically build your code", modify the matrix above 81 | # to set the build mode to "manual" for that language. Then modify this step 82 | # to build your code. 83 | # ℹ️ Command-line programs to run using the OS shell. 84 | # 📚 See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#jobsjob_idstepsrun 85 | - if: matrix.build-mode == 'manual' 86 | shell: bash 87 | run: | 88 | echo 'If you are using a "manual" build mode for one or more of the' \ 89 | 'languages you are analyzing, replace this with the commands to build' \ 90 | 'your code, for example:' 91 | echo ' make bootstrap' 92 | echo ' make release' 93 | exit 1 94 | 95 | - name: Perform CodeQL Analysis 96 | uses: github/codeql-action/analyze@v3 97 | with: 98 | category: "/language:${{matrix.language}}" 99 | -------------------------------------------------------------------------------- /Package.swift: -------------------------------------------------------------------------------- 1 | // swift-tools-version: 5.6 2 | // The swift-tools-version declares the minimum version of Swift required to build this package and MUST be the first 3 | // line of this file.
5.6 is required to support zip files for the pod archive binaryTarget. 4 | // 5 | // Copyright (c) Microsoft Corporation. All rights reserved. 6 | // Licensed under the MIT License. 7 | // 8 | // A user of the Swift Package Manager (SPM) package will consume this file directly from the ORT SPM GitHub repository. 9 | // For example, the end user's config will look something like: 10 | // 11 | // dependencies: [ 12 | // .package(url: "https://github.com/microsoft/onnxruntime-swift-package-manager", from: "1.16.0"), 13 | // ... 14 | // ], 15 | // 16 | // NOTE: For valid version numbers, please refer to this page: 17 | // https://github.com/microsoft/onnxruntime-swift-package-manager/releases 18 | 19 | import PackageDescription 20 | import class Foundation.ProcessInfo 21 | 22 | let package = Package( 23 | name: "onnxruntime", 24 | platforms: [.iOS(.v13), 25 | .macOS(.v11)], 26 | products: [ 27 | .library(name: "onnxruntime", 28 | type: .static, 29 | targets: ["OnnxRuntimeBindings"]), 30 | .library(name: "onnxruntime_extensions", 31 | type: .static, 32 | targets: ["OnnxRuntimeExtensions"]), 33 | ], 34 | dependencies: [], 35 | targets: [ 36 | .target(name: "OnnxRuntimeBindings", 37 | dependencies: ["onnxruntime"], 38 | path: "objectivec", 39 | exclude: ["ReadMe.md", "format_objc.sh", "test", "docs", 40 | "ort_checkpoint.mm", 41 | "ort_checkpoint_internal.h", 42 | "ort_training_session_internal.h", 43 | "ort_training_session.mm", 44 | "include/ort_checkpoint.h", 45 | "include/ort_training_session.h", 46 | "include/onnxruntime_training.h"], 47 | cxxSettings: [ 48 | .define("SPM_BUILD"), 49 | ]), 50 | .testTarget(name: "OnnxRuntimeBindingsTests", 51 | dependencies: ["OnnxRuntimeBindings"], 52 | path: "swift/OnnxRuntimeBindingsTests", 53 | resources: [ 54 | .copy("Resources/single_add.basic.ort") 55 | ]), 56 | .target(name: "OnnxRuntimeExtensions", 57 | dependencies: ["onnxruntime_extensions", "onnxruntime"], 58 | path: "extensions", 59 | cxxSettings: [ 60 |
.define("ORT_SWIFT_PACKAGE_MANAGER_BUILD"), 61 | ]), 62 | .testTarget(name: "OnnxRuntimeExtensionsTests", 63 | dependencies: ["OnnxRuntimeExtensions", "OnnxRuntimeBindings"], 64 | path: "swift/OnnxRuntimeExtensionsTests", 65 | resources: [ 66 | .copy("Resources/decode_image.onnx") 67 | ]), 68 | ], 69 | cxxLanguageStandard: .cxx17 70 | ) 71 | 72 | // Add the ORT CocoaPods C/C++ pod archive as a binary target. 73 | // 74 | // There are 2 scenarios: 75 | // - Target will be set to a released pod archive and its checksum. 76 | // 77 | // - Target will be set to a local pod archive. 78 | // This can be used to test with the latest (not yet released) ORT Objective-C source code. 79 | 80 | // CI or local testing where you have built/obtained the pod archive matching the current source code. 81 | // Requires the ORT_POD_LOCAL_PATH environment variable to be set to specify the location of the pod. 82 | if let pod_archive_path = ProcessInfo.processInfo.environment["ORT_POD_LOCAL_PATH"] { 83 | // ORT_POD_LOCAL_PATH MUST be a path that is relative to Package.swift. 84 | // 85 | // To build locally, tools/ci_build/github/apple/build_and_assemble_apple_pods.py can be used. 86 | // See https://onnxruntime.ai/docs/build/custom.html#ios 87 | // Example command: 88 | // python3 tools/ci_build/github/apple/build_and_assemble_apple_pods.py \ 89 | // --variant Full \ 90 | // --build-settings-file tools/ci_build/github/apple/default_full_apple_framework_build_settings.json 91 | // 92 | // This should produce the pod archive in build/apple_pod_staging, and ORT_POD_LOCAL_PATH can be set to 93 | // "build/apple_pod_staging/pod-archive-onnxruntime-c-???.zip" where '???' is replaced by the version info in the 94 | // actual filename.
95 | package.targets.append(Target.binaryTarget(name: "onnxruntime", path: pod_archive_path)) 96 | 97 | } else { 98 | // ORT release: download the published pod archive and verify it against its SHA256 checksum. 99 | package.targets.append( 100 | Target.binaryTarget(name: "onnxruntime", 101 | url: "https://download.onnxruntime.ai/pod-archive-onnxruntime-c-1.20.0.zip", 102 | // SHA256 checksum 103 | checksum: "50891a8aadd17d4811acb05ed151ba6c394129bb3ab14e843b0fc83a48d450ff") 104 | ) 105 | } 106 | 107 | if let ext_pod_archive_path = ProcessInfo.processInfo.environment["ORT_EXTENSIONS_POD_LOCAL_PATH"] { 108 | package.targets.append(Target.binaryTarget(name: "onnxruntime_extensions", path: ext_pod_archive_path)) 109 | } else { 110 | // ORT Extensions release: download the published pod archive and verify it against its SHA256 checksum. 111 | package.targets.append( 112 | Target.binaryTarget(name: "onnxruntime_extensions", 113 | url: "https://download.onnxruntime.ai/pod-archive-onnxruntime-extensions-c-0.13.0.zip", 114 | // SHA256 checksum 115 | checksum: "346522d1171d4c99cb0908fa8e4e9330a4a6aad39cd83ce36eb654437b33e6b5") 116 | ) 117 | } 118 | -------------------------------------------------------------------------------- /objectivec/ort_enums.mm: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License.
#import "ort_enums_internal.h"

#include <algorithm>
#include <iterator>
#include <optional>

#import "cxx_api.h"

namespace {

// Each k*Infos table below pairs a public Objective-C enum value with its ONNX Runtime C API counterpart.
// The conversion functions at the end of this file search these tables, so supporting another value only
// requires adding a row to the relevant table.

struct LoggingLevelInfo {
  ORTLoggingLevel logging_level;
  OrtLoggingLevel capi_logging_level;
};

// supported ORT logging levels
// define the mapping from ORTLoggingLevel to C API OrtLoggingLevel here
constexpr LoggingLevelInfo kLoggingLevelInfos[]{
    {ORTLoggingLevelVerbose, ORT_LOGGING_LEVEL_VERBOSE},
    {ORTLoggingLevelInfo, ORT_LOGGING_LEVEL_INFO},
    {ORTLoggingLevelWarning, ORT_LOGGING_LEVEL_WARNING},
    {ORTLoggingLevelError, ORT_LOGGING_LEVEL_ERROR},
    {ORTLoggingLevelFatal, ORT_LOGGING_LEVEL_FATAL},
};

struct ValueTypeInfo {
  ORTValueType type;
  ONNXType capi_type;
};

// supported ORT value types
// define the mapping from ORTValueType to C API ONNXType here
constexpr ValueTypeInfo kValueTypeInfos[]{
    {ORTValueTypeUnknown, ONNX_TYPE_UNKNOWN},
    {ORTValueTypeTensor, ONNX_TYPE_TENSOR},
};

struct TensorElementTypeInfo {
  ORTTensorElementDataType type;
  ONNXTensorElementDataType capi_type;
  // size in bytes of a single element, or std::nullopt for types without a fixed element size
  std::optional<size_t> element_size;
};

// supported ORT tensor element data types
// define the mapping from ORTTensorElementDataType to C API ONNXTensorElementDataType here
constexpr TensorElementTypeInfo kElementTypeInfos[]{
    {ORTTensorElementDataTypeUndefined, ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED, std::nullopt},
    {ORTTensorElementDataTypeFloat, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, sizeof(float)},
    {ORTTensorElementDataTypeInt8, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, sizeof(int8_t)},
    {ORTTensorElementDataTypeUInt8, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, sizeof(uint8_t)},
    {ORTTensorElementDataTypeInt32, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, sizeof(int32_t)},
    {ORTTensorElementDataTypeUInt32, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32, sizeof(uint32_t)},
    {ORTTensorElementDataTypeInt64, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, sizeof(int64_t)},
    {ORTTensorElementDataTypeUInt64, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64, sizeof(uint64_t)},
    {ORTTensorElementDataTypeString, ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING, std::nullopt},
};

struct GraphOptimizationLevelInfo {
  ORTGraphOptimizationLevel opt_level;
  GraphOptimizationLevel capi_opt_level;
};

// ORT graph optimization levels
// define the mapping from ORTGraphOptimizationLevel to C API GraphOptimizationLevel here
constexpr GraphOptimizationLevelInfo kGraphOptimizationLevelInfos[]{
    {ORTGraphOptimizationLevelNone, ORT_DISABLE_ALL},
    {ORTGraphOptimizationLevelBasic, ORT_ENABLE_BASIC},
    {ORTGraphOptimizationLevelExtended, ORT_ENABLE_EXTENDED},
    {ORTGraphOptimizationLevelAll, ORT_ENABLE_ALL},
};

// Returns transform_fn(element) for the first element of `container` for which select_fn(element) is true.
// If no element matches, reports `not_found_msg` with status ORT_NOT_IMPLEMENTED via ORT_CXX_API_THROW.
template <typename Container, typename SelectFn, typename TransformFn>
auto SelectAndTransform(
    const Container& container, SelectFn select_fn, TransformFn transform_fn,
    const char* not_found_msg)
    -> decltype(transform_fn(*std::begin(container))) {
  const auto it = std::find_if(
      std::begin(container), std::end(container), select_fn);
  if (it == std::end(container)) {
    ORT_CXX_API_THROW(not_found_msg, ORT_NOT_IMPLEMENTED);
  }
  return transform_fn(*it);
}

}  // namespace

// Maps a public ORTLoggingLevel to the C API OrtLoggingLevel; errors for values missing from kLoggingLevelInfos.
OrtLoggingLevel PublicToCAPILoggingLevel(ORTLoggingLevel logging_level) {
  return SelectAndTransform(
      kLoggingLevelInfos,
      [logging_level](const auto& logging_level_info) { return logging_level_info.logging_level == logging_level; },
      [](const auto& logging_level_info) { return logging_level_info.capi_logging_level; },
      "unsupported logging level");
}

// Maps a C API ONNXType to the public ORTValueType; errors for values missing from kValueTypeInfos.
ORTValueType CAPIToPublicValueType(ONNXType capi_type) {
  return SelectAndTransform(
      kValueTypeInfos,
      [capi_type](const auto& type_info) { return type_info.capi_type == capi_type; },
      [](const auto& type_info) { return type_info.type; },
      "unsupported value type");
}

// Maps a public ORTTensorElementDataType to the C API ONNXTensorElementDataType.
ONNXTensorElementDataType PublicToCAPITensorElementType(ORTTensorElementDataType type) {
  return SelectAndTransform(
      kElementTypeInfos,
      [type](const auto& type_info) { return type_info.type == type; },
      [](const auto& type_info) { return type_info.capi_type; },
      "unsupported tensor element type");
}

// Maps a C API ONNXTensorElementDataType to the public ORTTensorElementDataType.
ORTTensorElementDataType CAPIToPublicTensorElementType(ONNXTensorElementDataType capi_type) {
  return SelectAndTransform(
      kElementTypeInfos,
      [capi_type](const auto& type_info) { return type_info.capi_type == capi_type; },
      [](const auto& type_info) { return type_info.type; },
      "unsupported tensor element type");
}

// Returns the size in bytes of one element of `capi_type`.
// Element types whose table entry has no size (std::nullopt) are treated the same as unknown types.
size_t SizeOfCAPITensorElementType(ONNXTensorElementDataType capi_type) {
  return SelectAndTransform(
      kElementTypeInfos,
      [capi_type](const auto& type_info) {
        return type_info.element_size.has_value() && type_info.capi_type == capi_type;
      },
      [](const auto& type_info) { return *type_info.element_size; },
      "unsupported tensor element type or tensor element type does not have a known size");
}

// Maps a public ORTGraphOptimizationLevel to the C API GraphOptimizationLevel.
GraphOptimizationLevel PublicToCAPIGraphOptimizationLevel(ORTGraphOptimizationLevel opt_level) {
  return SelectAndTransform(
      kGraphOptimizationLevelInfos,
      [opt_level](const auto& opt_level_info) { return opt_level_info.opt_level == opt_level; },
      [](const auto& opt_level_info) { return opt_level_info.capi_opt_level; },
      "unsupported graph optimization level");
}
-------------------------------------------------------------------------------- /objectivec/ort_training_session.mm: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License.
#import "ort_training_session_internal.h"

#import <optional>
#import <string>
#import <vector>

#import "cxx_api.h"
#import "cxx_utils.h"
#import "error_utils.h"
#import "ort_checkpoint_internal.h"
#import "ort_session_internal.h"
#import "ort_enums_internal.h"
#import "ort_env_internal.h"
#import "ort_value_internal.h"

NS_ASSUME_NONNULL_BEGIN

@implementation ORTTrainingSession {
  ORTEnv* _env;                // keep a strong reference so the ORTEnv doesn't get destroyed before this does
  ORTCheckpoint* _checkpoint;  // keep a strong reference to the checkpoint state the session was created from
  std::optional<Ort::TrainingSession> _session;  // engaged once initialization succeeds
}

#pragma mark - Internal

- (Ort::TrainingSession&)CXXAPIOrtTrainingSession {
  return *_session;
}

#pragma mark - Public

- (nullable instancetype)initWithEnv:(ORTEnv*)env
                      sessionOptions:(nullable ORTSessionOptions*)sessionOptions
                          checkpoint:(ORTCheckpoint*)checkpoint
                      trainModelPath:(NSString*)trainModelPath
                       evalModelPath:(nullable NSString*)evalModelPath
                  optimizerModelPath:(nullable NSString*)optimizerModelPath
                               error:(NSError**)error {
  if ((self = [super init]) == nil) {
    return nil;
  }

  try {
    // fall back to freshly created session options when none are provided
    if (!sessionOptions) {
      sessionOptions = [[ORTSessionOptions alloc] initWithError:error];
      if (!sessionOptions) {
        return nil;
      }
    }

    // the eval and optimizer model paths are optional (presumably nil maps to an empty std::optional)
    const auto evalPath = utils::toStdOptionalString(evalModelPath);
    const auto optimizerPath = utils::toStdOptionalString(optimizerModelPath);

    _env = env;
    _checkpoint = checkpoint;
    _session = Ort::TrainingSession{
        [env CXXAPIOrtEnv],
        [sessionOptions CXXAPIOrtSessionOptions],
        [checkpoint CXXAPIOrtCheckpoint],
        trainModelPath.UTF8String,
        evalPath,
        optimizerPath};

    return self;
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error)
}

- (nullable NSArray<ORTValue*>*)trainStepWithInputValues:(NSArray<ORTValue*>*)inputs
                                                   error:(NSError**)error {
  try {
    const auto inputValues = utils::getWrappedCAPIOrtValues(inputs);

    // one output slot per training model output; TrainStep fills them in
    size_t outputCount;
    Ort::ThrowOnError(Ort::GetTrainingApi().TrainingSessionGetTrainingModelOutputCount(*_session, &outputCount));
    std::vector<OrtValue*> outputValues(outputCount, nullptr);

    Ort::RunOptions runOptions;
    Ort::ThrowOnError(Ort::GetTrainingApi().TrainStep(
        *_session,
        runOptions,
        inputValues.size(),
        inputValues.data(),
        outputValues.size(),
        outputValues.data()));

    return utils::wrapUnownedCAPIOrtValues(outputValues, error);
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error)
}

- (nullable NSArray<ORTValue*>*)evalStepWithInputValues:(NSArray<ORTValue*>*)inputs
                                                  error:(NSError**)error {
  try {
    const auto inputValues = utils::getWrappedCAPIOrtValues(inputs);

    // one output slot per eval model output; EvalStep fills them in
    size_t outputCount;
    Ort::ThrowOnError(Ort::GetTrainingApi().TrainingSessionGetEvalModelOutputCount(*_session, &outputCount));
    std::vector<OrtValue*> outputValues(outputCount, nullptr);

    Ort::RunOptions runOptions;
    Ort::ThrowOnError(Ort::GetTrainingApi().EvalStep(
        *_session,
        runOptions,
        inputValues.size(),
        inputValues.data(),
        outputValues.size(),
        outputValues.data()));

    return utils::wrapUnownedCAPIOrtValues(outputValues, error);
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error)
}

- (BOOL)lazyResetGradWithError:(NSError**)error {
  try {
    [self CXXAPIOrtTrainingSession].LazyResetGrad();
    return YES;
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error)
}

- (BOOL)optimizerStepWithError:(NSError**)error {
  try {
    [self CXXAPIOrtTrainingSession].OptimizerStep();
    return YES;
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error)
}

// The boolean passed to InputNames/OutputNames selects the training model (true) or the eval model (false).

- (nullable NSArray<NSString*>*)getTrainInputNamesWithError:(NSError**)error {
  try {
    const auto inputNames = [self CXXAPIOrtTrainingSession].InputNames(true);
    return utils::toNSStringNSArray(inputNames);
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error)
}

- (nullable NSArray<NSString*>*)getTrainOutputNamesWithError:(NSError**)error {
  try {
    const auto outputNames = [self CXXAPIOrtTrainingSession].OutputNames(true);
    return utils::toNSStringNSArray(outputNames);
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error)
}

- (nullable NSArray<NSString*>*)getEvalInputNamesWithError:(NSError**)error {
  try {
    const auto inputNames = [self CXXAPIOrtTrainingSession].InputNames(false);
    return utils::toNSStringNSArray(inputNames);
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error)
}

- (nullable NSArray<NSString*>*)getEvalOutputNamesWithError:(NSError**)error {
  try {
    const auto outputNames = [self CXXAPIOrtTrainingSession].OutputNames(false);
    return utils::toNSStringNSArray(outputNames);
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error)
}

- (BOOL)registerLinearLRSchedulerWithWarmupStepCount:(int64_t)warmupStepCount
                                      totalStepCount:(int64_t)totalStepCount
                                           initialLr:(float)initialLr
                                               error:(NSError**)error {
  try {
    [self CXXAPIOrtTrainingSession].RegisterLinearLRScheduler(warmupStepCount, totalStepCount, initialLr);
    return YES;
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error)
}

- (BOOL)schedulerStepWithError:(NSError**)error {
  try {
    [self CXXAPIOrtTrainingSession].SchedulerStep();
    return YES;
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error)
}

// Returns 0.0f (and sets `error`) on failure.
- (float)getLearningRateWithError:(NSError**)error {
  try {
    return [self CXXAPIOrtTrainingSession].GetLearningRate();
  }
  ORT_OBJC_API_IMPL_CATCH(error, 0.0f)
}

- (BOOL)setLearningRate:(float)lr
                  error:(NSError**)error {
  try {
    [self CXXAPIOrtTrainingSession].SetLearningRate(lr);
    return YES;
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error)
}

- (BOOL)fromBufferWithValue:(ORTValue*)buffer
                      error:(NSError**)error {
  try {
    [self CXXAPIOrtTrainingSession].FromBuffer([buffer CXXAPIOrtValue]);
    return YES;
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error)
}

- (nullable ORTValue*)toBufferWithTrainable:(BOOL)onlyTrainable
                                      error:(NSError**)error {
  try {
    Ort::Value val = [self CXXAPIOrtTrainingSession].ToBuffer(onlyTrainable);
    // the returned ORTValue takes ownership of the C++ value
    return [[ORTValue alloc] initWithCXXAPIOrtValue:std::move(val)
                                 externalTensorData:nil
                                              error:error];
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error)
}

- (BOOL)exportModelForInferenceWithOutputPath:(NSString*)inferenceModelPath
                             graphOutputNames:(NSArray<NSString*>*)graphOutputNames
                                        error:(NSError**)error {
  try {
    [self CXXAPIOrtTrainingSession].ExportModelForInferencing(utils::toStdString(inferenceModelPath),
                                                              utils::toStdStringVector(graphOutputNames));
    return YES;
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error)
}

@end

void ORTSetSeed(int64_t seed) {
  Ort::SetSeed(seed);
}

NS_ASSUME_NONNULL_END
-------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | ## Ignore Visual Studio temporary files, build results, and 2 | ## files generated by popular Visual Studio add-ons.
3 | ## 4 | ## Get latest from https://github.com/github/gitignore/blob/main/VisualStudio.gitignore 5 | 6 | # User-specific files 7 | *.rsuser 8 | *.suo 9 | *.user 10 | *.userosscache 11 | *.sln.docstates 12 | 13 | # User-specific files (MonoDevelop/Xamarin Studio) 14 | *.userprefs 15 | 16 | # Mono auto generated files 17 | mono_crash.* 18 | 19 | # Build results 20 | [Dd]ebug/ 21 | [Dd]ebugPublic/ 22 | [Rr]elease/ 23 | [Rr]eleases/ 24 | x64/ 25 | x86/ 26 | [Ww][Ii][Nn]32/ 27 | [Aa][Rr][Mm]/ 28 | [Aa][Rr][Mm]64/ 29 | bld/ 30 | [Bb]in/ 31 | [Oo]bj/ 32 | [Ll]og/ 33 | [Ll]ogs/ 34 | 35 | # Visual Studio 2015/2017 cache/options directory 36 | .vs/ 37 | # Uncomment if you have tasks that create the project's static files in wwwroot 38 | #wwwroot/ 39 | 40 | # Visual Studio 2017 auto generated files 41 | Generated\ Files/ 42 | 43 | # MSTest test Results 44 | [Tt]est[Rr]esult*/ 45 | [Bb]uild[Ll]og.* 46 | 47 | # NUnit 48 | *.VisualState.xml 49 | TestResult.xml 50 | nunit-*.xml 51 | 52 | # Build Results of an ATL Project 53 | [Dd]ebugPS/ 54 | [Rr]eleasePS/ 55 | dlldata.c 56 | 57 | # Benchmark Results 58 | BenchmarkDotNet.Artifacts/ 59 | 60 | # .NET Core 61 | project.lock.json 62 | project.fragment.lock.json 63 | artifacts/ 64 | 65 | # ASP.NET Scaffolding 66 | ScaffoldingReadMe.txt 67 | 68 | # StyleCop 69 | StyleCopReport.xml 70 | 71 | # Files built by Visual Studio 72 | *_i.c 73 | *_p.c 74 | *_h.h 75 | *.ilk 76 | *.meta 77 | *.obj 78 | *.iobj 79 | *.pch 80 | *.pdb 81 | *.ipdb 82 | *.pgc 83 | *.pgd 84 | *.rsp 85 | *.sbr 86 | *.tlb 87 | *.tli 88 | *.tlh 89 | *.tmp 90 | *.tmp_proj 91 | *_wpftmp.csproj 92 | *.log 93 | *.tlog 94 | *.vspscc 95 | *.vssscc 96 | .builds 97 | *.pidb 98 | *.svclog 99 | *.scc 100 | 101 | # Chutzpah Test files 102 | _Chutzpah* 103 | 104 | # Visual C++ cache files 105 | ipch/ 106 | *.aps 107 | *.ncb 108 | *.opendb 109 | *.opensdf 110 | *.sdf 111 | *.cachefile 112 | *.VC.db 113 | *.VC.VC.opendb 114 | 115 | # Visual Studio profiler 116 | *.psess 117 | *.vsp 
118 | *.vspx 119 | *.sap 120 | 121 | # Visual Studio Trace Files 122 | *.e2e 123 | 124 | # TFS 2012 Local Workspace 125 | $tf/ 126 | 127 | # Guidance Automation Toolkit 128 | *.gpState 129 | 130 | # ReSharper is a .NET coding add-in 131 | _ReSharper*/ 132 | *.[Rr]e[Ss]harper 133 | *.DotSettings.user 134 | 135 | # TeamCity is a build add-in 136 | _TeamCity* 137 | 138 | # DotCover is a Code Coverage Tool 139 | *.dotCover 140 | 141 | # AxoCover is a Code Coverage Tool 142 | .axoCover/* 143 | !.axoCover/settings.json 144 | 145 | # Coverlet is a free, cross platform Code Coverage Tool 146 | coverage*.json 147 | coverage*.xml 148 | coverage*.info 149 | 150 | # Visual Studio code coverage results 151 | *.coverage 152 | *.coveragexml 153 | 154 | # NCrunch 155 | _NCrunch_* 156 | .*crunch*.local.xml 157 | nCrunchTemp_* 158 | 159 | # MightyMoose 160 | *.mm.* 161 | AutoTest.Net/ 162 | 163 | # Web workbench (sass) 164 | .sass-cache/ 165 | 166 | # Installshield output folder 167 | [Ee]xpress/ 168 | 169 | # DocProject is a documentation generator add-in 170 | DocProject/buildhelp/ 171 | DocProject/Help/*.HxT 172 | DocProject/Help/*.HxC 173 | DocProject/Help/*.hhc 174 | DocProject/Help/*.hhk 175 | DocProject/Help/*.hhp 176 | DocProject/Help/Html2 177 | DocProject/Help/html 178 | 179 | # Click-Once directory 180 | publish/ 181 | 182 | # Publish Web Output 183 | *.[Pp]ublish.xml 184 | *.azurePubxml 185 | # Note: Comment the next line if you want to checkin your web deploy settings, 186 | # but database connection strings (with potential passwords) will be unencrypted 187 | *.pubxml 188 | *.publishproj 189 | 190 | # Microsoft Azure Web App publish settings. 
Comment the next line if you want to 191 | # checkin your Azure Web App publish settings, but sensitive information contained 192 | # in these scripts will be unencrypted 193 | PublishScripts/ 194 | 195 | # NuGet Packages 196 | *.nupkg 197 | # NuGet Symbol Packages 198 | *.snupkg 199 | # The packages folder can be ignored because of Package Restore 200 | **/[Pp]ackages/* 201 | # except build/, which is used as an MSBuild target. 202 | !**/[Pp]ackages/build/ 203 | # Uncomment if necessary however generally it will be regenerated when needed 204 | #!**/[Pp]ackages/repositories.config 205 | # NuGet v3's project.json files produces more ignorable files 206 | *.nuget.props 207 | *.nuget.targets 208 | 209 | # Microsoft Azure Build Output 210 | csx/ 211 | *.build.csdef 212 | 213 | # Microsoft Azure Emulator 214 | ecf/ 215 | rcf/ 216 | 217 | # Windows Store app package directories and files 218 | AppPackages/ 219 | BundleArtifacts/ 220 | Package.StoreAssociation.xml 221 | _pkginfo.txt 222 | *.appx 223 | *.appxbundle 224 | *.appxupload 225 | 226 | # Visual Studio cache files 227 | # files ending in .cache can be ignored 228 | *.[Cc]ache 229 | # but keep track of directories ending in .cache 230 | !?*.[Cc]ache/ 231 | 232 | # Others 233 | ClientBin/ 234 | ~$* 235 | *~ 236 | *.dbmdl 237 | *.dbproj.schemaview 238 | *.jfm 239 | *.pfx 240 | *.publishsettings 241 | orleans.codegen.cs 242 | 243 | # Including strong name files can present a security risk 244 | # (https://github.com/github/gitignore/pull/2483#issue-259490424) 245 | #*.snk 246 | 247 | # Since there are multiple workflows, uncomment next line to ignore bower_components 248 | # (https://github.com/github/gitignore/pull/1529#issuecomment-104372622) 249 | #bower_components/ 250 | 251 | # RIA/Silverlight projects 252 | Generated_Code/ 253 | 254 | # Backup & report files from converting an old project file 255 | # to a newer Visual Studio version. 
Backup files are not needed, 256 | # because we have git ;-) 257 | _UpgradeReport_Files/ 258 | Backup*/ 259 | UpgradeLog*.XML 260 | UpgradeLog*.htm 261 | ServiceFabricBackup/ 262 | *.rptproj.bak 263 | 264 | # SQL Server files 265 | *.mdf 266 | *.ldf 267 | *.ndf 268 | 269 | # Business Intelligence projects 270 | *.rdl.data 271 | *.bim.layout 272 | *.bim_*.settings 273 | *.rptproj.rsuser 274 | *- [Bb]ackup.rdl 275 | *- [Bb]ackup ([0-9]).rdl 276 | *- [Bb]ackup ([0-9][0-9]).rdl 277 | 278 | # Microsoft Fakes 279 | FakesAssemblies/ 280 | 281 | # GhostDoc plugin setting file 282 | *.GhostDoc.xml 283 | 284 | # Node.js Tools for Visual Studio 285 | .ntvs_analysis.dat 286 | node_modules/ 287 | 288 | # Visual Studio 6 build log 289 | *.plg 290 | 291 | # Visual Studio 6 workspace options file 292 | *.opt 293 | 294 | # Visual Studio 6 auto-generated workspace file (contains which files were open etc.) 295 | *.vbw 296 | 297 | # Visual Studio 6 auto-generated project file (contains which files were open etc.) 
298 | *.vbp 299 | 300 | # Visual Studio 6 workspace and project file (working project files containing files to include in project) 301 | *.dsw 302 | *.dsp 303 | 304 | # Visual Studio 6 technical files 305 | *.ncb 306 | *.aps 307 | 308 | # Visual Studio LightSwitch build output 309 | **/*.HTMLClient/GeneratedArtifacts 310 | **/*.DesktopClient/GeneratedArtifacts 311 | **/*.DesktopClient/ModelManifest.xml 312 | **/*.Server/GeneratedArtifacts 313 | **/*.Server/ModelManifest.xml 314 | _Pvt_Extensions 315 | 316 | # Paket dependency manager 317 | .paket/paket.exe 318 | paket-files/ 319 | 320 | # FAKE - F# Make 321 | .fake/ 322 | 323 | # CodeRush personal settings 324 | .cr/personal 325 | 326 | # Python Tools for Visual Studio (PTVS) 327 | __pycache__/ 328 | *.pyc 329 | 330 | # Cake - Uncomment if you are using it 331 | # tools/** 332 | # !tools/packages.config 333 | 334 | # Tabs Studio 335 | *.tss 336 | 337 | # Telerik's JustMock configuration file 338 | *.jmconfig 339 | 340 | # BizTalk build output 341 | *.btp.cs 342 | *.btm.cs 343 | *.odx.cs 344 | *.xsd.cs 345 | 346 | # OpenCover UI analysis results 347 | OpenCover/ 348 | 349 | # Azure Stream Analytics local run output 350 | ASALocalRun/ 351 | 352 | # MSBuild Binary and Structured Log 353 | *.binlog 354 | 355 | # NVidia Nsight GPU debugger configuration file 356 | *.nvuser 357 | 358 | # MFractors (Xamarin productivity tool) working folder 359 | .mfractor/ 360 | 361 | # Local History for Visual Studio 362 | .localhistory/ 363 | 364 | # Visual Studio History (VSHistory) files 365 | .vshistory/ 366 | 367 | # BeatPulse healthcheck temp database 368 | healthchecksdb 369 | 370 | # Backup folder for Package Reference Convert tool in Visual Studio 2017 371 | MigrationBackup/ 372 | 373 | # Ionide (cross platform F# VS Code tools) working folder 374 | .ionide/ 375 | 376 | # Fody - auto-generated XML schema 377 | FodyWeavers.xsd 378 | 379 | # VS Code files for those working on multiple tools 380 | .vscode/* 381 | 
!.vscode/settings.json 382 | !.vscode/tasks.json 383 | !.vscode/launch.json 384 | !.vscode/extensions.json 385 | *.code-workspace 386 | 387 | # Local History for Visual Studio Code 388 | .history/ 389 | 390 | # Windows Installer files from build outputs 391 | *.cab 392 | *.msi 393 | *.msix 394 | *.msm 395 | *.msp 396 | 397 | # JetBrains Rider 398 | *.sln.iml 399 | 400 | .DS_Store 401 | *.DS_Store 402 | 403 | .build/ 404 | .swiftpm -------------------------------------------------------------------------------- /objectivec/ort_value.mm: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License.

#import "ort_value_internal.h"

#include <optional>
#include <vector>

#import "cxx_api.h"
#import "error_utils.h"
#import "ort_enums_internal.h"

NS_ASSUME_NONNULL_BEGIN

namespace {

// Converts C++ API tensor type/shape info into a new public ORTTensorTypeAndShapeInfo.
ORTTensorTypeAndShapeInfo* CXXAPIToPublicTensorTypeAndShapeInfo(
    const Ort::ConstTensorTypeAndShapeInfo& CXXAPITensorTypeAndShapeInfo) {
  auto* result = [[ORTTensorTypeAndShapeInfo alloc] init];
  const auto elementType = CXXAPITensorTypeAndShapeInfo.GetElementType();
  const std::vector<int64_t> shape = CXXAPITensorTypeAndShapeInfo.GetShape();

  result.elementType = CAPIToPublicTensorElementType(elementType);
  NSMutableArray<NSNumber*>* shapeArray = [[NSMutableArray alloc] initWithCapacity:shape.size()];
  for (const int64_t dim : shape) {
    [shapeArray addObject:@(dim)];
  }
  result.shape = shapeArray;

  return result;
}

// Converts C++ API type info into a new public ORTValueTypeInfo.
// tensorTypeAndShapeInfo is only populated for tensor values.
ORTValueTypeInfo* CXXAPIToPublicValueTypeInfo(
    const Ort::TypeInfo& CXXAPITypeInfo) {
  auto* result = [[ORTValueTypeInfo alloc] init];
  const auto valueType = CXXAPITypeInfo.GetONNXType();

  result.type = CAPIToPublicValueType(valueType);

  if (valueType == ONNX_TYPE_TENSOR) {
    const auto tensorTypeAndShapeInfo = CXXAPITypeInfo.GetTensorTypeAndShapeInfo();
    result.tensorTypeAndShapeInfo = CXXAPIToPublicTensorTypeAndShapeInfo(tensorTypeAndShapeInfo);
  }

  return result;
}

// out = a * b
// returns true iff the result does not overflow
bool SafeMultiply(size_t a, size_t b, size_t& out) {
  return !__builtin_mul_overflow(a, b, &out);
}

}  // namespace

@interface ORTValue ()

// pointer to any external tensor data to keep alive for the lifetime of the ORTValue
// (strong rather than copy: the tensor references these exact bytes, so the same object must be retained)
@property(nonatomic, nullable) NSMutableData* externalTensorData;

@end

@implementation ORTValue {
  std::optional<Ort::Value> _value;
  std::optional<Ort::TypeInfo> _typeInfo;
}

#pragma mark - Public

// Creates a tensor that uses `tensorData`'s bytes directly (no copy); `tensorData` is retained by the ORTValue.
- (nullable instancetype)initWithTensorData:(NSMutableData*)tensorData
                                elementType:(ORTTensorElementDataType)elementType
                                      shape:(NSArray<NSNumber*>*)shape
                                      error:(NSError**)error {
  try {
    if (elementType == ORTTensorElementDataTypeString) {
      ORT_CXX_API_THROW(
          "ORTTensorElementDataTypeString element type provided. "
          "Please call initWithTensorStringData:shape:error: instead to create an ORTValue with string data.",
          ORT_INVALID_ARGUMENT);
    }
    const auto memoryInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
    const auto ONNXElementType = PublicToCAPITensorElementType(elementType);
    const auto shapeVector = [shape]() {
      std::vector<int64_t> result{};
      result.reserve(shape.count);
      for (NSNumber* dim in shape) {
        result.push_back(dim.longLongValue);
      }
      return result;
    }();
    Ort::Value ortValue = Ort::Value::CreateTensor(
        memoryInfo, tensorData.mutableBytes, tensorData.length,
        shapeVector.data(), shapeVector.size(), ONNXElementType);

    return [self initWithCXXAPIOrtValue:std::move(ortValue)
                     externalTensorData:tensorData
                                  error:error];
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error)
}

// Creates a string tensor; the number of strings must equal the product of the (non-negative) shape dimensions.
- (nullable instancetype)initWithTensorStringData:(NSArray<NSString*>*)tensorStringData
                                            shape:(NSArray<NSNumber*>*)shape
                                            error:(NSError**)error {
  try {
    Ort::AllocatorWithDefaultOptions allocator;
    size_t tensorSize = 1U;
    const auto shapeVector = [&tensorSize, shape]() {
      std::vector<int64_t> result{};
      result.reserve(shape.count);
      for (NSNumber* dim in shape) {
        const auto dimValue = dim.longLongValue;
        if (dimValue < 0 || !SafeMultiply(static_cast<size_t>(dimValue), tensorSize, tensorSize)) {
          ORT_CXX_API_THROW("Failed to compute the tensor size.", ORT_RUNTIME_EXCEPTION);
        }
        result.push_back(dimValue);
      }
      return result;
    }();

    if (tensorSize != [tensorStringData count]) {
      ORT_CXX_API_THROW(
          "Computed tensor size does not equal the length of the provided tensor string data.",
          ORT_INVALID_ARGUMENT);
    }

    Ort::Value ortValue = Ort::Value::CreateTensor(
        allocator, shapeVector.data(), shapeVector.size(), ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING);

    size_t index = 0;
    for (NSString* stringData in tensorStringData) {
      ortValue.FillStringTensorElement([stringData UTF8String], index++);
    }

    return [self initWithCXXAPIOrtValue:std::move(ortValue)
                     externalTensorData:nil
                                  error:error];
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error)
}

- (nullable ORTValueTypeInfo*)typeInfoWithError:(NSError**)error {
  try {
    return CXXAPIToPublicValueTypeInfo(*_typeInfo);
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error)
}

- (nullable ORTTensorTypeAndShapeInfo*)tensorTypeAndShapeInfoWithError:(NSError**)error {
  try {
    const auto tensorTypeAndShapeInfo = _typeInfo->GetTensorTypeAndShapeInfo();
    if (!tensorTypeAndShapeInfo) {
      ORT_CXX_API_THROW("ORTValue is not a tensor.", ORT_RUNTIME_EXCEPTION);
    }
    return CXXAPIToPublicTensorTypeAndShapeInfo(tensorTypeAndShapeInfo);
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error)
}

// Returns an NSMutableData that wraps the tensor's buffer without copying it (freeWhenDone:NO).
// The returned object is only valid while this ORTValue is alive.
- (nullable NSMutableData*)tensorDataWithError:(NSError**)error {
  try {
    const auto tensorTypeAndShapeInfo = _typeInfo->GetTensorTypeAndShapeInfo();
    if (!tensorTypeAndShapeInfo) {
      ORT_CXX_API_THROW("ORTValue is not a tensor.", ORT_RUNTIME_EXCEPTION);
    }
    if (tensorTypeAndShapeInfo.GetElementType() == ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) {
      ORT_CXX_API_THROW(
          "This ORTValue holds string data. Please call tensorStringDataWithError: "
          "instead to retrieve the string data from this ORTValue.",
          ORT_RUNTIME_EXCEPTION);
    }
    const size_t elementCount = tensorTypeAndShapeInfo.GetElementCount();
    const size_t elementSize = SizeOfCAPITensorElementType(tensorTypeAndShapeInfo.GetElementType());
    size_t rawDataLength;
    if (!SafeMultiply(elementCount, elementSize, rawDataLength)) {
      ORT_CXX_API_THROW("failed to compute tensor data length", ORT_RUNTIME_EXCEPTION);
    }

    void* rawData;
    Ort::ThrowOnError(Ort::GetApi().GetTensorMutableData(*_value, &rawData));

    return [NSMutableData dataWithBytesNoCopy:rawData
                                       length:rawDataLength
                                 freeWhenDone:NO];
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error)
}

// Returns a copy of the string tensor's elements decoded as UTF-8.
// Sets `error` if this is not a string tensor or an element is not valid UTF-8.
- (nullable NSArray<NSString*>*)tensorStringDataWithError:(NSError**)error {
  try {
    const auto tensorTypeAndShapeInfo = _typeInfo->GetTensorTypeAndShapeInfo();
    if (!tensorTypeAndShapeInfo) {
      ORT_CXX_API_THROW("ORTValue is not a tensor.", ORT_RUNTIME_EXCEPTION);
    }
    // mirror the element type check in tensorDataWithError: so misuse gets a clear error message
    if (tensorTypeAndShapeInfo.GetElementType() != ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) {
      ORT_CXX_API_THROW(
          "This ORTValue does not hold string data. Please call tensorDataWithError: "
          "instead to retrieve the data from this ORTValue.",
          ORT_RUNTIME_EXCEPTION);
    }
    const size_t elementCount = tensorTypeAndShapeInfo.GetElementCount();
    const size_t tensorStringDataLength = _value->GetStringTensorDataLength();
    // all elements are written back-to-back into one buffer; offsets[i] is where element i starts
    std::vector<char> tensorStringData(tensorStringDataLength, '\0');
    std::vector<size_t> offsets(elementCount);
    _value->GetStringTensorContent(tensorStringData.data(), tensorStringDataLength,
                                   offsets.data(), offsets.size());

    NSMutableArray<NSString*>* result = [NSMutableArray arrayWithCapacity:elementCount];
    for (size_t idx = 0; idx < elementCount; ++idx) {
      // an element ends where the next one starts, or at the end of the buffer for the last element
      const size_t strLength = (idx == elementCount - 1) ? tensorStringDataLength - offsets[idx]
                                                         : offsets[idx + 1] - offsets[idx];
      NSString* element = [[NSString alloc] initWithBytes:tensorStringData.data() + offsets[idx]
                                                   length:strLength
                                                 encoding:NSUTF8StringEncoding];
      // initWithBytes:length:encoding: returns nil for bytes that are not valid UTF-8, and adding nil to an
      // NSMutableArray raises an Objective-C exception, so report an error instead of crashing.
      if (element == nil) {
        ORT_CXX_API_THROW("Failed to decode tensor string data as UTF-8.", ORT_RUNTIME_EXCEPTION);
      }
      [result addObject:element];
    }
    return result;
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error)
}

#pragma mark - Internal

// Takes ownership of `existingCXXAPIOrtValue`; `externalTensorData`, if any, is kept alive alongside it.
- (nullable instancetype)initWithCXXAPIOrtValue:(Ort::Value&&)existingCXXAPIOrtValue
                             externalTensorData:(nullable NSMutableData*)externalTensorData
                                          error:(NSError**)error {
  if ((self = [super init]) == nil) {
    return nil;
  }

  try {
    _typeInfo = existingCXXAPIOrtValue.GetTypeInfo();
    _externalTensorData = externalTensorData;

    // transfer C++ Ort::Value ownership to this instance
    _value = std::move(existingCXXAPIOrtValue);
    return self;
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error)
}

- (Ort::Value&)CXXAPIOrtValue {
  return *_value;
}

@end

@implementation ORTValueTypeInfo
@end

@implementation ORTTensorTypeAndShapeInfo
@end

NS_ASSUME_NONNULL_END
-------------------------------------------------------------------------------- /objectivec/include/ort_training_session.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License. 3 | 4 | #import 5 | #include 6 | 7 | NS_ASSUME_NONNULL_BEGIN 8 | 9 | @class ORTCheckpoint; 10 | @class ORTEnv; 11 | @class ORTValue; 12 | @class ORTSessionOptions; 13 | 14 | /** 15 | * Trainer class that provides methods to train, evaluate and optimize ONNX models. 16 | * 17 | * The training session requires four training artifacts: 18 | * 1. Training onnx model 19 | * 2. Evaluation onnx model (optional) 20 | * 3. Optimizer onnx model 21 | * 4.
Checkpoint directory 22 | * 23 | * [onnxruntime-training python utility](https://github.com/microsoft/onnxruntime/blob/main/orttraining/orttraining/python/training/onnxblock/README.md) 24 | * can be used to generate the above training artifacts. 25 | * 26 | * Available since 1.16. 27 | * 28 | * @note This class is only available when the training APIs are enabled. 29 | */ 30 | @interface ORTTrainingSession : NSObject 31 | 32 | - (instancetype)init NS_UNAVAILABLE; 33 | 34 | /** 35 | * Creates a training session from the training artifacts that can be used to begin or resume training. 36 | * 37 | * The initializer instantiates the training session based on the provided env and session options, which can be used to 38 | * begin or resume training from a given checkpoint state. The checkpoint state represents the parameters of the training 39 | * session, which will be moved to the device specified in the session options if needed. 40 | * 41 | * @param env The `ORTEnv` instance to use for the training session. 42 | * @param sessionOptions The optional `ORTSessionOptions` to use for the training session. 43 | * @param checkpoint Training states that are used as a starting point for training. 44 | * @param trainModelPath The path to the training onnx model. 45 | * @param evalModelPath The optional path to the evaluation onnx model. 46 | * @param optimizerModelPath The path to the optimizer onnx model used to perform gradient descent. 47 | * @param error Optional error information set if an error occurs. 48 | * @return The instance, or nil if an error occurs. 49 | * 50 | * @note The training session created with a checkpoint state uses this state to store the entire training 51 | * state (including model parameters, its gradients, the optimizer states and the properties). The training session 52 | * keeps a strong (owning) pointer to the checkpoint state.
53 | */ 54 | - (nullable instancetype)initWithEnv:(ORTEnv*)env 55 | sessionOptions:(nullable ORTSessionOptions*)sessionOptions 56 | checkpoint:(ORTCheckpoint*)checkpoint 57 | trainModelPath:(NSString*)trainModelPath 58 | evalModelPath:(nullable NSString*)evalModelPath 59 | optimizerModelPath:(nullable NSString*)optimizerModelPath 60 | error:(NSError**)error NS_DESIGNATED_INITIALIZER; 61 | 62 | /** 63 | * Performs a training step, which is equivalent to a forward and backward propagation in a single step. 64 | * 65 | * The training step computes the outputs of the training model and the gradients of the trainable parameters 66 | * for the given input values. The train step is performed based on the training model that was provided to the training session. 67 | * It is equivalent to running forward and backward propagation in a single step. The computed gradients are stored inside 68 | * the training session state so they can be later consumed by `optimizerStep`. The gradients can be lazily reset by 69 | * calling the `lazyResetGrad` method. 70 | * 71 | * @param inputs The input values to the training model. 72 | * @param error Optional error information set if an error occurs. 73 | * @return The output values of the training model, or nil if an error occurs. 74 | */ 75 | - (nullable NSArray*)trainStepWithInputValues:(NSArray*)inputs 76 | error:(NSError**)error; 77 | 78 | /** 79 | * Performs an evaluation step that computes the outputs of the evaluation model for the given inputs. 80 | * The eval step is performed based on the evaluation model that was provided to the training session. 81 | * 82 | * @param inputs The input values to the eval model. 83 | * @param error Optional error information set if an error occurs. 84 | * @return The output values of the eval model, or nil if an error occurs. 85 | * 86 | */ 87 | - (nullable NSArray*)evalStepWithInputValues:(NSArray*)inputs 88 | error:(NSError**)error; 89 | 90 | /** 91 | * Resets the gradients of all trainable parameters to zero lazily.
92 | * 93 | * Calling this method sets the internal state of the training session such that the gradients of the trainable parameters 94 | * in the ORTCheckpoint will be scheduled to be reset just before the new gradients are computed on the next 95 | * invocation of the `trainStep` method. 96 | * 97 | * @param error Optional error information set if an error occurs. 98 | * @return YES if the gradients were successfully scheduled to be reset, NO otherwise. 99 | */ 100 | - (BOOL)lazyResetGradWithError:(NSError**)error; 101 | 102 | /** 103 | * Performs the weight updates for the trainable parameters using the optimizer model. The optimizer step is performed 104 | * based on the optimizer model that was provided to the training session. The updated parameters are stored inside the 105 | * training state so that they can be used by the next `trainStep` method call. 106 | * 107 | * @param error Optional error information set if an error occurs. 108 | * @return YES if the optimizer step was performed successfully, NO otherwise. 109 | */ 110 | - (BOOL)optimizerStepWithError:(NSError**)error; 111 | 112 | /** 113 | * Returns the names of the user inputs for the training model that can be associated with 114 | * the `ORTValue` provided to the `trainStep`. 115 | * 116 | * @param error Optional error information set if an error occurs. 117 | * @return The names of the user inputs for the training model, or nil if an error occurs. 118 | */ 119 | - (nullable NSArray*)getTrainInputNamesWithError:(NSError**)error; 120 | 121 | /** 122 | * Returns the names of the user inputs for the evaluation model that can be associated with 123 | * the `ORTValue` provided to the `evalStep`. 124 | * 125 | * @param error Optional error information set if an error occurs. 126 | * @return The names of the user inputs for the evaluation model, or nil if an error occurs.
127 | */ 128 | - (nullable NSArray*)getEvalInputNamesWithError:(NSError**)error; 129 | 130 | /** 131 | * Returns the names of the user outputs for the training model that can be associated with 132 | * the `ORTValue` returned by the `trainStep`. 133 | * 134 | * @param error Optional error information set if an error occurs. 135 | * @return The names of the user outputs for the training model, or nil if an error occurs. 136 | */ 137 | - (nullable NSArray*)getTrainOutputNamesWithError:(NSError**)error; 138 | 139 | /** 140 | * Returns the names of the user outputs for the evaluation model that can be associated with 141 | * the `ORTValue` returned by the `evalStep`. 142 | * 143 | * @param error Optional error information set if an error occurs. 144 | * @return The names of the user outputs for the evaluation model, or nil if an error occurs. 145 | */ 146 | - (nullable NSArray*)getEvalOutputNamesWithError:(NSError**)error; 147 | 148 | /** 149 | * Registers a linear learning rate scheduler for the training session. 150 | * 151 | * The scheduler gradually decreases the learning rate from the initial value to zero over the course of the training. 152 | * The decrease is performed by multiplying the current learning rate by a linearly updated factor. 153 | * Before the decrease, the learning rate is gradually increased from zero to the initial value during a warmup phase. 154 | * 155 | * @param warmupStepCount The number of steps to perform the linear warmup. 156 | * @param totalStepCount The total number of steps to perform the linear decay. 157 | * @param initialLr The initial learning rate. 158 | * @param error Optional error information set if an error occurs. 159 | * @return YES if the scheduler was registered successfully, NO otherwise.
160 | */ 161 | - (BOOL)registerLinearLRSchedulerWithWarmupStepCount:(int64_t)warmupStepCount 162 | totalStepCount:(int64_t)totalStepCount 163 | initialLr:(float)initialLr 164 | error:(NSError**)error; 165 | 166 | /** 167 | * Updates the learning rate based on the registered learning rate scheduler. 168 | * 169 | * Performs a scheduler step that updates the learning rate that is being used by the training session. 170 | * This function should typically be called before invoking the optimizer step for each round, or as necessary 171 | * to update the learning rate being used by the training session. 172 | * 173 | * @note A valid predefined learning rate scheduler must be registered before invoking this method. 174 | * 175 | * @param error Optional error information set if an error occurs. 176 | * @return YES if the scheduler step was performed successfully, NO otherwise. 177 | */ 178 | - (BOOL)schedulerStepWithError:(NSError**)error; 179 | 180 | /** 181 | * Returns the current learning rate being used by the training session. 182 | * 183 | * @param error Optional error information set if an error occurs. 184 | * @return The current learning rate or 0.0f if an error occurs. Since 0.0f can also be a valid learning rate, check `error` to detect failure. 185 | */ 186 | - (float)getLearningRateWithError:(NSError**)error __attribute__((swift_error(nonnull_error))); 187 | 188 | /** 189 | * Sets the learning rate being used by the training session. 190 | * 191 | * The current learning rate is maintained by the training session and can be overwritten by invoking this method 192 | * with the desired learning rate. This function should not be used when a valid learning rate scheduler is registered. 193 | * It should be used either to set the learning rate derived from a custom learning rate scheduler or to set a constant 194 | * learning rate to be used throughout the training session. 195 | * 196 | * @note It does not set the initial learning rate that may be needed by the predefined learning rate schedulers.
197 | * To set the initial learning rate for learning rate schedulers, use the `registerLinearLRScheduler` method. 198 | * 199 | * @param lr The learning rate to be used by the training session. 200 | * @param error Optional error information set if an error occurs. 201 | * @return YES if the learning rate was set successfully, NO otherwise. 202 | */ 203 | - (BOOL)setLearningRate:(float)lr 204 | error:(NSError**)error; 205 | 206 | /** 207 | * Loads the training session model parameters from a contiguous buffer. 208 | * 209 | * @param buffer Contiguous buffer to load the parameters from. 210 | * @param error Optional error information set if an error occurs. 211 | * @return YES if the parameters were loaded successfully, NO otherwise. 212 | */ 213 | - (BOOL)fromBufferWithValue:(ORTValue*)buffer 214 | error:(NSError**)error; 215 | 216 | /** 217 | * Returns a contiguous buffer that holds a copy of the training state parameters. 218 | * 219 | * @param onlyTrainable If YES, returns a buffer that holds only the trainable parameters, otherwise returns a buffer 220 | * that holds all the parameters. 221 | * @param error Optional error information set if an error occurs. 222 | * @return A contiguous buffer that holds a copy of the requested training state parameters, or nil if an error occurs. 223 | */ 224 | - (nullable ORTValue*)toBufferWithTrainable:(BOOL)onlyTrainable 225 | error:(NSError**)error; 226 | 227 | /** 228 | * Exports the training session model that can be used for inference. 229 | * 230 | * If the training session was provided with an eval model, the training session can generate an inference model if it 231 | * knows the inference graph outputs. The input inference graph outputs are used to prune the eval model so that the 232 | * inference model's outputs align with the provided outputs. The exported model is saved at the path provided and 233 | * can be used for inferencing with `ORTSession`.
234 | * 235 | * @note The method reloads the eval model from the path provided to the initializer and expects this path to be valid. 236 | * 237 | * @param inferenceModelPath The path at which the serialized inference model will be saved. 238 | * @param graphOutputNames The names of the outputs that are needed in the inference model. 239 | * @param error Optional error information set if an error occurs. 240 | * @return YES if the inference model was exported successfully, NO otherwise. 241 | */ 242 | - (BOOL)exportModelForInferenceWithOutputPath:(NSString*)inferenceModelPath 243 | graphOutputNames:(NSArray*)graphOutputNames 244 | error:(NSError**)error; 245 | @end 246 | 247 | #ifdef __cplusplus 248 | extern "C" { 249 | #endif 250 | 251 | /** 252 | * Sets the seed for generating random numbers. 253 | * Use this function to generate reproducible results. Note that completely reproducible results are not guaranteed. 254 | * 255 | * @param seed Manually set seed to use for random number generation. 256 | */ 257 | void ORTSetSeed(int64_t seed); 258 | 259 | #ifdef __cplusplus 260 | } 261 | #endif 262 | 263 | NS_ASSUME_NONNULL_END 264 | -------------------------------------------------------------------------------- /objectivec/include/ort_session.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License. 3 | 4 | #import 5 | 6 | #import "ort_custom_op_registration.h" 7 | #import "ort_enums.h" 8 | 9 | NS_ASSUME_NONNULL_BEGIN 10 | 11 | @class ORTEnv; 12 | @class ORTRunOptions; 13 | @class ORTSessionOptions; 14 | @class ORTValue; 15 | 16 | /** 17 | * An ORT session loads and runs a model. 18 | */ 19 | @interface ORTSession : NSObject 20 | 21 | - (instancetype)init NS_UNAVAILABLE; 22 | 23 | /** 24 | * Creates a session. 25 | * 26 | * @param env The ORT Environment instance. 27 | * @param path The path to the ONNX model.
28 | * @param sessionOptions Optional session configuration options. If nil, default session options are used. 29 | * @param error Optional error information set if an error occurs. 30 | * @return The instance, or nil if an error occurs. 31 | */ 32 | - (nullable instancetype)initWithEnv:(ORTEnv*)env 33 | modelPath:(NSString*)path 34 | sessionOptions:(nullable ORTSessionOptions*)sessionOptions 35 | error:(NSError**)error NS_DESIGNATED_INITIALIZER; 36 | 37 | /** 38 | * Runs the model. 39 | * The inputs and outputs are pre-allocated by the caller. 40 | * 41 | * @param inputs Dictionary of input names to input ORT values. 42 | * @param outputs Dictionary of output names to output ORT values. 43 | * @param runOptions Optional run configuration options. If nil, default run options are used. 44 | * @param error Optional error information set if an error occurs. 45 | * @return Whether the model was run successfully. 46 | */ 47 | - (BOOL)runWithInputs:(NSDictionary*)inputs 48 | outputs:(NSDictionary*)outputs 49 | runOptions:(nullable ORTRunOptions*)runOptions 50 | error:(NSError**)error; 51 | 52 | /** 53 | * Runs the model. 54 | * The inputs are pre-allocated and the outputs are allocated by ORT. 55 | * 56 | * @param inputs Dictionary of input names to input ORT values. 57 | * @param outputNames Set of output names. 58 | * @param runOptions Optional run configuration options. If nil, default run options are used. 59 | * @param error Optional error information set if an error occurs. 60 | * @return A dictionary of output names to output ORT values with the outputs 61 | * requested in `outputNames`, or nil if an error occurs. 62 | */ 63 | - (nullable NSDictionary*)runWithInputs:(NSDictionary*)inputs 64 | outputNames:(NSSet*)outputNames 65 | runOptions:(nullable ORTRunOptions*)runOptions 66 | error:(NSError**)error; 67 | 68 | /** 69 | * Gets the model's input names. 70 | * 71 | * @param error Optional error information set if an error occurs. 72 | * @return An array of input names, or nil if an error occurs.
73 | */ 74 | - (nullable NSArray*)inputNamesWithError:(NSError**)error; 75 | 76 | /** 77 | * Gets the model's overridable initializer names. 78 | * 79 | * @param error Optional error information set if an error occurs. 80 | * @return An array of overridable initializer names, or nil if an error occurs. 81 | */ 82 | - (nullable NSArray*)overridableInitializerNamesWithError:(NSError**)error; 83 | 84 | /** 85 | * Gets the model's output names. 86 | * 87 | * @param error Optional error information set if an error occurs. 88 | * @return An array of output names, or nil if an error occurs. 89 | */ 90 | - (nullable NSArray*)outputNamesWithError:(NSError**)error; 91 | 92 | @end 93 | 94 | /** 95 | * Options for configuring a session. 96 | */ 97 | @interface ORTSessionOptions : NSObject 98 | 99 | - (instancetype)init NS_UNAVAILABLE; 100 | 101 | /** 102 | * Creates session configuration options. 103 | * 104 | * @param error Optional error information set if an error occurs. 105 | * @return The instance, or nil if an error occurs. 106 | */ 107 | - (nullable instancetype)initWithError:(NSError**)error NS_SWIFT_NAME(init()); 108 | 109 | /** 110 | * Appends an execution provider to the session options to enable the execution provider to be used when running 111 | * the model. 112 | * 113 | * Available since 1.14. 114 | * 115 | * The execution provider list is ordered by decreasing priority, 116 | * i.e., the first provider registered has the highest priority. 117 | * 118 | * @param providerName Provider name. For example, "xnnpack". 119 | * @param providerOptions Provider-specific options. For example, for provider "xnnpack", {"intra_op_num_threads": "2"}. 120 | * @param error Optional error information set if an error occurs.
121 | * @return Whether the execution provider was appended successfully. 122 | */ 123 | - (BOOL)appendExecutionProvider:(NSString*)providerName 124 | providerOptions:(NSDictionary*)providerOptions 125 | error:(NSError**)error; 126 | /** 127 | * Sets the number of threads used to parallelize the execution within nodes. 128 | * A value of 0 means ORT will pick a default value. 129 | * 130 | * @param intraOpNumThreads The number of threads. 131 | * @param error Optional error information set if an error occurs. 132 | * @return Whether the option was set successfully. 133 | */ 134 | - (BOOL)setIntraOpNumThreads:(int)intraOpNumThreads 135 | error:(NSError**)error; 136 | 137 | /** 138 | * Sets the graph optimization level. 139 | * 140 | * @param graphOptimizationLevel The graph optimization level. 141 | * @param error Optional error information set if an error occurs. 142 | * @return Whether the option was set successfully. 143 | */ 144 | - (BOOL)setGraphOptimizationLevel:(ORTGraphOptimizationLevel)graphOptimizationLevel 145 | error:(NSError**)error; 146 | 147 | /** 148 | * Sets the path to which the optimized model file will be saved. 149 | * 150 | * @param optimizedModelFilePath The optimized model file path. 151 | * @param error Optional error information set if an error occurs. 152 | * @return Whether the option was set successfully. 153 | */ 154 | - (BOOL)setOptimizedModelFilePath:(NSString*)optimizedModelFilePath 155 | error:(NSError**)error; 156 | 157 | /** 158 | * Sets the session log ID. 159 | * 160 | * @param logID The log ID. 161 | * @param error Optional error information set if an error occurs. 162 | * @return Whether the option was set successfully. 163 | */ 164 | - (BOOL)setLogID:(NSString*)logID 165 | error:(NSError**)error; 166 | 167 | /** 168 | * Sets the session log severity level. 169 | * 170 | * @param loggingLevel The log severity level. 171 | * @param error Optional error information set if an error occurs.
172 | * @return Whether the option was set successfully. 173 | */ 174 | - (BOOL)setLogSeverityLevel:(ORTLoggingLevel)loggingLevel 175 | error:(NSError**)error; 176 | 177 | /** 178 | * Sets a session configuration key-value pair. 179 | * Any value for a previously set key will be overwritten. 180 | * The session configuration keys and values are documented here: 181 | * https://github.com/microsoft/onnxruntime/blob/main/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h 182 | * 183 | * @param key The key. 184 | * @param value The value. 185 | * @param error Optional error information set if an error occurs. 186 | * @return Whether the option was set successfully. 187 | */ 188 | - (BOOL)addConfigEntryWithKey:(NSString*)key 189 | value:(NSString*)value 190 | error:(NSError**)error; 191 | 192 | /** 193 | * Registers custom ops for use with `ORTSession`s using this SessionOptions by calling the native function with the 194 | * specified name. The custom ops library must either be linked against, or have previously been loaded 195 | * by the user. 196 | * 197 | * Available since 1.14. 198 | * 199 | * The registration function must have the signature: 200 | * `OrtStatus* (*fn)(OrtSessionOptions* options, const OrtApiBase* api);` 201 | * 202 | * The signature is defined in the ONNX Runtime C API: 203 | * https://github.com/microsoft/onnxruntime/blob/67f4cd54fab321d83e4a75a40efeee95a6a17079/include/onnxruntime/core/session/onnxruntime_c_api.h#L697 204 | * 205 | * See https://onnxruntime.ai/docs/reference/operators/add-custom-op.html for more information on custom ops. 206 | * See https://github.com/microsoft/onnxruntime/blob/342a5bf2b756d1a1fc6fdc582cfeac15182632fe/onnxruntime/test/testdata/custom_op_library/custom_op_library.cc#L115 207 | * for an example of a custom op library registration function.
208 | * 209 | * @note The caller must ensure that `registrationFuncName` names a valid function that is visible to the native ONNX 210 | * Runtime code and has the correct signature. 211 | * They must ensure that the function does what they expect it to do because this method will just call it. 212 | * 213 | * @param registrationFuncName The name of the registration function to call. 214 | * @param error Optional error information set if an error occurs. 215 | * @return Whether the registration function was successfully called. 216 | */ 217 | - (BOOL)registerCustomOpsUsingFunction:(NSString*)registrationFuncName 218 | error:(NSError**)error; 219 | 220 | /** 221 | * Registers custom ops for use with `ORTSession`s using this SessionOptions by calling the specified function 222 | * pointed to by `registerCustomOpsFn`. 223 | * 224 | * Available since 1.16. 225 | * 226 | * The registration function must have the signature: 227 | * `OrtStatus* (*fn)(OrtSessionOptions* options, const OrtApiBase* api);` 228 | * 229 | * The signature is defined in the ONNX Runtime C API: 230 | * https://github.com/microsoft/onnxruntime/blob/67f4cd54fab321d83e4a75a40efeee95a6a17079/include/onnxruntime/core/session/onnxruntime_c_api.h#L697 231 | * 232 | * See https://onnxruntime.ai/docs/reference/operators/add-custom-op.html for more information on custom ops. 233 | * See https://github.com/microsoft/onnxruntime/blob/342a5bf2b756d1a1fc6fdc582cfeac15182632fe/onnxruntime/test/testdata/custom_op_library/custom_op_library.cc#L115 234 | * for an example of a custom op library registration function. 235 | * 236 | * @note The caller must ensure that `registerCustomOpsFn` is a valid function pointer and has the correct signature. 237 | * They must ensure that the function does what they expect it to do because this method will just call it. 238 | * 239 | * @param registerCustomOpsFn A pointer to the registration function to call. A NULL pointer is rejected with an error. 240 | * @param error Optional error information set if an error occurs.
241 | * @return Whether the registration function was successfully called. 242 | */ 243 | - (BOOL)registerCustomOpsUsingFunctionPointer:(ORTCAPIRegisterCustomOpsFnPtr)registerCustomOpsFn 244 | error:(NSError**)error; 245 | 246 | /** 247 | * Registers ONNX Runtime Extensions custom ops that have been built into ONNX Runtime. 248 | * 249 | * Available since 1.16. 250 | * 251 | * @note ONNX Runtime must have been built with the `--use_extensions` flag for the ONNX Runtime Extensions custom ops 252 | * to be registered with this method. When using a separate ONNX Runtime Extensions library, use 253 | * `registerCustomOpsUsingFunctionPointer:error:` instead. 254 | * 255 | * @param error Optional error information set if an error occurs. 256 | * @return Whether the ONNX Runtime Extensions custom ops were successfully registered. 257 | */ 258 | - (BOOL)enableOrtExtensionsCustomOpsWithError:(NSError**)error; 259 | 260 | @end 261 | 262 | /** 263 | * Options for configuring a run. 264 | */ 265 | @interface ORTRunOptions : NSObject 266 | 267 | - (instancetype)init NS_UNAVAILABLE; 268 | 269 | /** 270 | * Creates run configuration options. 271 | * 272 | * @param error Optional error information set if an error occurs. 273 | * @return The instance, or nil if an error occurs. 274 | */ 275 | - (nullable instancetype)initWithError:(NSError**)error NS_SWIFT_NAME(init()); 276 | 277 | /** 278 | * Sets the run log tag. 279 | * 280 | * @param logTag The log tag. 281 | * @param error Optional error information set if an error occurs. 282 | * @return Whether the option was set successfully. 283 | */ 284 | - (BOOL)setLogTag:(NSString*)logTag 285 | error:(NSError**)error; 286 | 287 | /** 288 | * Sets the run log severity level. 289 | * 290 | * @param loggingLevel The log severity level. 291 | * @param error Optional error information set if an error occurs. 292 | * @return Whether the option was set successfully.
293 | */ 294 | - (BOOL)setLogSeverityLevel:(ORTLoggingLevel)loggingLevel 295 | error:(NSError**)error; 296 | 297 | /** 298 | * Sets a run configuration key-value pair. 299 | * Any value for a previously set key will be overwritten. 300 | * The run configuration keys and values are documented here: 301 | * https://github.com/microsoft/onnxruntime/blob/main/include/onnxruntime/core/session/onnxruntime_run_options_config_keys.h 302 | * 303 | * @param key The key. 304 | * @param value The value. 305 | * @param error Optional error information set if an error occurs. 306 | * @return Whether the option was set successfully. 307 | */ 308 | - (BOOL)addConfigEntryWithKey:(NSString*)key 309 | value:(NSString*)value 310 | error:(NSError**)error; 311 | 312 | @end 313 | 314 | NS_ASSUME_NONNULL_END 315 | -------------------------------------------------------------------------------- /objectivec/ort_session.mm: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License.
3 | 4 | #import "ort_session_internal.h" 5 | 6 | #include 7 | #include 8 | 9 | #import "cxx_api.h" 10 | #import "cxx_utils.h" 11 | #import "error_utils.h" 12 | #import "ort_enums_internal.h" 13 | #import "ort_env_internal.h" 14 | #import "ort_value_internal.h" 15 | 16 | namespace { 17 | enum class NamedValueType { // selects which kind of model names -namesWithType:error: returns 18 | Input, 19 | OverridableInitializer, 20 | Output, 21 | }; 22 | } // namespace 23 | 24 | NS_ASSUME_NONNULL_BEGIN 25 | 26 | @implementation ORTSession { // wraps an Ort::Session 27 | ORTEnv* _env; // keep a strong reference so the ORTEnv doesn't get destroyed before this does 28 | std::optional _session; 29 | } 30 | 31 | #pragma mark - Public 32 | 33 | - (nullable instancetype)initWithEnv:(ORTEnv*)env 34 | modelPath:(NSString*)path 35 | sessionOptions:(nullable ORTSessionOptions*)sessionOptions 36 | error:(NSError**)error { 37 | if ((self = [super init]) == nil) { 38 | return nil; 39 | } 40 | 41 | try { 42 | if (!sessionOptions) { // no options supplied: use a new ORTSessionOptions with default settings 43 | sessionOptions = [[ORTSessionOptions alloc] initWithError:error]; 44 | if (!sessionOptions) { 45 | return nil; 46 | } 47 | } 48 | 49 | _env = env; 50 | _session = Ort::Session{[env CXXAPIOrtEnv], 51 | path.UTF8String, 52 | [sessionOptions CXXAPIOrtSessionOptions]}; 53 | 54 | return self; 55 | } 56 | ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error) 57 | } 58 | 59 | - (BOOL)runWithInputs:(NSDictionary*)inputs 60 | outputs:(NSDictionary*)outputs 61 | runOptions:(nullable ORTRunOptions*)runOptions 62 | error:(NSError**)error { 63 | try { 64 | if (!runOptions) { // no run options supplied: use a new ORTRunOptions with default settings 65 | runOptions = [[ORTRunOptions alloc] initWithError:error]; 66 | if (!runOptions) { 67 | return NO; 68 | } 69 | } 70 | 71 | std::vector inputNames, outputNames; // names and values are kept as parallel arrays, the form OrtApi::Run takes 72 | std::vector inputCAPIValues; 73 | std::vector outputCAPIValues; 74 | 75 | inputNames.reserve(inputs.count); 76 | inputCAPIValues.reserve(inputs.count); 77 | for (NSString* inputName in inputs) { 78 | inputNames.push_back(inputName.UTF8String); 79 | inputCAPIValues.push_back(static_cast([inputs[inputName] CXXAPIOrtValue])); 80
| } 81 | 82 | outputNames.reserve(outputs.count); 83 | outputCAPIValues.reserve(outputs.count); 84 | for (NSString* outputName in outputs) { 85 | outputNames.push_back(outputName.UTF8String); 86 | outputCAPIValues.push_back(static_cast([outputs[outputName] CXXAPIOrtValue])); 87 | } 88 | 89 | Ort::ThrowOnError(Ort::GetApi().Run(*_session, [runOptions CXXAPIOrtRunOptions], // results are written into the caller-provided output values 90 | inputNames.data(), inputCAPIValues.data(), inputNames.size(), 91 | outputNames.data(), outputNames.size(), outputCAPIValues.data())); 92 | 93 | return YES; 94 | } 95 | ORT_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error) 96 | } 97 | 98 | - (nullable NSDictionary*)runWithInputs:(NSDictionary*)inputs 99 | outputNames:(NSSet*)outputNameSet 100 | runOptions:(nullable ORTRunOptions*)runOptions 101 | error:(NSError**)error { 102 | try { 103 | if (!runOptions) { // no run options supplied: use a new ORTRunOptions with default settings 104 | runOptions = [[ORTRunOptions alloc] initWithError:error]; 105 | if (!runOptions) { 106 | return nil; 107 | } 108 | } 109 | 110 | NSArray* outputNameArray = outputNameSet.allObjects; // fixed order: outputNameArray[i] pairs with outputCAPIValues[i] 111 | 112 | std::vector inputNames, outputNames; 113 | std::vector inputCAPIValues; 114 | std::vector outputCAPIValues; 115 | 116 | inputNames.reserve(inputs.count); 117 | inputCAPIValues.reserve(inputs.count); 118 | for (NSString* inputName in inputs) { 119 | inputNames.push_back(inputName.UTF8String); 120 | inputCAPIValues.push_back(static_cast([inputs[inputName] CXXAPIOrtValue])); 121 | } 122 | 123 | outputNames.reserve(outputNameArray.count); 124 | outputCAPIValues.reserve(outputNameArray.count); 125 | for (NSString* outputName in outputNameArray) { 126 | outputNames.push_back(outputName.UTF8String); 127 | outputCAPIValues.push_back(nullptr); // filled in by Run with an ORT-allocated value 128 | } 129 | 130 | Ort::ThrowOnError(Ort::GetApi().Run(*_session, [runOptions CXXAPIOrtRunOptions], 131 | inputNames.data(), inputCAPIValues.data(), inputNames.size(), 132 | outputNames.data(), outputNames.size(), outputCAPIValues.data())); 133 | // From here on, each C OrtValue in outputCAPIValues must be released exactly once: by its Ort::Value wrapper or by the cleanup loop. 134 | NSMutableDictionary* outputs = [[NSMutableDictionary alloc] init]; 135
| for (NSUInteger i = 0; i < outputNameArray.count; ++i) { 136 | // Wrap the C OrtValue in a C++ Ort::Value to automatically handle its release. 137 | // Then, transfer that C++ Ort::Value to a new ORTValue. 138 | Ort::Value outputCXXAPIValue{outputCAPIValues[i]}; 139 | ORTValue* outputValue = [[ORTValue alloc] initWithCXXAPIOrtValue:std::move(outputCXXAPIValue) 140 | externalTensorData:nil 141 | error:error]; 142 | if (!outputValue) { // value i is still owned by outputCXXAPIValue (initWithCXXAPIOrtValue only moves on success), so only j > i need a manual release 143 | // clean up remaining C OrtValues which haven't been wrapped by a C++ Ort::Value yet 144 | for (NSUInteger j = i + 1; j < outputNameArray.count; ++j) { 145 | Ort::GetApi().ReleaseValue(outputCAPIValues[j]); 146 | } 147 | return nil; 148 | } 149 | 150 | outputs[outputNameArray[i]] = outputValue; 151 | } 152 | 153 | return outputs; 154 | } 155 | ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error) 156 | } 157 | 158 | - (nullable NSArray*)inputNamesWithError:(NSError**)error { 159 | return [self namesWithType:NamedValueType::Input error:error]; 160 | } 161 | 162 | - (nullable NSArray*)overridableInitializerNamesWithError:(NSError**)error { 163 | return [self namesWithType:NamedValueType::OverridableInitializer error:error]; 164 | } 165 | 166 | - (nullable NSArray*)outputNamesWithError:(NSError**)error { 167 | return [self namesWithType:NamedValueType::Output error:error]; 168 | } 169 | 170 | #pragma mark - Private 171 | // Shared implementation of the input, overridable initializer and output name getters. 172 | - (nullable NSArray*)namesWithType:(NamedValueType)namedValueType 173 | error:(NSError**)error { 174 | try { 175 | auto getCount = [&session = *_session, namedValueType]() { 176 | if (namedValueType == NamedValueType::Input) { 177 | return session.GetInputCount(); 178 | } else if (namedValueType == NamedValueType::OverridableInitializer) { 179 | return session.GetOverridableInitializerCount(); 180 | } else { 181 | return session.GetOutputCount(); 182 | } 183 | }; 184 | 185 | auto getName = [&session = *_session, namedValueType](size_t i, OrtAllocator* allocator) { 186 | if (namedValueType ==
NamedValueType::Input) { 187 | return session.GetInputNameAllocated(i, allocator); 188 | } else if (namedValueType == NamedValueType::OverridableInitializer) { 189 | return session.GetOverridableInitializerNameAllocated(i, allocator); 190 | } else { 191 | return session.GetOutputNameAllocated(i, allocator); 192 | } 193 | }; 194 | 195 | const size_t nameCount = getCount(); 196 | 197 | Ort::AllocatorWithDefaultOptions allocator; 198 | NSMutableArray* result = [NSMutableArray arrayWithCapacity:nameCount]; 199 | 200 | for (size_t i = 0; i < nameCount; ++i) { 201 | auto name = getName(i, allocator); // NOTE(review): appears to own the allocator-provided string (read via .get()), so it is copied into an NSString before going out of scope 202 | NSString* nameNsstr = utils::toNSString(name.get()); 203 | [result addObject:nameNsstr]; 204 | } 205 | 206 | return result; 207 | } 208 | ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error) 209 | } 210 | 211 | @end 212 | 213 | @implementation ORTSessionOptions { // wraps an Ort::SessionOptions 214 | std::optional _sessionOptions; 215 | } 216 | 217 | #pragma mark - Public 218 | 219 | - (nullable instancetype)initWithError:(NSError**)error { 220 | if ((self = [super init]) == nil) { 221 | return nil; 222 | } 223 | 224 | try { 225 | _sessionOptions = Ort::SessionOptions{}; 226 | return self; 227 | } 228 | ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error) 229 | } 230 | 231 | - (BOOL)appendExecutionProvider:(NSString*)providerName 232 | providerOptions:(NSDictionary*)providerOptions 233 | error:(NSError**)error { 234 | try { 235 | std::unordered_map options; // provider options copied out of the NSDictionary for the C++ API 236 | NSArray* keys = [providerOptions allKeys]; 237 | 238 | for (NSString* key in keys) { 239 | NSString* value = [providerOptions objectForKey:key]; 240 | options.emplace(key.UTF8String, value.UTF8String); 241 | } 242 | 243 | _sessionOptions->AppendExecutionProvider(providerName.UTF8String, options); 244 | return YES; 245 | } 246 | ORT_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error); 247 | } 248 | 249 | - (BOOL)setIntraOpNumThreads:(int)intraOpNumThreads 250 | error:(NSError**)error { 251 | try { 252 |
_sessionOptions->SetIntraOpNumThreads(intraOpNumThreads); 253 | return YES; 254 | } 255 | ORT_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error) 256 | } 257 | 258 | - (BOOL)setGraphOptimizationLevel:(ORTGraphOptimizationLevel)graphOptimizationLevel 259 | error:(NSError**)error { 260 | try { 261 | _sessionOptions->SetGraphOptimizationLevel( 262 | PublicToCAPIGraphOptimizationLevel(graphOptimizationLevel)); 263 | return YES; 264 | } 265 | ORT_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error) 266 | } 267 | 268 | - (BOOL)setOptimizedModelFilePath:(NSString*)optimizedModelFilePath 269 | error:(NSError**)error { 270 | try { 271 | _sessionOptions->SetOptimizedModelFilePath(optimizedModelFilePath.UTF8String); 272 | return YES; 273 | } 274 | ORT_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error) 275 | } 276 | 277 | - (BOOL)setLogID:(NSString*)logID 278 | error:(NSError**)error { 279 | try { 280 | _sessionOptions->SetLogId(logID.UTF8String); 281 | return YES; 282 | } 283 | ORT_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error) 284 | } 285 | 286 | - (BOOL)setLogSeverityLevel:(ORTLoggingLevel)loggingLevel 287 | error:(NSError**)error { 288 | try { 289 | _sessionOptions->SetLogSeverityLevel(PublicToCAPILoggingLevel(loggingLevel)); 290 | return YES; 291 | } 292 | ORT_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error) 293 | } 294 | 295 | - (BOOL)addConfigEntryWithKey:(NSString*)key 296 | value:(NSString*)value 297 | error:(NSError**)error { 298 | try { 299 | _sessionOptions->AddConfigEntry(key.UTF8String, value.UTF8String); 300 | return YES; 301 | } 302 | ORT_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error) 303 | } 304 | 305 | - (BOOL)registerCustomOpsUsingFunction:(NSString*)registrationFuncName 306 | error:(NSError**)error { 307 | try { 308 | _sessionOptions->RegisterCustomOpsUsingFunction(registrationFuncName.UTF8String); 309 | return YES; 310 | } 311 | ORT_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error) 312 | } 313 | 314 | - (BOOL)registerCustomOpsUsingFunctionPointer:(ORTCAPIRegisterCustomOpsFnPtr)registerCustomOpsFn 315 |
error:(NSError**)error { 316 | try { 317 | if (!registerCustomOpsFn) { // reject NULL instead of calling through it 318 | ORT_CXX_API_THROW("registerCustomOpsFn must not be null", ORT_INVALID_ARGUMENT); 319 | } 320 | Ort::ThrowOnError((*registerCustomOpsFn)(static_cast(*_sessionOptions), 321 | OrtGetApiBase())); 322 | return YES; 323 | } 324 | ORT_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error) 325 | } 326 | 327 | - (BOOL)enableOrtExtensionsCustomOpsWithError:(NSError**)error { 328 | try { 329 | _sessionOptions->EnableOrtCustomOps(); 330 | return YES; 331 | } 332 | ORT_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error) 333 | } 334 | 335 | #pragma mark - Internal 336 | 337 | - (Ort::SessionOptions&)CXXAPIOrtSessionOptions { 338 | return *_sessionOptions; 339 | } 340 | 341 | @end 342 | 343 | @implementation ORTRunOptions { // wraps an Ort::RunOptions 344 | std::optional _runOptions; 345 | } 346 | 347 | #pragma mark - Public 348 | 349 | - (nullable instancetype)initWithError:(NSError**)error { 350 | if ((self = [super init]) == nil) { 351 | return nil; 352 | } 353 | 354 | try { 355 | _runOptions = Ort::RunOptions{}; 356 | return self; 357 | } 358 | ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error) 359 | } 360 | 361 | - (BOOL)setLogTag:(NSString*)logTag 362 | error:(NSError**)error { 363 | try { 364 | _runOptions->SetRunTag(logTag.UTF8String); 365 | return YES; 366 | } 367 | ORT_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error) 368 | } 369 | 370 | - (BOOL)setLogSeverityLevel:(ORTLoggingLevel)loggingLevel 371 | error:(NSError**)error { 372 | try { 373 | _runOptions->SetRunLogSeverityLevel(PublicToCAPILoggingLevel(loggingLevel)); 374 | return YES; 375 | } 376 | ORT_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error) 377 | } 378 | 379 | - (BOOL)addConfigEntryWithKey:(NSString*)key 380 | value:(NSString*)value 381 | error:(NSError**)error { 382 | try { 383 | _runOptions->AddConfigEntry(key.UTF8String, value.UTF8String); 384 | return YES; 385 | } 386 | ORT_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error) 387 | } 388 | 389 | #pragma mark - Internal 390 | 391 | -
(Ort::RunOptions&)CXXAPIOrtRunOptions { 392 | return *_runOptions; 393 | } 394 | 395 | @end 396 | 397 | NS_ASSUME_NONNULL_END 398 | -------------------------------------------------------------------------------- /objectivec/test/ort_session_test.mm: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License. 3 | 4 | #import 5 | 6 | #import "ort_coreml_execution_provider.h" 7 | #import "ort_xnnpack_execution_provider.h" 8 | #import "ort_env.h" 9 | #import "ort_session.h" 10 | #import "ort_value.h" 11 | 12 | #import "test/assertion_utils.h" 13 | 14 | #include 15 | 16 | NS_ASSUME_NONNULL_BEGIN 17 | 18 | @interface ORTSessionTest : XCTestCase 19 | 20 | @property(readonly, nullable) ORTEnv* ortEnv; 21 | 22 | @end 23 | 24 | @implementation ORTSessionTest 25 | 26 | - (void)setUp { 27 | [super setUp]; 28 | 29 | self.continueAfterFailure = NO; 30 | 31 | NSError* err = nil; 32 | _ortEnv = [[ORTEnv alloc] initWithLoggingLevel:ORTLoggingLevelWarning 33 | error:&err]; 34 | ORTAssertNullableResultSuccessful(_ortEnv, err); 35 | } 36 | 37 | - (void)tearDown { 38 | _ortEnv = nil; 39 | 40 | [super tearDown]; 41 | } 42 | 43 | // model with an Add op 44 | // inputs: A, B 45 | // output: C = A + B 46 | + (NSString*)getAddModelPath { 47 | NSBundle* bundle = [NSBundle bundleForClass:[ORTSessionTest class]]; 48 | NSString* path = [bundle pathForResource:@"single_add.basic" 49 | ofType:@"ort"]; 50 | return path; 51 | } 52 | 53 | + (NSString*)getStringModelPath { 54 | NSBundle* bundle = [NSBundle bundleForClass:[ORTSessionTest class]]; 55 | NSString* path = [bundle pathForResource:@"identity_string" 56 | ofType:@"ort"]; 57 | return path; 58 | } 59 | 60 | + (NSMutableData*)dataWithScalarFloat:(float)value { 61 | NSMutableData* data = [[NSMutableData alloc] initWithBytes:&value length:sizeof(value)]; 62 | return data; 63 | } 64 | 65 | + 
(ORTValue*)ortValueWithScalarFloatData:(NSMutableData*)data { 66 | NSArray* shape = @[ @1 ]; 67 | NSError* err = nil; 68 | ORTValue* ortValue = [[ORTValue alloc] initWithTensorData:data 69 | elementType:ORTTensorElementDataTypeFloat 70 | shape:shape 71 | error:&err]; 72 | ORTAssertNullableResultSuccessful(ortValue, err); 73 | return ortValue; 74 | } 75 | 76 | + (ORTSessionOptions*)makeSessionOptions { 77 | NSError* err = nil; 78 | ORTSessionOptions* sessionOptions = [[ORTSessionOptions alloc] initWithError:&err]; 79 | ORTAssertNullableResultSuccessful(sessionOptions, err); 80 | return sessionOptions; 81 | } 82 | 83 | + (ORTRunOptions*)makeRunOptions { 84 | NSError* err = nil; 85 | ORTRunOptions* runOptions = [[ORTRunOptions alloc] initWithError:&err]; 86 | ORTAssertNullableResultSuccessful(runOptions, err); 87 | return runOptions; 88 | } 89 | 90 | - (void)testInitAndRunWithPreallocatedOutputOk { 91 | NSMutableData* aData = [ORTSessionTest dataWithScalarFloat:1.0f]; 92 | NSMutableData* bData = [ORTSessionTest dataWithScalarFloat:2.0f]; 93 | NSMutableData* cData = [ORTSessionTest dataWithScalarFloat:0.0f]; 94 | 95 | ORTValue* a = [ORTSessionTest ortValueWithScalarFloatData:aData]; 96 | ORTValue* b = [ORTSessionTest ortValueWithScalarFloatData:bData]; 97 | ORTValue* c = [ORTSessionTest ortValueWithScalarFloatData:cData]; 98 | 99 | NSError* err = nil; 100 | ORTSession* session = [[ORTSession alloc] initWithEnv:self.ortEnv 101 | modelPath:[ORTSessionTest getAddModelPath] 102 | sessionOptions:[ORTSessionTest makeSessionOptions] 103 | error:&err]; 104 | ORTAssertNullableResultSuccessful(session, err); 105 | 106 | BOOL runResult = [session runWithInputs:@{@"A" : a, @"B" : b} 107 | outputs:@{@"C" : c} 108 | runOptions:[ORTSessionTest makeRunOptions] 109 | error:&err]; 110 | ORTAssertBoolResultSuccessful(runResult, err); 111 | 112 | const float cExpected = 3.0f; 113 | float cActual; 114 | memcpy(&cActual, cData.bytes, sizeof(float)); 115 | XCTAssertEqual(cActual, cExpected); 
116 | } 117 | 118 | - (void)testInitAndRunOk { 119 | NSMutableData* aData = [ORTSessionTest dataWithScalarFloat:1.0f]; 120 | NSMutableData* bData = [ORTSessionTest dataWithScalarFloat:2.0f]; 121 | 122 | ORTValue* a = [ORTSessionTest ortValueWithScalarFloatData:aData]; 123 | ORTValue* b = [ORTSessionTest ortValueWithScalarFloatData:bData]; 124 | 125 | NSError* err = nil; 126 | ORTSession* session = [[ORTSession alloc] initWithEnv:self.ortEnv 127 | modelPath:[ORTSessionTest getAddModelPath] 128 | sessionOptions:[ORTSessionTest makeSessionOptions] 129 | error:&err]; 130 | ORTAssertNullableResultSuccessful(session, err); 131 | 132 | NSDictionary* outputs = 133 | [session runWithInputs:@{@"A" : a, @"B" : b} 134 | outputNames:[NSSet setWithArray:@[ @"C" ]] 135 | runOptions:[ORTSessionTest makeRunOptions] 136 | error:&err]; 137 | ORTAssertNullableResultSuccessful(outputs, err); 138 | 139 | ORTValue* cOutput = outputs[@"C"]; 140 | XCTAssertNotNil(cOutput); 141 | 142 | NSData* cData = [cOutput tensorDataWithError:&err]; 143 | ORTAssertNullableResultSuccessful(cData, err); 144 | 145 | const float cExpected = 3.0f; 146 | float cActual; 147 | memcpy(&cActual, cData.bytes, sizeof(float)); 148 | XCTAssertEqual(cActual, cExpected); 149 | } 150 | 151 | - (void)testGetNamesOk { 152 | NSError* err = nil; 153 | ORTSession* session = [[ORTSession alloc] initWithEnv:self.ortEnv 154 | modelPath:[ORTSessionTest getAddModelPath] 155 | sessionOptions:[ORTSessionTest makeSessionOptions] 156 | error:&err]; 157 | ORTAssertNullableResultSuccessful(session, err); 158 | 159 | NSArray* inputNames = [session inputNamesWithError:&err]; 160 | ORTAssertNullableResultSuccessful(inputNames, err); 161 | XCTAssertEqualObjects(inputNames, (@[ @"A", @"B" ])); 162 | 163 | NSArray* overridableInitializerNames = [session overridableInitializerNamesWithError:&err]; 164 | ORTAssertNullableResultSuccessful(overridableInitializerNames, err); 165 | XCTAssertEqualObjects(overridableInitializerNames, (@[])); 166 | 
167 | NSArray* outputNames = [session outputNamesWithError:&err]; 168 | ORTAssertNullableResultSuccessful(outputNames, err); 169 | XCTAssertEqualObjects(outputNames, (@[ @"C" ])); 170 | } 171 | 172 | - (void)testInitFailsWithInvalidPath { 173 | NSString* invalidModelPath = @"invalid/path/to/model.ort"; 174 | NSError* err = nil; 175 | ORTSession* session = [[ORTSession alloc] initWithEnv:self.ortEnv 176 | modelPath:invalidModelPath 177 | sessionOptions:[ORTSessionTest makeSessionOptions] 178 | error:&err]; 179 | ORTAssertNullableResultUnsuccessful(session, err); 180 | } 181 | 182 | - (void)testRunFailsWithInvalidInput { 183 | NSMutableData* dData = [ORTSessionTest dataWithScalarFloat:1.0f]; 184 | NSMutableData* cData = [ORTSessionTest dataWithScalarFloat:0.0f]; 185 | 186 | ORTValue* d = [ORTSessionTest ortValueWithScalarFloatData:dData]; 187 | ORTValue* c = [ORTSessionTest ortValueWithScalarFloatData:cData]; 188 | 189 | NSError* err = nil; 190 | ORTSession* session = [[ORTSession alloc] initWithEnv:self.ortEnv 191 | modelPath:[ORTSessionTest getAddModelPath] 192 | sessionOptions:[ORTSessionTest makeSessionOptions] 193 | error:&err]; 194 | ORTAssertNullableResultSuccessful(session, err); 195 | 196 | BOOL runResult = [session runWithInputs:@{@"D" : d} 197 | outputs:@{@"C" : c} 198 | runOptions:[ORTSessionTest makeRunOptions] 199 | error:&err]; 200 | ORTAssertBoolResultUnsuccessful(runResult, err); 201 | } 202 | 203 | - (void)testAppendCoreMLEP { 204 | NSError* err = nil; 205 | ORTSessionOptions* sessionOptions = [ORTSessionTest makeSessionOptions]; 206 | ORTCoreMLExecutionProviderOptions* coreMLOptions = [[ORTCoreMLExecutionProviderOptions alloc] init]; 207 | coreMLOptions.enableOnSubgraphs = YES; // set an arbitrary option 208 | 209 | BOOL appendResult = [sessionOptions appendCoreMLExecutionProviderWithOptions:coreMLOptions 210 | error:&err]; 211 | 212 | if (!ORTIsCoreMLExecutionProviderAvailable()) { 213 | ORTAssertBoolResultUnsuccessful(appendResult, err); 214 | 
return; 215 | } 216 | 217 | ORTAssertBoolResultSuccessful(appendResult, err); 218 | 219 | ORTSession* session = [[ORTSession alloc] initWithEnv:self.ortEnv 220 | modelPath:[ORTSessionTest getAddModelPath] 221 | sessionOptions:sessionOptions 222 | error:&err]; 223 | ORTAssertNullableResultSuccessful(session, err); 224 | } 225 | 226 | - (void)testAppendXnnpackEP { 227 | NSError* err = nil; 228 | ORTSessionOptions* sessionOptions = [ORTSessionTest makeSessionOptions]; 229 | ORTXnnpackExecutionProviderOptions* XnnpackOptions = [[ORTXnnpackExecutionProviderOptions alloc] init]; 230 | XnnpackOptions.intra_op_num_threads = 2; 231 | 232 | BOOL appendResult = [sessionOptions appendXnnpackExecutionProviderWithOptions:XnnpackOptions 233 | error:&err]; 234 | // Without xnnpack EP in building also can pass the test 235 | NSString* err_msg = [err localizedDescription]; 236 | if (!appendResult && [err_msg containsString:@"XNNPACK execution provider is not supported in this build. "]) { 237 | return; 238 | } 239 | 240 | ORTAssertBoolResultSuccessful(appendResult, err); 241 | 242 | ORTSession* session = [[ORTSession alloc] initWithEnv:self.ortEnv 243 | modelPath:[ORTSessionTest getAddModelPath] 244 | sessionOptions:sessionOptions 245 | error:&err]; 246 | ORTAssertNullableResultSuccessful(session, err); 247 | } 248 | 249 | static bool gDummyRegisterCustomOpsFnCalled = false; 250 | 251 | static OrtStatus* _Nullable DummyRegisterCustomOpsFn(OrtSessionOptions* /*session_options*/, 252 | const OrtApiBase* /*api*/) { 253 | gDummyRegisterCustomOpsFnCalled = true; 254 | return nullptr; 255 | } 256 | 257 | - (void)testRegisterCustomOpsUsingFunctionPointer { 258 | NSError* err = nil; 259 | ORTSessionOptions* sessionOptions = [ORTSessionTest makeSessionOptions]; 260 | 261 | gDummyRegisterCustomOpsFnCalled = false; 262 | BOOL registerResult = [sessionOptions registerCustomOpsUsingFunctionPointer:&DummyRegisterCustomOpsFn 263 | error:&err]; 264 | ORTAssertBoolResultSuccessful(registerResult, 
err); 265 | 266 | XCTAssertEqual(gDummyRegisterCustomOpsFnCalled, true); 267 | } 268 | 269 | - (void)testStringInputs { 270 | NSError* err = nil; 271 | NSArray* stringData = @[ @"ONNX Runtime", @"is the", @"best", @"AI Framework" ]; 272 | ORTValue* stringValue = [[ORTValue alloc] initWithTensorStringData:stringData shape:@[ @2, @2 ] error:&err]; 273 | ORTAssertNullableResultSuccessful(stringValue, err); 274 | 275 | ORTSession* session = [[ORTSession alloc] initWithEnv:self.ortEnv 276 | modelPath:[ORTSessionTest getStringModelPath] 277 | sessionOptions:[ORTSessionTest makeSessionOptions] 278 | error:&err]; 279 | ORTAssertNullableResultSuccessful(session, err); 280 | 281 | NSDictionary* outputs = 282 | [session runWithInputs:@{@"input:0" : stringValue} 283 | outputNames:[NSSet setWithArray:@[ @"output:0" ]] 284 | runOptions:[ORTSessionTest makeRunOptions] 285 | error:&err]; 286 | ORTAssertNullableResultSuccessful(outputs, err); 287 | 288 | ORTValue* outputStringValue = outputs[@"output:0"]; 289 | XCTAssertNotNil(outputStringValue); 290 | 291 | NSArray* outputStringData = [outputStringValue tensorStringDataWithError:&err]; 292 | ORTAssertNullableResultSuccessful(outputStringData, err); 293 | 294 | XCTAssertEqual([stringData count], [outputStringData count]); 295 | XCTAssertTrue([stringData isEqualToArray:outputStringData]); 296 | } 297 | 298 | - (void)testKeepORTEnvReference { 299 | ORTEnv* __weak envWeak = _ortEnv; 300 | // Remove sole strong reference to the ORTEnv created in setUp. 301 | _ortEnv = nil; 302 | // There should be no more strong references to it. 303 | XCTAssertNil(envWeak); 304 | 305 | // Create a new ORTEnv. 
306 | NSError* err = nil; 307 | ORTEnv* env = [[ORTEnv alloc] initWithLoggingLevel:ORTLoggingLevelWarning 308 | error:&err]; 309 | ORTAssertNullableResultSuccessful(env, err); 310 | 311 | ORTSession* session = [[ORTSession alloc] initWithEnv:env 312 | modelPath:[ORTSessionTest getAddModelPath] 313 | sessionOptions:[ORTSessionTest makeSessionOptions] 314 | error:&err]; 315 | ORTAssertNullableResultSuccessful(session, err); 316 | 317 | envWeak = env; 318 | // Remove strong reference to the ORTEnv passed to the ORTSession initializer. 319 | env = nil; 320 | // ORTSession should keep a strong reference to it. 321 | XCTAssertNotNil(envWeak); 322 | } 323 | 324 | @end 325 | 326 | NS_ASSUME_NONNULL_END 327 | -------------------------------------------------------------------------------- /objectivec/test/ort_training_session_test.mm: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License. 
3 | 4 | #import 5 | 6 | #import "ort_checkpoint.h" 7 | #import "ort_training_session.h" 8 | #import "ort_env.h" 9 | #import "ort_session.h" 10 | #import "ort_value.h" 11 | 12 | #import "test/test_utils.h" 13 | #import "test/assertion_utils.h" 14 | 15 | NS_ASSUME_NONNULL_BEGIN 16 | 17 | @interface ORTTrainingSessionTest : XCTestCase 18 | @property(readonly, nullable) ORTEnv* ortEnv; 19 | @property(readonly, nullable) ORTCheckpoint* checkpoint; 20 | @property(readonly, nullable) ORTTrainingSession* session; 21 | @end 22 | 23 | @implementation ORTTrainingSessionTest 24 | 25 | - (void)setUp { 26 | [super setUp]; 27 | 28 | self.continueAfterFailure = NO; 29 | 30 | NSError* err = nil; 31 | _ortEnv = [[ORTEnv alloc] initWithLoggingLevel:ORTLoggingLevelWarning 32 | error:&err]; 33 | ORTAssertNullableResultSuccessful(_ortEnv, err); 34 | _checkpoint = [[ORTCheckpoint alloc] initWithPath:[ORTTrainingSessionTest 35 | getFilePathFromName:@"checkpoint.ckpt"] 36 | error:&err]; 37 | ORTAssertNullableResultSuccessful(_checkpoint, err); 38 | _session = [self makeTrainingSessionWithCheckpoint:_checkpoint]; 39 | } 40 | 41 | + (NSString*)getFilePathFromName:(NSString*)name { 42 | NSBundle* bundle = [NSBundle bundleForClass:[ORTTrainingSessionTest class]]; 43 | NSString* path = [[bundle resourcePath] stringByAppendingPathComponent:name]; 44 | return path; 45 | } 46 | 47 | + (NSMutableData*)loadTensorDataFromFile:(NSString*)filePath skipHeader:(BOOL)skipHeader { 48 | NSError* error = nil; 49 | NSString* fileContents = [NSString stringWithContentsOfFile:filePath 50 | encoding:NSUTF8StringEncoding 51 | error:&error]; 52 | ORTAssertNullableResultSuccessful(fileContents, error); 53 | 54 | NSArray* lines = [fileContents componentsSeparatedByCharactersInSet:[NSCharacterSet newlineCharacterSet]]; 55 | 56 | if (skipHeader) { 57 | lines = [lines subarrayWithRange:NSMakeRange(1, lines.count - 1)]; 58 | } 59 | 60 | NSArray* dataArray = [lines[0] componentsSeparatedByCharactersInSet: 61 | 
[NSCharacterSet characterSetWithCharactersInString:@",[] "]]; 62 | NSMutableData* tensorData = [NSMutableData data]; 63 | 64 | for (NSString* str in dataArray) { 65 | if (str.length > 0) { 66 | float value = [str floatValue]; 67 | [tensorData appendBytes:&value length:sizeof(float)]; 68 | } 69 | } 70 | 71 | return tensorData; 72 | } 73 | 74 | - (ORTTrainingSession*)makeTrainingSessionWithCheckpoint:(ORTCheckpoint*)checkpoint { 75 | NSError* error = nil; 76 | ORTSessionOptions* sessionOptions = [[ORTSessionOptions alloc] initWithError:&error]; 77 | ORTAssertNullableResultSuccessful(sessionOptions, error); 78 | 79 | ORTTrainingSession* session = [[ORTTrainingSession alloc] 80 | initWithEnv:self.ortEnv 81 | sessionOptions:sessionOptions 82 | checkpoint:checkpoint 83 | trainModelPath:[ORTTrainingSessionTest getFilePathFromName:@"training_model.onnx"] 84 | evalModelPath:[ORTTrainingSessionTest getFilePathFromName:@"eval_model.onnx"] 85 | optimizerModelPath:[ORTTrainingSessionTest getFilePathFromName:@"adamw.onnx"] 86 | error:&error]; 87 | 88 | ORTAssertNullableResultSuccessful(session, error); 89 | return session; 90 | } 91 | 92 | - (void)testInitTrainingSession { 93 | NSError* error = nil; 94 | 95 | // check that inputNames contains input-0 96 | NSArray* inputNames = [self.session getTrainInputNamesWithError:&error]; 97 | ORTAssertNullableResultSuccessful(inputNames, error); 98 | 99 | XCTAssertTrue(inputNames.count > 0); 100 | XCTAssertTrue([inputNames containsObject:@"input-0"]); 101 | 102 | // check that outNames contains onnx::loss::21273 103 | NSArray* outputNames = [self.session getTrainOutputNamesWithError:&error]; 104 | ORTAssertNullableResultSuccessful(outputNames, error); 105 | 106 | XCTAssertTrue(outputNames.count > 0); 107 | XCTAssertTrue([outputNames containsObject:@"onnx::loss::21273"]); 108 | } 109 | 110 | - (void)testInitTrainingSessionWithEval { 111 | NSError* error = nil; 112 | 113 | // check that inputNames contains input-0 114 | NSArray* inputNames = 
[self.session getEvalInputNamesWithError:&error]; 115 | ORTAssertNullableResultSuccessful(inputNames, error); 116 | 117 | XCTAssertTrue(inputNames.count > 0); 118 | XCTAssertTrue([inputNames containsObject:@"input-0"]); 119 | 120 | // check that outNames contains onnx::loss::21273 121 | NSArray* outputNames = [self.session getEvalOutputNamesWithError:&error]; 122 | ORTAssertNullableResultSuccessful(outputNames, error); 123 | 124 | XCTAssertTrue(outputNames.count > 0); 125 | XCTAssertTrue([outputNames containsObject:@"onnx::loss::21273"]); 126 | } 127 | 128 | - (void)runTrainStep { 129 | // load input and expected output 130 | NSError* error = nil; 131 | NSMutableData* expectedOutput = [ORTTrainingSessionTest loadTensorDataFromFile:[ORTTrainingSessionTest 132 | getFilePathFromName:@"loss_1.out"] 133 | skipHeader:YES]; 134 | 135 | NSMutableData* input = [ORTTrainingSessionTest loadTensorDataFromFile:[ORTTrainingSessionTest 136 | getFilePathFromName:@"input-0.in"] 137 | skipHeader:YES]; 138 | 139 | int32_t labels[] = {1, 1}; 140 | 141 | // create ORTValue array for input and labels 142 | NSMutableArray* inputValues = [NSMutableArray array]; 143 | 144 | ORTValue* inputTensor = [[ORTValue alloc] initWithTensorData:input 145 | elementType:ORTTensorElementDataTypeFloat 146 | shape:@[ @2, @784 ] 147 | error:&error]; 148 | ORTAssertNullableResultSuccessful(inputTensor, error); 149 | [inputValues addObject:inputTensor]; 150 | 151 | ORTValue* labelTensor = [[ORTValue alloc] initWithTensorData:[NSMutableData dataWithBytes:labels 152 | length:sizeof(labels)] 153 | elementType:ORTTensorElementDataTypeInt32 154 | shape:@[ @2 ] 155 | error:&error]; 156 | 157 | ORTAssertNullableResultSuccessful(labelTensor, error); 158 | [inputValues addObject:labelTensor]; 159 | 160 | NSArray* outputs = [self.session trainStepWithInputValues:inputValues error:&error]; 161 | ORTAssertNullableResultSuccessful(outputs, error); 162 | XCTAssertTrue(outputs.count > 0); 163 | 164 | BOOL result = 
[self.session lazyResetGradWithError:&error]; 165 | ORTAssertBoolResultSuccessful(result, error); 166 | 167 | outputs = [self.session trainStepWithInputValues:inputValues error:&error]; 168 | ORTAssertNullableResultSuccessful(outputs, error); 169 | XCTAssertTrue(outputs.count > 0); 170 | 171 | ORTValue* outputValue = outputs[0]; 172 | ORTValueTypeInfo* typeInfo = [outputValue typeInfoWithError:&error]; 173 | ORTAssertNullableResultSuccessful(typeInfo, error); 174 | XCTAssertEqual(typeInfo.type, ORTValueTypeTensor); 175 | XCTAssertNotNil(typeInfo.tensorTypeAndShapeInfo); 176 | 177 | ORTTensorTypeAndShapeInfo* tensorInfo = [outputValue tensorTypeAndShapeInfoWithError:&error]; 178 | ORTAssertNullableResultSuccessful(tensorInfo, error); 179 | XCTAssertEqual(tensorInfo.elementType, ORTTensorElementDataTypeFloat); 180 | 181 | NSMutableData* tensorData = [outputValue tensorDataWithError:&error]; 182 | ORTAssertNullableResultSuccessful(tensorData, error); 183 | ORTAssertEqualFloatArrays(test_utils::getFloatArrayFromData(tensorData), 184 | test_utils::getFloatArrayFromData(expectedOutput)); 185 | } 186 | 187 | - (void)testTrainStepOutput { 188 | [self runTrainStep]; 189 | } 190 | 191 | - (void)testOptimizerStep { 192 | // load input and expected output 193 | NSError* error = nil; 194 | NSMutableData* expectedOutput1 = [ORTTrainingSessionTest loadTensorDataFromFile:[ORTTrainingSessionTest 195 | getFilePathFromName:@"loss_1.out"] 196 | skipHeader:YES]; 197 | 198 | NSMutableData* expectedOutput2 = [ORTTrainingSessionTest loadTensorDataFromFile:[ORTTrainingSessionTest 199 | getFilePathFromName:@"loss_2.out"] 200 | skipHeader:YES]; 201 | 202 | NSMutableData* input = [ORTTrainingSessionTest loadTensorDataFromFile:[ORTTrainingSessionTest 203 | getFilePathFromName:@"input-0.in"] 204 | skipHeader:YES]; 205 | 206 | int32_t labels[] = {1, 1}; 207 | 208 | // create ORTValue array for input and labels 209 | NSMutableArray* inputValues = [NSMutableArray array]; 210 | 211 | ORTValue* 
inputTensor = [[ORTValue alloc] initWithTensorData:input 212 | elementType:ORTTensorElementDataTypeFloat 213 | shape:@[ @2, @784 ] 214 | error:&error]; 215 | ORTAssertNullableResultSuccessful(inputTensor, error); 216 | [inputValues addObject:inputTensor]; 217 | 218 | ORTValue* labelTensor = [[ORTValue alloc] initWithTensorData:[NSMutableData dataWithBytes:labels 219 | length:sizeof(labels)] 220 | elementType:ORTTensorElementDataTypeInt32 221 | shape:@[ @2 ] 222 | error:&error]; 223 | ORTAssertNullableResultSuccessful(labelTensor, error); 224 | [inputValues addObject:labelTensor]; 225 | 226 | // run train step, optimizer steps and check loss 227 | NSArray* outputs = [self.session trainStepWithInputValues:inputValues error:&error]; 228 | ORTAssertNullableResultSuccessful(outputs, error); 229 | 230 | NSMutableData* loss = [outputs[0] tensorDataWithError:&error]; 231 | ORTAssertNullableResultSuccessful(loss, error); 232 | ORTAssertEqualFloatArrays(test_utils::getFloatArrayFromData(loss), 233 | test_utils::getFloatArrayFromData(expectedOutput1)); 234 | 235 | BOOL result = [self.session lazyResetGradWithError:&error]; 236 | ORTAssertBoolResultSuccessful(result, error); 237 | 238 | outputs = [self.session trainStepWithInputValues:inputValues error:&error]; 239 | ORTAssertNullableResultSuccessful(outputs, error); 240 | 241 | loss = [outputs[0] tensorDataWithError:&error]; 242 | ORTAssertNullableResultSuccessful(loss, error); 243 | ORTAssertEqualFloatArrays(test_utils::getFloatArrayFromData(loss), 244 | test_utils::getFloatArrayFromData(expectedOutput1)); 245 | 246 | result = [self.session optimizerStepWithError:&error]; 247 | ORTAssertBoolResultSuccessful(result, error); 248 | 249 | outputs = [self.session trainStepWithInputValues:inputValues error:&error]; 250 | ORTAssertNullableResultSuccessful(outputs, error); 251 | loss = [outputs[0] tensorDataWithError:&error]; 252 | ORTAssertNullableResultSuccessful(loss, error); 253 | 
ORTAssertEqualFloatArrays(test_utils::getFloatArrayFromData(loss), 254 | test_utils::getFloatArrayFromData(expectedOutput2)); 255 | } 256 | 257 | - (void)testSetLearningRate { 258 | NSError* error = nil; 259 | 260 | float learningRate = 0.1f; 261 | BOOL result = [self.session setLearningRate:learningRate error:&error]; 262 | ORTAssertBoolResultSuccessful(result, error); 263 | 264 | float actualLearningRate = [self.session getLearningRateWithError:&error]; 265 | ORTAssertEqualFloatAndNoError(learningRate, actualLearningRate, error); 266 | } 267 | 268 | - (void)testLinearLRScheduler { 269 | NSError* error = nil; 270 | 271 | float learningRate = 0.1f; 272 | BOOL result = [self.session registerLinearLRSchedulerWithWarmupStepCount:2 273 | totalStepCount:4 274 | initialLr:learningRate 275 | error:&error]; 276 | 277 | ORTAssertBoolResultSuccessful(result, error); 278 | 279 | [self runTrainStep]; 280 | 281 | result = [self.session optimizerStepWithError:&error]; 282 | ORTAssertBoolResultSuccessful(result, error); 283 | result = [self.session schedulerStepWithError:&error]; 284 | ORTAssertBoolResultSuccessful(result, error); 285 | ORTAssertEqualFloatAndNoError(0.05f, [self.session getLearningRateWithError:&error], error); 286 | 287 | result = [self.session optimizerStepWithError:&error]; 288 | ORTAssertBoolResultSuccessful(result, error); 289 | result = [self.session schedulerStepWithError:&error]; 290 | ORTAssertBoolResultSuccessful(result, error); 291 | ORTAssertEqualFloatAndNoError(0.1f, [self.session getLearningRateWithError:&error], error); 292 | 293 | result = [self.session optimizerStepWithError:&error]; 294 | ORTAssertBoolResultSuccessful(result, error); 295 | result = [self.session schedulerStepWithError:&error]; 296 | ORTAssertBoolResultSuccessful(result, error); 297 | ORTAssertEqualFloatAndNoError(0.05f, [self.session getLearningRateWithError:&error], error); 298 | 299 | result = [self.session optimizerStepWithError:&error]; 300 | 
ORTAssertBoolResultSuccessful(result, error); 301 | result = [self.session schedulerStepWithError:&error]; 302 | ORTAssertBoolResultSuccessful(result, error); 303 | ORTAssertEqualFloatAndNoError(0.0f, [self.session getLearningRateWithError:&error], error); 304 | } 305 | 306 | - (void)testExportModelForInference { 307 | NSError* error = nil; 308 | 309 | NSString* inferenceModelPath = [test_utils::createTemporaryDirectory(self) 310 | stringByAppendingPathComponent:@"inference_model.onnx"]; 311 | XCTAssertNotNil(inferenceModelPath); 312 | 313 | NSArray* graphOutputNames = [NSArray arrayWithObjects:@"output-0", nil]; 314 | 315 | BOOL result = [self.session exportModelForInferenceWithOutputPath:inferenceModelPath 316 | graphOutputNames:graphOutputNames 317 | error:&error]; 318 | 319 | ORTAssertBoolResultSuccessful(result, error); 320 | XCTAssertTrue([[NSFileManager defaultManager] fileExistsAtPath:inferenceModelPath]); 321 | 322 | [self addTeardownBlock:^{ 323 | NSError* error = nil; 324 | [[NSFileManager defaultManager] removeItemAtPath:inferenceModelPath error:&error]; 325 | }]; 326 | } 327 | 328 | - (void)testToBuffer { 329 | NSError* error = nil; 330 | ORTValue* buffer = [self.session toBufferWithTrainable:YES error:&error]; 331 | ORTAssertNullableResultSuccessful(buffer, error); 332 | 333 | ORTValueTypeInfo* typeInfo = [buffer typeInfoWithError:&error]; 334 | ORTAssertNullableResultSuccessful(typeInfo, error); 335 | XCTAssertEqual(typeInfo.type, ORTValueTypeTensor); 336 | XCTAssertNotNil(typeInfo.tensorTypeAndShapeInfo); 337 | } 338 | 339 | - (void)testFromBuffer { 340 | NSError* error = nil; 341 | 342 | ORTValue* buffer = [self.session toBufferWithTrainable:YES error:&error]; 343 | ORTAssertNullableResultSuccessful(buffer, error); 344 | 345 | BOOL result = [self.session fromBufferWithValue:buffer error:&error]; 346 | ORTAssertBoolResultSuccessful(result, error); 347 | } 348 | 349 | - (void)tearDown { 350 | _session = nil; 351 | _checkpoint = nil; 352 | _ortEnv = 
nil; 353 | 354 | [super tearDown]; 355 | } 356 | 357 | @end 358 | 359 | NS_ASSUME_NONNULL_END 360 | --------------------------------------------------------------------------------