├── .gitattributes
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── go.mod
├── legacy_code.go
├── onnxruntime_c_api.h
├── onnxruntime_go.go
├── onnxruntime_test.go
├── onnxruntime_wrapper.c
├── onnxruntime_wrapper.h
├── setup_env.go
├── setup_env_windows.go
├── tensor_type_constraints.go
└── test_data
    ├── example ż 大 김.onnx
    ├── example_0_dim_output.onnx
    ├── example_big_compute.onnx
    ├── example_big_fanout.onnx
    ├── example_dynamic_axes.onnx
    ├── example_float16.onnx
    ├── example_multitype.onnx
    ├── example_several_inputs_and_outputs.onnx
    ├── generate_0_dimension_output.py
    ├── generate_big_fanout.py
    ├── generate_dynamic_axes_network.py
    ├── generate_float16_network.py
    ├── generate_network_big_compute.py
    ├── generate_network_different_types.py
    ├── generate_odd_name_onnx.py
    ├── generate_several_inputs_and_outputs.py
    ├── generate_sklearn_network.py
    ├── modify_metadata.py
    ├── onnxruntime.dll
    ├── onnxruntime_amd64.dylib
    ├── onnxruntime_arm64.dylib
    ├── onnxruntime_arm64.so
    └── sklearn_randomforest.onnx

--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
1 | onnxruntime_c_api.h linguist-vendored
2 | onnxruntime_training_c_api.h linguist-vendored
3 |
4 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | Contribution Guidelines
2 | =======================
3 |
4 | This library began as a personal project, and is primarily still maintained as
5 | such. The following list of guidelines is not necessarily exhaustive, and,
6 | ultimately, any contribution is subject to the maintainer's discretion. That
7 | being said, contributions are welcome, and most recent new features have been
8 | added by users who need them!
9 |
10 | Coding Style
11 | ------------
12 |
13 | - Go code must be formatted using the official `gofmt` tool.
14 |
15 | - C code should adhere to the portions of Google's C++ style guide that
16 | apply to C.
17 |
18 | - If at all possible, any Go or C code should have at most 80 character lines.
19 | (This may not be enforced very strictly.)
20 |
21 | - Purely stylistic changes are unlikely to be accepted. Instead, the
22 | maintainer or other contributors may make small stylistic adjustments to
23 | surrounding code as part of other contributions.
24 |
25 | - Attempt to mimic the existing style of the surrounding code.
26 |
27 |
28 | Documentation
29 | -------------
30 |
31 | - All Go types, public-facing functions, and nontrivial internal functions
32 | must include a comment on their intended usage, to be parsed by godoc.
33 |
34 | - As per the Google C++ style guide, all C functions must be documented with a
35 | comment as well. If a C function is defined in a header file, the comment
36 | should appear with the definition in the header. If it's a static function
37 | in a `.c` file, the comment should appear with the function definition.
38 |
39 |
40 | Tests
41 | -----
42 |
43 | - All new features and bugfixes must include a basic unit test (in
44 | `onnxruntime_test.go`) to serve as a sanity check.
45 |
46 | - If a test is for a platform-dependent or execution-provider-dependent
47 | feature, the test must be skipped if run on an unsupported system.
48 |
49 | - No tests should panic. Always check errors and fail rather than allowing
50 | tests to panic.
51 |
52 | - Every change must ensure that `go test -v -bench=.` passes on every
53 | supported platform. (A sketch of a test following these conventions is
shown directly below.)
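For example, a new test following the conventions above might look like the
following sketch. The feature and the `runMyNewFeature` helper are
hypothetical; `InitializeRuntime` and `CleanupRuntime` are the existing
helpers in `onnxruntime_test.go`, and `runtime` is Go's standard-library
package:

```go
func TestMyNewFeature(t *testing.T) {
	InitializeRuntime(t)
	defer CleanupRuntime(t)
	if runtime.GOOS == "windows" {
		// Skip, rather than fail, on systems that don't support the feature.
		t.Skipf("MyNewFeature is not supported on %s.\n", runtime.GOOS)
	}
	e := runMyNewFeature()
	if e != nil {
		// Check errors and fail with a reason; never let the test panic.
		t.Fatalf("Error running MyNewFeature: %s\n", e)
	}
}
```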
54 |
55 | - Every test failure should be accompanied by a message containing the reason,
56 | either using `t.Logf()`, `t.Errorf()`, or `t.Fatalf()`.
57 |
58 |
59 | Adding New Files
60 | ----------------
61 |
62 | - Apart from testing data, try not to add new source files.
63 |
64 | - Do not add third-party code or headers. The only exception for now is
65 | `onnxruntime_c_api.h`.
66 |
67 | - No C++ at all. Developing Go-to-C wrappers is annoying enough as it is.
68 |
69 | - Do not add any new `onnxruntime` shared libraries under `test_data`. I know
70 | there are additional platforms that would be nice to include (such as
71 | `x86_64` Linux), but I do not want this project turning into an unofficial
72 | distribution channel for onnxruntime libraries. It also clogs up the git
73 | repo with large files, and increases the size of the history every time
74 | these files are updated. The libraries that are included were only intended
75 | to allow a majority of users to run `go test -v -bench=.` without further
76 | setup or modification. Currently: amd64 Windows, arm64 Linux (I wish I
77 | hadn't included this!), arm64 osx, and amd64 osx. All other users must set
78 | the `ONNXRUNTIME_SHARED_LIBRARY_PATH` environment variable to a valid path
79 | to the correct `onnxruntime` shared library file prior to running tests.
80 |
81 | - If you need to add a .onnx file for a test, place both the .onnx file
82 | _and_ the script used to generate it into `test_data/`.
83 |
84 | - Keep any testing .onnx files as small as possible.
85 |
86 | - Without a good reason (i.e., implementing an entire class of APIs such as
87 | training), avoid adding new Go files---just add to `onnxruntime_go.go`.
88 |
89 |
90 | Dependencies
91 | ------------
92 |
93 | - Avoid Go or C dependencies outside of the language's standard libraries.
94 | This package currently does not depend on any third-party Go modules, and
95 | it would be great to keep it this way.
96 |
97 | - Python scripts within `test_data/` can use whatever dependencies they need,
98 | because end users should not be required to run the Python files, and the
99 | `.onnx` files they produce should already be included.
100 |
101 |
102 | C-Specific Stuff
103 | ----------------
104 |
105 | - Minimize Go management of C-allocated memory as much as possible. For
106 | example, see the `convertORTString` function in `onnxruntime_go.go`, which
107 | copies a C-allocated string into a garbage-collected Go `string`.
108 |
109 | - If you need to use an `OrtAllocator` in onnxruntime's C API, always use the
110 | default `OrtAllocator` returned by
111 | `ort_api->GetAllocatorWithDefaultOptions()`.
112 |
113 | - ONNXRuntime APIs requiring file paths typically use `ORTCHAR_T*`
114 | strings. On Linux/OSX/etc, these should be UTF-8, but on Windows they will
115 | be wide-character strings. (We use `#include` tricks to make them look
116 | like `char*` to C code even on Windows, but the DLL still expects a
117 | `wchar_t*`.) The important takeaway: when passing `ORTCHAR_T*`
118 | values to the onnxruntime C API, use the `createOrtCharString(...)`
119 | function. It converts a Go string to a C string, but unlike `C.CString`, it
120 | will do UTF-8 to UTF-16 conversion on Windows. (On Linux, it simply wraps
121 | `C.CString`.)
122 |
123 |
124 | A Few Notes on Organization
125 | ---------------------------
126 |
127 | - The `onnxruntime` C API uses a struct containing function pointers. Cgo
128 | can't directly invoke functions via pointers, so `onnxruntime_wrapper.c`
129 | (along with the associated header file) is used to provide top-level C
130 | functions that call the function pointers within the `OrtApi` struct.
131 |
132 | - Linux and OSX use `dlopen` to load the onnxruntime shared library, but this
133 | isn't possible on Windows, which instead can use the `syscall.LoadLibrary()`
134 | function from Go's standard library. This different behavior is locked
135 | behind build constraints in `setup_env.go` and `setup_env_windows.go`,
136 | respectively.
137 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2023 Nathan Otterness
2 |
3 | Permission is hereby granted, free of charge, to any person obtaining a copy
4 | of this software and associated documentation files (the "Software"), to deal
5 | in the Software without restriction, including without limitation the rights
6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 | copies of the Software, and to permit persons to whom the Software is
8 | furnished to do so, subject to the following conditions:
9 |
10 | The above copyright notice and this permission notice shall be included in all
11 | copies or substantial portions of the Software.
12 |
13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
19 | SOFTWARE.
20 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | Cross-Platform `onnxruntime` Wrapper for Go
2 | ===========================================
3 |
4 | About
5 | -----
6 |
7 | This library seeks to provide an interface for loading and executing neural
8 | networks from Go(lang) code, while remaining as simple to use as possible.
9 |
10 | A few example applications using this library can be found in the
11 | [`onnxruntime_go_examples` repository](https://github.com/yalue/onnxruntime_go_examples).
12 |
13 | The [onnxruntime](https://github.com/microsoft/onnxruntime) library provides a
14 | way to load and execute ONNX-format neural networks, though the library
15 | primarily supports C and C++ APIs. Several efforts have been made to write
16 | Go(lang) wrappers for the `onnxruntime` library, but as far as I can tell, none
17 | of these existing Go wrappers support Windows. This is due to the fact that
18 | Microsoft's `onnxruntime` library assumes the user will be using the MSVC
19 | compiler on Windows systems, while CGo on Windows requires using Mingw.
20 |
21 | This wrapper works around this issue by manually loading the `onnxruntime`
22 | shared library, removing any dependency on the `onnxruntime` source code beyond
23 | the header files. Naturally, this approach works equally well on non-Windows
24 | systems.
25 |
26 | Additionally, this library uses Go's recent addition of generics to support
27 | multiple Tensor data types; see the `NewTensor` or `NewEmptyTensor` functions.
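For example (a brief sketch; error handling is omitted), the same generic
constructors work for any supported element type:

```go
shape := ort.NewShape(2, 2)
// A 2x2 tensor of int64 values, backed by the provided slice.
intTensor, _ := ort.NewTensor(shape, []int64{1, 2, 3, 4})
defer intTensor.Destroy()
// A zero-initialized 2x2 tensor of float64 values.
doubleTensor, _ := ort.NewEmptyTensor[float64](shape)
defer doubleTensor.Destroy()
```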
28 | 29 | **IMPORTANT:** As of onnxruntime_go v1.12.0 or above, for CUDA acceleration we 30 | now require the use of CUDA 12.x and CuDNN 9.x (as required by onnxruntime 31 | v1.19.0+). Those wishing to stay on CUDA 11.8 should remain on onnxruntime_go 32 | v1.11.0 or below. 33 | 34 | Note on onnxruntime Library Versions 35 | ------------------------------------ 36 | 37 | At the time of writing, this library uses version 1.22.0 of the onnxruntime 38 | C API headers. So, it will probably only work with version 1.22.0 of the 39 | onnxruntime shared libraries, as well. If you need to use a different version, 40 | or if I get behind on updating this repository, updating or changing the 41 | onnxruntime version should be fairly easy: 42 | 43 | 1. Replace the `onnxruntime_c_api.h` file with the version corresponding to 44 | the onnxruntime version you wish to use. 45 | 46 | 2. Replace the `test_data/onnxruntime.dll` (or `test_data/onnxruntime*.so`) 47 | file with the version corresponding to the onnxruntime version you wish to 48 | use. 49 | 50 | 3. (If you care about DirectML support) Verify that the entries in the 51 | `DummyOrtDMLAPI` struct in `onnxruntime_wrapper.c` match the order in which 52 | they appear in the `OrtDmlApi` struct from the `dml_provider_factory.h` 53 | header in the official repo. See the comment on this struct in 54 | `onnxruntime_wrapper.c` for more information. 55 | 56 | Note that both the C API header and the shared library files are available to 57 | download from the releases page in the 58 | [official repo](https://github.com/microsoft/onnxruntime). Download the archive 59 | for the release you want to use, and extract it. The header file is located in 60 | the "include" subdirectory, and the shared library will be located in the "lib" 61 | subdirectory. (On Linux systems, you'll need the version of the .so with the 62 | appended version numbers, e.g., `libonnxruntime.so.1.22.0`, and _not_ the 63 | `libonnxruntime.so`, which is just a symbolic link.) The archive will contain 64 | several other files containing C++ headers, debug symbols, and so on, but you 65 | shouldn't need anything other than the single onnxruntime shared library and 66 | `onnxruntime_c_api.h`. (The exception is if you're wanting to enable GPU 67 | support, where you may need other shared-library files, such as 68 | `execution_providers_cuda.dll` and `execution_providers_shared.dll` on Windows.) 69 | 70 | 71 | Requirements 72 | ------------ 73 | 74 | To use this library, you'll need a version of Go with cgo support. If you are 75 | not using an amd64 version of Windows or Linux (or if you want to provide your 76 | own library for some other reason), you simply need to provide the correct path 77 | to the shared library when initializing the wrapper. This is seen in the first 78 | few lines of the following example. 79 | 80 | Note that if you want to use CUDA, you'll need to be using a version of the 81 | onnxruntime shared library with CUDA support, as well as be using a CUDA 82 | version supported by the underlying version of your onnxruntime library. For 83 | example, version 1.22.0 of the onnxruntime library only supports CUDA versions 84 | 12.x. See 85 | [the onnxruntime CUDA support documentation](https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html) 86 | for more specifics. 87 | 88 | 89 | Example Usage 90 | ------------- 91 | 92 | The full documentation can be found at [pkg.go.dev](https://pkg.go.dev/github.com/yalue/onnxruntime_go). 
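As a concrete sketch of the CUDA notes in the Requirements section above
(assuming a CUDA-enabled build of the onnxruntime shared library; see the
`SessionOptions` and `NewCUDAProviderOptions` documentation for details),
enabling the CUDA execution provider looks roughly like this:

```go
// Error handling is omitted here for brevity.
cudaOptions, _ := ort.NewCUDAProviderOptions()
defer cudaOptions.Destroy()
// "device_id" selects which GPU to use; 0 is just a typical choice.
_ = cudaOptions.Update(map[string]string{"device_id": "0"})
sessionOptions, _ := ort.NewSessionOptions()
defer sessionOptions.Destroy()
_ = sessionOptions.AppendExecutionProviderCUDA(cudaOptions)
// Pass sessionOptions when creating a session, in place of the nil
// options argument shown in the example below.
```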
93 |
94 | Additionally, several example command-line applications complete with necessary
95 | networks and data can be found in the
96 | [`onnxruntime_go_examples` repository](https://github.com/yalue/onnxruntime_go_examples).
97 |
98 | The following example illustrates how this library can be used to load and run
99 | an ONNX network taking a single input tensor and producing a single output
100 | tensor, both of which contain 32-bit floating point values. Note that error
101 | handling is omitted; each of the functions returns an err value, which will be
102 | non-nil in the case of failure.
103 |
104 | ```go
105 | import (
106 | "fmt"
107 | ort "github.com/yalue/onnxruntime_go"
108 | "os"
109 | )
110 |
111 | func main() {
112 | // This line _may_ be optional; by default the library will try to load
113 | // "onnxruntime.dll" on Windows, and "onnxruntime.so" on any other system.
114 | // For stability, it is probably a good idea to always set this explicitly.
115 | ort.SetSharedLibraryPath("path/to/onnxruntime.so")
116 |
117 | err := ort.InitializeEnvironment()
118 | if err != nil {
119 | panic(err)
120 | }
121 | defer ort.DestroyEnvironment()
122 |
123 | // For a slight performance boost and convenience when re-using existing
124 | // tensors, this library expects the user to create all input and output
125 | // tensors prior to creating the session. If this isn't ideal for your use
126 | // case, see the DynamicAdvancedSession type in the documentation, which
127 | // allows input and output tensors to be specified when calling Run()
128 | // rather than when initializing a session.
129 | inputData := []float32{0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9}
130 | inputShape := ort.NewShape(2, 5)
131 | inputTensor, err := ort.NewTensor(inputShape, inputData)
132 | defer inputTensor.Destroy()
133 | // This hypothetical network maps a 2x5 input -> 2x3x4 output.
134 | outputShape := ort.NewShape(2, 3, 4)
135 | outputTensor, err := ort.NewEmptyTensor[float32](outputShape)
136 | defer outputTensor.Destroy()
137 |
138 | session, err := ort.NewAdvancedSession("path/to/network.onnx",
139 | []string{"Input 1 Name"}, []string{"Output 1 Name"},
140 | []ort.Value{inputTensor}, []ort.Value{outputTensor}, nil)
141 | defer session.Destroy()
142 |
143 | // Calling Run() will run the network, reading the current contents of the
144 | // input tensors and modifying the contents of the output tensors.
145 | err = session.Run()
146 |
147 | // Get a slice view of the output tensor's data.
148 | outputData := outputTensor.GetData()
149 |
150 | // If you want to run the network on a different input, all you need to do
151 | // is modify the input tensor data (available via inputTensor.GetData())
152 | // and call Run() again.
153 |
154 | // ...
155 | }
156 | ```
157 |
158 |
159 | Deprecated APIs
160 | ---------------
161 |
162 | Older versions of this library used a typed `Session[T]` struct to keep track
163 | of sessions. In retrospect, associating type parameters with Sessions was
164 | unnecessary, and the `AdvancedSession` type, along with its associated APIs,
165 | was added to rectify this mistake. For backwards compatibility, the old typed
166 | `Session[T]` and `DynamicSession[In, Out]` types are still included and unlikely
167 | to be removed. However, they now delegate their functionality to
168 | `AdvancedSession` internally. New code should always favor using
169 | `AdvancedSession` directly.
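For code migrating from `DynamicSession[In, Out]`, a minimal sketch of the
recommended `DynamicAdvancedSession` equivalent (reusing the network and
tensor names from the example above; error handling omitted):

```go
session, _ := ort.NewDynamicAdvancedSession("path/to/network.onnx",
	[]string{"Input 1 Name"}, []string{"Output 1 Name"}, nil)
defer session.Destroy()
// Leaving an output entry nil asks onnxruntime to allocate it for us; we
// must still Destroy() it ourselves when we're done with it.
outputs := []ort.Value{nil}
_ = session.Run([]ort.Value{inputTensor}, outputs)
defer outputs[0].Destroy()
```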
170 |
171 |
172 | Running Tests and System Compatibility for Testing
173 | --------------------------------------------------
174 |
175 | Navigate to this directory and run `go test -v`, or optionally
176 | `go test -v -bench=.`. All tests should pass; tests relating to CUDA or other
177 | accelerator support will be skipped on systems or onnxruntime builds that don't
178 | support them.
179 |
180 | Currently, this repository includes a copy of the onnxruntime shared libraries
181 | for a few systems, including AMD64 Windows, ARM64 Linux, and ARM64 darwin.
182 | These should allow tests to pass on those systems without users needing to copy
183 | additional libraries beyond cloning this repository. In the future, however,
184 | this may change if support for more systems is added or removed.
185 |
186 | You may want to use a different version of the `onnxruntime` shared library for
187 | a couple of reasons. In particular:
188 |
189 | 1. The included shared library copies do not include support for CUDA or other
190 | accelerated execution providers, so CUDA-related tests will always be
191 | skipped if you use the default libraries in this repo.
192 |
193 | 2. Many systems, including AMD64 and i386 Linux, and x86 osx, do not currently
194 | have shared libraries included in `test_data/` in the first place. (I would
195 | like to keep this directory, and the overall repo, smaller by keeping the
196 | number of shared libraries small.)
197 |
198 | If these or other reasons apply to you, the test code will check the
199 | `ONNXRUNTIME_SHARED_LIBRARY_PATH` environment variable before attempting to
200 | load a library from `test_data/`. So, if you are using one of these systems or
201 | want accelerator-related tests to run, you should set the environment variable
202 | to the path to the onnxruntime shared library. Afterwards, `go test -v` should
203 | run and pass.
204 |
205 |
206 | Training API Support
207 | --------------------
208 |
209 | The training API has been deprecated as of onnxruntime version 1.20. Rather
210 | than continuing to maintain wrappers for a deprecated API, `onnxruntime_go` has
211 | replaced the wrapper functions for the training API with stubs that return an
212 | error. Users who need to continue to use the training API will need to use an
213 | older version. For example, the following versions should be compatible with
214 | training:
215 |
216 | - Version `v1.12.1` of `onnxruntime_go`, and
217 | - Version 1.19.2 of `onnxruntime`.
--------------------------------------------------------------------------------
/go.mod:
--------------------------------------------------------------------------------
1 | module github.com/yalue/onnxruntime_go
2 |
3 | go 1.19
--------------------------------------------------------------------------------
/legacy_code.go:
--------------------------------------------------------------------------------
1 | package onnxruntime_go
2 |
3 | // This file contains code and types that we maintain for compatibility
4 | // purposes, but is not expected to be regularly maintained or updated.
5 |
6 | import (
7 | "fmt"
8 | "os"
9 | )
10 |
11 | // #include "onnxruntime_wrapper.h"
12 | import "C"
13 |
14 | // This type of session is for ONNX networks with the same input and output
15 | // data types.
16 | //
17 | // NOTE: This type was written with a type parameter despite the fact that a
18 | // type parameter is not necessary for any of its underlying implementation,
19 | // which is a mistake in retrospect. It is preserved only for compatibility
20 | // with older code, and new users should almost certainly be using an
21 | // AdvancedSession instead.
22 | //
23 | // Using an AdvancedSession struct should be easier, and supports arbitrary
24 | // combinations of input and output tensor data types as well as more options.
25 | type Session[T TensorData] struct {
26 | // We now delegate all of the implementation to an AdvancedSession here.
27 | s *AdvancedSession
28 | }
29 |
30 | // Similar to Session, but does not require the specification of the input
31 | // and output shapes at session creation time, and allows for input and output
32 | // tensors to have different types. This allows for fully dynamic input to the
33 | // onnx model.
34 | //
35 | // NOTE: As with Session[T], new users should probably be using
36 | // DynamicAdvancedSession in the future.
37 | type DynamicSession[In TensorData, Out TensorData] struct {
38 | s *DynamicAdvancedSession
39 | }
40 |
41 | // The same as NewSession, but takes a slice of bytes containing the .onnx
42 | // network rather than a file path.
43 | func NewSessionWithONNXData[T TensorData](onnxData []byte, inputNames,
44 | outputNames []string, inputs, outputs []*Tensor[T]) (*Session[T], error) {
45 | // Unfortunately, a slice of pointers that satisfies an interface doesn't
46 | // count as a slice of interfaces (at least, as I write this), so we'll
47 | // make the conversion here.
48 | tmpInputs := make([]Value, len(inputs))
49 | tmpOutputs := make([]Value, len(outputs))
50 | for i, t := range inputs {
51 | tmpInputs[i] = t
52 | }
53 | for i, t := range outputs {
54 | tmpOutputs[i] = t
55 | }
56 | s, e := NewAdvancedSessionWithONNXData(onnxData, inputNames, outputNames,
57 | tmpInputs, tmpOutputs, nil)
58 | if e != nil {
59 | return nil, e
60 | }
61 | return &Session[T]{
62 | s: s,
63 | }, nil
64 | }
65 |
66 | // Similar to NewSessionWithONNXData, but for dynamic sessions.
67 | func NewDynamicSessionWithONNXData[in TensorData, out TensorData](onnxData []byte,
68 | inputNames, outputNames []string) (*DynamicSession[in, out], error) {
69 | s, e := NewDynamicAdvancedSessionWithONNXData(onnxData, inputNames,
70 | outputNames, nil)
71 | if e != nil {
72 | return nil, e
73 | }
74 | return &DynamicSession[in, out]{
75 | s: s,
76 | }, nil
77 | }
78 |
79 | // Loads the ONNX network at the given path, and initializes a Session
80 | // instance. If this returns successfully, the caller must call Destroy() on
81 | // the returned session when it is no longer needed. We require the user to
82 | // provide the input and output tensors and names at this point, in order to
83 | // not need to re-allocate them every time Run() is called. The user instead
84 | // can just update or access the input/output tensor data after calling Run().
85 | // The input and output tensors MUST outlive this session, and calling
86 | // session.Destroy() will not destroy the input or output tensors.
87 | func NewSession[T TensorData](onnxFilePath string, inputNames, 88 | outputNames []string, inputs, outputs []*Tensor[T]) (*Session[T], error) { 89 | fileContent, e := os.ReadFile(onnxFilePath) 90 | if e != nil { 91 | return nil, fmt.Errorf("Error reading %s: %w", onnxFilePath, e) 92 | } 93 | 94 | toReturn, e := NewSessionWithONNXData[T](fileContent, inputNames, 95 | outputNames, inputs, outputs) 96 | if e != nil { 97 | return nil, fmt.Errorf("Error creating session from %s: %w", 98 | onnxFilePath, e) 99 | } 100 | return toReturn, nil 101 | } 102 | 103 | // Same as NewSession, but for dynamic sessions. 104 | func NewDynamicSession[in TensorData, out TensorData](onnxFilePath string, 105 | inputNames, outputNames []string) (*DynamicSession[in, out], error) { 106 | fileContent, e := os.ReadFile(onnxFilePath) 107 | if e != nil { 108 | return nil, fmt.Errorf("Error reading %s: %w", onnxFilePath, e) 109 | } 110 | 111 | toReturn, e := NewDynamicSessionWithONNXData[in, out](fileContent, 112 | inputNames, outputNames) 113 | if e != nil { 114 | return nil, fmt.Errorf("Error creating session from %s: %w", 115 | onnxFilePath, e) 116 | } 117 | return toReturn, nil 118 | } 119 | 120 | func (s *Session[_]) Destroy() error { 121 | return s.s.Destroy() 122 | } 123 | 124 | func (s *DynamicSession[_, _]) Destroy() error { 125 | return s.s.Destroy() 126 | } 127 | 128 | func (s *Session[T]) Run() error { 129 | return s.s.Run() 130 | } 131 | 132 | // Unlike the non-dynamic equivalents, the DynamicSession's Run() function 133 | // takes a list of input and output tensors rather than requiring the tensors 134 | // to be specified at Session creation time. It is still the caller's 135 | // responsibility to create and Destroy all tensors passed to this function. 136 | func (s *DynamicSession[in, out]) Run(inputs []*Tensor[in], 137 | outputs []*Tensor[out]) error { 138 | if len(inputs) != len(s.s.s.inputNames) { 139 | return fmt.Errorf("The session specified %d input names, but Run() "+ 140 | "was called with %d input tensors", len(s.s.s.inputNames), 141 | len(inputs)) 142 | } 143 | if len(outputs) != len(s.s.s.outputNames) { 144 | return fmt.Errorf("The session specified %d output names, but Run() "+ 145 | "was called with %d output tensors", len(s.s.s.outputNames), 146 | len(outputs)) 147 | } 148 | inputValues := make([]*C.OrtValue, len(inputs)) 149 | for i, v := range inputs { 150 | inputValues[i] = v.GetInternals().ortValue 151 | } 152 | outputValues := make([]*C.OrtValue, len(outputs)) 153 | for i, v := range outputs { 154 | outputValues[i] = v.GetInternals().ortValue 155 | } 156 | 157 | status := C.RunOrtSession(s.s.s.ortSession, &inputValues[0], 158 | &s.s.s.inputNames[0], C.int(len(inputs)), &outputValues[0], 159 | &s.s.s.outputNames[0], C.int(len(outputs))) 160 | if status != nil { 161 | return fmt.Errorf("Error running network: %w", statusToError(status)) 162 | } 163 | return nil 164 | } 165 | 166 | // This type alias is included to avoid breaking older code, where the inputs 167 | // and outputs to session.Run() were ArbitraryTensors rather than Values. 168 | type ArbitraryTensor = Value 169 | 170 | // As with the ArbitraryTensor type, this type alias only exists to facilitate 171 | // renaming an old type without breaking existing code. 172 | type TensorInternalData = ValueInternalData 173 | 174 | var TrainingAPIRemovedError error = fmt.Errorf("Support for the training " + 175 | "API has been removed from onnxruntime_go following its deprecation in " + 176 | "onnxruntime versions 1.19.2 and later. 
The last revision of " + 177 | "onnxruntime_go supporting the training API is version v1.12.1") 178 | 179 | // Support for TrainingSessions has been removed from onnxruntime_go following 180 | // the deprecation of the training API in onnxruntime 1.20.0. 181 | type TrainingSession struct{} 182 | 183 | // Always returns TrainingAPIRemovedError. 184 | func (s *TrainingSession) ExportModel(path string, outputNames []string) error { 185 | return TrainingAPIRemovedError 186 | } 187 | 188 | // Always returns TrainingAPIRemovedError. 189 | func (s *TrainingSession) SaveCheckpoint(path string, 190 | saveOptimizerState bool) error { 191 | return TrainingAPIRemovedError 192 | } 193 | 194 | // Always returns TrainingAPIRemovedError. 195 | func (s *TrainingSession) Destroy() error { 196 | return TrainingAPIRemovedError 197 | } 198 | 199 | // Always returns TrainingAPIRemovedError. 200 | func (s *TrainingSession) TrainStep() error { 201 | return TrainingAPIRemovedError 202 | } 203 | 204 | // Always returns TrainingAPIRemovedError. 205 | func (s *TrainingSession) OptimizerStep() error { 206 | return TrainingAPIRemovedError 207 | } 208 | 209 | // Always returns TrainingAPIRemovedError. 210 | func (s *TrainingSession) LazyResetGrad() error { 211 | return TrainingAPIRemovedError 212 | } 213 | 214 | // Support for TrainingInputOutputNames has been removed from onnxruntime_go 215 | // following the deprecation of the training API in onnxruntime 1.20.0. 216 | type TrainingInputOutputNames struct { 217 | TrainingInputNames []string 218 | EvalInputNames []string 219 | TrainingOutputNames []string 220 | EvalOutputNames []string 221 | } 222 | 223 | // Always returns (nil, TrainingAPIRemovedError). 224 | func GetInputOutputNames(checkpointStatePath string, trainingModelPath string, 225 | evalModelPath string) (*TrainingInputOutputNames, error) { 226 | return nil, TrainingAPIRemovedError 227 | } 228 | 229 | // Always returns false. 230 | func IsTrainingSupported() bool { 231 | return false 232 | } 233 | 234 | // Always returns (nil, TrainingAPIRemovedError). 235 | func NewTrainingSessionWithOnnxData(checkpointData, trainingData, evalData, 236 | optimizerData []byte, inputs, outputs []Value, 237 | options *SessionOptions) (*TrainingSession, error) { 238 | return nil, TrainingAPIRemovedError 239 | } 240 | 241 | // Always returns (nil, TrainingAPIRemovedError). 242 | func NewTrainingSession(checkpointStatePath, trainingModelPath, evalModelPath, 243 | optimizerModelPath string, inputs, outputs []Value, 244 | options *SessionOptions) (*TrainingSession, error) { 245 | return nil, TrainingAPIRemovedError 246 | } 247 | -------------------------------------------------------------------------------- /onnxruntime_test.go: -------------------------------------------------------------------------------- 1 | package onnxruntime_go 2 | 3 | import ( 4 | "fmt" 5 | "math" 6 | "math/rand" 7 | "os" 8 | "runtime" 9 | "testing" 10 | ) 11 | 12 | // Always use the same RNG seed for benchmarks, so we can compare the 13 | // performance on the same random input data. 14 | const benchmarkRNGSeed = 12345678 15 | 16 | // If the ONNXRUNTIME_SHARED_LIBRARY_PATH environment variable is set, then 17 | // we'll try to use its contents as the location of the shared library for 18 | // these tests. Otherwise, we'll fall back to trying the shared library copies 19 | // in the test_data directory. 
20 | func getTestSharedLibraryPath(t testing.TB) string { 21 | toReturn := os.Getenv("ONNXRUNTIME_SHARED_LIBRARY_PATH") 22 | if toReturn != "" { 23 | return toReturn 24 | } 25 | if runtime.GOOS == "windows" { 26 | return "test_data/onnxruntime.dll" 27 | } 28 | if runtime.GOARCH == "arm64" { 29 | if runtime.GOOS == "darwin" { 30 | return "test_data/onnxruntime_arm64.dylib" 31 | } 32 | return "test_data/onnxruntime_arm64.so" 33 | } 34 | if runtime.GOARCH == "amd64" && runtime.GOOS == "darwin" { 35 | return "test_data/onnxruntime_amd64.dylib" 36 | } 37 | return "test_data/onnxruntime.so" 38 | } 39 | 40 | // This must be called prior to running each test. 41 | func InitializeRuntime(t testing.TB) { 42 | if IsInitialized() { 43 | return 44 | } 45 | SetSharedLibraryPath(getTestSharedLibraryPath(t)) 46 | e := InitializeEnvironment() 47 | if e != nil { 48 | t.Fatalf("Failed setting up onnxruntime environment: %s\n", e) 49 | } 50 | } 51 | 52 | // Should be called at the end of each test to de-initialize the runtime. 53 | func CleanupRuntime(t testing.TB) { 54 | e := DestroyEnvironment() 55 | if e != nil { 56 | t.Fatalf("Error cleaning up environment: %s\n", e) 57 | } 58 | } 59 | 60 | // Returns nil if a and b are within a small delta of one another, otherwise 61 | // returns an error indicating their values. 62 | func floatsEqual(a, b float32) error { 63 | diff := a - b 64 | if diff < 0 { 65 | diff = -diff 66 | } 67 | // Arbitrarily chosen precision. (Unfortunately, going higher than this may 68 | // cause test failures, since the Sum operator doesn't have the same 69 | // results as doing sums purely in Go.) 70 | if diff >= 0.000001 { 71 | return fmt.Errorf("Values differ by too much: %f vs %f", a, b) 72 | } 73 | return nil 74 | } 75 | 76 | // Returns an error if any element between a and b don't match. 77 | func allFloatsEqual(a, b []float32) error { 78 | if len(a) != len(b) { 79 | return fmt.Errorf("Length mismatch: %d vs %d", len(a), len(b)) 80 | } 81 | for i := range a { 82 | e := floatsEqual(a[i], b[i]) 83 | if e != nil { 84 | return fmt.Errorf("Data element %d doesn't match: %s", i, e) 85 | } 86 | } 87 | return nil 88 | } 89 | 90 | // Returns an empty tensor with the given type and shape, or fails the test on 91 | // error. 
92 | func newTestTensor[T TensorData](t testing.TB, s Shape) *Tensor[T] { 93 | toReturn, e := NewEmptyTensor[T](s) 94 | if e != nil { 95 | t.Fatalf("Failed creating empty tensor with shape %s: %s\n", s, e) 96 | } 97 | return toReturn 98 | } 99 | 100 | func TestGetVersion(t *testing.T) { 101 | InitializeRuntime(t) 102 | defer CleanupRuntime(t) 103 | version := GetVersion() 104 | if version == "" { 105 | t.Fatalf("Unable to get the onnxruntime library version\n") 106 | } 107 | t.Logf("Found onnxruntime library version: %s\n", version) 108 | } 109 | 110 | func TestInitializationOptions(t *testing.T) { 111 | e := InitializeEnvironment(WithLogLevelVerbose()) 112 | if e != nil { 113 | t.Fatalf("Error initializing with 'verbose' log level: %s\n", e) 114 | } 115 | e = DestroyEnvironment() 116 | if e != nil { 117 | t.Fatalf("Error cleaning up 'verbose' log level environment: %s\n", e) 118 | } 119 | e = InitializeEnvironment(WithLogLevelError()) 120 | if e != nil { 121 | t.Fatalf("Error initializing with 'error' log level: %s\n", e) 122 | } 123 | e = DestroyEnvironment() 124 | if e != nil { 125 | t.Fatalf("Error cleaning up 'error' log level environment: %s\n", e) 126 | } 127 | } 128 | 129 | func TestTensorTypes(t *testing.T) { 130 | type myFloat float64 131 | dataType := TensorElementDataType(GetTensorElementDataType[myFloat]()) 132 | expected := TensorElementDataType(TensorElementDataTypeDouble) 133 | if dataType != expected { 134 | t.Fatalf("Expected float64 data type to be %d (%s), got %d (%s)\n", 135 | expected, expected, dataType, dataType) 136 | } 137 | t.Logf("Got data type for float64-based double: %d (%s)\n", 138 | dataType, dataType) 139 | } 140 | 141 | func TestCreateTensor(t *testing.T) { 142 | InitializeRuntime(t) 143 | defer CleanupRuntime(t) 144 | s := NewShape(1, 2, 3) 145 | tensor1, e := NewEmptyTensor[uint8](s) 146 | if e != nil { 147 | t.Fatalf("Failed creating %s uint8 tensor: %s\n", s, e) 148 | } 149 | defer tensor1.Destroy() 150 | if len(tensor1.GetData()) != 6 { 151 | t.Fatalf("Incorrect data length for tensor1: %d\n", 152 | len(tensor1.GetData())) 153 | } 154 | // Make sure that the underlying tensor created a copy of the shape we 155 | // passed to NewEmptyTensor. 156 | s[1] = 3 157 | if tensor1.GetShape()[1] == s[1] { 158 | t.Fatalf("Modifying the original shape incorrectly changed the " + 159 | "tensor's shape.\n") 160 | } 161 | 162 | // Try making a tensor with a different data type. 163 | s = NewShape(2, 5) 164 | data := []float32{1.0} 165 | _, e = NewTensor(s, data) 166 | if e == nil { 167 | t.Fatalf("Didn't get error when creating a tensor with too little " + 168 | "data.\n") 169 | } 170 | t.Logf("Got expected error when creating a tensor without enough data: "+ 171 | "%s\n", e) 172 | 173 | // It shouldn't be an error to create a tensor with too *much* underlying 174 | // data; we'll just use the first portion of it. 175 | data = []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14} 176 | tensor2, e := NewTensor(s, data) 177 | if e != nil { 178 | t.Fatalf("Error creating tensor with data: %s\n", e) 179 | } 180 | defer tensor2.Destroy() 181 | // Make sure the tensor's internal slice only refers to the part we care 182 | // about, and not the entire slice. 183 | if len(tensor2.GetData()) != 10 { 184 | t.Fatalf("New tensor data contains %d elements, when it should "+ 185 | "contain 10.\n", len(tensor2.GetData())) 186 | } 187 | 188 | // It shouldn't be an error to create a tensor with a 0 dimension; it just 189 | // shouldn't have any underlying data. 
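// (A 0 in any dimension makes the shape's flattened element count 0.)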
190 | tensor3, e := NewEmptyTensor[uint8](NewShape(4, 5, 6, 0, 8)) 191 | if e != nil { 192 | t.Fatalf("Error creating a tensor with a dimension of 0: %s\n", e) 193 | } 194 | defer tensor3.Destroy() 195 | if len(tensor3.GetData()) != 0 { 196 | t.Fatalf("Tensor 3's data length is %d, expected 0.\n", 197 | len(tensor3.GetData())) 198 | } 199 | } 200 | 201 | func TestBoolTensor(t *testing.T) { 202 | InitializeRuntime(t) 203 | defer CleanupRuntime(t) 204 | data := []bool{true, false, true, true} 205 | tensor1, e := NewTensor(NewShape(4), data) 206 | if e != nil { 207 | t.Fatalf("Failed creating bool tensor: %s\n", e) 208 | } 209 | defer tensor1.Destroy() 210 | if len(tensor1.GetData()) != 4 { 211 | t.Fatalf("Incorrect data length for bool tensor: %d\n", 212 | len(tensor1.GetData())) 213 | } 214 | if !tensor1.GetData()[3] { 215 | t.Errorf("Expected tensor1[3] to be true; it wasn't.\n") 216 | } 217 | if tensor1.GetData()[1] { 218 | t.Errorf("Expected tensor1[1] to be false; it wasn't.\n") 219 | } 220 | tensor1.ZeroContents() 221 | for _, b := range data { 222 | if b { 223 | t.Errorf("Zeroing tensor1 didn't set its data slice to false.\n") 224 | } 225 | } 226 | tensor2, e := NewEmptyTensor[bool](tensor1.GetShape()) 227 | if e != nil { 228 | t.Fatalf("Error creating empty bool tensor: %s\n", e) 229 | } 230 | defer tensor2.Destroy() 231 | for _, b := range tensor2.GetData() { 232 | if b { 233 | t.Errorf("A new empty bool tensor wasn't initialized to false.\n") 234 | } 235 | } 236 | } 237 | 238 | func TestBadTensorShapes(t *testing.T) { 239 | InitializeRuntime(t) 240 | defer CleanupRuntime(t) 241 | s := NewShape() 242 | _, e := NewEmptyTensor[float64](s) 243 | if e == nil { 244 | t.Fatalf("Didn't get an error when creating a tensor with an empty " + 245 | "shape.\n") 246 | } 247 | t.Logf("Got expected error when creating a tensor with an empty shape: "+ 248 | "%s\n", e) 249 | s = NewShape(10, 10, -10) 250 | _, e = NewEmptyTensor[int32](s) 251 | if e == nil { 252 | t.Fatalf("Didn't get an error when creating a tensor with a negative" + 253 | " dimension.\n") 254 | } 255 | t.Logf("Got expected error when creating a tensor with a negative "+ 256 | "dimension: %s\n", e) 257 | s = NewShape(10, -10, -10) 258 | _, e = NewEmptyTensor[uint64](s) 259 | if e == nil { 260 | t.Fatalf("Didn't get an error when creating a tensor with two " + 261 | "negative dimensions.\n") 262 | } 263 | t.Logf("Got expected error when creating a tensor with two negative "+ 264 | "dimensions: %s\n", e) 265 | s = NewShape(int64(1)<<62, 1, int64(1)<<62) 266 | _, e = NewEmptyTensor[float32](s) 267 | if e == nil { 268 | t.Fatalf("Didn't get an error when creating a tensor with an " + 269 | "overflowing shape.\n") 270 | } 271 | t.Logf("Got expected error when creating a tensor with an overflowing "+ 272 | "shape: %s\n", e) 273 | } 274 | 275 | func TestCloneTensor(t *testing.T) { 276 | InitializeRuntime(t) 277 | defer CleanupRuntime(t) 278 | originalData := []float32{1, 2, 3, 4} 279 | originalTensor, e := NewTensor(NewShape(2, 2), originalData) 280 | if e != nil { 281 | t.Fatalf("Error creating tensor: %s\n", e) 282 | } 283 | clone, e := originalTensor.Clone() 284 | if e != nil { 285 | t.Fatalf("Error cloning tensor: %s\n", e) 286 | } 287 | if !clone.GetShape().Equals(originalTensor.GetShape()) { 288 | t.Fatalf("Clone shape (%s) doesn't match original shape (%s)\n", 289 | clone.GetShape(), originalTensor.GetShape()) 290 | } 291 | cloneData := clone.GetData() 292 | for i := range originalData { 293 | if cloneData[i] != originalData[i] { 
294 | t.Fatalf("Clone data incorrect at index %d: %f (expected %f)\n",
295 | i, cloneData[i], originalData[i])
296 | }
297 | }
298 | cloneData[2] = 1337
299 | if originalData[2] != 3 {
300 | t.Fatalf("Modifying clone data affected the original.\n")
301 | }
302 | }
303 |
304 | func TestZeroTensorContents(t *testing.T) {
305 | InitializeRuntime(t)
306 | defer CleanupRuntime(t)
307 | a := newTestTensor[float64](t, NewShape(3, 4, 5))
308 | defer a.Destroy()
309 | data := a.GetData()
310 | for i := range data {
311 | data[i] = float64(i)
312 | }
313 | t.Logf("Before zeroing: a[%d] = %f\n", len(data)-1, data[len(data)-1])
314 | a.ZeroContents()
315 | for i, v := range data {
316 | if v != 0.0 {
317 | t.Fatalf("a[%d] = %f, expected it to be set to 0.\n", i, v)
318 | }
319 | }
320 |
321 | // Do the same basic test with a CustomDataTensor
322 | shape := NewShape(2, 3, 4, 5)
323 | customData := randomBytes(123, 2*shape.FlattenedSize())
324 | b, e := NewCustomDataTensor(shape, customData, TensorElementDataTypeUint16)
325 | if e != nil {
326 | t.Fatalf("Error creating custom data tensor: %s\n", e)
327 | }
328 | defer b.Destroy()
329 | for i := range customData {
330 | // This will wrap around, but doesn't matter. We just need arbitrary
331 | // nonzero data for the test.
332 | customData[i] = uint8(i)
333 | }
334 | t.Logf("Start of custom data before zeroing: % x\n", customData[0:10])
335 | b.ZeroContents()
336 | for i, v := range customData {
337 | if v != 0 {
338 | t.Fatalf("b[%d] = %d, expected it to be set to 0.\n", i, v)
339 | }
340 | }
341 | }
342 |
343 | // This test makes sure that functions taking .onnx data don't crash when
344 | // passed an empty slice. (This used to be a bug.)
345 | func TestEmptyONNXFiles(t *testing.T) {
346 | InitializeRuntime(t)
347 | defer CleanupRuntime(t)
348 | inputNames := []string{"whatever"}
349 | outputNames := []string{"whatever_out"}
350 | dummyIn := newTestTensor[float32](t, NewShape(1))
351 | defer dummyIn.Destroy()
352 | dummyOut := newTestTensor[float32](t, NewShape(1))
353 | defer dummyOut.Destroy()
354 | inputTensors := []Value{dummyIn}
355 | outputTensors := []Value{dummyOut}
356 | _, e := NewAdvancedSessionWithONNXData([]byte{}, inputNames, outputNames,
357 | inputTensors, outputTensors, nil)
358 | if e == nil {
359 | // Really we're checking for a panic due to the empty slice, rather
360 | // than a nil error.
361 | t.Fatalf("Didn't get expected error when creating session.\n")
362 | }
363 | t.Logf("Got expected error creating session with no ONNX content: %s\n", e)
364 | _, e = NewDynamicAdvancedSessionWithONNXData([]byte{}, inputNames,
365 | outputNames, nil)
366 | if e == nil {
367 | t.Fatalf("Didn't get expected error when creating dynamic advanced " +
368 | "session.\n")
369 | }
370 | t.Logf("Got expected error when creating dynamic session with no ONNX "+
371 | "content: %s\n", e)
372 | _, _, e = GetInputOutputInfoWithONNXData([]byte{})
373 | if e == nil {
374 | t.Fatalf("Didn't get expected error when getting input/output info " +
375 | "with no ONNX content.\n")
376 | }
377 | t.Logf("Got expected error when getting input/output info with no "+
378 | "ONNX content: %s\n", e)
379 | _, e = GetModelMetadataWithONNXData([]byte{})
380 | if e == nil {
381 | t.Fatalf("Didn't get expected error when getting metadata with no " +
382 | "ONNX content.\n")
383 | }
384 | t.Logf("Got expected error when getting metadata with no ONNX "+
385 | "content: %s\n", e)
386 | }
387 |
388 | func TestLegacyAPI(t *testing.T) {
389 | InitializeRuntime(t)
390 | defer CleanupRuntime(t)
391 |
392 | // We'll use this network simply due to its simple input and output format,
393 | // as well as it using the same data type for inputs and outputs. See
394 | // TestNonAsciiPath for more comments.
395 | filePath := "test_data/example ż 大 김.onnx"
396 | inputData := []int32{12, 21}
397 | input, e := NewTensor(NewShape(1, 2), inputData)
398 | if e != nil {
399 | t.Fatalf("Error creating input tensor: %s\n", e)
400 | }
401 | defer input.Destroy()
402 | output := newTestTensor[int32](t, NewShape(1))
403 | defer output.Destroy()
404 |
405 | session, e := NewSession[int32](filePath, []string{"in"}, []string{"out"},
406 | []*Tensor[int32]{input}, []*Tensor[int32]{output})
407 | if e != nil {
408 | t.Fatalf("Error creating session via legacy API: %s\n", e)
409 | }
410 | e = session.Run()
411 | if e != nil {
412 | t.Fatalf("Error running session: %s\n", e)
413 | }
414 | expected := inputData[0] + inputData[1]
415 | result := output.GetData()[0]
416 | if result != expected {
417 | t.Errorf("Incorrect result. Expected %d, got %d.\n", expected, result)
418 | }
419 | }
420 |
421 | func TestLegacyAPIDynamic(t *testing.T) {
422 | InitializeRuntime(t)
423 | defer CleanupRuntime(t)
424 | filePath := "test_data/example ż 大 김.onnx"
425 | inputData := []int32{12, 21}
426 | input, e := NewTensor(NewShape(1, 2), inputData)
427 | if e != nil {
428 | t.Fatalf("Error creating input tensor: %s\n", e)
429 | }
430 | defer input.Destroy()
431 | output := newTestTensor[int32](t, NewShape(1))
432 | defer output.Destroy()
433 |
434 | session, e := NewDynamicSession[int32, int32](filePath,
435 | []string{"in"}, []string{"out"})
436 | if e != nil {
437 | t.Fatalf("Error creating session via legacy API: %s\n", e)
438 | }
439 | e = session.Run([]*Tensor[int32]{input}, []*Tensor[int32]{output})
440 | if e != nil {
441 | t.Fatalf("Error running session: %s\n", e)
442 | }
443 | expected := inputData[0] + inputData[1]
444 | result := output.GetData()[0]
445 | if result != expected {
446 | t.Errorf("Incorrect result. 
Expected %d, got %d.\n", expected, result) 447 | } 448 | } 449 | 450 | func TestEnableDisableTelemetry(t *testing.T) { 451 | InitializeRuntime(t) 452 | defer CleanupRuntime(t) 453 | 454 | e := EnableTelemetry() 455 | if e != nil { 456 | t.Errorf("Error enabling onnxruntime telemetry: %s\n", e) 457 | } 458 | e = DisableTelemetry() 459 | if e != nil { 460 | t.Errorf("Error disabling onnxruntime telemetry: %s\n", e) 461 | } 462 | e = EnableTelemetry() 463 | if e != nil { 464 | t.Errorf("Error re-enabling onnxruntime telemetry after "+ 465 | "disabling: %s\n", e) 466 | } 467 | } 468 | 469 | func TestArbitraryTensors(t *testing.T) { 470 | InitializeRuntime(t) 471 | defer CleanupRuntime(t) 472 | 473 | tensorShape := NewShape(2, 2) 474 | tensorA, e := NewTensor(tensorShape, []uint8{1, 2, 3, 4}) 475 | if e != nil { 476 | t.Fatalf("Error creating uint8 tensor: %s\n", e) 477 | } 478 | defer tensorA.Destroy() 479 | tensorB, e := NewTensor(tensorShape, []float64{5, 6, 7, 8}) 480 | if e != nil { 481 | t.Fatalf("Error creating float64 tensor: %s\n", e) 482 | } 483 | defer tensorB.Destroy() 484 | tensorC, e := NewTensor(tensorShape, []int16{9, 10, 11, 12}) 485 | if e != nil { 486 | t.Fatalf("Error creating int16 tensor: %s\n", e) 487 | } 488 | defer tensorC.Destroy() 489 | tensorList := []ArbitraryTensor{tensorA, tensorB, tensorC} 490 | for i, v := range tensorList { 491 | ortValue := v.GetInternals().ortValue 492 | t.Logf("ArbitraryTensor %d: Data type %d, shape %s, OrtValue %p\n", 493 | i, v.DataType(), v.GetShape(), ortValue) 494 | } 495 | } 496 | 497 | // Used for testing the operation of test_data/example_multitype.onnx 498 | func randomMultitypeInputs(t *testing.T, seed int64) (*Tensor[uint8], 499 | *Tensor[float64]) { 500 | rng := rand.New(rand.NewSource(seed)) 501 | inputA := newTestTensor[uint8](t, NewShape(1, 1, 1)) 502 | // We won't use newTestTensor here, otherwise we won't have a chance to 503 | // destroy inputA on failure. 504 | inputB, e := NewEmptyTensor[float64](NewShape(1, 2, 2)) 505 | if e != nil { 506 | inputA.Destroy() 507 | t.Fatalf("Failed creating input B: %s\n", e) 508 | } 509 | inputA.GetData()[0] = uint8(rng.Intn(256)) 510 | for i := 0; i < 4; i++ { 511 | inputB.GetData()[i] = rng.Float64() 512 | } 513 | return inputA, inputB 514 | } 515 | 516 | // Used when checking the output produced by test_data/example_multitype.onnx 517 | func getExpectedMultitypeOutputs(inputA *Tensor[uint8], 518 | inputB *Tensor[float64]) ([]int16, []int64) { 519 | outputA := make([]int16, 4) 520 | dataA := inputA.GetData()[0] 521 | dataB := inputB.GetData() 522 | for i := 0; i < len(outputA); i++ { 523 | outputA[i] = int16((dataB[i] * float64(dataA)) - 512) 524 | } 525 | return outputA, []int64{int64(dataA) * 1234} 526 | } 527 | 528 | // Verifies that the given tensor's data matches the expected content. Prints 529 | // an error and fails the test if anything doesn't match. 530 | func verifyTensorData[T TensorData](t *testing.T, tensor *Tensor[T], 531 | expectedContent []T) { 532 | data := tensor.GetData() 533 | if len(data) != len(expectedContent) { 534 | t.Fatalf("Expected tensor to contain %d elements, got %d elements.\n", 535 | len(expectedContent), len(data)) 536 | } 537 | for i, v := range expectedContent { 538 | if v != data[i] { 539 | t.Fatalf("Data mismatch at index %d: expected %v, got %v\n", i, v, 540 | data[i]) 541 | } 542 | } 543 | } 544 | 545 | // Tests a session taking multiple input tensors of different types and 546 | // producing multiple output tensors of different types. 
547 | func TestDifferentInputOutputTypes(t *testing.T) {
548 | InitializeRuntime(t)
549 | defer CleanupRuntime(t)
550 |
551 | inputA, inputB := randomMultitypeInputs(t, 9999)
552 | defer inputA.Destroy()
553 | defer inputB.Destroy()
554 | outputA := newTestTensor[int16](t, NewShape(1, 2, 2))
555 | defer outputA.Destroy()
556 | outputB := newTestTensor[int64](t, NewShape(1, 1, 1))
557 | defer outputB.Destroy()
558 |
559 | // Decided to toss in an "ArbitraryTensor" here to ensure that it remains
560 | // compatible with Value in the future.
561 | session, e := NewAdvancedSession("test_data/example_multitype.onnx",
562 | []string{"InputA", "InputB"}, []string{"OutputA", "OutputB"},
563 | []Value{inputA, inputB}, []ArbitraryTensor{outputA, outputB}, nil)
564 | if e != nil {
565 | t.Fatalf("Failed creating session: %s\n", e)
566 | }
567 | defer session.Destroy()
568 | e = session.Run()
569 | if e != nil {
570 | t.Fatalf("Error running session: %s\n", e)
571 | }
572 | expectedA, expectedB := getExpectedMultitypeOutputs(inputA, inputB)
573 | verifyTensorData(t, outputA, expectedA)
574 | verifyTensorData(t, outputB, expectedB)
575 | }
576 |
577 | func TestDynamicDifferentInputOutputTypes(t *testing.T) {
578 | InitializeRuntime(t)
579 | defer CleanupRuntime(t)
580 |
581 | session, e := NewDynamicAdvancedSession("test_data/example_multitype.onnx",
582 | []string{"InputA", "InputB"}, []string{"OutputA", "OutputB"}, nil)
583 | if e != nil { t.Fatalf("Failed creating session: %s\n", e) }
584 | defer session.Destroy()
585 | numTests := 100
586 | aInputs := make([]*Tensor[uint8], numTests)
587 | bInputs := make([]*Tensor[float64], numTests)
588 | aOutputs := make([]*Tensor[int16], numTests)
589 | bOutputs := make([]*Tensor[int64], numTests)
590 |
591 | // Make sure we clean up all the tensors created for this test, even if we
592 | // somehow fail before we've created them all.
593 | defer func() {
594 | for i := 0; i < numTests; i++ {
595 | if aInputs[i] != nil {
596 | aInputs[i].Destroy()
597 | }
598 | if bInputs[i] != nil {
599 | bInputs[i].Destroy()
600 | }
601 | if aOutputs[i] != nil {
602 | aOutputs[i].Destroy()
603 | }
604 | if bOutputs[i] != nil {
605 | bOutputs[i].Destroy()
606 | }
607 | }
608 | }()
609 |
610 | // Actually create the inputs and run the tests.
611 | for i := 0; i < numTests; i++ {
612 | aInputs[i], bInputs[i] = randomMultitypeInputs(t, 999+int64(i))
613 | aOutputs[i] = newTestTensor[int16](t, NewShape(1, 2, 2))
614 | bOutputs[i] = newTestTensor[int64](t, NewShape(1, 1, 1))
615 | e = session.Run([]Value{aInputs[i], bInputs[i]},
616 | []Value{aOutputs[i], bOutputs[i]})
617 | if e != nil {
618 | t.Fatalf("Failed running session for test %d: %s\n", i, e)
619 | }
620 | }
621 |
622 | // Now that all the tests ran, check the outputs. If the
623 | // DynamicAdvancedSession worked properly, each run should have only
624 | // modified its given outputs.
625 | for i := 0; i < numTests; i++ {
626 | expectedA, expectedB := getExpectedMultitypeOutputs(aInputs[i],
627 | bInputs[i])
628 | verifyTensorData(t, aOutputs[i], expectedA)
629 | verifyTensorData(t, bOutputs[i], expectedB)
630 | }
631 | }
632 |
633 | func TestDynamicAllocatedOutputTensor(t *testing.T) {
634 | InitializeRuntime(t)
635 | defer CleanupRuntime(t)
636 |
637 | session, e := NewDynamicAdvancedSession("test_data/example_multitype.onnx",
638 | []string{"InputA", "InputB"}, []string{"OutputA", "OutputB"}, nil)
639 | if e != nil {
640 | t.Fatalf("Error creating session: %s\n", e)
641 | }
642 | defer session.Destroy()
643 |
644 | // Actually create the inputs and run the tests.
645 | aInput, bInput := randomMultitypeInputs(t, 999)
646 | var outputs [2]Value
647 | e = session.Run([]Value{aInput, bInput}, outputs[:])
648 | if e != nil {
649 | t.Fatalf("Failed running session: %s\n", e)
650 | }
651 | defer func() {
652 | for _, output := range outputs {
653 | output.Destroy()
654 | }
655 | }()
656 |
657 | expectedA, expectedB := getExpectedMultitypeOutputs(aInput, bInput)
658 | expectedShape := NewShape(1, 2, 2)
659 | outputA, ok := outputs[0].(*Tensor[int16])
660 | if !ok {
661 | t.Fatalf("Expected outputA to be of type %T, got of type %T\n",
662 | outputA, outputs[0])
663 | }
664 | if !outputA.shape.Equals(expectedShape) {
665 | t.Fatalf("Expected outputA to be of shape %s, got of shape %s\n",
666 | expectedShape, outputA.shape)
667 | }
668 | verifyTensorData(t, outputA, expectedA)
669 |
670 | outputB, ok := outputs[1].(*Tensor[int64])
671 | expectedShape = NewShape(1, 1, 1)
672 | if !ok {
673 | t.Fatalf("Expected outputB to be of type %T, got of type %T\n",
674 | outputB, outputs[1])
675 | }
676 | if !outputB.shape.Equals(expectedShape) {
677 | t.Fatalf("Expected outputB to be of shape %s, got of shape %s\n",
678 | expectedShape, outputB.shape)
679 | }
680 | verifyTensorData(t, outputB, expectedB)
681 | }
682 |
683 | func TestZeroDimensionOutput(t *testing.T) {
684 | InitializeRuntime(t)
685 | defer CleanupRuntime(t)
686 |
687 | filename := "test_data/example_0_dim_output.onnx"
688 | session, e := NewDynamicAdvancedSession(filename, []string{"x"},
689 | []string{"y"}, nil)
690 | if e != nil {
691 | t.Fatalf("Error creating session for %s: %s\n", filename, e)
692 | }
693 | defer session.Destroy()
694 |
695 | input, e := NewEmptyTensor[float32](NewShape(2, 8))
696 | if e != nil {
697 | t.Fatalf("Error creating input tensor: %s\n", e)
698 | }
699 | defer input.Destroy()
700 | outputs := []Value{nil}
701 |
702 | e = session.Run([]Value{input}, outputs)
703 | if e != nil {
704 | t.Fatalf("Error running session: %s\n", e)
705 | }
706 | if outputs[0] == nil {
707 | t.Fatalf("Session didn't allocate the output automatically\n")
708 | }
709 | defer outputs[0].Destroy()
710 |
711 | outputShape := outputs[0].GetShape()
712 | expectedShape := NewShape(2, 0, 8)
713 | if !outputShape.Equals(expectedShape) {
714 | t.Fatalf("Output shape %s does not equal expected shape %s\n",
715 | outputShape, expectedShape)
716 | }
717 | }
718 |
719 | // Makes sure that the sum of each vector in the input tensor matches the
720 | // corresponding scalar in the output tensor. Used when testing tensors with
721 | // unknown batch dimensions.
722 | // NOTE: Destroys the input and output tensors before returning, regardless of
723 | // test success.
724 | func checkVectorSum(input *Tensor[float32], output *Tensor[float32],
725 | t testing.TB) {
726 | defer input.Destroy()
727 | defer output.Destroy()
728 | // Make sure the sizes are what we expect.
729 | inputShape := input.GetShape()
730 | outputShape := output.GetShape()
731 | if len(inputShape) != 2 {
732 | t.Fatalf("Expected a 2-dimensional input shape, got %v\n", inputShape)
733 | }
734 | if len(outputShape) != 1 {
735 | t.Fatalf("Expected 1-dimensional output shape, got %v\n", outputShape)
736 | }
737 | if inputShape[0] != outputShape[0] {
738 | t.Fatalf("Input and output batch dimensions don't match (%d vs %d)\n",
739 | inputShape[0], outputShape[0])
740 | }
741 |
742 | // Compute the sums in Go
743 | batchSize := inputShape[0]
744 | vectorLength := inputShape[1]
745 | expectedSums := make([]float32, batchSize)
746 | for i := int64(0); i < batchSize; i++ {
747 | inputVector := input.GetData()[i*vectorLength : (i+1)*vectorLength]
748 | sum := float32(0.0)
749 | for _, v := range inputVector {
750 | sum += v
751 | }
752 | expectedSums[i] = sum
753 | }
754 |
755 | e := allFloatsEqual(expectedSums, output.GetData())
756 | if e != nil {
757 | t.Fatalf("ONNX-produced sums don't match CPU-produced sums: %s\n", e)
758 | }
759 | }
760 |
761 | func TestDynamicInputOutputAxes(t *testing.T) {
762 | InitializeRuntime(t)
763 | defer CleanupRuntime(t)
764 |
765 | netPath := "test_data/example_dynamic_axes.onnx"
766 | session, e := NewDynamicAdvancedSession(netPath,
767 | []string{"input_vectors"}, []string{"output_scalars"}, nil)
768 | if e != nil {
769 | t.Fatalf("Error loading %s: %s\n", netPath, e)
770 | }
771 | defer session.Destroy()
772 | maxBatchSize := 99
773 | // The example network takes a dynamic batch size of vectors containing 10
774 | // elements each.
775 | dataBuffer := make([]float32, maxBatchSize*10)
776 |
777 | // Try running the session with many different batch sizes
778 | for i := 11; i <= maxBatchSize; i += 11 {
779 | // Create an input with the new batch size.
780 | inputShape := NewShape(int64(i), 10)
781 | input, e := NewTensor(inputShape, dataBuffer)
782 | if e != nil {
783 | t.Fatalf("Error creating input tensor with shape %v: %s\n",
784 | inputShape, e)
785 | }
786 |
787 | // Populate the input with new random floats.
788 | fillRandomFloats(input.GetData(), 1234)
789 |
790 | // Run the session; make onnxruntime allocate the output tensor for us.
791 | outputs := []Value{nil}
792 | e = session.Run([]Value{input}, outputs)
793 | if e != nil {
794 | input.Destroy()
795 | t.Fatalf("Error running the session with batch size %d: %s\n",
796 | i, e)
797 | }
798 |
799 | // The checkVectorSum function will destroy the input and output tensor
800 | // regardless of their correctness.
801 | checkVectorSum(input, outputs[0].(*Tensor[float32]), t)
802 | // (input was already destroyed by checkVectorSum above.)
803 | t.Logf("Batch size %d seems OK!\n", i)
804 | }
805 | }
806 |
807 | func TestWrongInputs(t *testing.T) {
808 | InitializeRuntime(t)
809 | defer CleanupRuntime(t)
810 |
811 | session, e := NewDynamicAdvancedSession("test_data/example_multitype.onnx",
812 | []string{"InputA", "InputB"}, []string{"OutputA", "OutputB"}, nil)
813 | if e != nil { t.Fatalf("Error creating session: %s\n", e) }
814 | defer session.Destroy()
815 | inputA, inputB := randomMultitypeInputs(t, 123456)
816 | defer inputA.Destroy()
817 | defer inputB.Destroy()
818 | outputA := newTestTensor[int16](t, NewShape(1, 2, 2))
819 | defer outputA.Destroy()
820 | outputB := newTestTensor[int64](t, NewShape(1, 1, 1))
821 | defer outputB.Destroy()
822 |
823 | // Make sure that passing a tensor with the wrong type but correct shape
824 | // will correctly cause an error rather than a crash, whether used as an
825 | // input or output.
826 | wrongTypeTensor := newTestTensor[float32](t, NewShape(1, 2, 2)) 827 | defer wrongTypeTensor.Destroy() 828 | e = session.Run([]Value{inputA, inputB}, []Value{wrongTypeTensor, outputB}) 829 | if e == nil { 830 | t.Fatalf("Didn't get expected error when passing a float32 tensor in" + 831 | " place of an int16 output tensor.\n") 832 | } 833 | t.Logf("Got expected error when passing a float32 tensor in place of an "+ 834 | "int16 output tensor: %s\n", e) 835 | e = session.Run([]Value{inputA, wrongTypeTensor}, 836 | []Value{outputA, outputB}) 837 | if e == nil { 838 | t.Fatalf("Didn't get expected error when passing a float32 tensor in" + 839 | " place of a float64 input tensor.\n") 840 | } 841 | t.Logf("Got expected error when passing a float32 tensor in place of a "+ 842 | "float64 input tensor: %s\n", e) 843 | 844 | // Make sure that passing a tensor with the wrong shape but correct type 845 | // will cause an error rather than a crash, when used as an input or an 846 | // output. 847 | wrongShapeInput := newTestTensor[uint8](t, NewShape(22)) 848 | defer wrongShapeInput.Destroy() 849 | e = session.Run([]Value{wrongShapeInput, inputB}, 850 | []Value{outputA, outputB}) 851 | if e == nil { 852 | t.Fatalf("Didn't get expected error when running with an incorrectly" + 853 | " shaped input.\n") 854 | } 855 | t.Logf("Got expected error when running with an incorrectly shaped "+ 856 | "input: %s\n", e) 857 | wrongShapeOutput := newTestTensor[int64](t, NewShape(1, 1, 1, 1, 1, 1)) 858 | defer wrongShapeOutput.Destroy() 859 | e = session.Run([]Value{inputA, inputB}, 860 | []Value{outputA, wrongShapeOutput}) 861 | if e == nil { 862 | t.Fatalf("Didn't get expected error when running with an incorrectly" + 863 | " shaped output.\n") 864 | } 865 | t.Logf("Got expected error when running with an incorrectly shaped "+ 866 | "output: %s\n", e) 867 | 868 | e = session.Run([]Value{inputA, inputB}, []Value{outputA, outputB}) 869 | if e != nil { 870 | t.Fatalf("Got error attempting to (correctly) Run a session after "+ 871 | "attempting to use incorrect inputs or outputs: %s\n", e) 872 | } 873 | } 874 | 875 | func TestGetInputOutputInfo(t *testing.T) { 876 | InitializeRuntime(t) 877 | defer CleanupRuntime(t) 878 | file := "test_data/example_several_inputs_and_outputs.onnx" 879 | inputs, outputs, e := GetInputOutputInfo(file) 880 | if e != nil { 881 | t.Fatalf("Error getting input and output info for %s: %s\n", file, e) 882 | } 883 | if len(inputs) != 3 { 884 | t.Fatalf("Expected 3 inputs, got %d\n", len(inputs)) 885 | } 886 | if len(outputs) != 2 { 887 | t.Fatalf("Expected 2 outputs, got %d\n", len(outputs)) 888 | } 889 | for i, v := range inputs { 890 | t.Logf("Input %d: %s\n", i, &v) 891 | } 892 | for i, v := range outputs { 893 | t.Logf("Output %d: %s\n", i, &v) 894 | } 895 | 896 | if outputs[1].Name != "output 2" { 897 | t.Errorf("Incorrect output 1 name: %s, expected \"output 2\"\n", 898 | outputs[1].Name) 899 | } 900 | expectedShape := NewShape(1, 2, 3, 4, 5) 901 | if !outputs[1].Dimensions.Equals(expectedShape) { 902 | t.Errorf("Incorrect output 1 shape: %s, expected %s\n", 903 | outputs[1].Dimensions, expectedShape) 904 | } 905 | var expectedType TensorElementDataType = TensorElementDataTypeDouble 906 | if outputs[1].DataType != expectedType { 907 | t.Errorf("Incorrect output 1 data type: %s, expected %s\n", 908 | outputs[1].DataType, expectedType) 909 | } 910 | if inputs[0].Name != "input 1" { 911 | t.Errorf("Incorrect input 0 name: %s, expected \"input 1\"\n", 912 | inputs[0].Name) 913 | } 914 |
expectedShape = NewShape(2, 5, 2, 5) 915 | if !inputs[0].Dimensions.Equals(expectedShape) { 916 | t.Errorf("Incorrect input 0 shape: %s, expected %s\n", 917 | inputs[0].Dimensions, expectedShape) 918 | } 919 | expectedType = TensorElementDataTypeInt32 920 | if inputs[0].DataType != expectedType { 921 | t.Errorf("Incorrect input 0 data type: %s, expected %s\n", 922 | inputs[0].DataType, expectedType) 923 | } 924 | } 925 | 926 | func TestModelMetadata(t *testing.T) { 927 | InitializeRuntime(t) 928 | defer CleanupRuntime(t) 929 | file := "test_data/example_big_compute.onnx" 930 | metadata, e := GetModelMetadata(file) 931 | if e != nil { 932 | t.Fatalf("Error getting metadata for %s: %s\n", file, e) 933 | } 934 | // We'll just test Destroy once; after this we won't check its return value 935 | e = metadata.Destroy() 936 | if e != nil { 937 | t.Fatalf("Error destroying metadata: %s\n", e) 938 | } 939 | 940 | // Try getting the metadata from a session instead of from a file. 941 | // NOTE: All of the expected values here were manually set using the 942 | // test_data/modify_metadata.py script after generating the network. See 943 | // that script for the expected values of each of the metadata accessors. 944 | session, e := NewDynamicAdvancedSession(file, []string{"Input"}, 945 | []string{"Output"}, nil) 946 | if e != nil { 947 | t.Fatalf("Error creating session: %s\n", e) 948 | } 949 | defer session.Destroy() 950 | metadata, e = session.GetModelMetadata() 951 | if e != nil { 952 | t.Fatalf("Error getting metadata from DynamicAdvancedSession: %s\n", e) 953 | } 954 | defer metadata.Destroy() 955 | producerName, e := metadata.GetProducerName() 956 | if e != nil { 957 | t.Errorf("Error getting producer name: %s\n", e) 958 | } else { 959 | t.Logf("Got producer name: %s\n", producerName) 960 | } 961 | graphName, e := metadata.GetGraphName() 962 | if e != nil { 963 | t.Errorf("Error getting graph name: %s\n", e) 964 | } else { 965 | t.Logf("Got graph name: %s\n", graphName) 966 | } 967 | domainStr, e := metadata.GetDomain() 968 | if e != nil { 969 | t.Errorf("Error getting domain: %s\n", e) 970 | } else { 971 | t.Logf("Got domain: %s\n", domainStr) 972 | if domainStr != "test domain" { 973 | t.Errorf("Incorrect domain string, expected \"test domain\"\n") 974 | } 975 | } 976 | description, e := metadata.GetDescription() 977 | if e != nil { 978 | t.Errorf("Error getting description: %s\n", e) 979 | } else { 980 | t.Logf("Got description: %s\n", description) 981 | } 982 | version, e := metadata.GetVersion() 983 | if e != nil { 984 | t.Errorf("Error getting version: %s\n", e) 985 | } else { 986 | t.Logf("Got version: %d\n", version) 987 | if version != 1337 { 988 | t.Errorf("Incorrect version number, expected 1337\n") 989 | } 990 | } 991 | mapKeys, e := metadata.GetCustomMetadataMapKeys() 992 | if e != nil { 993 | t.Fatalf("Error getting custom metadata keys: %s\n", e) 994 | } 995 | t.Logf("Got %d custom metadata map keys.\n", len(mapKeys)) 996 | if len(mapKeys) != 2 { 997 | t.Errorf("Incorrect number of custom metadata keys, expected 2\n") 998 | } 999 | for _, k := range mapKeys { 1000 | value, present, e := metadata.LookupCustomMetadataMap(k) 1001 | if e != nil { 1002 | t.Errorf("Error looking up key %s in custom metadata: %s\n", k, e) 1003 | } else { 1004 | if !present { 1005 | t.Errorf("LookupCustomMetadataMap didn't return true for a " + 1006 | "key that should be present in the map\n") 1007 | } 1008 | t.Logf(" Metadata key \"%s\" = \"%s\"\n", k, value) 1009 | } 1010 | } 1011 | badValue, present, e :=
metadata.LookupCustomMetadataMap("invalid key") 1012 | if len(badValue) != 0 { 1013 | t.Fatalf("Didn't get an empty string when looking up an invalid "+ 1014 | "metadata key, got \"%s\" instead\n", badValue) 1015 | } 1016 | if present { 1017 | t.Errorf("LookupCustomMetadataMap didn't return false for a key that" + 1018 | " isn't in the map\n") 1019 | } 1020 | // Tossing in this check, since the docs aren't clear on this topic. (The 1021 | // docs specify returning an empty string, but do not mention a non-NULL 1022 | // OrtStatus.) At the time of writing, it does _not_ return an error. 1023 | if e == nil { 1024 | t.Logf("Informational: looking up an invalid metadata key doesn't " + 1025 | "return an error\n") 1026 | } else { 1027 | t.Logf("Informational: got error when looking up an invalid "+ 1028 | "metadata key: %s\n", e) 1029 | } 1030 | } 1031 | 1032 | func randomBytes(seed, n int64) []byte { 1033 | toReturn := make([]byte, n) 1034 | rng := rand.New(rand.NewSource(seed)) 1035 | rng.Read(toReturn) 1036 | return toReturn 1037 | } 1038 | 1039 | func fillRandomFloats(dst []float32, seed int64) { 1040 | rng := rand.New(rand.NewSource(seed)) 1041 | for i := range dst { 1042 | dst[i] = rng.Float32() 1043 | } 1044 | } 1045 | 1046 | func TestCustomDataTensors(t *testing.T) { 1047 | InitializeRuntime(t) 1048 | defer CleanupRuntime(t) 1049 | shape := NewShape(2, 3, 4, 5) 1050 | tensorData := randomBytes(123, 2*shape.FlattenedSize()) 1051 | // This could have been created using a Tensor[uint16], but we'll make sure 1052 | // it works this way, too. 1053 | v, e := NewCustomDataTensor(shape, tensorData, TensorElementDataTypeUint16) 1054 | if e != nil { 1055 | t.Fatalf("Error creating uint16 CustomDataTensor: %s\n", e) 1056 | } 1057 | shape[0] = 6 1058 | if v.GetShape().Equals(shape) { 1059 | t.Fatalf("CustomDataTensor didn't properly clone its shape") 1060 | } 1061 | e = v.Destroy() 1062 | if e != nil { 1063 | t.Fatalf("Error destroying CustomDataTensor: %s\n", e) 1064 | } 1065 | tensorData = randomBytes(1234, 2*shape.FlattenedSize()) 1066 | v, e = NewCustomDataTensor(shape, tensorData, TensorElementDataTypeFloat16) 1067 | if e != nil { 1068 | t.Fatalf("Error creating float16 tensor: %s\n", e) 1069 | } 1070 | e = v.Destroy() 1071 | if e != nil { 1072 | t.Fatalf("Error destroying float16 tensor: %s\n", e) 1073 | } 1074 | // Make sure we don't fail if providing more data than necessary 1075 | shape[0] = 1 1076 | v, e = NewCustomDataTensor(shape, tensorData, 1077 | TensorElementDataTypeBFloat16) 1078 | if e != nil { 1079 | t.Fatalf("Got error when creating a tensor with more data than "+ 1080 | "necessary: %s\n", e) 1081 | } 1082 | v.Destroy() 1083 | 1084 | // Make sure we fail when using a bad shape 1085 | shape = NewShape(0, -1, -2) 1086 | v, e = NewCustomDataTensor(shape, tensorData, TensorElementDataTypeFloat16) 1087 | if e == nil { 1088 | v.Destroy() 1089 | t.Fatalf("Didn't get error when creating custom tensor with an " + 1090 | "invalid shape\n") 1091 | } 1092 | t.Logf("Got expected error creating tensor with invalid shape: %s\n", e) 1093 | shape = NewShape(1, 2, 3, 4, 5) 1094 | tensorData = []byte{1, 2, 3, 4} 1095 | v, e = NewCustomDataTensor(shape, tensorData, TensorElementDataTypeUint8) 1096 | if e == nil { 1097 | v.Destroy() 1098 | t.Fatalf("Didn't get error when creating custom tensor with too " + 1099 | "little data\n") 1100 | } 1101 | t.Logf("Got expected error when creating custom data tensor with "+ 1102 | "too little data: %s\n", e) 1103 | 1104 | // Make sure we fail when using a bad 
type 1105 | tensorData = []byte{1, 2, 3, 4, 5, 6, 7, 8} 1106 | badType := TensorElementDataType(0xffffff) 1107 | v, e = NewCustomDataTensor(NewShape(2), tensorData, badType) 1108 | if e == nil { 1109 | v.Destroy() 1110 | t.Fatalf("Didn't get error when creating tensor with bad type\n") 1111 | } 1112 | t.Logf("Got expected error when creating custom data tensor with bad "+ 1113 | "type: %s\n", e) 1114 | } 1115 | 1116 | // Converts a slice of floats to their representation as bfloat16 bytes. 1117 | func floatsToBfloat16(f []float32) []byte { 1118 | toReturn := make([]byte, 2*len(f)) 1119 | // bfloat16 is just a truncated version of a float32 1120 | for i := range f { 1121 | bf16Bits := uint16(math.Float32bits(f[i]) >> 16) 1122 | toReturn[i*2] = uint8(bf16Bits) 1123 | toReturn[i*2+1] = uint8(bf16Bits >> 8) 1124 | } 1125 | return toReturn 1126 | } 1127 | 1128 | func TestFloat16Network(t *testing.T) { 1129 | InitializeRuntime(t) 1130 | defer CleanupRuntime(t) 1131 | 1132 | // The network takes a 1x2x2x2 float16 input 1133 | inputData := []byte{ 1134 | // 0.0, 1.0, 2.0, 3.0 1135 | 0x00, 0x00, 0x00, 0x3c, 0x00, 0x40, 0x00, 0x42, 1136 | // 4.0, 5.0, 6.0, 7.0 1137 | 0x00, 0x44, 0x00, 0x45, 0x00, 0x46, 0x00, 0x47, 1138 | } 1139 | // The network produces a 1x2x2x2 bfloat16 output: the input multiplied 1140 | // by 3 1141 | expectedOutput := floatsToBfloat16([]float32{0, 3, 6, 9, 12, 15, 18, 21}) 1142 | outputData := make([]byte, len(expectedOutput)) 1143 | inputTensor, e := NewCustomDataTensor(NewShape(1, 2, 2, 2), inputData, 1144 | TensorElementDataTypeFloat16) 1145 | if e != nil { 1146 | t.Fatalf("Error creating input tensor: %s\n", e) 1147 | } 1148 | defer inputTensor.Destroy() 1149 | outputTensor, e := NewCustomDataTensor(NewShape(1, 2, 2, 2), outputData, 1150 | TensorElementDataTypeBFloat16) 1151 | if e != nil { 1152 | t.Fatalf("Error creating output tensor: %s\n", e) 1153 | } 1154 | defer outputTensor.Destroy() 1155 | 1156 | session, e := NewAdvancedSession("test_data/example_float16.onnx", 1157 | []string{"InputA"}, []string{"OutputA"}, 1158 | []Value{inputTensor}, []Value{outputTensor}, nil) 1159 | if e != nil { 1160 | t.Fatalf("Error creating session: %s\n", e) 1161 | } 1162 | defer session.Destroy() 1163 | e = session.Run() 1164 | if e != nil { 1165 | t.Fatalf("Error running session: %s\n", e) 1166 | } 1167 | for i := range outputData { 1168 | if outputData[i] != expectedOutput[i] { 1169 | t.Fatalf("Incorrect output byte at index %d: 0x%02x (expected "+ 1170 | "0x%02x)\n", i, outputData[i], expectedOutput[i]) 1171 | } 1172 | } 1173 | } 1174 |
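// A hedged sketch (floatsToFloat16 is a hypothetical helper, not part of
// this package) showing how the hard-coded float16 input bytes in
// TestFloat16Network above can be derived. It ignores rounding, subnormals,
// infinities, and NaNs, which is fine for the small whole numbers used
// there; math is already imported for floatsToBfloat16.
func floatsToFloat16(f []float32) []byte {
	toReturn := make([]byte, 2*len(f))
	for i, v := range f {
		bits := math.Float32bits(v)
		sign := uint16(bits>>16) & 0x8000
		// Rebias the exponent from float32's 127 to float16's 15, and keep
		// the top 10 mantissa bits.
		exponent := int32((bits>>23)&0xff) - 127 + 15
		mantissa := uint16((bits >> 13) & 0x3ff)
		var f16Bits uint16
		if v != 0 {
			f16Bits = sign | uint16(exponent)<<10 | mantissa
		}
		toReturn[i*2] = uint8(f16Bits)
		toReturn[i*2+1] = uint8(f16Bits >> 8)
	}
	return toReturn
}

1175 | // Returns a 10-element tensor filled with random values using the given rng seed.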
1176 | func randomSmallTensor(seed int64, t testing.TB) *Tensor[float32] { 1177 | toReturn, e := NewEmptyTensor[float32](NewShape(10)) 1178 | if e != nil { 1179 | t.Fatalf("Error creating small tensor: %s\n", e) 1180 | } 1181 | fillRandomFloats(toReturn.GetData(), seed) 1182 | return toReturn 1183 | } 1184 | 1185 | func TestONNXSequence(t *testing.T) { 1186 | InitializeRuntime(t) 1187 | defer CleanupRuntime(t) 1188 | sequenceLength := int64(123) 1189 | 1190 | values := make([]Value, sequenceLength) 1191 | for i := range values { 1192 | values[i] = randomSmallTensor(int64(i)+123, t) 1193 | } 1194 | defer func() { 1195 | for _, v := range values { 1196 | v.Destroy() 1197 | } 1198 | }() 1199 | sequence, e := NewSequence(values) 1200 | if e != nil { 1201 | t.Fatalf("Error creating sequence: %s\n", e) 1202 | } 1203 | defer sequence.Destroy() 1204 | sequenceContents, e := sequence.GetValues() 1205 | if e != nil { 1206 | t.Fatalf("Error getting sequence contents: %s\n", e) 1207 | } 1208 | if int64(len(sequenceContents)) != sequenceLength { 1209 | t.Fatalf("Got %d values in sequence, expected %d\n", 1210 | len(sequenceContents), sequenceLength) 1211 | } 1212 | if sequence.GetONNXType() != ONNXTypeSequence { 1213 | t.Fatalf("Got incorrect ONNX type for sequence: %s\n", 1214 | sequence.GetONNXType()) 1215 | } 1216 | // Make sure we adhere to what I wrote in the docs 1217 | if !sequence.GetShape().Equals(NewShape(sequenceLength)) { 1218 | t.Fatalf("Sequence.GetShape() returned incorrect shape: %s\n", 1219 | sequence.GetShape()) 1220 | } 1221 | 1222 | selectedIndex := 44 1223 | selectedValue := sequenceContents[selectedIndex] 1224 | if selectedValue.GetONNXType() != ONNXTypeTensor { 1225 | t.Fatalf("Got incorrect ONNXType for value at index %d: "+ 1226 | "expected %s, got %s\n", selectedIndex, ONNXType(ONNXTypeTensor), 1227 | selectedValue.GetONNXType()) 1228 | } 1229 | } 1230 | 1231 | func TestBadSequences(t *testing.T) { 1232 | InitializeRuntime(t) 1233 | defer CleanupRuntime(t) 1234 | 1235 | // Sequences containing no elements or nil entries shouldn't be allowed 1236 | _, e := NewSequence([]Value{}) 1237 | if e == nil { 1238 | t.Fatalf("Didn't get expected error when creating an empty sequence\n") 1239 | } 1240 | t.Logf("Got expected error when creating an empty sequence: %s\n", e) 1241 | _, e = NewSequence([]Value{nil}) 1242 | if e == nil { 1243 | t.Fatalf("Didn't get expected error when creating sequence with a " + 1244 | "nil entry.\n") 1245 | } 1246 | t.Logf("Got expected error when creating sequence with nil entry: %s\n", e) 1247 | 1248 | // Sequences containing mixed data types shouldn't be allowed 1249 | tensor := randomSmallTensor(1337, t) 1250 | defer tensor.Destroy() 1251 | innerSequence, e := NewSequence([]Value{tensor}) 1252 | if e != nil { 1253 | t.Fatalf("Error creating 1-element sequence: %s\n", e) 1254 | } 1255 | defer innerSequence.Destroy() 1256 | _, e = NewSequence([]Value{tensor, innerSequence}) 1257 | if e == nil { 1258 | t.Fatalf("Didn't get expected error when attempting to create a " + 1259 | "mixed sequence.\n") 1260 | } 1261 | t.Logf("Got expected error when attempting a mixed sequence: %s\n", e) 1262 | 1263 | // Nested sequences also aren't allowed, though the C API docs don't seem 1264 | // to mention this restriction.
1265 | _, e = NewSequence([]Value{innerSequence, innerSequence}) 1266 | if e == nil { 1267 | t.Fatalf("Didn't get an error creating a sequence with nested " + 1268 | "sequences.\n") 1269 | } 1270 | t.Logf("Got expected error when creating a sequence with nested "+ 1271 | "sequences: %s\n", e) 1272 | } 1273 | 1274 | func TestMap(t *testing.T) { 1275 | InitializeRuntime(t) 1276 | defer CleanupRuntime(t) 1277 | 1278 | testGoMap := map[int64]float64{ 1279 | 123: 456.7, 1280 | 789: 123.4, 1281 | } 1282 | m, e := NewMapFromGoMap(testGoMap) 1283 | if e != nil { 1284 | t.Fatalf("Error creating onnx map from Go map: %s\n", e) 1285 | } 1286 | defer m.Destroy() 1287 | keys, values, e := m.GetKeysAndValues() 1288 | if e != nil { 1289 | t.Fatalf("Error getting map keys and values: %s\n", e) 1290 | } 1291 | 1292 | // In real code I'd probably do these type assertions without the checks 1293 | // and just panic if one was wrong, but the checks make sense in a test. 1294 | keysTensor, ok := keys.(*Tensor[int64]) 1295 | if !ok { 1296 | t.Fatalf("Keys weren't an int64 tensor, but %s\n", 1297 | TensorElementDataType(keys.DataType())) 1298 | } 1299 | valuesTensor, ok := values.(*Tensor[float64]) 1300 | if !ok { 1301 | t.Fatalf("Values weren't a float64 tensor, but %s\n", 1302 | TensorElementDataType(values.DataType())) 1303 | } 1304 | 1305 | if !keysTensor.GetShape().Equals(valuesTensor.GetShape()) { 1306 | t.Fatalf("Key and value tensor shapes don't match: %s vs %s\n", 1307 | keysTensor.GetShape(), valuesTensor.GetShape()) 1308 | } 1309 | 1310 | for i, k := range keysTensor.GetData() { 1311 | v := valuesTensor.GetData()[i] 1312 | e = floatsEqual(float32(v), float32(testGoMap[k])) 1313 | if e != nil { 1314 | t.Errorf("Value for key %d doesn't match: %s\n", k, e) 1315 | } 1316 | } 1317 | } 1318 | 1319 | func TestBadMaps(t *testing.T) { 1320 | InitializeRuntime(t) 1321 | defer CleanupRuntime(t) 1322 | 1323 | // There are many, many ways I've found to create a bad map. This test only 1324 | // checks a few of them. 1325 | 1326 | // Floats aren't supported as keys. 1327 | floatKeysTensor := newTestTensor[float32](t, NewShape(10)) 1328 | defer floatKeysTensor.Destroy() 1329 | floatValuesTensor := newTestTensor[float32](t, NewShape(10)) 1330 | defer floatValuesTensor.Destroy() 1331 | _, e := NewMap(floatKeysTensor, floatValuesTensor) 1332 | if e == nil { 1333 | t.Fatalf("Didn't get expected error when using float map keys.\n") 1334 | } 1335 | t.Logf("Got expected error when using float map keys: %s\n", e) 1336 | 1337 | // The length of keys and values must match.
1338 | tooManyKeysTensor := newTestTensor[int64](t, NewShape(16)) 1339 | for i := range tooManyKeysTensor.GetData() { 1340 | tooManyKeysTensor.GetData()[i] = int64(i) 1341 | } 1342 | defer tooManyKeysTensor.Destroy() 1343 | _, e = NewMap(tooManyKeysTensor, floatValuesTensor) 1344 | if e == nil { 1345 | t.Fatalf("Didn't get expected error when map keys and values are " + 1346 | "different sizes.\n") 1347 | } 1348 | t.Logf("Got expected error when keys and values lengths mismatch: %s\n", e) 1349 | } 1350 | 1351 | func TestSklearnNetwork(t *testing.T) { 1352 | InitializeRuntime(t) 1353 | defer CleanupRuntime(t) 1354 | 1355 | // These inputs and outputs were taken from the information printed by 1356 | // test_data/generate_sklearn_network.py 1357 | inputShape := NewShape(6, 4) 1358 | inputValues := []float32{ 1359 | 5.9, 3.0, 5.1, 1.8, 1360 | 6.8, 2.8, 4.8, 1.4, 1361 | 6.3, 2.3, 4.4, 1.3, 1362 | 6.5, 3.0, 5.5, 1.8, 1363 | 7.7, 2.8, 6.7, 2.0, 1364 | 5.5, 2.5, 4.0, 1.3, 1365 | } 1366 | 1367 | // "output_label": A tensor of an int64 label per set of 4 inputs 1368 | expectedPredictions := []int64{2, 1, 1, 2, 2, 1} 1369 | 1370 | // "output_probability": A sequence of maps, mapping each int64 label to a 1371 | // float32 output. We'll just store them in order here. 1372 | outputProbabilities := []map[int64]float32{ 1373 | {0: 0.0, 1: 0.12999998033046722, 2: 0.8699994683265686}, 1374 | {0: 0.0, 1: 0.7699995636940002, 2: 0.23000003397464752}, 1375 | {0: 0.0, 1: 0.969999372959137, 2: 0.029999999329447746}, 1376 | {0: 0.0, 1: 0.0, 2: 0.9999993443489075}, 1377 | {0: 0.0, 1: 0.0, 2: 0.9999993443489075}, 1378 | {0: 0.0, 1: 0.9999993443489075, 2: 0.0}, 1379 | } 1380 | 1381 | modelPath := "test_data/sklearn_randomforest.onnx" 1382 | session, e := NewDynamicAdvancedSession(modelPath, []string{"X"}, 1383 | []string{"output_label", "output_probability"}, nil) 1384 | if e != nil { 1385 | t.Fatalf("Error loading %s: %s\n", modelPath, e) 1386 | } 1387 | defer session.Destroy() 1388 | 1389 | // The point of this test is to make sure we get the correct types and 1390 | // results when the network allocates the output values. 1391 | outputs := []Value{nil, nil} 1392 | inputTensor, e := NewTensor(inputShape, inputValues) 1393 | if e != nil { 1394 | t.Fatalf("Error creating input tensor: %s\n", e) 1395 | } 1396 | defer inputTensor.Destroy() 1397 | e = session.Run([]Value{inputTensor}, outputs) 1398 | if e != nil { 1399 | t.Fatalf("Error running %s: %s\n", modelPath, e) 1400 | } 1401 | defer func() { 1402 | for _, v := range outputs { 1403 | v.Destroy() 1404 | } 1405 | }() 1406 | 1407 | // First, check the easy part: the int64 output tensor 1408 | tensorDataType := TensorElementDataType(outputs[0].DataType()) 1409 | if tensorDataType != TensorElementDataTypeInt64 { 1410 | t.Fatalf("Expected int64 output tensor, got %s\n", tensorDataType) 1411 | } 1412 | predictionTensor := outputs[0].(*Tensor[int64]) 1413 | predictions := predictionTensor.GetData() 1414 | if len(predictions) != len(expectedPredictions) { 1415 | t.Fatalf("Expected %d predictions, got %d\n", len(expectedPredictions), 1416 | len(predictions)) 1417 | } 1418 | for i, v := range expectedPredictions { 1419 | actualPrediction := predictions[i] 1420 | if v != actualPrediction { 1421 | t.Errorf("Incorrect prediction at index %d: %d (expected %d)\n", 1422 | i, actualPrediction, v) 1423 | } 1424 | } 1425 | 1426 | // Next, check the sequence of maps. There is one map giving the fine- 1427 | // grained probabilities for each label.
(Predictions is just the entry 1428 | // of each map with the highest probability.) 1429 | sequence, ok := outputs[1].(*Sequence) 1430 | if !ok { 1431 | t.Fatalf("Expected a sequence for the probabilities output, got %s\n", 1432 | outputs[1].GetONNXType()) 1433 | } 1434 | probabilityMaps, e := sequence.GetValues() 1435 | if e != nil { 1436 | t.Fatalf("Error getting contents of sequence of maps: %s\n", e) 1437 | } 1438 | if len(probabilityMaps) != len(expectedPredictions) { 1439 | t.Fatalf("Expected a %d-element sequence, got %d\n", 1440 | len(expectedPredictions), len(probabilityMaps)) 1441 | } 1442 | for i := range probabilityMaps { 1443 | m, isMap := probabilityMaps[i].(*Map) 1444 | if !isMap { 1445 | t.Fatalf("Output sequence index %d wasn't a map, but a %s\n", i, 1446 | probabilityMaps[i].GetONNXType()) 1447 | } 1448 | keys, values, e := m.GetKeysAndValues() 1449 | if e != nil { 1450 | t.Fatalf("Error getting keys and values for map at index %d: %s\n", 1451 | i, e) 1452 | } 1453 | if !keys.GetShape().Equals(values.GetShape()) { 1454 | t.Fatalf("Key and value tensors don't match in shape: %s vs %s\n", 1455 | keys.GetShape(), values.GetShape()) 1456 | } 1457 | keysTensor, ok := keys.(*Tensor[int64]) 1458 | if !ok { 1459 | t.Fatalf("Keys were not an int64 tensor\n") 1460 | } 1461 | valuesTensor, ok := values.(*Tensor[float32]) 1462 | if !ok { 1463 | t.Fatalf("Values were not a float32 tensor\n") 1464 | } 1465 | expectedProbabilities := outputProbabilities[i] 1466 | for j, key := range keysTensor.GetData() { 1467 | v := valuesTensor.GetData()[j] 1468 | e = floatsEqual(expectedProbabilities[key], v) 1469 | if e != nil { 1470 | t.Errorf("Expected values don't match for key %d in map "+ 1471 | "index %d: %s\n", key, i, e) 1472 | } 1473 | } 1474 | } 1475 | } 1476 | 1477 | // This tests that we're able to read a file containing multi-byte characters 1478 | // in the path. 1479 | func TestNonAsciiPath(t *testing.T) { 1480 | InitializeRuntime(t) 1481 | defer CleanupRuntime(t) 1482 | 1483 | // The test network just adds two integers and returns the result. 1484 | inputData := []int32{12, 21} 1485 | input, e := NewTensor(NewShape(1, 2), inputData) 1486 | if e != nil { 1487 | t.Fatalf("Error creating input tensor: %s\n", e) 1488 | } 1489 | defer input.Destroy() 1490 | output := newTestTensor[int32](t, NewShape(1)) 1491 | defer output.Destroy() 1492 | 1493 | filePath := "test_data/example ż 大 김.onnx" 1494 | session, e := NewAdvancedSession(filePath, []string{"in"}, []string{"out"}, 1495 | []Value{input}, []Value{output}, nil) 1496 | if e != nil { 1497 | t.Fatalf("Failed creating session for %s: %s\n", filePath, e) 1498 | } 1499 | defer session.Destroy() 1500 | e = session.Run() 1501 | if e != nil { 1502 | t.Fatalf("Error running %s: %s\n", filePath, e) 1503 | } 1504 | expected := inputData[0] + inputData[1] 1505 | result := output.GetData()[0] 1506 | if result != expected { 1507 | t.Errorf("Running %s gave the wrong result. Expected %d, got %d.\n", 1508 | filePath, expected, result) 1509 | } 1510 | } 1511 |
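// A hedged sketch of an alternative (not used by these tests): with Go 1.16
// or newer, a model's bytes can be embedded at compile time and passed to
// the *WithONNXData functions, avoiding runtime file I/O. This assumes an
// `import _ "embed"` at the top of the file; embeddedModelData is a
// hypothetical name.
//
//go:embed test_data/example_big_compute.onnx
var embeddedModelData []byte

1512 | // This tests that the *WithONNXData method works for loading a session. 1513 | // Hopefully this covers most other *WithONNXData variants, since all use the 1514 | // same code internally when creating an OrtSession in C. 1515 | func TestSessionFromDataBuffer(t *testing.T) { 1516 | // This test is almost a copy of TestNonAsciiPath, since it was fairly 1517 | // simple.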
1518 | InitializeRuntime(t) 1519 | defer CleanupRuntime(t) 1520 | 1521 | inputData := []int32{12, 21} 1522 | input, e := NewTensor(NewShape(1, 2), inputData) 1523 | if e != nil { 1524 | t.Fatalf("Error creating input tensor: %s\n", e) 1525 | } 1526 | defer input.Destroy() 1527 | output := newTestTensor[int32](t, NewShape(1)) 1528 | defer output.Destroy() 1529 | 1530 | filePath := "test_data/example ż 大 김.onnx" 1531 | fileData, e := os.ReadFile(filePath) 1532 | if e != nil { 1533 | t.Fatalf("Error buffering content of %s: %s\n", filePath, e) 1534 | } 1535 | session, e := NewAdvancedSessionWithONNXData(fileData, []string{"in"}, 1536 | []string{"out"}, []Value{input}, []Value{output}, nil) 1537 | if e != nil { 1538 | t.Fatalf("Failed creating session: %s\n", e) 1539 | } 1540 | defer session.Destroy() 1541 | e = session.Run() 1542 | if e != nil { 1543 | t.Fatalf("Error running session: %s\n", e) 1544 | } 1545 | expected := inputData[0] + inputData[1] 1546 | result := output.GetData()[0] 1547 | if result != expected { 1548 | t.Errorf("Incorrect result. Expected %d, got %d.\n", expected, result) 1549 | } 1550 | } 1551 | 1552 | func TestScalar(t *testing.T) { 1553 | InitializeRuntime(t) 1554 | defer CleanupRuntime(t) 1555 | s, e := NewEmptyScalar[float32]() 1556 | if e != nil { 1557 | t.Fatalf("Error creating empty scalar: %s\n", e) 1558 | } 1559 | if s.GetData() != 0.0 { 1560 | t.Fatalf("Empty scalar not initialized to 0, got %f\n", s.GetData()) 1561 | } 1562 | e = s.Destroy() 1563 | if e != nil { 1564 | t.Fatalf("Failed destroying scalar: %s\n", e) 1565 | } 1566 | s2, e := NewScalar(int64(1337)) 1567 | if e != nil { 1568 | t.Fatalf("Failed creating int64 scalar: %s\n", e) 1569 | } 1570 | defer s2.Destroy() 1571 | contents := s2.GetData() 1572 | if contents != 1337 { 1573 | t.Fatalf("Incorrect initial contents of s2: %d\n", contents) 1574 | } 1575 | s2.ZeroContents() 1576 | contents = s2.GetData() 1577 | if contents != 0 { 1578 | t.Fatalf("Incorrect value of s2 after zeroing: %d\n", contents) 1579 | } 1580 | s2.Set(1234) 1581 | contents = s2.GetData() 1582 | if contents != 1234 { 1583 | t.Fatalf("Incorrect value of s2: %d (expected 1234)\n", contents) 1584 | } 1585 | } 1586 | 1587 | func TestIoBinding(t *testing.T) { 1588 | InitializeRuntime(t) 1589 | defer CleanupRuntime(t) 1590 | // TODO: Create a slightly better test for I/O bindings: 1591 | // - Maybe make it something with a fixed input that will perform better 1592 | // if bound and unchanged. 1593 | // - Have multiple outputs, including ones with multibyte chars in their 1594 | // names (if this is supported). This would better exercise the 1595 | // GetBoundOutputNames function. 1596 | filePath := "test_data/example ż 大 김.onnx" 1597 | 1598 | session, e := NewDynamicAdvancedSession(filePath, nil, nil, nil) 1599 | if e != nil { 1600 | t.Fatalf("Error creating session without specifying input names: %s\n", 1601 | e) 1602 | } 1603 | defer session.Destroy() 1604 | binding, e := session.CreateIoBinding() 1605 | if e != nil { 1606 | t.Fatalf("Error creating I/O binding object: %s\n", e) 1607 | } 1608 | defer func() { 1609 | e := binding.Destroy() 1610 | if e != nil { 1611 | t.Logf("Error destroying I/O binding: %s\n", e) 1612 | t.Fail() 1613 | } 1614 | }() 1615 | 1616 | // Run the session with the bindings and allow it to allocate the outputs 1617 | // automatically.
1618 | inputData := []int32{1000, 337} 1619 | input, e := NewTensor(NewShape(1, 2), inputData) 1620 | if e != nil { 1621 | t.Fatalf("Error creating input tensor: %s\n", e) 1622 | } 1623 | defer input.Destroy() 1624 | e = binding.BindInput("in", input) 1625 | if e != nil { 1626 | t.Fatalf("Error binding input: %s\n", e) 1627 | } 1628 | output, e := NewEmptyTensor[int32](NewShape(1)) 1629 | if e != nil { 1630 | t.Fatalf("Error creating output tensor: %s\n", e) 1631 | } 1632 | defer output.Destroy() 1633 | e = binding.BindOutput("out", output) 1634 | if e != nil { 1635 | t.Fatalf("Error binding output: %s\n", e) 1636 | } 1637 | 1638 | e = session.RunWithBinding(binding) 1639 | if e != nil { 1640 | t.Fatalf("Error running session with I/O binding: %s\n", e) 1641 | } 1642 | outputNames, e := binding.GetBoundOutputNames() 1643 | if e != nil { 1644 | t.Fatalf("Error getting bound output names: %s\n", e) 1645 | } 1646 | if len(outputNames) != 1 { 1647 | t.Fatalf("Expected 1 output name, got %d\n", len(outputNames)) 1648 | } 1649 | if outputNames[0] != "out" { 1650 | t.Fatalf("Incorrect bound output name, expected \"out\", got \"%s\"\n", 1651 | outputNames[0]) 1652 | } 1653 | 1654 | // Ensure the actual output is what we expect 1655 | outputValues, e := binding.GetBoundOutputValues() 1656 | if e != nil { 1657 | t.Fatalf("Error getting bound output values: %s\n", e) 1658 | } 1659 | defer func() { 1660 | for _, v := range outputValues { 1661 | v.Destroy() 1662 | } 1663 | }() 1664 | if len(outputValues) != 1 { 1665 | t.Fatalf("Expected 1 output value, got %d\n", len(outputValues)) 1666 | } 1667 | v := outputValues[0] 1668 | if !v.GetShape().Equals(NewShape(1)) { 1669 | t.Fatalf("Incorrect output shape, expected [1], got %s\n", v.GetShape()) 1670 | } 1671 | outputTensor, ok := v.(*Tensor[int32]) 1672 | if !ok { 1673 | t.Fatalf("Bad output type\n") 1674 | } 1675 | if outputTensor.GetData()[0] != 1337 { 1676 | t.Fatalf("Bad output value, expected 1337, got %d\n", 1677 | outputTensor.GetData()[0]) 1678 | } 1679 | } 1680 | 1681 | // See the comment in generate_network_big_compute.py for information about 1682 | // the inputs and outputs used for testing or benchmarking session options. 1683 | func prepareBenchmarkTensors(t testing.TB, seed int64) (*Tensor[float32], 1684 | *Tensor[float32]) { 1685 | vectorLength := int64(1024 * 1024 * 50) 1686 | inputData := make([]float32, vectorLength) 1687 | rng := rand.New(rand.NewSource(seed)) 1688 | for i := range inputData { 1689 | inputData[i] = rng.Float32() 1690 | } 1691 | input, e := NewTensor(NewShape(1, vectorLength), inputData) 1692 | if e != nil { 1693 | t.Fatalf("Error creating input tensor: %s\n", e) 1694 | } 1695 | output, e := NewEmptyTensor[float32](NewShape(1, vectorLength)) 1696 | if e != nil { 1697 | input.Destroy() 1698 | t.Fatalf("Error creating output tensor: %s\n", e) 1699 | } 1700 | return input, output 1701 | } 1702 | 1703 | // Used mostly when testing different execution providers. Runs the 1704 | // example_big_compute.onnx network on a session created with the given 1705 | // options. May fail or skip the test on error. The runtime must have already 1706 | // been initialized when calling this. 
1707 | func testBigSessionWithOptions(t *testing.T, options *SessionOptions) { 1708 | input, output := prepareBenchmarkTensors(t, 1337) 1709 | defer input.Destroy() 1710 | defer output.Destroy() 1711 | session, e := NewAdvancedSession("test_data/example_big_compute.onnx", 1712 | []string{"Input"}, []string{"Output"}, 1713 | []Value{input}, []Value{output}, options) 1714 | if e != nil { 1715 | t.Fatalf("Error creating session: %s\n", e) 1716 | } 1717 | defer session.Destroy() 1718 | e = session.Run() 1719 | if e != nil { 1720 | t.Fatalf("Error running the session: %s\n", e) 1721 | } 1722 | } 1723 | 1724 | func TestSessionOptions(t *testing.T) { 1725 | InitializeRuntime(t) 1726 | defer CleanupRuntime(t) 1727 | options, e := NewSessionOptions() 1728 | if e != nil { 1729 | t.Fatalf("Error creating session options: %s\n", e) 1730 | } 1731 | defer options.Destroy() 1732 | e = options.SetIntraOpNumThreads(3) 1733 | if e != nil { 1734 | t.Fatalf("Error setting intra-op num threads: %s\n", e) 1735 | } 1736 | e = options.SetInterOpNumThreads(1) 1737 | if e != nil { 1738 | t.Fatalf("Error setting inter-op num threads: %s\n", e) 1739 | } 1740 | e = options.SetCpuMemArena(true) 1741 | if e != nil { 1742 | t.Fatalf("Error setting CPU memory arena: %s\n", e) 1743 | } 1744 | e = options.SetMemPattern(true) 1745 | if e != nil { 1746 | t.Fatalf("Error setting memory pattern: %s\n", e) 1747 | } 1748 | e = options.SetLogSeverityLevel(LoggingLevelWarning) 1749 | if e != nil { 1750 | t.Fatalf("Error setting log severity level: %s\n", e) 1751 | } 1752 | testBigSessionWithOptions(t, options) 1753 | } 1754 | 1755 | func TestSessionOptionsConfig(t *testing.T) { 1756 | InitializeRuntime(t) 1757 | defer CleanupRuntime(t) 1758 | options, e := NewSessionOptions() 1759 | if e != nil { 1760 | t.Fatalf("Error creating session options: %s\n", e) 1761 | } 1762 | defer options.Destroy() 1763 | hasEntry, e := options.HasSessionConfigEntry("whatever lol") 1764 | if e != nil { 1765 | t.Errorf("Got error calling HasSessionConfigEntry for invalid "+ 1766 | "key: %s\n", e) 1767 | } 1768 | if hasEntry { 1769 | t.Errorf("HasSessionConfigEntry didn't return false for bad key\n") 1770 | } 1771 | _, e = options.GetSessionConfigEntry("hahahaha whatever!!") 1772 | if e == nil { 1773 | t.Fatalf("Didn't get error calling GetSessionConfigEntry for " + 1774 | "invalid key.\n") 1775 | } 1776 | t.Logf("Got expected error when getting session config entry for "+ 1777 | "invalid key: %s\n", e) 1778 | key := "beans factor" 1779 | expectedValue := "extremely high" 1780 | e = options.AddSessionConfigEntry(key, expectedValue) 1781 | if e != nil { 1782 | t.Fatalf("Got error setting arbitrary config entry: %s\n", e) 1783 | } 1784 | value, e := options.GetSessionConfigEntry(key) 1785 | if e != nil { 1786 | t.Fatalf("Error getting config entry %s: %s\n", key, e) 1787 | } 1788 | if value != expectedValue { 1789 | t.Errorf("Got incorrect value for config entry %s: expected %s, "+ 1790 | "got %s\n", key, expectedValue, value) 1791 | } 1792 | } 1793 | 1794 | func TestGraphOptimizationLevel(t *testing.T) { 1795 | InitializeRuntime(t) 1796 | defer CleanupRuntime(t) 1797 | options, e := NewSessionOptions() 1798 | if e != nil { 1799 | t.Fatalf("Error creating session options: %s\n", e) 1800 | } 1801 | defer options.Destroy() 1802 | e = options.SetGraphOptimizationLevel(GraphOptimizationLevel(-24)) 1803 | if e == nil { 1804 | t.Fatalf("Didn't get expected error setting bad optimization level\n") 1805 | } 1806 | t.Logf("Got expected error when setting bad 
optimization level: %s\n", e) 1807 | e = options.SetGraphOptimizationLevel(GraphOptimizationLevelDisableAll) 1808 | if e != nil { 1809 | t.Fatalf("Error setting optimization level: %s\n", e) 1810 | } 1811 | 1812 | // Actually make sure the network runs with optimizations disabled. 1813 | inputData := []int32{123, 123} 1814 | input, e := NewTensor(NewShape(1, 2), inputData) 1815 | if e != nil { 1816 | t.Fatalf("Error creating input tensor: %s\n", e) 1817 | } 1818 | defer input.Destroy() 1819 | output := newTestTensor[int32](t, NewShape(1)) 1820 | defer output.Destroy() 1821 | session, e := NewAdvancedSession("test_data/example ż 大 김.onnx", 1822 | []string{"in"}, []string{"out"}, []Value{input}, []Value{output}, 1823 | options) 1824 | if e != nil { 1825 | t.Fatalf("Failed creating session with optimizations off: %s\n", e) 1826 | } 1827 | defer session.Destroy() 1828 | e = session.Run() 1829 | if e != nil { 1830 | t.Fatalf("Error running session with optimizations off: %s\n", e) 1831 | } 1832 | result := output.GetData()[0] 1833 | if result != 246 { 1834 | t.Fatalf("Got the incorrect session result: %d\n", result) 1835 | } 1836 | } 1837 | // Used when benchmarking different execution providers. Otherwise, basically 1838 | // identical in usage to testBigSessionWithOptions. 1839 | func benchmarkBigSessionWithOptions(b *testing.B, options *SessionOptions) { 1840 | // It's also OK for the caller to have already stopped the timer, but we'll 1841 | // make sure it's stopped here. 1842 | b.StopTimer() 1843 | input, output := prepareBenchmarkTensors(b, benchmarkRNGSeed) 1844 | defer input.Destroy() 1845 | defer output.Destroy() 1846 | session, e := NewAdvancedSession("test_data/example_big_compute.onnx", 1847 | []string{"Input"}, []string{"Output"}, 1848 | []Value{input}, []Value{output}, options) 1849 | if e != nil { 1850 | b.Fatalf("Error creating session: %s\n", e) 1851 | } 1852 | defer session.Destroy() 1853 | b.StartTimer() 1854 | for n := 0; n < b.N; n++ { 1855 | e = session.Run() 1856 | if e != nil { 1857 | b.Fatalf("Error running iteration %d/%d: %s\n", n+1, b.N, e) 1858 | } 1859 | } 1860 | } 1861 | 1862 | // Very similar to TestSessionOptions, but structured as a benchmark. 1863 | func runNumThreadsBenchmark(b *testing.B, nThreads int) { 1864 | // Don't run the benchmark timer when doing initialization. 1865 | b.StopTimer() 1866 | InitializeRuntime(b) 1867 | defer CleanupRuntime(b) 1868 | options, e := NewSessionOptions() 1869 | if e != nil { 1870 | b.Fatalf("Error creating options: %s\n", e) 1871 | } 1872 | defer options.Destroy() 1873 | e = options.SetIntraOpNumThreads(nThreads) 1874 | if e != nil { 1875 | b.Fatalf("Error setting intra-op threads to %d: %s\n", nThreads, e) 1876 | } 1877 | e = options.SetInterOpNumThreads(nThreads) 1878 | if e != nil { 1879 | b.Fatalf("Error setting inter-op threads to %d: %s\n", nThreads, e) 1880 | } 1881 | benchmarkBigSessionWithOptions(b, options) 1882 | } 1883 | 1884 | func BenchmarkOpSingleThreaded(b *testing.B) { 1885 | runNumThreadsBenchmark(b, 1) 1886 | } 1887 | 1888 | func BenchmarkOpMultiThreaded(b *testing.B) { 1889 | runNumThreadsBenchmark(b, 0) 1890 | } 1891 |
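// A hedged sketch, not part of the test suite: the two benchmarks above
// could equally be written as table-driven subbenchmarks to compare more
// thread counts (passing 0, as BenchmarkOpMultiThreaded does, leaves the
// thread counts at onnxruntime's default). BenchmarkOpThreadCounts is a
// hypothetical name, and this assumes the fmt package is imported.
func BenchmarkOpThreadCounts(b *testing.B) {
	for _, n := range []int{1, 2, 4, 0} {
		b.Run(fmt.Sprintf("threads=%d", n), func(b *testing.B) {
			runNumThreadsBenchmark(b, n)
		})
	}
}

1892 | // Used for benchmarking the network with big fanout, both in sequential and 1893 | // parallel mode.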
1894 | func runParallelismBenchmark(b *testing.B, parallel bool) { 1895 | b.StopTimer() 1896 | InitializeRuntime(b) 1897 | defer CleanupRuntime(b) 1898 | options, e := NewSessionOptions() 1899 | if e != nil { 1900 | b.Fatalf("Error creating options: %s\n", e) 1901 | } 1902 | defer options.Destroy() 1903 | if parallel { 1904 | e = options.SetExecutionMode(ExecutionModeParallel) 1905 | } else { 1906 | e = options.SetExecutionMode(ExecutionModeSequential) 1907 | } 1908 | if e != nil { 1909 | b.Fatalf("Error setting execution mode: %s\n", e) 1910 | } 1911 | input := newTestTensor[float32](b, NewShape(1, 4)) 1912 | defer input.Destroy() 1913 | output := newTestTensor[float32](b, NewShape(4)) 1914 | defer output.Destroy() 1915 | session, e := NewAdvancedSession("test_data/example_big_fanout.onnx", 1916 | []string{"input"}, []string{"output"}, []Value{input}, []Value{output}, 1917 | options) 1918 | if e != nil { 1919 | b.Fatalf("Error creating session: %s\n", e) 1920 | } 1921 | defer session.Destroy() 1922 | b.StartTimer() 1923 | for n := 0; n < b.N; n++ { 1924 | e = session.Run() 1925 | if e != nil { 1926 | b.Fatalf("Error running iteration %d/%d: %s\n", n+1, b.N, e) 1927 | } 1928 | } 1929 | } 1930 | 1931 | func BenchmarkFanoutWithoutParallelism(b *testing.B) { 1932 | runParallelismBenchmark(b, false) 1933 | } 1934 | 1935 | func BenchmarkFanoutWithParallelism(b *testing.B) { 1936 | runParallelismBenchmark(b, true) 1937 | } 1938 | 1939 | // Arbitrarily sets some config options; mostly just to see if they have an 1940 | // impact on performance, but also to just make sure setting the options 1941 | // doesn't cause errors. 1942 | func BenchmarkSessionWithNoSpinningOptions(b *testing.B) { 1943 | b.StopTimer() 1944 | InitializeRuntime(b) 1945 | defer CleanupRuntime(b) 1946 | options, e := NewSessionOptions() 1947 | if e != nil { 1948 | b.Fatalf("Error creating options: %s\n", e) 1949 | } 1950 | defer options.Destroy() 1951 | e = options.AddSessionConfigEntry("session.inter_op.allow_spinning", "0") 1952 | if e != nil { 1953 | b.Fatalf("Error setting inter-op spinning config: %s\n", e) 1954 | } 1955 | e = options.AddSessionConfigEntry("session.intra_op.allow_spinning", "0") 1956 | if e != nil { 1957 | b.Fatalf("Error setting intra-op spinning config: %s\n", e) 1958 | } 1959 | e = options.AddSessionConfigEntry("session.force_spinning_stop", "1") 1960 | if e != nil { 1961 | b.Fatalf("Error setting force spinning stop config: %s\n", e) 1962 | } 1963 | benchmarkBigSessionWithOptions(b, options) 1964 | } 1965 | 1966 | // Creates a SessionOptions struct that's configured to enable CUDA. Skips the 1967 | // test if CUDA isn't supported. If some other error occurs, this will fail the 1968 | // test instead. There may be other possible places for failures to occur due 1969 | // to CUDA not being supported, or incorrectly configured, but this at least 1970 | // checks for the ones I've encountered on my system. 1971 | func getCUDASessionOptions(t testing.TB) *SessionOptions { 1972 | // First, create the CUDA options 1973 | cudaOptions, e := NewCUDAProviderOptions() 1974 | if e != nil { 1975 | // This is where things seem to fail if the onnxruntime library version 1976 | // doesn't support CUDA. 1977 | t.Skipf("Error creating CUDA provider options: %s. "+ 1978 | "Your version of the onnxruntime library may not support CUDA. 
"+ 1979 | "Skipping the remainder of this test.\n", e) 1980 | } 1981 | defer cudaOptions.Destroy() 1982 | e = cudaOptions.Update(map[string]string{"device_id": "0"}) 1983 | if e != nil { 1984 | // This is where things seem to fail if the system doesn't support CUDA 1985 | // or if CUDA is misconfigured somehow (i.e. a wrong version that isn't 1986 | // supported by onnxruntime, libraries not being located correctly, 1987 | // etc.) 1988 | t.Skipf("Error updating CUDA options to use device ID 0: %s. "+ 1989 | "Your system may not support CUDA, or CUDA may be misconfigured "+ 1990 | "or a version incompatible with this version of onnxruntime. "+ 1991 | "Skipping the remainder of this test.\n", e) 1992 | } 1993 | // Next, provide the CUDA options to the sesison options 1994 | sessionOptions, e := NewSessionOptions() 1995 | if e != nil { 1996 | t.Fatalf("Error creating SessionOptions: %s\n", e) 1997 | } 1998 | e = sessionOptions.AppendExecutionProviderCUDA(cudaOptions) 1999 | if e != nil { 2000 | sessionOptions.Destroy() 2001 | t.Fatalf("Error setting CUDA execution provider options: %s\n", e) 2002 | } 2003 | return sessionOptions 2004 | } 2005 | 2006 | func TestCUDASession(t *testing.T) { 2007 | InitializeRuntime(t) 2008 | defer CleanupRuntime(t) 2009 | sessionOptions := getCUDASessionOptions(t) 2010 | defer sessionOptions.Destroy() 2011 | testBigSessionWithOptions(t, sessionOptions) 2012 | } 2013 | 2014 | func BenchmarkCUDASession(b *testing.B) { 2015 | b.StopTimer() 2016 | InitializeRuntime(b) 2017 | defer CleanupRuntime(b) 2018 | sessionOptions := getCUDASessionOptions(b) 2019 | defer sessionOptions.Destroy() 2020 | benchmarkBigSessionWithOptions(b, sessionOptions) 2021 | } 2022 | 2023 | // Creates a SessionOptions struct that's configured to enable TensorRT. 2024 | // Basically the same as getCUDASessionOptions; see the comments there. 2025 | func getTensorRTSessionOptions(t testing.TB) *SessionOptions { 2026 | trtOptions, e := NewTensorRTProviderOptions() 2027 | if e != nil { 2028 | t.Skipf("Error creating TensorRT provider options; %s. "+ 2029 | "Your version of the onnxruntime library may not include "+ 2030 | "TensorRT support. Skipping the remainder of this test.\n", e) 2031 | } 2032 | defer trtOptions.Destroy() 2033 | // Arbitrarily update an option to test trtOptions.Update() 2034 | e = trtOptions.Update( 2035 | map[string]string{"trt_max_partition_iterations": "60"}) 2036 | if e != nil { 2037 | t.Skipf("Error updating TensorRT options: %s. Your system may not "+ 2038 | "support TensorRT, TensorRT may be misconfigured, or it may be "+ 2039 | "incompatible with this build of onnxruntime. 
Skipping the "+ 2040 | "remainder of this test.\n", e) 2041 | } 2042 | sessionOptions, e := NewSessionOptions() 2043 | if e != nil { 2044 | t.Fatalf("Error creating SessionOptions: %s\n", e) 2045 | } 2046 | e = sessionOptions.AppendExecutionProviderTensorRT(trtOptions) 2047 | if e != nil { 2048 | sessionOptions.Destroy() 2049 | t.Fatalf("Error setting TensorRT execution provider: %s\n", e) 2050 | } 2051 | return sessionOptions 2052 | } 2053 | 2054 | func TestTensorRTSession(t *testing.T) { 2055 | InitializeRuntime(t) 2056 | defer CleanupRuntime(t) 2057 | sessionOptions := getTensorRTSessionOptions(t) 2058 | defer sessionOptions.Destroy() 2059 | testBigSessionWithOptions(t, sessionOptions) 2060 | } 2061 | 2062 | func BenchmarkTensorRTSession(b *testing.B) { 2063 | b.StopTimer() 2064 | InitializeRuntime(b) 2065 | defer CleanupRuntime(b) 2066 | sessionOptions := getTensorRTSessionOptions(b) 2067 | defer sessionOptions.Destroy() 2068 | benchmarkBigSessionWithOptions(b, sessionOptions) 2069 | } 2070 | 2071 | func getCoreMLSessionOptions(t testing.TB) *SessionOptions { 2072 | options, e := NewSessionOptions() 2073 | if e != nil { 2074 | t.Fatalf("Error creating session options: %s\n", e) 2075 | } 2076 | e = options.AppendExecutionProviderCoreML(0) 2077 | if e != nil { 2078 | options.Destroy() 2079 | t.Skipf("Couldn't enable CoreML: %s. This may be due to your system "+ 2080 | "or onnxruntime library version not supporting CoreML.\n", e) 2081 | } 2082 | return options 2083 | } 2084 | 2085 | func TestCoreMLSession(t *testing.T) { 2086 | InitializeRuntime(t) 2087 | defer CleanupRuntime(t) 2088 | sessionOptions := getCoreMLSessionOptions(t) 2089 | defer sessionOptions.Destroy() 2090 | testBigSessionWithOptions(t, sessionOptions) 2091 | } 2092 | 2093 | func BenchmarkCoreMLSession(b *testing.B) { 2094 | b.StopTimer() 2095 | InitializeRuntime(b) 2096 | defer CleanupRuntime(b) 2097 | sessionOptions := getCoreMLSessionOptions(b) 2098 | defer sessionOptions.Destroy() 2099 | benchmarkBigSessionWithOptions(b, sessionOptions) 2100 | } 2101 | 2102 | func getCoreMLV2SessionOptions(t testing.TB) *SessionOptions { 2103 | options, e := NewSessionOptions() 2104 | if e != nil { 2105 | t.Fatalf("Error creating session options: %s\n", e) 2106 | } 2107 | 2108 | // Use the new API with CoreML provider options 2109 | coreMLOptions := map[string]string{ 2110 | "ModelFormat": "NeuralNetwork", // Use NeuralNetwork format instead of MLProgram 2111 | "MLComputeUnits": "ALL", // Use all available compute units 2112 | "RequireStaticInputShapes": "0", // Allow dynamic shapes 2113 | "EnableOnSubgraphs": "0", // Don't enable on subgraphs 2114 | "ModelCacheDirectory": "/tmp", // Set the cache directory to /tmp 2115 | "AllowLowPrecisionAccumulationOnGPU": "1", // Enable low precision acceleration 2116 | } 2117 | 2118 | // If options can't be empty, catch any potential problems 2119 | if len(coreMLOptions) == 0 { 2120 | options.Destroy() 2121 | t.Skipf("Empty CoreML options map. This should not happen.\n") 2122 | return nil 2123 | } 2124 | 2125 | e = options.AppendExecutionProviderCoreMLV2(coreMLOptions) 2126 | if e != nil { 2127 | options.Destroy() 2128 | t.Skipf("Couldn't enable CoreML: %s. 
This may be due to your system "+ 2129 | "or onnxruntime library version not supporting CoreML.\n", e) 2130 | } 2131 | return options 2132 | } 2133 | 2134 | func TestCoreMLV2Session(t *testing.T) { 2135 | InitializeRuntime(t) 2136 | defer CleanupRuntime(t) 2137 | sessionOptions := getCoreMLV2SessionOptions(t) 2138 | defer sessionOptions.Destroy() 2139 | testBigSessionWithOptions(t, sessionOptions) 2140 | } 2141 | 2142 | func BenchmarkCoreMLV2Session(b *testing.B) { 2143 | b.StopTimer() 2144 | InitializeRuntime(b) 2145 | defer CleanupRuntime(b) 2146 | sessionOptions := getCoreMLV2SessionOptions(b) 2147 | defer sessionOptions.Destroy() 2148 | benchmarkBigSessionWithOptions(b, sessionOptions) 2149 | } 2150 | 2151 | func getDirectMLSessionOptions(t testing.TB) *SessionOptions { 2152 | options, e := NewSessionOptions() 2153 | if e != nil { 2154 | t.Fatalf("Error creating session options: %s\n", e) 2155 | } 2156 | e = options.AppendExecutionProviderDirectML(0) 2157 | if e != nil { 2158 | options.Destroy() 2159 | t.Skipf("Couldn't enable DirectML: %s. This may be due to your "+ 2160 | "system or onnxruntime library version not supporting DirectML.\n", 2161 | e) 2162 | } 2163 | return options 2164 | } 2165 | 2166 | func TestDirectMLSession(t *testing.T) { 2167 | InitializeRuntime(t) 2168 | defer CleanupRuntime(t) 2169 | sessionOptions := getDirectMLSessionOptions(t) 2170 | defer sessionOptions.Destroy() 2171 | testBigSessionWithOptions(t, sessionOptions) 2172 | } 2173 | 2174 | func BenchmarkDirectMLSession(b *testing.B) { 2175 | b.StopTimer() 2176 | InitializeRuntime(b) 2177 | defer CleanupRuntime(b) 2178 | sessionOptions := getDirectMLSessionOptions(b) 2179 | defer sessionOptions.Destroy() 2180 | benchmarkBigSessionWithOptions(b, sessionOptions) 2181 | } 2182 | 2183 | func getOpenVINOSessionOptions(t testing.TB) *SessionOptions { 2184 | options, e := NewSessionOptions() 2185 | if e != nil { 2186 | t.Fatalf("Error creating session options: %s\n", e) 2187 | } 2188 | // OpenVINO prints messages to stdout even when libraries are merely 2189 | // missing. We don't want to clutter the test output, so reduce the logging level. 2190 | e = SetEnvironmentLogLevel(LoggingLevelFatal) 2191 | if e != nil { 2192 | t.Fatalf("Error reducing log severity level for OpenVINO: %s\n", e) 2193 | } 2194 | e = options.AppendExecutionProviderOpenVINO(map[string]string{}) 2195 | if e != nil { 2196 | options.Destroy() 2197 | t.Skipf("Couldn't enable OpenVINO: %s. 
This may be due to your "+ 2198 | "system or onnxruntime library version not supporting OpenVINO.\n", 2199 | e) 2200 | } 2201 | return options 2202 | } 2203 | 2204 | func TestOpenVINOSession(t *testing.T) { 2205 | InitializeRuntime(t) 2206 | defer CleanupRuntime(t) 2207 | sessionOptions := getOpenVINOSessionOptions(t) 2208 | defer sessionOptions.Destroy() 2209 | testBigSessionWithOptions(t, sessionOptions) 2210 | } 2211 | 2212 | func BenchmarkOpenVINOSession(b *testing.B) { 2213 | b.StopTimer() 2214 | InitializeRuntime(b) 2215 | defer CleanupRuntime(b) 2216 | sessionOptions := getOpenVINOSessionOptions(b) 2217 | defer sessionOptions.Destroy() 2218 | benchmarkBigSessionWithOptions(b, sessionOptions) 2219 | } 2220 | -------------------------------------------------------------------------------- /onnxruntime_wrapper.c: -------------------------------------------------------------------------------- 1 | #include "onnxruntime_wrapper.h" 2 | 3 | static const OrtApi *ort_api = NULL; 4 | static const char *ORT_VERSION = NULL; 5 | 6 | static AppendCoreMLProviderFn append_coreml_provider_fn = NULL; 7 | 8 | // The dml_provider_factory.h header for using DirectML is annoying to include 9 | // here for a couple of reasons: 10 | // - It contains C++ 11 | // - It includes d3d12.h and DirectML.h, both of which may be hard to set up 12 | // under mingw 13 | // Fortunately, the basic AppendExecutionProvider_DML function from the 14 | // OrtDmlApi struct does not rely on any of these things, but we still need the 15 | // struct definition itself. Obviously, copying it here is not perfect, and 16 | // we'll need to keep an eye on it to make sure it doesn't change between 17 | // updates. Most importantly, we need to make sure that the one function we 18 | // care about remains at the same place in the struct. Since it's first, 19 | // hopefully it's unlikely to change. 20 | typedef OrtStatus* (*AppendDirectMLProviderFn)(OrtSessionOptions*, int); 21 | typedef struct { 22 | AppendDirectMLProviderFn SessionOptionsAppendExecutionProvider_DML; 23 | // All of these function pointers should be irrelevant (and they depend on 24 | // other definitions from dml_provider_factory.h), but I'll copy them here 25 | // regardless as plain void*s. GetExecutionProviderApi shouldn't write to 26 | // this struct anyway, as it only provides a const pointer to it.
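// A hedged sketch: once the struct below is complete, the "first member"
// assumption above could be pinned down at compile time with C11's
// _Static_assert (with <stddef.h> included for offsetof), e.g.:
//
//   _Static_assert(offsetof(DummyOrtDMLAPI,
//       SessionOptionsAppendExecutionProvider_DML) == 0,
//       "AppendExecutionProvider_DML must remain the first member");
//
// Of course, this only checks our local copy; the real guarantee still
// depends on the layout in dml_provider_factory.h.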
27 | void *SessionOptionsAppendExecutionProvider_DML1; 28 | void *CreateGPUAllocationFromD3DResource; 29 | void *FreeGPUAllocation; 30 | void *GetD3D12ResourceFromAllocation; 31 | void *SessionOptionsAppendExecutionProvider_DML2; 32 | } DummyOrtDMLAPI; 33 | 34 | int SetAPIFromBase(OrtApiBase *api_base) { 35 | if (!api_base) return 1; 36 | ort_api = api_base->GetApi(ORT_API_VERSION); 37 | ORT_VERSION = api_base->GetVersionString(); 38 | if (!ort_api) return 2; 39 | return 0; 40 | } 41 | 42 | const char *GetVersion() { 43 | return ORT_VERSION; 44 | } 45 | 46 | void SetCoreMLProviderFunctionPointer(void *ptr) { 47 | append_coreml_provider_fn = (AppendCoreMLProviderFn) ptr; 48 | } 49 | 50 | void ReleaseOrtStatus(OrtStatus *status) { 51 | ort_api->ReleaseStatus(status); 52 | } 53 | 54 | OrtStatus *CreateOrtEnv(char *name, OrtEnv **env) { 55 | return ort_api->CreateEnv(ORT_LOGGING_LEVEL_ERROR, name, env); 56 | } 57 | 58 | OrtStatus *UpdateEnvWithCustomLogLevel(OrtEnv *ort_env, 59 | OrtLoggingLevel log_severity_level) { 60 | return ort_api->UpdateEnvWithCustomLogLevel(ort_env, log_severity_level); 61 | } 62 | 63 | OrtStatus *DisableTelemetry(OrtEnv *env) { 64 | return ort_api->DisableTelemetryEvents(env); 65 | } 66 | 67 | OrtStatus *EnableTelemetry(OrtEnv *env) { 68 | return ort_api->EnableTelemetryEvents(env); 69 | } 70 | 71 | void ReleaseOrtEnv(OrtEnv *env) { 72 | ort_api->ReleaseEnv(env); 73 | } 74 | 75 | OrtStatus *CreateOrtMemoryInfo(OrtMemoryInfo **mem_info) { 76 | return ort_api->CreateCpuMemoryInfo(OrtArenaAllocator, OrtMemTypeDefault, 77 | mem_info); 78 | } 79 | 80 | void ReleaseOrtMemoryInfo(OrtMemoryInfo *info) { 81 | ort_api->ReleaseMemoryInfo(info); 82 | } 83 | 84 | const char *GetErrorMessage(OrtStatus *status) { 85 | if (!status) return "No error (NULL status)"; 86 | return ort_api->GetErrorMessage(status); 87 | } 88 | 89 | OrtStatus *CreateSessionOptions(OrtSessionOptions **o) { 90 | return ort_api->CreateSessionOptions(o); 91 | } 92 | 93 | void ReleaseSessionOptions(OrtSessionOptions *o) { 94 | ort_api->ReleaseSessionOptions(o); 95 | } 96 | 97 | OrtStatus *SetSessionExecutionMode(OrtSessionOptions *o, int new_mode) { 98 | return ort_api->SetSessionExecutionMode(o, new_mode); 99 | } 100 | 101 | OrtStatus *SetSessionGraphOptimizationLevel(OrtSessionOptions *o, int level) { 102 | return ort_api->SetSessionGraphOptimizationLevel(o, level); 103 | } 104 | 105 | OrtStatus *SetSessionLogSeverityLevel(OrtSessionOptions *o, int level) { 106 | return ort_api->SetSessionLogSeverityLevel(o, level); 107 | } 108 | 109 | OrtStatus *AddSessionConfigEntry(OrtSessionOptions *o, char *key, 110 | char *value) { 111 | return ort_api->AddSessionConfigEntry(o, key, value); 112 | } 113 | 114 | OrtStatus *HasSessionConfigEntry(OrtSessionOptions *o, char *key, 115 | int *result) { 116 | return ort_api->HasSessionConfigEntry(o, key, result); 117 | } 118 | 119 | // Wraps ort_api->GetSessionConfigEntry 120 | OrtStatus *GetSessionConfigEntry(OrtSessionOptions *o, char *key, char *result, 121 | size_t *required_size) { 122 | return ort_api->GetSessionConfigEntry(o, key, result, required_size); 123 | } 124 | 125 | OrtStatus *SetIntraOpNumThreads(OrtSessionOptions *o, int n) { 126 | return ort_api->SetIntraOpNumThreads(o, n); 127 | } 128 | 129 | OrtStatus *SetInterOpNumThreads(OrtSessionOptions *o, int n) { 130 | return ort_api->SetInterOpNumThreads(o, n); 131 | } 132 | 133 | OrtStatus *SetCpuMemArena(OrtSessionOptions *o, int use_arena){ 134 | if (use_arena) 135 | return ort_api->EnableCpuMemArena(o); 136 | 
return ort_api->DisableCpuMemArena(o); 137 | } 138 | 139 | OrtStatus *SetMemPattern(OrtSessionOptions *o, int use_mem_pattern){ 140 | if (use_mem_pattern) 141 | return ort_api->EnableMemPattern(o); 142 | return ort_api->DisableMemPattern(o); 143 | } 144 | 145 | OrtStatus *AppendExecutionProviderCUDAV2(OrtSessionOptions *o, 146 | OrtCUDAProviderOptionsV2 *cuda_options) { 147 | return ort_api->SessionOptionsAppendExecutionProvider_CUDA_V2(o, 148 | cuda_options); 149 | } 150 | 151 | OrtStatus *CreateCUDAProviderOptions(OrtCUDAProviderOptionsV2 **o) { 152 | return ort_api->CreateCUDAProviderOptions(o); 153 | } 154 | 155 | void ReleaseCUDAProviderOptions(OrtCUDAProviderOptionsV2 *o) { 156 | ort_api->ReleaseCUDAProviderOptions(o); 157 | } 158 | 159 | OrtStatus *UpdateCUDAProviderOptions(OrtCUDAProviderOptionsV2 *o, 160 | const char **keys, const char **values, int num_keys) { 161 | return ort_api->UpdateCUDAProviderOptions(o, keys, values, num_keys); 162 | } 163 | 164 | OrtStatus *CreateTensorRTProviderOptions(OrtTensorRTProviderOptionsV2 **o) { 165 | return ort_api->CreateTensorRTProviderOptions(o); 166 | } 167 | 168 | void ReleaseTensorRTProviderOptions(OrtTensorRTProviderOptionsV2 *o) { 169 | ort_api->ReleaseTensorRTProviderOptions(o); 170 | } 171 | 172 | OrtStatus *UpdateTensorRTProviderOptions(OrtTensorRTProviderOptionsV2 *o, 173 | const char **keys, const char **values, int num_keys) { 174 | return ort_api->UpdateTensorRTProviderOptions(o, keys, values, num_keys); 175 | } 176 | 177 | OrtStatus *AppendExecutionProviderTensorRTV2(OrtSessionOptions *o, 178 | OrtTensorRTProviderOptionsV2 *tensor_rt_options) { 179 | return ort_api->SessionOptionsAppendExecutionProvider_TensorRT_V2(o, 180 | tensor_rt_options); 181 | } 182 | 183 | OrtStatus *AppendExecutionProviderCoreML(OrtSessionOptions *o, 184 | uint32_t flags) { 185 | if (!append_coreml_provider_fn) { 186 | return ort_api->CreateStatus(ORT_NOT_IMPLEMENTED, "Your platform or " 187 | "onnxruntime library does not support CoreML"); 188 | } 189 | return append_coreml_provider_fn(o, flags); 190 | } 191 | 192 | OrtStatus *AppendExecutionProviderCoreMLV2(OrtSessionOptions *o, 193 | const char **keys, const char **values, size_t num_options) { 194 | if (!append_coreml_provider_fn) { 195 | return ort_api->CreateStatus(ORT_NOT_IMPLEMENTED, "Your platform or " 196 | "onnxruntime library does not support CoreML"); 197 | } 198 | return ort_api->SessionOptionsAppendExecutionProvider(o, "CoreML", keys, values, num_options); 199 | } 200 | 201 | OrtStatus *AppendExecutionProviderDirectML(OrtSessionOptions *o, 202 | int device_id) { 203 | DummyOrtDMLAPI *dml_api = NULL; 204 | OrtStatus *status = NULL; 205 | status = ort_api->GetExecutionProviderApi("DML", ORT_API_VERSION, 206 | (const void **) (&dml_api)); 207 | if (status) return status; 208 | status = dml_api->SessionOptionsAppendExecutionProvider_DML(o, device_id); 209 | return status; 210 | } 211 | 212 | OrtStatus *AppendExecutionProviderOpenVINOV2(OrtSessionOptions *o, 213 | const char **keys, const char **values, int num_keys) { 214 | return ort_api->SessionOptionsAppendExecutionProvider_OpenVINO_V2(o, keys, 215 | values, num_keys); 216 | } 217 | 218 | OrtStatus *CreateSession(void *model_data, size_t model_data_length, 219 | OrtEnv *env, OrtSession **out, OrtSessionOptions *options) { 220 | OrtStatus *status = NULL; 221 | int default_options = 0; 222 | if (!options) { 223 | default_options = 1; 224 | status = ort_api->CreateSessionOptions(&options); 225 | if (status) return status; 226 | } 227 | status 
= ort_api->CreateSessionFromArray(env, model_data, model_data_length, 228 | options, out); 229 | if (default_options) { 230 | // If we created a default, empty, options struct, we don't need to keep it 231 | // after creating the session. 232 | ort_api->ReleaseSessionOptions(options); 233 | } 234 | return status; 235 | } 236 | 237 | OrtStatus *CreateSessionFromFile(char *model_path, OrtEnv *env, 238 | OrtSession **out, OrtSessionOptions *options) { 239 | // Nearly identical to CreateSession, except invokes ort_api->CreateSession 240 | // rather than ort_api->CreateSessionFromArray. 241 | OrtStatus *status = NULL; 242 | int default_options = 0; 243 | if (!options) { 244 | default_options = 1; 245 | status = ort_api->CreateSessionOptions(&options); 246 | if (status) return status; 247 | } 248 | status = ort_api->CreateSession(env, (const ORTCHAR_T*) model_path, options, 249 | out); 250 | if (default_options) ort_api->ReleaseSessionOptions(options); 251 | return status; 252 | } 253 | 254 | OrtStatus *RunOrtSession(OrtSession *session, 255 | OrtValue **inputs, char **input_names, int input_count, 256 | OrtValue **outputs, char **output_names, int output_count) { 257 | OrtStatus *status = NULL; 258 | status = ort_api->Run(session, NULL, (const char* const*) input_names, 259 | (const OrtValue* const*) inputs, input_count, 260 | (const char* const*) output_names, output_count, outputs); 261 | return status; 262 | } 263 | 264 | OrtStatus *RunSessionWithBinding(OrtSession *session, OrtIoBinding *b) { 265 | return ort_api->RunWithBinding(session, NULL, b); 266 | } 267 | 268 | void ReleaseOrtSession(OrtSession *session) { 269 | ort_api->ReleaseSession(session); 270 | } 271 | 272 | OrtStatus *CreateIoBinding(OrtSession *session, OrtIoBinding **out) { 273 | return ort_api->CreateIoBinding(session, out); 274 | } 275 | 276 | void ReleaseIoBinding(OrtIoBinding *b) { 277 | ort_api->ReleaseIoBinding(b); 278 | } 279 | 280 | OrtStatus *BindInput(OrtIoBinding *b, char *name, OrtValue *value) { 281 | return ort_api->BindInput(b, name, value); 282 | } 283 | 284 | OrtStatus *BindOutput(OrtIoBinding *b, char *name, OrtValue *value) { 285 | return ort_api->BindOutput(b, name, value); 286 | } 287 | 288 | OrtStatus *GetBoundOutputNames(OrtIoBinding *b, char **buffer, 289 | size_t **lengths, size_t *count) { 290 | OrtAllocator *allocator = NULL; 291 | OrtStatus *status = NULL; 292 | status = ort_api->GetAllocatorWithDefaultOptions(&allocator); 293 | if (status) return status; 294 | return ort_api->GetBoundOutputNames(b, allocator, buffer, lengths, count); 295 | } 296 | 297 | OrtStatus *SessionGetInputCount(OrtSession *session, size_t *result) { 298 | return ort_api->SessionGetInputCount(session, result); 299 | } 300 | 301 | OrtStatus *GetBoundOutputValues(OrtIoBinding *b, OrtValue ***buffer, 302 | size_t *count) { 303 | OrtAllocator *allocator = NULL; 304 | OrtStatus *status = NULL; 305 | status = ort_api->GetAllocatorWithDefaultOptions(&allocator); 306 | if (status) return status; 307 | return ort_api->GetBoundOutputValues(b, allocator, buffer, count); 308 | } 309 | 310 | void ClearBoundInputs(OrtIoBinding *b) { 311 | ort_api->ClearBoundInputs(b); 312 | } 313 | 314 | void ClearBoundOutputs(OrtIoBinding *b) { 315 | ort_api->ClearBoundOutputs(b); 316 | } 317 | 318 | OrtStatus *SessionGetOutputCount(OrtSession *session, size_t *result) { 319 | return ort_api->SessionGetOutputCount(session, result); 320 | } 321 | 322 | void ReleaseOrtValue(OrtValue *value) { 323 | ort_api->ReleaseValue(value); 324 | } 325 | 326 | 
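// A note on the wrappers below that return allocator-owned buffers (names,
// metadata strings, bound output lists): they all follow the same pattern of
// obtaining the default allocator first. As an illustrative sketch (the
// SomeAllocatorBackedFn name here is hypothetical, not a real OrtApi entry):
//
//   OrtAllocator *allocator = NULL;
//   OrtStatus *status = ort_api->GetAllocatorWithDefaultOptions(&allocator);
//   if (status) return status;
//   return ort_api->SomeAllocatorBackedFn(..., allocator, ...);
//
// Anything allocated this way is later released from Go via
// FreeWithDefaultORTAllocator.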
OrtStatus *CreateOrtTensorWithShape(void *data, size_t data_size, 327 | int64_t *shape, int64_t shape_size, OrtMemoryInfo *mem_info, 328 | ONNXTensorElementDataType dtype, OrtValue **out) { 329 | OrtStatus *status = NULL; 330 | status = ort_api->CreateTensorWithDataAsOrtValue(mem_info, data, data_size, 331 | shape, shape_size, dtype, out); 332 | return status; 333 | } 334 | 335 | OrtStatus *GetTensorTypeAndShape(const OrtValue *value, OrtTensorTypeAndShapeInfo **out) { 336 | return ort_api->GetTensorTypeAndShape(value, out); 337 | } 338 | 339 | OrtStatus *GetDimensionsCount(const OrtTensorTypeAndShapeInfo *info, size_t *out) { 340 | return ort_api->GetDimensionsCount(info, out); 341 | } 342 | 343 | OrtStatus *GetDimensions(const OrtTensorTypeAndShapeInfo *info, int64_t *dim_values, size_t dim_values_length) { 344 | return ort_api->GetDimensions(info, dim_values, dim_values_length); 345 | } 346 | 347 | OrtStatus *GetTensorElementType(const OrtTensorTypeAndShapeInfo *info, enum ONNXTensorElementDataType *out) { 348 | return ort_api->GetTensorElementType(info, out); 349 | } 350 | 351 | void ReleaseTensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo *input) { 352 | ort_api->ReleaseTensorTypeAndShapeInfo(input); 353 | } 354 | 355 | OrtStatus *GetTensorMutableData(OrtValue *value, void **out) { 356 | return ort_api->GetTensorMutableData(value, out); 357 | } 358 | 359 | OrtStatus *SessionGetInputName(OrtSession *session, size_t i, char **name) { 360 | OrtAllocator *allocator = NULL; 361 | OrtStatus *status = NULL; 362 | status = ort_api->GetAllocatorWithDefaultOptions(&allocator); 363 | if (status) return status; 364 | return ort_api->SessionGetInputName(session, i, allocator, name); 365 | } 366 | 367 | OrtStatus *SessionGetOutputName(OrtSession *session, size_t i, char **name) { 368 | OrtAllocator *allocator = NULL; 369 | OrtStatus *status = NULL; 370 | status = ort_api->GetAllocatorWithDefaultOptions(&allocator); 371 | if (status) return status; 372 | return ort_api->SessionGetOutputName(session, i, allocator, name); 373 | } 374 | 375 | OrtStatus *FreeWithDefaultORTAllocator(void *to_free) { 376 | OrtAllocator *allocator = NULL; 377 | OrtStatus *status = NULL; 378 | status = ort_api->GetAllocatorWithDefaultOptions(&allocator); 379 | if (status) return status; 380 | return ort_api->AllocatorFree(allocator, to_free); 381 | } 382 | 383 | OrtStatus *SessionGetInputTypeInfo(OrtSession *session, size_t i, 384 | OrtTypeInfo **out) { 385 | return ort_api->SessionGetInputTypeInfo(session, i, out); 386 | } 387 | 388 | OrtStatus *SessionGetOutputTypeInfo(OrtSession *session, size_t i, 389 | OrtTypeInfo **out) { 390 | return ort_api->SessionGetOutputTypeInfo(session, i, out); 391 | } 392 | 393 | void ReleaseTypeInfo(OrtTypeInfo *o) { 394 | ort_api->ReleaseTypeInfo(o); 395 | } 396 | 397 | OrtStatus *GetONNXTypeFromTypeInfo(OrtTypeInfo *info, enum ONNXType *out) { 398 | return ort_api->GetOnnxTypeFromTypeInfo(info, out); 399 | } 400 | 401 | OrtStatus *CastTypeInfoToTensorInfo(OrtTypeInfo *type_info, 402 | OrtTensorTypeAndShapeInfo **out) { 403 | return ort_api->CastTypeInfoToTensorInfo(type_info, 404 | (const OrtTensorTypeAndShapeInfo **) out); 405 | } 406 | 407 | OrtStatus *SessionGetModelMetadata(OrtSession *s, OrtModelMetadata **m) { 408 | return ort_api->SessionGetModelMetadata(s, m); 409 | } 410 | 411 | void ReleaseModelMetadata(OrtModelMetadata *m) { 412 | return ort_api->ReleaseModelMetadata(m); 413 | } 414 | 415 | OrtStatus *ModelMetadataGetProducerName(OrtModelMetadata *m, char **name) { 416 | 
OrtAllocator *allocator = NULL; 417 | OrtStatus *status = NULL; 418 | status = ort_api->GetAllocatorWithDefaultOptions(&allocator); 419 | if (status) return status; 420 | return ort_api->ModelMetadataGetProducerName(m, allocator, name); 421 | } 422 | 423 | OrtStatus *ModelMetadataGetGraphName(OrtModelMetadata *m, char **name) { 424 | OrtAllocator *allocator = NULL; 425 | OrtStatus *status = NULL; 426 | status = ort_api->GetAllocatorWithDefaultOptions(&allocator); 427 | if (status) return status; 428 | return ort_api->ModelMetadataGetGraphName(m, allocator, name); 429 | } 430 | 431 | OrtStatus *ModelMetadataGetDomain(OrtModelMetadata *m, char **domain) { 432 | OrtAllocator *allocator = NULL; 433 | OrtStatus *status = NULL; 434 | status = ort_api->GetAllocatorWithDefaultOptions(&allocator); 435 | if (status) return status; 436 | return ort_api->ModelMetadataGetDomain(m, allocator, domain); 437 | } 438 | 439 | OrtStatus *ModelMetadataGetDescription(OrtModelMetadata *m, char **desc) { 440 | OrtAllocator *allocator = NULL; 441 | OrtStatus *status = NULL; 442 | status = ort_api->GetAllocatorWithDefaultOptions(&allocator); 443 | if (status) return status; 444 | return ort_api->ModelMetadataGetDescription(m, allocator, desc); 445 | } 446 | 447 | OrtStatus *ModelMetadataLookupCustomMetadataMap(OrtModelMetadata *m, char *key, 448 | char **value) { 449 | OrtAllocator *allocator = NULL; 450 | OrtStatus *status = NULL; 451 | status = ort_api->GetAllocatorWithDefaultOptions(&allocator); 452 | if (status) return status; 453 | return ort_api->ModelMetadataLookupCustomMetadataMap(m, allocator, key, 454 | value); 455 | } 456 | 457 | OrtStatus *ModelMetadataGetCustomMetadataMapKeys(OrtModelMetadata *m, 458 | char ***keys, int64_t *num_keys) { 459 | OrtAllocator *allocator = NULL; 460 | OrtStatus *status = NULL; 461 | status = ort_api->GetAllocatorWithDefaultOptions(&allocator); 462 | if (status) return status; 463 | return ort_api->ModelMetadataGetCustomMetadataMapKeys(m, allocator, keys, 464 | num_keys); 465 | } 466 | 467 | OrtStatus *ModelMetadataGetVersion(OrtModelMetadata *m, int64_t *version) { 468 | return ort_api->ModelMetadataGetVersion(m, version); 469 | } 470 | 471 | OrtStatus *GetValue(OrtValue *container, int index, OrtValue **dst) { 472 | OrtAllocator *allocator = NULL; 473 | OrtStatus *status = NULL; 474 | status = ort_api->GetAllocatorWithDefaultOptions(&allocator); 475 | if (status) return status; 476 | return ort_api->GetValue(container, index, allocator, dst); 477 | } 478 | 479 | OrtStatus *GetValueType(OrtValue *v, enum ONNXType *out) { 480 | return ort_api->GetValueType(v, out); 481 | } 482 | 483 | OrtStatus *GetValueCount(OrtValue *v, size_t *out) { 484 | return ort_api->GetValueCount(v, out); 485 | } 486 | 487 | OrtStatus *CreateOrtValue(OrtValue **in, size_t num_values, 488 | enum ONNXType value_type, OrtValue **out) { 489 | return ort_api->CreateValue((const OrtValue* const*) in, num_values, 490 | value_type, out); 491 | } 492 | -------------------------------------------------------------------------------- /onnxruntime_wrapper.h: -------------------------------------------------------------------------------- 1 | #ifndef ONNXRUNTIME_WRAPPER_H 2 | #define ONNXRUNTIME_WRAPPER_H 3 | 4 | // We want to always use the unix-like onnxruntime C APIs, even on Windows, so 5 | // we need to undefine _WIN32 before including onnxruntime_c_api.h. However, 6 | // this requires a careful song-and-dance. 
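// In short: (1) pre-include the standard headers that onnxruntime_c_api.h
// pulls in, (2) undefine _WIN32, (3) include onnxruntime_c_api.h so that it
// takes its unix-like code paths, and (4) redefine _WIN32 afterwards so that
// mingw stays happy. Each step is carried out below.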
7 |
8 | // First, include these common headers, as they get transitively included by
9 | // onnxruntime_c_api.h. We need to include them ourselves, first, so that the
10 | // preprocessor will skip them while _WIN32 is undefined.
11 | #include <stdio.h>
12 | #include <stdlib.h>
13 | #include <string.h>
14 |
15 | // Next, we actually include the header.
16 | #undef _WIN32
17 | #include "onnxruntime_c_api.h"
18 |
19 | // ... However, mingw will complain if _WIN32 is *not* defined! So redefine it.
20 | #define _WIN32
21 |
22 | #ifdef __cplusplus
23 | extern "C" {
24 | #endif
25 |
26 | // Used for the OrtSessionOptionsAppendExecutionProvider_CoreML function
27 | // pointer on supported systems. Must match the signature in
28 | // coreml_provider_factory.h provided along with the onnxruntime releases for
29 | // Apple platforms.
30 | typedef OrtStatus* (*AppendCoreMLProviderFn)(OrtSessionOptions*, uint32_t);
31 |
32 | // Takes a pointer to the api_base struct in order to obtain the OrtApi
33 | // pointer. Intended to be called from Go. Returns nonzero on error.
34 | int SetAPIFromBase(OrtApiBase *api_base);
35 |
36 | // Get the version of the Onnxruntime library for logging.
37 | const char *GetVersion();
38 |
39 | // OrtSessionOptionsAppendExecutionProvider_CoreML is exported directly from
40 | // the Apple .dylib, so we call this function on Apple platforms to set the
41 | // function pointer to the correct address. On other platforms, the function
42 | // pointer should remain NULL.
43 | void SetCoreMLProviderFunctionPointer(void *ptr);
44 |
45 | // Wraps ort_api->ReleaseStatus(status)
46 | void ReleaseOrtStatus(OrtStatus *status);
47 |
48 | // Wraps calling ort_api->CreateEnv. Returns a non-NULL status on error.
49 | OrtStatus *CreateOrtEnv(char *name, OrtEnv **env);
50 |
51 | // Wraps calling ort_api->UpdateEnvWithCustomLogLevel. Returns a non-NULL
52 | // status on error.
53 | OrtStatus *UpdateEnvWithCustomLogLevel(OrtEnv *ort_env,
54 | OrtLoggingLevel log_severity_level);
55 |
56 | // Wraps ort_api->DisableTelemetryEvents. Returns a non-NULL status on error.
57 | OrtStatus *DisableTelemetry(OrtEnv *env);
58 |
59 | // Wraps ort_api->EnableTelemetryEvents. Returns a non-NULL status on error.
60 | OrtStatus *EnableTelemetry(OrtEnv *env);
61 |
62 | // Wraps ort_api->ReleaseEnv
63 | void ReleaseOrtEnv(OrtEnv *env);
64 |
65 | // Wraps ort_api->CreateCpuMemoryInfo with some basic, default settings.
66 | OrtStatus *CreateOrtMemoryInfo(OrtMemoryInfo **mem_info);
67 |
68 | // Wraps ort_api->ReleaseMemoryInfo
69 | void ReleaseOrtMemoryInfo(OrtMemoryInfo *info);
70 |
71 | // Returns the message associated with the given ORT status.
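// Unlike most wrappers in this file, this one cannot fail: if the status is
// NULL, the C side returns a placeholder message instead of dereferencing it.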
72 | const char *GetErrorMessage(OrtStatus *status);
73 |
74 | // Wraps ort_api->CreateSessionOptions
75 | OrtStatus *CreateSessionOptions(OrtSessionOptions **o);
76 |
77 | // Wraps ort_api->ReleaseSessionOptions
78 | void ReleaseSessionOptions(OrtSessionOptions *o);
79 |
80 | // Wraps ort_api->SetSessionExecutionMode
81 | OrtStatus *SetSessionExecutionMode(OrtSessionOptions *o, int new_mode);
82 |
83 | // Wraps ort_api->SetSessionGraphOptimizationLevel
84 | OrtStatus *SetSessionGraphOptimizationLevel(OrtSessionOptions *o, int level);
85 |
86 | // Wraps ort_api->SetSessionLogSeverityLevel
87 | OrtStatus *SetSessionLogSeverityLevel(OrtSessionOptions *o, int level);
88 |
89 | // Wraps ort_api->AddSessionConfigEntry
90 | OrtStatus *AddSessionConfigEntry(OrtSessionOptions *o, char *key, char *value);
91 |
92 | // Wraps ort_api->HasSessionConfigEntry
93 | OrtStatus *HasSessionConfigEntry(OrtSessionOptions *o, char *key, int *result);
94 |
95 | // Wraps ort_api->GetSessionConfigEntry
96 | OrtStatus *GetSessionConfigEntry(OrtSessionOptions *o, char *key, char *result,
97 | size_t *required_size);
98 |
99 | // Wraps ort_api->SetIntraOpNumThreads
100 | OrtStatus *SetIntraOpNumThreads(OrtSessionOptions *o, int n);
101 |
102 | // Wraps ort_api->SetInterOpNumThreads
103 | OrtStatus *SetInterOpNumThreads(OrtSessionOptions *o, int n);
104 |
105 | // Wraps ort_api->EnableCpuMemArena & ort_api->DisableCpuMemArena
106 | OrtStatus *SetCpuMemArena(OrtSessionOptions *o, int use_arena);
107 |
108 | // Wraps ort_api->EnableMemPattern & ort_api->DisableMemPattern
109 | OrtStatus *SetMemPattern(OrtSessionOptions *o, int use_mem_pattern);
110 |
111 | // Wraps ort_api->SessionOptionsAppendExecutionProvider_CUDA_V2
112 | OrtStatus *AppendExecutionProviderCUDAV2(OrtSessionOptions *o,
113 | OrtCUDAProviderOptionsV2 *cuda_options);
114 |
115 | // Wraps ort_api->CreateCUDAProviderOptions
116 | OrtStatus *CreateCUDAProviderOptions(OrtCUDAProviderOptionsV2 **o);
117 |
118 | // Wraps ort_api->ReleaseCUDAProviderOptions
119 | void ReleaseCUDAProviderOptions(OrtCUDAProviderOptionsV2 *o);
120 |
121 | // Wraps ort_api->UpdateCUDAProviderOptions
122 | OrtStatus *UpdateCUDAProviderOptions(OrtCUDAProviderOptionsV2 *o,
123 | const char **keys, const char **values, int num_keys);
124 |
125 | // Wraps ort_api->CreateTensorRTProviderOptions
126 | OrtStatus *CreateTensorRTProviderOptions(OrtTensorRTProviderOptionsV2 **o);
127 |
128 | // Wraps ort_api->ReleaseTensorRTProviderOptions
129 | void ReleaseTensorRTProviderOptions(OrtTensorRTProviderOptionsV2 *o);
130 |
131 | // Wraps ort_api->UpdateTensorRTProviderOptions
132 | OrtStatus *UpdateTensorRTProviderOptions(OrtTensorRTProviderOptionsV2 *o,
133 | const char **keys, const char **values, int num_keys);
134 |
135 | // Wraps ort_api->SessionOptionsAppendExecutionProvider_TensorRT_V2
136 | OrtStatus *AppendExecutionProviderTensorRTV2(OrtSessionOptions *o,
137 | OrtTensorRTProviderOptionsV2 *tensor_rt_options);
138 |
139 | // Wraps OrtSessionOptionsAppendExecutionProvider_CoreML, exported from the
140 | // dylib on Apple devices. Safely returns a non-NULL status on other platforms.
141 | OrtStatus *AppendExecutionProviderCoreML(OrtSessionOptions *o,
142 | uint32_t flags);
143 |
144 | // Wraps ort_api->SessionOptionsAppendExecutionProvider for the CoreML
145 | // provider using the newer options-based API; checks that CoreML is supported.
146 | OrtStatus *AppendExecutionProviderCoreMLV2(OrtSessionOptions *o,
147 | const char **keys, const char **values, size_t num_options);
148 |
149 | // Wraps getting the OrtDmlApi struct and calling
150 | // dml_api->SessionOptionsAppendExecutionProvider_DML.
151 | OrtStatus *AppendExecutionProviderDirectML(OrtSessionOptions *o,
152 | int device_id);
153 |
154 | // Wraps ort_api->SessionOptionsAppendExecutionProvider_OpenVINO_V2
155 | OrtStatus *AppendExecutionProviderOpenVINOV2(OrtSessionOptions *o,
156 | const char **keys, const char **values, int num_keys);
157 |
158 | // Creates an ORT session using the given model. The given options pointer may
159 | // be NULL; if it is, then we'll use default options.
160 | OrtStatus *CreateSession(void *model_data, size_t model_data_length,
161 | OrtEnv *env, OrtSession **out, OrtSessionOptions *options);
162 |
163 | // Like the CreateSession function, but takes a path to a model rather than a
164 | // buffer containing it.
165 | OrtStatus *CreateSessionFromFile(char *model_path, OrtEnv *env,
166 | OrtSession **out, OrtSessionOptions *options);
167 |
168 | // Runs an ORT session with the given input and output tensors, along with
169 | // their names.
170 | OrtStatus *RunOrtSession(OrtSession *session,
171 | OrtValue **inputs, char **input_names, int input_count,
172 | OrtValue **outputs, char **output_names, int output_count);
173 |
174 | // Wraps ort_api->RunWithBinding.
175 | OrtStatus *RunSessionWithBinding(OrtSession *session, OrtIoBinding *b);
176 |
177 | // Wraps ort_api->ReleaseSession
178 | void ReleaseOrtSession(OrtSession *session);
179 |
180 | // Wraps ort_api->CreateIoBinding
181 | OrtStatus *CreateIoBinding(OrtSession *session, OrtIoBinding **out);
182 |
183 | // Wraps ort_api->ReleaseIoBinding
184 | void ReleaseIoBinding(OrtIoBinding *b);
185 |
186 | // Wraps ort_api->BindInput
187 | OrtStatus *BindInput(OrtIoBinding *b, char *name, OrtValue *value);
188 |
189 | // Wraps ort_api->BindOutput
190 | OrtStatus *BindOutput(OrtIoBinding *b, char *name, OrtValue *value);
191 |
192 | // Wraps ort_api->GetBoundOutputNames. Uses the default allocator. The caller
193 | // must free the buffer and lengths using FreeWithDefaultORTAllocator after
194 | // this is done.
195 | OrtStatus *GetBoundOutputNames(OrtIoBinding *b, char **buffer,
196 | size_t **lengths, size_t *count);
197 |
198 | // Wraps ort_api->GetBoundOutputValues. Uses the default allocator, and the
199 | // caller must free the buffer of values using FreeWithDefaultORTAllocator.
200 | OrtStatus *GetBoundOutputValues(OrtIoBinding *b, OrtValue ***buffer,
201 | size_t *count);
202 |
203 | // Wraps ort_api->ClearBoundInputs.
204 | void ClearBoundInputs(OrtIoBinding *b);
205 |
206 | // Wraps ort_api->ClearBoundOutputs.
207 | void ClearBoundOutputs(OrtIoBinding *b);
208 |
209 | // Wraps ort_api->SessionGetInputCount.
210 | OrtStatus *SessionGetInputCount(OrtSession *session, size_t *result);
211 |
212 | // Wraps ort_api->SessionGetOutputCount.
213 | OrtStatus *SessionGetOutputCount(OrtSession *session, size_t *result);
214 |
215 | // Used to free OrtValue instances, such as tensors.
216 | void ReleaseOrtValue(OrtValue *value);
217 |
218 | // Creates an OrtValue tensor with the given shape, and backed by the user-
219 | // supplied data buffer.
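// Note that onnxruntime does not copy the buffer; the data pointer must stay
// valid (and must not be garbage-collected on the Go side) for as long as the
// resulting OrtValue is in use.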
220 | OrtStatus *CreateOrtTensorWithShape(void *data, size_t data_size,
221 | int64_t *shape, int64_t shape_size, OrtMemoryInfo *mem_info,
222 | ONNXTensorElementDataType dtype, OrtValue **out);
223 |
224 | // Wraps ort_api->GetTensorTypeAndShape
225 | OrtStatus *GetTensorTypeAndShape(const OrtValue *value, OrtTensorTypeAndShapeInfo **out);
226 |
227 | // Wraps ort_api->GetDimensionsCount
228 | OrtStatus *GetDimensionsCount(const OrtTensorTypeAndShapeInfo *info, size_t *out);
229 |
230 | // Wraps ort_api->GetDimensions
231 | OrtStatus *GetDimensions(const OrtTensorTypeAndShapeInfo *info, int64_t *dim_values, size_t dim_values_length);
232 |
233 | // Wraps ort_api->GetTensorElementType
234 | OrtStatus *GetTensorElementType(const OrtTensorTypeAndShapeInfo *info, enum ONNXTensorElementDataType *out);
235 |
236 | // Wraps ort_api->ReleaseTensorTypeAndShapeInfo
237 | void ReleaseTensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo *input);
238 |
239 | // Wraps ort_api->GetTensorMutableData
240 | OrtStatus *GetTensorMutableData(OrtValue *value, void **out);
241 |
242 | // Wraps ort_api->SessionGetInputName, using the default allocator.
243 | OrtStatus *SessionGetInputName(OrtSession *session, size_t i, char **name);
244 |
245 | // Wraps ort_api->SessionGetOutputName, using the default allocator.
246 | OrtStatus *SessionGetOutputName(OrtSession *session, size_t i, char **name);
247 |
248 | // Frees anything that was allocated using the default ORT allocator.
249 | OrtStatus *FreeWithDefaultORTAllocator(void *to_free);
250 |
251 | // Wraps ort_api->SessionGetInputTypeInfo.
252 | OrtStatus *SessionGetInputTypeInfo(OrtSession *session, size_t i,
253 | OrtTypeInfo **out);
254 |
255 | // Wraps ort_api->SessionGetOutputTypeInfo.
256 | OrtStatus *SessionGetOutputTypeInfo(OrtSession *session, size_t i,
257 | OrtTypeInfo **out);
258 |
259 | // If the type_info is for a tensor, sets out to a pointer to the tensor's
260 | // OrtTensorTypeAndShapeInfo. Do _not_ free the out pointer; it will be freed
261 | // when type_info is released.
262 | //
263 | // Wraps ort_api->CastTypeInfoToTensorInfo.
264 | OrtStatus *CastTypeInfoToTensorInfo(OrtTypeInfo *type_info,
265 | OrtTensorTypeAndShapeInfo **out);
266 |
267 | // Wraps ort_api->GetOnnxTypeFromTypeInfo.
268 | OrtStatus *GetONNXTypeFromTypeInfo(OrtTypeInfo *info, enum ONNXType *out);
269 |
270 | // Wraps ort_api->ReleaseTypeInfo.
271 | void ReleaseTypeInfo(OrtTypeInfo *o);
272 |
273 | // Wraps ort_api->SessionGetModelMetadata.
274 | OrtStatus *SessionGetModelMetadata(OrtSession *s, OrtModelMetadata **out);
275 |
276 | // Wraps ort_api->ReleaseModelMetadata.
277 | void ReleaseModelMetadata(OrtModelMetadata *m);
278 |
279 | // Wraps ort_api->ModelMetadataGetProducerName, using the default allocator.
280 | OrtStatus *ModelMetadataGetProducerName(OrtModelMetadata *m, char **name);
281 |
282 | // Wraps ort_api->ModelMetadataGetGraphName, using the default allocator.
283 | OrtStatus *ModelMetadataGetGraphName(OrtModelMetadata *m, char **name);
284 |
285 | // Wraps ort_api->ModelMetadataGetDomain, using the default allocator.
286 | OrtStatus *ModelMetadataGetDomain(OrtModelMetadata *m, char **domain);
287 |
288 | // Wraps ort_api->ModelMetadataGetDescription, using the default allocator.
289 | OrtStatus *ModelMetadataGetDescription(OrtModelMetadata *m, char **desc);
290 |
291 | // Wraps ort_api->ModelMetadataLookupCustomMetadataMap, using the default
292 | // allocator.
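// The returned value string is owned by the default allocator, and should be
// released via FreeWithDefaultORTAllocator when no longer needed.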
293 | OrtStatus *ModelMetadataLookupCustomMetadataMap(OrtModelMetadata *m, char *key,
294 | char **value);
295 |
296 | // Wraps ort_api->ModelMetadataGetCustomMetadataMapKeys, using the default
297 | // allocator.
298 | OrtStatus *ModelMetadataGetCustomMetadataMapKeys(OrtModelMetadata *m,
299 | char ***keys, int64_t *num_keys);
300 |
301 | // Wraps ort_api->ModelMetadataGetVersion.
302 | OrtStatus *ModelMetadataGetVersion(OrtModelMetadata *m, int64_t *version);
303 |
304 | // Wraps ort_api->GetValue. Uses the default allocator.
305 | OrtStatus *GetValue(OrtValue *container, int index, OrtValue **dst);
306 |
307 | // Wraps ort_api->GetValueType.
308 | OrtStatus *GetValueType(OrtValue *v, enum ONNXType *out);
309 |
310 | // Wraps ort_api->GetValueCount.
311 | OrtStatus *GetValueCount(OrtValue *v, size_t *out);
312 |
313 | // Wraps ort_api->CreateValue to create a map or a sequence.
314 | OrtStatus *CreateOrtValue(OrtValue **in, size_t num_values,
315 | enum ONNXType value_type, OrtValue **out);
316 |
317 | #ifdef __cplusplus
318 | } // extern "C"
319 | #endif
320 | #endif // ONNXRUNTIME_WRAPPER_H
321 |
--------------------------------------------------------------------------------
/setup_env.go:
--------------------------------------------------------------------------------
1 | //go:build !windows
2 |
3 | package onnxruntime_go
4 |
5 | import (
6 | "fmt"
7 | "runtime"
8 | "unsafe"
9 | )
10 |
11 | /*
12 | #cgo LDFLAGS: -ldl
13 |
14 | #include <dlfcn.h>
15 | #include "onnxruntime_wrapper.h"
16 |
17 | typedef OrtApiBase* (*GetOrtApiBaseFunction)(void);
18 |
19 | // Since Go can't call C function pointers directly, we just use this helper
20 | // when calling GetApiBase
21 | OrtApiBase *CallGetAPIBaseFunction(void *fn) {
22 | OrtApiBase *to_return = ((GetOrtApiBaseFunction) fn)();
23 | return to_return;
24 | }
25 | */
26 | import "C"
27 |
28 | // This file includes the code for loading the onnxruntime library and setting
29 | // up the environment on non-Windows systems. For now, it has been tested on
30 | // Linux and arm64 OSX.
31 |
32 | // This will contain the handle to the onnxruntime shared library if it has
33 | // been loaded successfully.
34 | var libraryHandle unsafe.Pointer
35 |
36 | func platformCleanup() error {
37 | v, e := C.dlclose(libraryHandle)
38 | if v != 0 {
39 | return fmt.Errorf("Error closing the library: %w", e)
40 | }
41 | return nil
42 | }
43 |
44 | // Should only be called on Apple systems; looks up the CoreML provider
45 | // function, which should only be exported by Apple onnxruntime dylib files.
46 | func setAppendCoreMLFunctionPointer(libraryHandle unsafe.Pointer) error {
47 | // This function name must match the name in coreml_provider_factory.h,
48 | // which is provided in the onnxruntime release's include/ directory for
49 | // Apple platforms.
50 | fnName := "OrtSessionOptionsAppendExecutionProvider_CoreML" 51 | cFunctionName := C.CString(fnName) 52 | defer C.free(unsafe.Pointer(cFunctionName)) 53 | appendCoreMLProviderProc := C.dlsym(libraryHandle, cFunctionName) 54 | if appendCoreMLProviderProc == nil { 55 | msg := C.GoString(C.dlerror()) 56 | return fmt.Errorf("Error looking up %s: %s", fnName, msg) 57 | } 58 | C.SetCoreMLProviderFunctionPointer(appendCoreMLProviderProc) 59 | return nil 60 | } 61 | 62 | func platformInitializeEnvironment() error { 63 | if onnxSharedLibraryPath == "" { 64 | onnxSharedLibraryPath = "onnxruntime.so" 65 | } 66 | cName := C.CString(onnxSharedLibraryPath) 67 | defer C.free(unsafe.Pointer(cName)) 68 | handle := C.dlopen(cName, C.RTLD_LAZY) 69 | if handle == nil { 70 | msg := C.GoString(C.dlerror()) 71 | return fmt.Errorf("Error loading ONNX shared library \"%s\": %s", 72 | onnxSharedLibraryPath, msg) 73 | } 74 | cFunctionName := C.CString("OrtGetApiBase") 75 | defer C.free(unsafe.Pointer(cFunctionName)) 76 | getAPIBaseProc := C.dlsym(handle, cFunctionName) 77 | if getAPIBaseProc == nil { 78 | C.dlclose(handle) 79 | msg := C.GoString(C.dlerror()) 80 | return fmt.Errorf("Error looking up OrtGetApiBase in \"%s\": %s", 81 | onnxSharedLibraryPath, msg) 82 | } 83 | ortAPIBase := C.CallGetAPIBaseFunction(getAPIBaseProc) 84 | tmp := C.SetAPIFromBase((*C.OrtApiBase)(unsafe.Pointer(ortAPIBase))) 85 | if tmp != 0 { 86 | C.dlclose(handle) 87 | return fmt.Errorf("Error setting ORT API base: %d", tmp) 88 | } 89 | if (runtime.GOOS == "darwin") || (runtime.GOOS == "ios") { 90 | setAppendCoreMLFunctionPointer(handle) 91 | // We'll silently ignore potential errors returned by 92 | // setAppendCoreMLFunctionPointer (for now at least). Even though we're 93 | // on Apple hardware, it's possible that the user will have compiled 94 | // the onnxruntime library from source without CoreML support. 95 | // A failure here will only leave the coreml function pointer as NULL 96 | // in our C code, which will be detected and result in an error at 97 | // runtime. 98 | } 99 | libraryHandle = handle 100 | return nil 101 | } 102 | 103 | // Converts the given path to an ORTCHAR_T string, pointed to by a *C.char. The 104 | // returned string must be freed using C.free when no longer needed. This 105 | // wrapper is used for source compatibility with onnxruntime API functions 106 | // requiring paths, which must be UTF-16 on Windows but UTF-8 elsewhere. 107 | func createOrtCharString(str string) (*C.char, error) { 108 | return C.CString(str), nil 109 | } 110 | -------------------------------------------------------------------------------- /setup_env_windows.go: -------------------------------------------------------------------------------- 1 | //go:build windows 2 | 3 | package onnxruntime_go 4 | 5 | // This file includes the Windows-specific code for loading the onnxruntime 6 | // library and setting up the environment. 7 | 8 | import ( 9 | "fmt" 10 | "syscall" 11 | "unicode/utf16" 12 | "unicode/utf8" 13 | "unsafe" 14 | ) 15 | 16 | // #include "onnxruntime_wrapper.h" 17 | import "C" 18 | 19 | // This will contain the handle to the onnxruntime dll if it has been loaded 20 | // successfully. 
21 | var libraryHandle syscall.Handle 22 | 23 | func platformCleanup() error { 24 | e := syscall.FreeLibrary(libraryHandle) 25 | libraryHandle = 0 26 | return e 27 | } 28 | 29 | func platformInitializeEnvironment() error { 30 | if onnxSharedLibraryPath == "" { 31 | onnxSharedLibraryPath = "onnxruntime.dll" 32 | } 33 | handle, e := syscall.LoadLibrary(onnxSharedLibraryPath) 34 | if e != nil { 35 | return fmt.Errorf("Error loading ONNX shared library \"%s\": %w", 36 | onnxSharedLibraryPath, e) 37 | } 38 | getApiBaseProc, e := syscall.GetProcAddress(handle, "OrtGetApiBase") 39 | if e != nil { 40 | syscall.FreeLibrary(handle) 41 | return fmt.Errorf("Error finding OrtGetApiBase function in %s: %w", 42 | onnxSharedLibraryPath, e) 43 | } 44 | ortApiBase, _, e := syscall.SyscallN(uintptr(getApiBaseProc), 0) 45 | if ortApiBase == 0 { 46 | syscall.FreeLibrary(handle) 47 | if e != nil { 48 | return fmt.Errorf("Error calling OrtGetApiBase: %w", e) 49 | } else { 50 | return fmt.Errorf("Error calling OrtGetApiBase") 51 | } 52 | } 53 | tmp := C.SetAPIFromBase((*C.OrtApiBase)(unsafe.Pointer(ortApiBase))) 54 | if tmp != 0 { 55 | syscall.FreeLibrary(handle) 56 | return fmt.Errorf("Error setting ORT API base: %d", tmp) 57 | } 58 | 59 | libraryHandle = handle 60 | return nil 61 | } 62 | 63 | // Converts the given string to a UTF-16 string, pointed to by a raw 64 | // *C.char. Note that we actually keep ORTCHAR_T defined to char even 65 | // on Windows, so do _not_ index into this string from Cgo code and expect to 66 | // get correct characters! Instead, this should only be used to obtain pointers 67 | // that are passed to onnxruntime windows DLL functions expecting ORTCHAR_T* 68 | // args. This is required because we undefine _WIN32 for cgo compatibility when 69 | // including onnxruntime_c_api.h, but still interact with a DLL that was 70 | // compiled assuming _WIN32 was defined. 71 | // 72 | // The pointer returned by this function must still be freed using C.free when 73 | // no longer needed. This will return an error if the given string contains 74 | // non-UTF8 characters. 75 | func createOrtCharString(str string) (*C.char, error) { 76 | src := []uint8(str) 77 | // Assumed common case: the utf16 buffer contains one uint16 per utf8 byte 78 | // plus one more for the required null terminator in the C buffer. 79 | dst := make([]uint16, 0, len(src)+1) 80 | // Convert UTF-8 to UTF-16 by reading each subsequent rune from src and 81 | // appending it as UTF-16 to dst. 82 | for len(src) > 0 { 83 | r, size := utf8.DecodeRune(src) 84 | if r == utf8.RuneError { 85 | return nil, fmt.Errorf("Invalid UTF-8 rune found in \"%s\"", str) 86 | } 87 | src = src[size:] 88 | dst = utf16.AppendRune(dst, r) 89 | } 90 | // Make sure dst contains the null terminator. Additionally this will cause 91 | // us to return an empty string if the original string was empty. 92 | dst = append(dst, 0) 93 | 94 | // Finally, we need to copy dst into a C array for compatibility with 95 | // C.CString. 
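// Each uint16 element occupies 2 bytes, hence the element size of 2 passed to
// calloc and the len(dst)*2 byte count passed to memcpy below. As a worked
// example (illustrative only): the 3-character string "abc" becomes 4 uint16
// values (3 characters plus the null terminator), i.e. an 8-byte C buffer,
// while any rune outside the Basic Multilingual Plane is expanded into a
// 2-element surrogate pair by utf16.AppendRune above.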
96 | toReturn := C.calloc(C.size_t(len(dst)), 2) 97 | if toReturn == nil { 98 | return nil, fmt.Errorf("Error allocating buffer for the utf16 string") 99 | } 100 | C.memcpy(toReturn, unsafe.Pointer(&(dst[0])), C.size_t(len(dst))*2) 101 | 102 | return (*C.char)(toReturn), nil 103 | } 104 | -------------------------------------------------------------------------------- /tensor_type_constraints.go: -------------------------------------------------------------------------------- 1 | package onnxruntime_go 2 | 3 | // This file contains definitions for the generic tensor data types we support. 4 | 5 | // #include "onnxruntime_wrapper.h" 6 | import "C" 7 | 8 | import ( 9 | "reflect" 10 | ) 11 | 12 | type FloatData interface { 13 | ~float32 | ~float64 14 | } 15 | 16 | type IntData interface { 17 | ~int8 | ~uint8 | ~int16 | ~uint16 | ~int32 | ~uint32 | ~int64 | ~uint64 18 | } 19 | 20 | // This is used as a type constraint for the generic Tensor type. 21 | type TensorData interface { 22 | FloatData | IntData | ~bool 23 | } 24 | 25 | // Returns the ONNX enum value used to indicate TensorData type T. 26 | func GetTensorElementDataType[T TensorData]() C.ONNXTensorElementDataType { 27 | // Sadly, we can't do type assertions to get underlying types, so we need 28 | // to use reflect here instead. 29 | var v T 30 | kind := reflect.ValueOf(v).Kind() 31 | switch kind { 32 | case reflect.Float64: 33 | return C.ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE 34 | case reflect.Float32: 35 | return C.ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT 36 | case reflect.Int8: 37 | return C.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8 38 | case reflect.Uint8: 39 | return C.ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8 40 | case reflect.Int16: 41 | return C.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16 42 | case reflect.Uint16: 43 | return C.ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16 44 | case reflect.Int32: 45 | return C.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32 46 | case reflect.Uint32: 47 | return C.ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32 48 | case reflect.Int64: 49 | return C.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64 50 | case reflect.Uint64: 51 | return C.ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64 52 | case reflect.Bool: 53 | return C.ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL 54 | } 55 | return C.ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED 56 | } 57 | -------------------------------------------------------------------------------- /test_data/example ż 大 김.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yalue/onnxruntime_go/b8230be7b992cd4ecedaa9441974d932d736f42d/test_data/example ż 大 김.onnx -------------------------------------------------------------------------------- /test_data/example_0_dim_output.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yalue/onnxruntime_go/b8230be7b992cd4ecedaa9441974d932d736f42d/test_data/example_0_dim_output.onnx -------------------------------------------------------------------------------- /test_data/example_big_compute.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yalue/onnxruntime_go/b8230be7b992cd4ecedaa9441974d932d736f42d/test_data/example_big_compute.onnx -------------------------------------------------------------------------------- /test_data/example_big_fanout.onnx: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/yalue/onnxruntime_go/b8230be7b992cd4ecedaa9441974d932d736f42d/test_data/example_big_fanout.onnx
--------------------------------------------------------------------------------
/test_data/example_dynamic_axes.onnx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yalue/onnxruntime_go/b8230be7b992cd4ecedaa9441974d932d736f42d/test_data/example_dynamic_axes.onnx
--------------------------------------------------------------------------------
/test_data/example_float16.onnx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yalue/onnxruntime_go/b8230be7b992cd4ecedaa9441974d932d736f42d/test_data/example_float16.onnx
--------------------------------------------------------------------------------
/test_data/example_multitype.onnx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yalue/onnxruntime_go/b8230be7b992cd4ecedaa9441974d932d736f42d/test_data/example_multitype.onnx
--------------------------------------------------------------------------------
/test_data/example_several_inputs_and_outputs.onnx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yalue/onnxruntime_go/b8230be7b992cd4ecedaa9441974d932d736f42d/test_data/example_several_inputs_and_outputs.onnx
--------------------------------------------------------------------------------
/test_data/generate_0_dimension_output.py:
--------------------------------------------------------------------------------
1 | # This script creates example_0_dim_output.onnx to use in testing. The idea is
2 | # that the network produces an output with one dimension of size 0.
3 | import torch
4 |
5 | class ZeroDimOutputModel(torch.nn.Module):
6 | """ Takes a 2x8 input, and produces a 2xNx8 output, where N is the sum of
7 | the first input column, cast to an int. In tests, the input will be all 0s,
8 | so this should result in a 2x0x8 output. """
9 | def __init__(self):
10 | super().__init__()
11 |
12 | def forward(self, x):
13 | tmp = x.sum(0)
14 | return tmp.unsqueeze(0).expand(2, tmp.int()[0], -1)
15 |
16 | def main():
17 | model = ZeroDimOutputModel()
18 | model.eval()
19 | x = torch.ones((2, 8), dtype=torch.float32)
20 | out_name = "example_0_dim_output.onnx"
21 | torch.onnx.export(model, (x,), out_name, input_names=["x"],
22 | output_names=["y"])
23 | print(f"{out_name} saved OK.")
24 |
25 | if __name__ == "__main__":
26 | main()
27 |
28 |
--------------------------------------------------------------------------------
/test_data/generate_big_fanout.py:
--------------------------------------------------------------------------------
1 | # This script creates example_big_fanout.onnx to use in testing. The idea is
2 | # to create a network where parallelism makes a big difference.
3 | import torch
4 |
5 | class BigFanoutModel(torch.nn.Module):
6 | """ Maps a 1x4 vector to another 1x4 vector, but goes through a large
7 | number of parallelizable useless FC operations. """
8 | def __init__(self):
9 | super().__init__()
10 | self.fanout_amount = 100
11 | self.matrices = [torch.rand((4, 4)) for i in range(self.fanout_amount)]
12 |
13 | def forward(self, x):
14 | # Do fanout_amount matrix multiplies, then merge and sum the result.
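# Each matmul below depends only on the input x, so the exported graph
# contains fanout_amount independent branches that an onnxruntime session
# is free to execute in parallel.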
15 | tmp_results = [
16 | torch.matmul(x, self.matrices[i])
17 | for i in range(self.fanout_amount)
18 | ]
19 | combined_tensor = torch.cat(tmp_results)
20 | return combined_tensor.sum(0)
21 |
22 | def main():
23 | model = BigFanoutModel()
24 | model.float()
25 | model.eval()
26 | test_input = torch.rand((1, 4), dtype=torch.float32)
27 | output_name = "example_big_fanout.onnx"
28 | test_input.requires_grad = False
29 | for param in model.parameters():
30 | param.requires_grad = False
31 | with torch.no_grad():
32 | torch.onnx.export(model, (test_input), output_name,
33 | input_names=["input"], output_names=["output"])
34 | print(f"Saved {output_name} OK.")
35 |
36 | if __name__ == "__main__":
37 | main()
38 |
39 |
--------------------------------------------------------------------------------
/test_data/generate_dynamic_axes_network.py:
--------------------------------------------------------------------------------
1 | # This script creates example_dynamic_axes.onnx to use in testing. It takes a
2 | # batch of [-1, 10] input vectors and produces [-1] output scalars---the sum of
3 | # each input vector (where -1 is a dynamic batch size).
4 | import torch
5 |
6 | class DynamicSizeModel(torch.nn.Module):
7 | def __init__(self):
8 | super().__init__()
9 |
10 | def forward(self, input_batch):
11 | return input_batch.sum(1)
12 |
13 | def main():
14 | model = DynamicSizeModel()
15 | model.eval()
16 | test_input = torch.rand((123, 10), dtype=torch.float32)
17 | dynamic_axes = {
18 | "input_vectors": [0],
19 | "output_scalars": [0],
20 | }
21 | output_name = "example_dynamic_axes.onnx"
22 | torch.onnx.export(model, (test_input), output_name,
23 | input_names=["input_vectors"], output_names=["output_scalars"],
24 | dynamic_axes=dynamic_axes)
25 | print(f"Saved {output_name} OK.")
26 |
27 | if __name__ == "__main__":
28 | main()
29 |
30 |
--------------------------------------------------------------------------------
/test_data/generate_float16_network.py:
--------------------------------------------------------------------------------
1 | # This script creates example_float16.onnx to use in testing.
2 | # It takes one input:
3 | # - "InputA": A 2x2x2 float16 tensor
4 | # It produces one output:
5 | # - "OutputA": A 2x2x2 bfloat16 tensor
6 | #
7 | # The "network" just multiplies each element in the input by 3.0
8 | import torch
9 |
10 | class Float16Model(torch.nn.Module):
11 | def __init__(self):
12 | super().__init__()
13 |
14 | def forward(self, input_a):
15 | output_a = input_a * 3.0
16 | output_a = output_a.type(torch.bfloat16)
17 | return output_a
18 |
19 | def fake_inputs():
20 | return torch.rand((1, 2, 2, 2), dtype=torch.float16)
21 |
22 | def main():
23 | model = Float16Model()
24 | model.eval()
25 | input_a = torch.rand((1, 2, 2, 2), dtype=torch.float16)
26 | output_a = model(input_a)
27 |
28 | out_name = "example_float16.onnx"
29 | torch.onnx.export(model, (input_a), out_name, input_names=["InputA"],
30 | output_names=["OutputA"])
31 | print(f"{out_name} saved OK.")
32 |
33 | if __name__ == "__main__":
34 | main()
35 |
36 |
--------------------------------------------------------------------------------
/test_data/generate_network_big_compute.py:
--------------------------------------------------------------------------------
1 | # This script creates example_big_compute.onnx to use in testing.
2 | # The "network" is entirely deterministic; it simply does a large amount of
3 | # hopefully expensive arithmetic operations.
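# (Each iteration divides by 10.0 and then multiplies by 10.0 again, so the
# output should match the input up to floating-point rounding.)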
4 | #
5 | # It takes one input: "Input", a one-dimensional vector of 1024*1024*50 32-bit
6 | # floats, and produces one output, named "Output", of the same dimensions.
7 | import torch
8 |
9 | class BigComputeModel(torch.nn.Module):
10 | def __init__(self):
11 | super().__init__()
12 |
13 | def forward(self, x):
14 | for i in range(40):
15 | x = x / 10.0
16 | x = x * 10.0
17 | return x
18 |
19 | def main():
20 | model = BigComputeModel()
21 | model.eval()
22 | x = torch.zeros((1, 1024 * 1024 * 50), dtype=torch.float32)
23 |
24 | out_name = "example_big_compute.onnx"
25 | torch.onnx.export(model, x, out_name,
26 | input_names=["Input"], output_names=["Output"])
27 | print(f"{out_name} saved OK.")
28 |
29 | if __name__ == "__main__":
30 | main()
31 |
32 |
--------------------------------------------------------------------------------
/test_data/generate_network_different_types.py:
--------------------------------------------------------------------------------
1 | # This script creates example_multitype.onnx to use in testing.
2 | # The "network" doesn't actually do much other than cast around some types and
3 | # perform basic arithmetic. It takes two inputs:
4 | # - "InputA": A 1x1 8-bit unsigned int tensor
5 | # - "InputB": A 2x2 64-bit float tensor
6 | #
7 | # It produces 2 outputs:
8 | # - "OutputA": A 2x2 16-bit signed int tensor, equal to (InputB * InputA) - 512
9 | # - "OutputB": A 1x1 64-bit int tensor, equal to InputA multiplied by 1234
10 | import torch
11 |
12 | class DifferentTypesModel(torch.nn.Module):
13 | def __init__(self):
14 | super().__init__()
15 |
16 | def forward(self, input_a, input_b):
17 | output_a = input_b * input_a[0][0][0]
18 | output_a -= 512
19 | output_a = output_a.type(torch.int16)
20 | output_b = input_a.type(torch.int64)
21 | output_b *= 1234
22 | return output_a, output_b
23 |
24 | def fake_inputs():
25 | input_a = torch.rand((1, 1, 1)) * 255.0
26 | input_a = input_a.type(torch.uint8)
27 | input_b = torch.rand((1, 2, 2), dtype=torch.float64)
28 | return input_a, input_b
29 |
30 |
31 | def main():
32 | model = DifferentTypesModel()
33 | model.eval()
34 | input_a, input_b = fake_inputs()
35 | output_a, output_b = model(input_a, input_b)
36 | print(f"Example inputs: A = {input_a!s}, B = {input_b!s}")
37 | print(f"Produced outputs: A = {output_a!s}, B = {output_b!s}")
38 |
39 | out_name = "example_multitype.onnx"
40 | print(f"Saving model as {out_name}")
41 | input_names = ["InputA", "InputB"]
42 | output_names = ["OutputA", "OutputB"]
43 | torch.onnx.export(model, (input_a, input_b), out_name,
44 | input_names=input_names, output_names=output_names)
45 | print(f"{out_name} saved OK.")
46 |
47 | if __name__ == "__main__":
48 | main()
49 |
50 |
--------------------------------------------------------------------------------
/test_data/generate_odd_name_onnx.py:
--------------------------------------------------------------------------------
1 | # This script generates the .onnx file with a bunch of different special chars
2 | # in the filename. It takes a 1x2 int32 tensor and produces a 1-element
3 | # int32 output containing the sum of the 2 input elements.
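# The non-ASCII filename exercises the library's ORTCHAR_T path handling:
# plain UTF-8 paths on unix-like systems, and UTF-16 conversion (see
# createOrtCharString in setup_env_windows.go) on Windows.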
4 | import torch 5 | 6 | class AddModel(torch.nn.Module): 7 | def __init__(self): 8 | super().__init__() 9 | 10 | def forward(self, inputs): 11 | return inputs.sum(1).int() 12 | 13 | def main(): 14 | model = AddModel() 15 | model.eval() 16 | x = torch.ones((1, 2), dtype=torch.int32) 17 | file_name = "example ż 大 김.onnx" 18 | torch.onnx.export(model, (x,), file_name, input_names=["in"], 19 | output_names=["out"]) 20 | print(f"{file_name} saved OK.") 21 | 22 | if __name__ == "__main__": 23 | main() 24 | 25 | -------------------------------------------------------------------------------- /test_data/generate_several_inputs_and_outputs.py: -------------------------------------------------------------------------------- 1 | # This script creates example_several_inputs_and_outputs.onnx to use in 2 | # testing. The "network" is entirely deterministic, and is intended just to 3 | # illustrate a wide variety of inputs and outputs with varying names, 4 | # dimensions, and types. 5 | # 6 | # Inputs: 7 | # - "input 1": a 2x5x2x5 int32 tensor 8 | # - "input 2": a 2x3x20 float tensor 9 | # - "input 3": a 9-element bfloat16 tensor 10 | # 11 | # Outputs: 12 | # - "output 1": A 10x10 element int64 tensor 13 | # - "output 2": A 1x2x3x4x5 element double tensor 14 | # 15 | # The contents of the inputs and outputs are arbitrary. 16 | import torch 17 | 18 | class ManyInputOutputModel(torch.nn.Module): 19 | def __init__(self): 20 | super().__init__() 21 | 22 | def forward(self, a, b, c): 23 | output_a = a.reshape((10, 10)) 24 | output_a = output_a.type(torch.int64) 25 | output_b = b.reshape((1, 2, 3, 4, 5)) 26 | output_b = output_b.type(torch.double) 27 | # Just to make sure we use input C. 28 | output_a[0][0] += c[0].type(torch.int64) 29 | return output_a, output_b 30 | 31 | def main(): 32 | model = ManyInputOutputModel() 33 | model.eval() 34 | out_name = "example_several_inputs_and_outputs.onnx" 35 | input_a = torch.zeros((2, 5, 2, 5), dtype=torch.int32) 36 | input_b = torch.zeros((2, 3, 20), dtype=torch.float) 37 | input_c = torch.zeros((9), dtype=torch.bfloat16) 38 | torch.onnx.export(model, (input_a, input_b, input_c), out_name, 39 | input_names = ["input 1", "input 2", "input 3"], 40 | output_names = ["output 1", "output 2"]) 41 | print(f"{out_name} saved OK.") 42 | 43 | if __name__ == "__main__": 44 | main() 45 | 46 | -------------------------------------------------------------------------------- /test_data/generate_sklearn_network.py: -------------------------------------------------------------------------------- 1 | # This script is a modified version of the example from 2 | # https://pypi.org/project/skl2onnx/, which we use to produce 3 | # sklearn_randomforest.onnx. sklearn makes heavy use of onnxruntime maps and 4 | # sequences in its networks, so this is used for testing those data types. 5 | 6 | import numpy as np 7 | from sklearn.datasets import load_iris 8 | from sklearn.model_selection import train_test_split 9 | from sklearn.ensemble import RandomForestClassifier 10 | 11 | iris = load_iris() 12 | inputs, outputs = iris.data, iris.target 13 | inputs = inputs.astype(np.float32) 14 | inputs_train, inputs_test, outputs_train, outputs_test = train_test_split(inputs, outputs) 15 | classifier = RandomForestClassifier() 16 | classifier.fit(inputs_train, outputs_train) 17 | 18 | # Convert into ONNX format. 
19 | from skl2onnx import to_onnx 20 | output_filename = "sklearn_randomforest.onnx" 21 | onnx_content = to_onnx(classifier, inputs[:1]) 22 | with open(output_filename, "wb") as f: 23 | f.write(onnx_content.SerializeToString()) 24 | 25 | # Compute the prediction with onnxruntime. 26 | import onnxruntime as ort 27 | 28 | def float_formatter(f): 29 | return f"{float(f):.06f}" 30 | 31 | np.set_printoptions(formatter = {'float_kind': float_formatter}) 32 | session = ort.InferenceSession(output_filename) 33 | print(f"Input names: {[n.name for n in session.get_inputs()]!s}") 34 | print(f"Output names: {[o.name for o in session.get_outputs()]!s}") 35 | example_inputs = inputs_test.astype(np.float32)[:6] 36 | print(f"Inputs shape = {example_inputs.shape!s}") 37 | onnx_predictions = session.run(["output_label", "output_probability"], 38 | {"X": example_inputs}) 39 | labels = onnx_predictions[0] 40 | probabilities = onnx_predictions[1] 41 | 42 | print(f"Inputs to network: {example_inputs.astype(np.float32)}") 43 | print(f"ONNX predicted labels: {labels!s}") 44 | print(f"ONNX predicted probabilities: {probabilities!s}") 45 | 46 | -------------------------------------------------------------------------------- /test_data/modify_metadata.py: -------------------------------------------------------------------------------- 1 | # This file is used to modify the metadata of example_big_compute.onnx for the 2 | # sake of testing. The choice of example_big_compute.onnx was arbitrary. 3 | import onnx 4 | 5 | def main(): 6 | file_path = "example_big_compute.onnx" 7 | model = onnx.load(file_path) 8 | model.doc_string = "This is a test description." 9 | model.model_version = 1337 10 | model.domain = "test domain" 11 | custom_1 = model.metadata_props.add() 12 | custom_1.key = "test key 1" 13 | custom_1.value = "" 14 | custom_2 = model.metadata_props.add() 15 | custom_2.key = "test key 2" 16 | custom_2.value = "Test key 2 value" 17 | onnx.save(model, file_path) 18 | 19 | if __name__ == "__main__": 20 | main() 21 | -------------------------------------------------------------------------------- /test_data/onnxruntime.dll: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yalue/onnxruntime_go/b8230be7b992cd4ecedaa9441974d932d736f42d/test_data/onnxruntime.dll -------------------------------------------------------------------------------- /test_data/onnxruntime_amd64.dylib: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yalue/onnxruntime_go/b8230be7b992cd4ecedaa9441974d932d736f42d/test_data/onnxruntime_amd64.dylib -------------------------------------------------------------------------------- /test_data/onnxruntime_arm64.dylib: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yalue/onnxruntime_go/b8230be7b992cd4ecedaa9441974d932d736f42d/test_data/onnxruntime_arm64.dylib -------------------------------------------------------------------------------- /test_data/onnxruntime_arm64.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yalue/onnxruntime_go/b8230be7b992cd4ecedaa9441974d932d736f42d/test_data/onnxruntime_arm64.so -------------------------------------------------------------------------------- /test_data/sklearn_randomforest.onnx: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/yalue/onnxruntime_go/b8230be7b992cd4ecedaa9441974d932d736f42d/test_data/sklearn_randomforest.onnx --------------------------------------------------------------------------------