├── .gitignore ├── LICENSE ├── README.md ├── cgo_static.go ├── const.go ├── core.cpp ├── core.go ├── core.h ├── dockerfile └── Dockerfile_ubuntu_arm64_example ├── examples └── predict_example.go └── go.mod
/.gitignore: -------------------------------------------------------------------------------- 1 | *.DS_Store 2 | *.json 3 | .vscode 4 | vendor 5 | files/ --------------------------------------------------------------------------------
/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Ivan Suteja 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | --------------------------------------------------------------------------------
/README.md: -------------------------------------------------------------------------------- 1 | # go-onnxruntime 2 | Unofficial cgo binding for Onnxruntime in Golang. 3 | This is used to perform ONNX model inference in Go. 4 | 5 | ## Installation 6 | 7 | Download and install go-onnxruntime: 8 | 9 | ``` 10 | go get -v github.com/ivansuteja96/go-onnxruntime 11 | ``` 12 | 13 | The binding requires the Onnxruntime C++ library and Go 1.14+. 14 | 15 | ### Onnxruntime C++ Library 16 | 17 | This binding is built against Onnxruntime v1.11.0. 18 | 19 | To install Onnxruntime C++ on your system, go to [onnxruntime](https://github.com/microsoft/onnxruntime/releases/tag/v1.11.0) and download the asset that matches your system (Linux/macOS). 20 | 21 | The Onnxruntime C++ libraries are expected to be under `/usr/local/lib`. 22 | 23 | The Onnxruntime C++ header files are expected to be under `/usr/local/include`. 24 | 25 | 26 | ### Configure Environment Variables 27 | 28 | Configure the linker environment variables, since the Onnxruntime C++ library lives in a non-system directory. Place the following in either your `~/.bashrc` or `~/.zshrc` file: 29 | 30 | Linux (.bashrc) / macOS (.zshrc) 31 | ``` 32 | export LIBRARY_PATH=$LIBRARY_PATH:/usr/local/lib 33 | export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib 34 | ``` 35 | 36 | After that, run either `source ~/.bashrc` (Linux) or `source ~/.zshrc` (macOS). 37 | 38 | ## How to Run 39 | 40 | For a quick start, you can run go-onnxruntime using these commands: 41 | ``` 42 | docker build --platform linux/arm64/v8 -f dockerfile/Dockerfile_ubuntu_arm64_example -t go-onnxruntime .
43 | docker run --rm -it go-onnxruntime:latest 44 | ``` 45 | 46 | **_Note_** : Currently we only provide Dockerfile for ubuntu arm64 architecture. 47 | 48 | 49 | ## Examples 50 | 51 | Examples of using the Go Onnxruntime binding to do model inference are under [examples](examples). 52 | 53 | ## Credits 54 | 55 | Some of the logic of conversion is referenced from https://github.com/c3sr/go-onnxruntime. 56 | -------------------------------------------------------------------------------- /cgo_static.go: -------------------------------------------------------------------------------- 1 | package onnxruntime 2 | 3 | // Changes here should be mirrored in contrib/cgo_static.go and cuda/cgo_static.go. 4 | 5 | /* 6 | #cgo CXXFLAGS: --std=c++11 7 | #cgo !windows CPPFLAGS: -I/usr/local/include 8 | #cgo !windows LDFLAGS: -L/usr/local/lib -lonnxruntime -lstdc++ 9 | */ 10 | import "C" 11 | -------------------------------------------------------------------------------- /const.go: -------------------------------------------------------------------------------- 1 | package onnxruntime 2 | 3 | const ( 4 | ORT_LOGGING_LEVEL_VERBOSE ORTLoggingLevel = iota // Verbose informational messages (least severe). 5 | ORT_LOGGING_LEVEL_INFO // Informational messages. 6 | ORT_LOGGING_LEVEL_WARNING // Warning messages. 7 | ORT_LOGGING_LEVEL_ERROR // Error messages. 8 | ORT_LOGGING_LEVEL_FATAL // Fatal error messages (most severe). 9 | ) 10 | 11 | const ( 12 | ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED OnnxTensorElementDataType = iota 13 | ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT 14 | ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8 15 | ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8 16 | ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16 17 | ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16 18 | ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32 19 | ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64 20 | ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL 21 | ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE 22 | ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32 23 | ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64 24 | ) 25 | 26 | const ( 27 | OrtCudnnConvAlgoSearchExhaustive CudnnConvAlgoSearch = iota // expensive exhaustive benchmarking using cudnnFindConvolutionForwardAlgorithmEx 28 | OrtCudnnConvAlgoSearchHeuristic // lightweight heuristic based search using cudnnGetConvolutionForwardAlgorithm_v7 29 | OrtCudnnConvAlgoSearchDefault // default algorithm using CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM 30 | ) 31 | -------------------------------------------------------------------------------- /core.cpp: -------------------------------------------------------------------------------- 1 | #include "core.h" 2 | 3 | ORTSessionOptions ORTSessionOptions_New() { 4 | return new Ort::SessionOptions(); 5 | } 6 | 7 | void ORTSessionOptions_AppendExecutionProvider_CUDA(ORTSessionOptions session_options, CudaOptions cuda_options) { 8 | OrtCUDAProviderOptions ort_cuda_options; 9 | ort_cuda_options.device_id = cuda_options.device_id; 10 | ort_cuda_options.cudnn_conv_algo_search = (OrtCudnnConvAlgoSearch)cuda_options.cudnn_conv_algo_search; 11 | ort_cuda_options.gpu_mem_limit = cuda_options.gpu_mem_limit; 12 | ort_cuda_options.arena_extend_strategy = cuda_options.arena_extend_strategy; 13 | ort_cuda_options.do_copy_in_default_stream = cuda_options.do_copy_in_default_stream; 14 | ort_cuda_options.has_user_compute_stream = cuda_options.has_user_compute_stream; 15 | (*session_options).AppendExecutionProvider_CUDA(ort_cuda_options); 16 | } 17 | 18 | ORTEnv ORTEnv_New(int logging_level,char* log_env) { 19 | return new Ort::Env(OrtLoggingLevel(logging_level),log_env); 20 
| } 21 | 22 | ORTSession* ORTSession_New(ORTEnv ort_env,char* model_location, ORTSessionOptions session_options){ 23 | char* env_num_threads = std::getenv("ORT_NUM_THREADS"); 24 | if(env_num_threads) { 25 | const int num_threads = std::stoi(env_num_threads); 26 | (*session_options).SetIntraOpNumThreads(num_threads); 27 | } 28 | 29 | auto session = new Ort::Session(*ort_env, model_location, *session_options); 30 | Ort::AllocatorWithDefaultOptions allocator; 31 | size_t num_input_nodes = (*session).GetInputCount(); 32 | char **input_node_names = NULL; 33 | input_node_names = (char**)realloc(input_node_names, num_input_nodes*sizeof(*input_node_names)); 34 | 35 | // iterate over all input nodes 36 | for (int i = 0; i < num_input_nodes; i++) { 37 | char* input_name = (*session).GetInputName(i, allocator); 38 | auto shapes = (*session).GetInputTypeInfo(i).GetTensorTypeAndShapeInfo().GetShape(); 39 | input_node_names[i] = input_name; 40 | printf("Input %d : name=%s shape=", i, input_name); 41 | for (size_t i = 0; i < shapes.size(); ++i) { 42 | printf("%ld", shapes[i]); 43 | if (i < shapes.size() - 1) 44 | printf(","); 45 | } 46 | printf("\n"); 47 | } 48 | 49 | size_t num_output_nodes = (*session).GetOutputCount(); 50 | char **output_node_names = NULL; 51 | output_node_names = (char**)realloc(output_node_names, num_output_nodes*sizeof(*output_node_names)); 52 | 53 | // iterate over all output nodes 54 | for (int i = 0; i < num_output_nodes; i++) { 55 | char* output_name = (*session).GetOutputName(i, allocator); 56 | auto shapes = (*session).GetOutputTypeInfo(i).GetTensorTypeAndShapeInfo().GetShape(); 57 | output_node_names[i] = output_name; 58 | printf("Output %d : name=%s shape=", i, output_name); 59 | for (size_t i = 0; i < shapes.size(); ++i) { 60 | printf("%ld", shapes[i]); 61 | if (i < shapes.size() - 1) 62 | printf(","); 63 | } 64 | printf("\n"); 65 | } 66 | 67 | auto res = new ORTSession{session, input_node_names,num_input_nodes, output_node_names, num_output_nodes}; 68 | return res; 69 | } 70 | 71 | ORTValues* ORTValues_New(){ 72 | return new ORTValues{}; 73 | } 74 | 75 | void ORTValues_AppendTensor(TensorVector tensor_input, ORTValues *ort_values){ 76 | auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault); 77 | 78 | switch (tensor_input.data_type) { 79 | case ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED: 80 | throw std::runtime_error(std::string("undefined data type detected in ORTValues_AppendTensor")); 81 | break; 82 | case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: 83 | (*ort_values).emplace_back(Ort::Value::CreateTensor(memory_info, (float*)tensor_input.val, tensor_input.length, (int64_t*)tensor_input.shape.val, (size_t)tensor_input.shape.length)); 84 | break; 85 | case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: 86 | (*ort_values).emplace_back(Ort::Value::CreateTensor(memory_info, (uint8_t*)tensor_input.val, tensor_input.length, (int64_t*)tensor_input.shape.val, (size_t)tensor_input.shape.length)); 87 | break; 88 | case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: 89 | (*ort_values).emplace_back(Ort::Value::CreateTensor(memory_info, (int8_t*)tensor_input.val, tensor_input.length, (int64_t*)tensor_input.shape.val, (size_t)tensor_input.shape.length)); 90 | break; 91 | case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: 92 | (*ort_values).emplace_back(Ort::Value::CreateTensor(memory_info, (uint16_t*)(tensor_input.val), tensor_input.length, (int64_t*)tensor_input.shape.val, (size_t)tensor_input.shape.length)); 93 | break; 94 | case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16: 95 | 
(*ort_values).emplace_back(Ort::Value::CreateTensor(memory_info, (int16_t*)(tensor_input.val), tensor_input.length, (int64_t*)tensor_input.shape.val, (size_t)tensor_input.shape.length)); 96 | break; 97 | case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: 98 | (*ort_values).emplace_back(Ort::Value::CreateTensor(memory_info, (int32_t*)(tensor_input.val), tensor_input.length, (int64_t*)tensor_input.shape.val, (size_t)tensor_input.shape.length)); 99 | break; 100 | case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: 101 | (*ort_values).emplace_back(Ort::Value::CreateTensor(memory_info, (int64_t*)(tensor_input.val), tensor_input.length, (int64_t*)tensor_input.shape.val, (size_t)tensor_input.shape.length)); 102 | break; 103 | case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: 104 | (*ort_values).emplace_back(Ort::Value::CreateTensor(memory_info, (bool*)(tensor_input.val), tensor_input.length, (int64_t*)tensor_input.shape.val, (size_t)tensor_input.shape.length)); 105 | break; 106 | case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: 107 | (*ort_values).emplace_back(Ort::Value::CreateTensor(memory_info, (double*)(tensor_input.val), tensor_input.length, (int64_t*)tensor_input.shape.val, (size_t)tensor_input.shape.length)); 108 | break; 109 | case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: 110 | (*ort_values).emplace_back(Ort::Value::CreateTensor(memory_info, (uint32_t*)(tensor_input.val), tensor_input.length, (int64_t*)tensor_input.shape.val, (size_t)tensor_input.shape.length)); 111 | break; 112 | case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: 113 | (*ort_values).emplace_back(Ort::Value::CreateTensor(memory_info, (uint64_t*)(tensor_input.val), tensor_input.length, (int64_t*)tensor_input.shape.val, (size_t)tensor_input.shape.length)); 114 | break; 115 | default: // c++: FLOAT16; onnxruntime: COMPLEX64, COMPLEX128, BFLOAT16; TODO: Implement String method 116 | throw std::runtime_error(std::string("unsupported data type detected in ORTValues_AppendTensor")); 117 | } 118 | return ; 119 | } 120 | 121 | void *ORTValue_GetTensorMutableData(Ort::Value& ort_value, size_t size){ 122 | void *res = NULL; 123 | switch ((ort_value).GetTensorTypeAndShapeInfo().GetElementType()) { 124 | case ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED: 125 | throw std::runtime_error(std::string("undefined data type detected in ORTValue_GetTensorMutableData")); 126 | break; 127 | case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: 128 | res = (void*) malloc(sizeof(float) * size); 129 | memcpy(res, ort_value.GetTensorMutableData<float>(), sizeof(float) * size); 130 | break; 131 | case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: 132 | res = (void*) malloc(sizeof(uint8_t) * size); 133 | memcpy(res, ort_value.GetTensorMutableData<uint8_t>(), sizeof(uint8_t) * size); 134 | break; 135 | case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: 136 | res = (void*) malloc(sizeof(int8_t) * size); 137 | memcpy(res, ort_value.GetTensorMutableData<int8_t>(), sizeof(int8_t) * size); 138 | break; 139 | case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: 140 | res = (void*) malloc(sizeof(uint16_t) * size); 141 | memcpy(res, ort_value.GetTensorMutableData<uint16_t>(), sizeof(uint16_t) * size); 142 | break; 143 | case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16: 144 | res = (void*) malloc(sizeof(int16_t) * size); 145 | memcpy(res, ort_value.GetTensorMutableData<int16_t>(), sizeof(int16_t) * size); 146 | break; 147 | case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: 148 | res = (void*) malloc(sizeof(int32_t) * size); 149 | memcpy(res, ort_value.GetTensorMutableData<int32_t>(), sizeof(int32_t) * size); 150 | break; 151 | case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: 152 | res = (void*) malloc(sizeof(int64_t) * size); 153 | memcpy(res, ort_value.GetTensorMutableData<int64_t>(), sizeof(int64_t) * size); 154 | break; 155 | case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: 156 | res = (void*) malloc(sizeof(bool) * size); 157 | memcpy(res, ort_value.GetTensorMutableData<bool>(), sizeof(bool) * size); 158 | break; 159 | case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: 160 | res = (void*) malloc(sizeof(double) * size); 161 | memcpy(res, ort_value.GetTensorMutableData<double>(), sizeof(double) * size); 162 | break; 163 | case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: 164 | res = (void*) malloc(sizeof(uint32_t) * size); 165 | memcpy(res, ort_value.GetTensorMutableData<uint32_t>(), sizeof(uint32_t) * size); 166 | break; 167 | case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: 168 | res = (void*) malloc(sizeof(uint64_t) * size); 169 | memcpy(res, ort_value.GetTensorMutableData<uint64_t>(), sizeof(uint64_t) * size); break; 170 | default: // c++: FLOAT16; onnxruntime: COMPLEX64, COMPLEX128, BFLOAT16; TODO: Implement String method 171 | throw std::runtime_error(std::string("unsupported data type detected in ORTValue_GetTensorMutableData")); 172 | } 173 | return res; 174 | } 175 | 176 | TensorVectors ORTSession_Predict(ORTSession* session, ORTValues *ort_values_input){ 177 | // score model & input tensor, get back output tensor 178 | auto output_tensors = (*session->session).Run(Ort::RunOptions{nullptr}, session->input_node_names, (*ort_values_input).data(), session->input_node_names_length, session->output_node_names, session->output_node_names_length); 179 | 180 | auto output_tensors_count = output_tensors.size(); 181 | TensorVector* vector_tv = (TensorVector*)malloc(output_tensors_count*sizeof(*vector_tv)); 182 | for (size_t i = 0; i< output_tensors_count;i++){ 183 | auto output_shape_vector = output_tensors[i].GetTensorTypeAndShapeInfo().GetShape(); 184 | auto element_type = output_tensors[i].GetTensorTypeAndShapeInfo().GetElementType(); 185 | 186 | int output_length = 1; 187 | auto output_shape_size = output_shape_vector.size(); 188 | for (int i=0;i<output_shape_size;i++){ 189 | output_length *= output_shape_vector[i]; 190 | } 191 | 192 | long* shape_val = (long*)malloc(output_shape_size*sizeof(long)); 193 | for (size_t j = 0; j < output_shape_size; j++){ 194 | shape_val[j] = (long)output_shape_vector[j]; 195 | } 196 | LongVector shape = {shape_val, (int)output_shape_size}; 197 | 198 | // copy the tensor payload out of the Ort::Value so it remains valid after the outputs are destroyed 199 | void* val = ORTValue_GetTensorMutableData(output_tensors[i], output_length); 200 | TensorVector tv = {val, (int)element_type, shape, output_length}; 201 | vector_tv[i] = tv; 202 | } 203 | 204 | TensorVectors res = {vector_tv, (int)output_tensors_count}; 205 | return res; 206 | } 207 | 208 | void ORTSession_Free(ORTSession* session){ 209 | free(session->input_node_names); 210 | free(session->output_node_names); 211 | free(session); 212 | } 213 | 214 | void TensorVectors_Clear(TensorVectors tvs){ 215 | for (int i = 0; i < tvs.length; i++) { 216 | free(tvs.arr_vector[i].shape.val); 217 | free(tvs.arr_vector[i].val); 218 | } 219 | free(tvs.arr_vector); 220 | } --------------------------------------------------------------------------------
/core.go: -------------------------------------------------------------------------------- 1 | package onnxruntime 2 | 3 | /* 4 | #include <stdlib.h> 5 | #include "core.h" 6 | */ 7 | import "C" 8 | import ( 9 | "errors" 10 | "fmt" 11 | "os" 12 | "path/filepath" 13 | "unsafe" 14 | ) 15 | 16 | type ( 17 | ORTSessionOptions struct { 18 | sessOpts C.ORTSessionOptions 19 | } 20 | ORTSession struct { 21 | sess *C.struct_ORTSession 22 | } 23 | ORTEnv struct { 24 | env C.ORTEnv 25 | } 26 | ORTValues struct { 27 | val *C.ORTValues 28 | } 29 | ORTLoggingLevel int 30 | OnnxTensorElementDataType int 31 | CudnnConvAlgoSearch int 32 | TensorValue struct { 33 | Value interface{} 34 | Shape []int64 35 | } 36 | CudaOptions struct { 37 | DeviceID int 38 | CudnnConvAlgoSearch CudnnConvAlgoSearch 39 | GPUMemorylimit int 40 | ArenaExtendStrategy bool 41 | DoCopyInDefaultStream bool 42 | HasUserComputeStream bool 43 | } 44 | ) 45 | 46 | // NewORTEnv creates a new onnxruntime environment. 47 | func NewORTEnv(loggingLevel ORTLoggingLevel, logEnv string) (ortEnv *ORTEnv) { 48 | cLogEnv := C.CString(logEnv) 49 | ortEnv = &ORTEnv{ 50 | env: C.ORTEnv_New(C.int(int(loggingLevel)),
cLogEnv), 51 | } 52 | C.free(unsafe.Pointer(cLogEnv)) 53 | return ortEnv 54 | } 55 | 56 | func (o *ORTEnv) Close() error { 57 | C.free(unsafe.Pointer(o.env)) 58 | return nil 59 | } 60 | 61 | // NewORTSessionOptions return empty onnxruntime session options. 62 | func NewORTSessionOptions() *ORTSessionOptions { 63 | return &ORTSessionOptions{sessOpts: C.ORTSessionOptions_New()} 64 | } 65 | 66 | func (so ORTSessionOptions) Close() error { 67 | C.free(unsafe.Pointer(so.sessOpts)) 68 | return nil 69 | } 70 | 71 | // AppendExecutionProviderCUDA append cuda device to the session options. 72 | func (so ORTSessionOptions) AppendExecutionProviderCUDA(cudaOptions CudaOptions) { 73 | var intDoCopyInDefaultStream int 74 | if cudaOptions.DoCopyInDefaultStream { 75 | intDoCopyInDefaultStream = 1 76 | } 77 | 78 | var intHasUserComputeStream int 79 | if cudaOptions.HasUserComputeStream { 80 | intHasUserComputeStream = 1 81 | } 82 | 83 | var intArenaExtendStrategy int 84 | if cudaOptions.ArenaExtendStrategy { 85 | intArenaExtendStrategy = 1 86 | } 87 | C.ORTSessionOptions_AppendExecutionProvider_CUDA(so.sessOpts, C.CudaOptions{ 88 | device_id: C.int(cudaOptions.DeviceID), 89 | cudnn_conv_algo_search: C.int(cudaOptions.CudnnConvAlgoSearch), 90 | gpu_mem_limit: C.int(cudaOptions.GPUMemorylimit), 91 | arena_extend_strategy: C.int(intArenaExtendStrategy), 92 | do_copy_in_default_stream: C.int(intDoCopyInDefaultStream), 93 | has_user_compute_stream: C.int(intHasUserComputeStream), 94 | }) 95 | } 96 | 97 | // NewORTSession return new onnxruntime session 98 | func NewORTSession(ortEnv *ORTEnv, modelLocation string, sessionOptions *ORTSessionOptions) (ortSession *ORTSession, err error) { 99 | if ortEnv == nil { 100 | return ortSession, fmt.Errorf("error nil ort env") 101 | } 102 | if _, err = os.Stat(modelLocation); errors.Is(err, os.ErrNotExist) { 103 | return 104 | } else if fileExtension := filepath.Ext(modelLocation); fileExtension != ".onnx" { 105 | err = errors.New("file isn't an onnx model") 106 | return 107 | } 108 | if sessionOptions == nil { 109 | return ortSession, fmt.Errorf("error nil ort session options") 110 | } 111 | 112 | cModelLocation := C.CString(modelLocation) 113 | ortSession = &ORTSession{sess: C.ORTSession_New(ortEnv.env, cModelLocation, sessionOptions.sessOpts)} 114 | C.free(unsafe.Pointer(cModelLocation)) 115 | 116 | return ortSession, nil 117 | } 118 | 119 | // newTensorVector generate C.TensorVector 120 | func newTensorVector(tv TensorValue) (ctv C.TensorVector, err error) { 121 | switch tv.Value.(type) { 122 | case []float32: 123 | { 124 | val, _ := tv.Value.([]float32) 125 | ctv = C.TensorVector{ 126 | val: unsafe.Pointer(&val[0]), 127 | data_type: C.int(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT), 128 | length: C.int(len(val)), 129 | shape: C.LongVector{ 130 | val: (*C.long)(&tv.Shape[0]), 131 | length: C.int(len(tv.Shape)), 132 | }, 133 | } 134 | } 135 | case []uint8: 136 | { 137 | val := tv.Value.([]uint8) 138 | ctv = C.TensorVector{ 139 | val: unsafe.Pointer(&val[0]), 140 | data_type: C.int(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8), 141 | length: C.int(len(val)), 142 | shape: C.LongVector{ 143 | val: (*C.long)(&tv.Shape[0]), 144 | length: C.int(len(tv.Shape)), 145 | }, 146 | } 147 | } 148 | case []int8: 149 | { 150 | val := tv.Value.([]int8) 151 | ctv = C.TensorVector{ 152 | val: unsafe.Pointer(&val[0]), 153 | data_type: C.int(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8), 154 | length: C.int(len(val)), 155 | shape: C.LongVector{ 156 | val: (*C.long)(&tv.Shape[0]), 157 | length: C.int(len(tv.Shape)), 
158 | }, 159 | } 160 | } 161 | case []uint16: 162 | { 163 | val := tv.Value.([]uint16) 164 | ctv = C.TensorVector{ 165 | val: unsafe.Pointer(&val[0]), 166 | data_type: C.int(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16), 167 | length: C.int(len(val)), 168 | shape: C.LongVector{ 169 | val: (*C.long)(&tv.Shape[0]), 170 | length: C.int(len(tv.Shape)), 171 | }, 172 | } 173 | } 174 | case []int16: 175 | { 176 | val := tv.Value.([]int16) 177 | ctv = C.TensorVector{ 178 | val: unsafe.Pointer(&val[0]), 179 | data_type: C.int(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16), 180 | length: C.int(len(val)), 181 | shape: C.LongVector{ 182 | val: (*C.long)(&tv.Shape[0]), 183 | length: C.int(len(tv.Shape)), 184 | }, 185 | } 186 | } 187 | case []int32: 188 | { 189 | val := tv.Value.([]int32) 190 | ctv = C.TensorVector{ 191 | val: unsafe.Pointer(&val[0]), 192 | data_type: C.int(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32), 193 | length: C.int(len(val)), 194 | shape: C.LongVector{ 195 | val: (*C.long)(&tv.Shape[0]), 196 | length: C.int(len(tv.Shape)), 197 | }, 198 | } 199 | } 200 | case []int64: 201 | { 202 | val := tv.Value.([]int64) 203 | ctv = C.TensorVector{ 204 | val: unsafe.Pointer(&val[0]), 205 | data_type: C.int(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64), 206 | length: C.int(len(val)), 207 | shape: C.LongVector{ 208 | val: (*C.long)(&tv.Shape[0]), 209 | length: C.int(len(tv.Shape)), 210 | }, 211 | } 212 | } 213 | case []bool: 214 | { 215 | val := tv.Value.([]bool) 216 | ctv = C.TensorVector{ 217 | val: unsafe.Pointer(&val[0]), 218 | data_type: C.int(ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL), 219 | length: C.int(len(val)), 220 | shape: C.LongVector{ 221 | val: (*C.long)(&tv.Shape[0]), 222 | length: C.int(len(tv.Shape)), 223 | }, 224 | } 225 | } 226 | case []float64: 227 | { 228 | val := tv.Value.([]float64) 229 | ctv = C.TensorVector{ 230 | val: unsafe.Pointer(&val[0]), 231 | data_type: C.int(ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE), 232 | length: C.int(len(val)), 233 | shape: C.LongVector{ 234 | val: (*C.long)(&tv.Shape[0]), 235 | length: C.int(len(tv.Shape)), 236 | }, 237 | } 238 | } 239 | case []uint32: 240 | { 241 | val := tv.Value.([]uint32) 242 | ctv = C.TensorVector{ 243 | val: unsafe.Pointer(&val[0]), 244 | data_type: C.int(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32), 245 | length: C.int(len(val)), 246 | shape: C.LongVector{ 247 | val: (*C.long)(&tv.Shape[0]), 248 | length: C.int(len(tv.Shape)), 249 | }, 250 | } 251 | } 252 | case []uint64: 253 | { 254 | val := tv.Value.([]uint64) 255 | ctv = C.TensorVector{ 256 | val: unsafe.Pointer(&val[0]), 257 | data_type: C.int(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64), 258 | length: C.int(len(val)), 259 | shape: C.LongVector{ 260 | val: (*C.long)(&tv.Shape[0]), 261 | length: C.int(len(tv.Shape)), 262 | }, 263 | } 264 | } 265 | default: 266 | err = errors.New("invalid data type") 267 | } 268 | return 269 | } 270 | 271 | // cTensorVectorToGo convert C.TensorVector to Go Value 272 | func cTensorVectorToGo(cVal C.TensorVector) (goVal interface{}, shape []int64, err error) { 273 | cShapeValue := unsafe.Pointer(cVal.shape.val) 274 | shape = make([]int64, int64(cVal.shape.length)) 275 | copy(shape, (*[1 << 30]int64)(cShapeValue)[:]) 276 | 277 | switch cVal.data_type { 278 | case C.ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED: 279 | err = errors.New("undefined data type!") 280 | case C.ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: 281 | { 282 | tensorValue := make([]float32, int(cVal.length)) 283 | cTensorValue := unsafe.Pointer(cVal.val) 284 | copy(tensorValue, (*[1 << 30]float32)(cTensorValue)[:]) 285 | goVal = tensorValue 
286 | } 287 | case C.ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: 288 | { 289 | tensorValue := make([]uint8, int(cVal.length)) 290 | cTensorValue := unsafe.Pointer(cVal.val) 291 | copy(tensorValue, (*[1 << 30]uint8)(cTensorValue)[:]) 292 | goVal = tensorValue 293 | } 294 | case C.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: 295 | { 296 | tensorValue := make([]int8, int(cVal.length)) 297 | cTensorValue := unsafe.Pointer(cVal.val) 298 | copy(tensorValue, (*[1 << 30]int8)(cTensorValue)[:]) 299 | goVal = tensorValue 300 | } 301 | case C.ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: 302 | { 303 | tensorValue := make([]uint16, int(cVal.length)) 304 | cTensorValue := unsafe.Pointer(cVal.val) 305 | copy(tensorValue, (*[1 << 30]uint16)(cTensorValue)[:]) 306 | goVal = tensorValue 307 | } 308 | case C.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16: 309 | { 310 | tensorValue := make([]int16, int(cVal.length)) 311 | cTensorValue := unsafe.Pointer(cVal.val) 312 | copy(tensorValue, (*[1 << 30]int16)(cTensorValue)[:]) 313 | goVal = tensorValue 314 | } 315 | case C.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: 316 | { 317 | tensorValue := make([]int32, int(cVal.length)) 318 | cTensorValue := unsafe.Pointer(cVal.val) 319 | copy(tensorValue, (*[1 << 30]int32)(cTensorValue)[:]) 320 | goVal = tensorValue 321 | } 322 | case C.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: 323 | { 324 | tensorValue := make([]int64, int(cVal.length)) 325 | cTensorValue := unsafe.Pointer(cVal.val) 326 | copy(tensorValue, (*[1 << 30]int64)(cTensorValue)[:]) 327 | goVal = tensorValue 328 | } 329 | case C.ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: 330 | { 331 | tensorValue := make([]bool, int(cVal.length)) 332 | cTensorValue := unsafe.Pointer(cVal.val) 333 | copy(tensorValue, (*[1 << 30]bool)(cTensorValue)[:]) 334 | goVal = tensorValue 335 | } 336 | case C.ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: 337 | { 338 | tensorValue := make([]float64, int(cVal.length)) 339 | cTensorValue := unsafe.Pointer(cVal.val) 340 | copy(tensorValue, (*[1 << 30]float64)(cTensorValue)[:]) 341 | goVal = tensorValue 342 | } 343 | case C.ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: 344 | { 345 | tensorValue := make([]uint32, int(cVal.length)) 346 | cTensorValue := unsafe.Pointer(cVal.val) 347 | copy(tensorValue, (*[1 << 30]uint32)(cTensorValue)[:]) 348 | goVal = tensorValue 349 | } 350 | case C.ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: 351 | { 352 | tensorValue := make([]uint64, int(cVal.length)) 353 | cTensorValue := unsafe.Pointer(cVal.val) 354 | copy(tensorValue, (*[1 << 30]uint64)(cTensorValue)[:]) 355 | goVal = tensorValue 356 | } 357 | default: 358 | err = errors.New("invalid data type!") 359 | } 360 | return 361 | } 362 | 363 | // Predict do prediction from input data 364 | func (ortSession *ORTSession) Predict(inputTensorValues []TensorValue) (result []TensorValue, err error) { 365 | if ortSession == nil { 366 | return result, fmt.Errorf("error nil ortSession") 367 | } 368 | 369 | ortValuesInput := ORTValues{ 370 | val: C.ORTValues_New(), 371 | } 372 | 373 | for _, inputTensorValue := range inputTensorValues { 374 | tensorVector, err := newTensorVector(inputTensorValue) 375 | if err != nil { 376 | return nil, err 377 | } 378 | C.ORTValues_AppendTensor(tensorVector, ortValuesInput.val) 379 | } 380 | 381 | output := C.ORTSession_Predict(ortSession.sess, ortValuesInput.val) 382 | outputSize := int(output.length) 383 | tensorValues := make([]C.TensorVector, outputSize) 384 | arrVector := unsafe.Pointer(output.arr_vector) 385 | copy(tensorValues, (*[1 << 30]C.TensorVector)(arrVector)[:]) 386 | 387 | result = make([]TensorValue, 
outputSize) 388 | for i := 0; i < outputSize; i++ { 389 | goVal, shape, err := cTensorVectorToGo(tensorValues[i]) 390 | if err != nil { 391 | return nil, err 392 | } 393 | result[i] = TensorValue{ 394 | Value: goVal, 395 | Shape: shape, 396 | } 397 | } 398 | C.TensorVectors_Clear(output) 399 | 400 | return result, nil 401 | } 402 | 403 | func (ortSession *ORTSession) Close() error { 404 | C.ORTSession_Free(ortSession.sess) 405 | return nil 406 | } 407 | --------------------------------------------------------------------------------
/core.h: -------------------------------------------------------------------------------- 1 | #ifdef __cplusplus 2 | #include <onnxruntime_cxx_api.h> 3 | #include <vector> 4 | using namespace std; 5 | 6 | extern "C" { 7 | #endif 8 | #include <stdlib.h> 9 | 10 | #ifdef __cplusplus 11 | typedef Ort::SessionOptions* ORTSessionOptions; 12 | typedef struct ORTSession{ 13 | Ort::Session* session; 14 | char** input_node_names; 15 | size_t input_node_names_length; 16 | char** output_node_names; 17 | size_t output_node_names_length; 18 | } ORTSession; 19 | typedef Ort::Env* ORTEnv; 20 | typedef std::vector<Ort::Value> ORTValues; 21 | #else 22 | typedef void* ORTSessionOptions; 23 | typedef struct ORTSession{ 24 | void* session; 25 | char** input_node_names; 26 | size_t input_node_names_length; 27 | char** output_node_names; 28 | size_t output_node_names_length; 29 | } ORTSession; 30 | typedef void* ORTEnv; 31 | typedef void* ORTMemoryInfo; 32 | typedef void* ORTValues; 33 | #endif 34 | typedef struct LongVector{ 35 | long* val; 36 | int length; 37 | } LongVector; 38 | typedef struct TensorVector{ 39 | void* val; 40 | int data_type; 41 | LongVector shape; 42 | int length; 43 | } TensorVector; 44 | typedef struct TensorVectors{ 45 | TensorVector* arr_vector; 46 | int length; 47 | } TensorVectors; 48 | typedef struct CudaOptions{ 49 | int device_id; 50 | int cudnn_conv_algo_search; 51 | int gpu_mem_limit; 52 | int arena_extend_strategy; 53 | int do_copy_in_default_stream; 54 | int has_user_compute_stream; 55 | } CudaOptions; 56 | 57 | ORTSessionOptions ORTSessionOptions_New(); 58 | ORTSession* ORTSession_New(ORTEnv ort_env,char* model_location, ORTSessionOptions session_options); 59 | void ORTSessionOptions_AppendExecutionProvider_CUDA(ORTSessionOptions session_options, CudaOptions cuda_options); 60 | ORTEnv ORTEnv_New(int logging_level, char* log_env); 61 | TensorVectors ORTSession_Predict(ORTSession* session, ORTValues* ort_values_input); 62 | void ORTSession_Free(ORTSession* session); 63 | ORTValues* ORTValues_New(); 64 | void ORTValues_AppendTensor( TensorVector tensor_input, ORTValues* ort_values); 65 | void TensorVectors_Clear(TensorVectors tvs); 66 | 67 | #ifdef __cplusplus 68 | } 69 | #endif --------------------------------------------------------------------------------
/dockerfile/Dockerfile_ubuntu_arm64_example: -------------------------------------------------------------------------------- 1 | FROM ubuntu:20.04 2 | LABEL maintainer="Ivan Suteja " 3 | 4 | ARG GO_VERSION=1.14.15 5 | ARG ONNXRUNTIME_VERSION=1.11.1 6 | ARG ONNX_VERSION=1.12.0 7 | 8 | # DEFINE ALL ENV 9 | ENV GOROOT=/usr/local/go 10 | ENV GOPATH=/workspace/go 11 | ENV GOBIN=$GOPATH/bin 12 | ENV PATH=$GOPATH/bin:$GOROOT/bin:$PATH 13 | ENV LD_LIBRARY_PATH=/usr/local/lib 14 | ENV LIBRARY_PATH=/usr/local/lib 15 | ENV DYLD_LIBRARY_PATH=/usr/local/lib 16 | ENV DEBIAN_FRONTEND="noninteractive" 17 | ENV TZ=Asia/Jakarta 18 | 19 | # UPDATE AND INSTALL ALL REQUIRED DEPENDENCIES 20 | RUN apt-get update -y 21 | RUN apt-get install -y wget build-essential 22 | 23 | #
INSTALL GO 24 | RUN cd /tmp && \ 25 | wget https://dl.google.com/go/go${GO_VERSION}.linux-arm64.tar.gz && \ 26 | tar -C /usr/local -xvf go${GO_VERSION}.linux-arm64.tar.gz && \ 27 | rm -rf go${GO_VERSION}.linux-arm64.tar.gz && \ 28 | mkdir -p /workspace/go/src && mkdir -p /workspace/go/bin && mkdir -p /workspace/go/pkg 29 | 30 | # INSTALL ONNXRUNTIME 31 | RUN cd /tmp && \ 32 | wget -O onnxruntime.tgz https://github.com/microsoft/onnxruntime/releases/download/v${ONNXRUNTIME_VERSION}/onnxruntime-linux-aarch64-${ONNXRUNTIME_VERSION}.tgz && \ 33 | tar -C /tmp -xvf onnxruntime.tgz && \ 34 | mv onnxruntime-linux-aarch64-${ONNXRUNTIME_VERSION} onnxruntime && \ 35 | rm -rf onnxruntime.tgz && \ 36 | cp -R onnxruntime/lib /usr/local && \ 37 | cp -R onnxruntime/include /usr/local && \ 38 | rm -rf onnxruntime 39 | 40 | # DOWNLOAD MODEL 41 | RUN cd /tmp && mkdir model && cd model && \ 42 | wget -O model.onnx https://github.com/onnx/onnx/raw/v${ONNX_VERSION}/onnx/backend/test/data/node/test_sigmoid/model.onnx 43 | 44 | # COPY REPO go-onnxruntime 45 | COPY . $GOPATH/src/github.com/ivansuteja96/go-onnxruntime 46 | 47 | # RUN go-onnxruntime 48 | RUN cd $GOPATH/src/github.com/ivansuteja96/go-onnxruntime/examples && \ 49 | go run predict_example.go -------------------------------------------------------------------------------- /examples/predict_example.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | _ "image/jpeg" 6 | "log" 7 | "math/rand" 8 | 9 | "github.com/ivansuteja96/go-onnxruntime" 10 | ) 11 | 12 | func main() { 13 | ortEnvDet := onnxruntime.NewORTEnv(onnxruntime.ORT_LOGGING_LEVEL_VERBOSE, "development") 14 | ortDetSO := onnxruntime.NewORTSessionOptions() 15 | 16 | detModel, err := onnxruntime.NewORTSession(ortEnvDet, "/tmp/model/model.onnx", ortDetSO) 17 | if err != nil { 18 | log.Println(err) 19 | return 20 | } 21 | ortEnvDet.Close() 22 | ortDetSO.Close() 23 | defer detModel.Close() 24 | 25 | shape := []int64{3, 4, 5} 26 | input := randFloats(0, 1, int(shape[0]*shape[1]*shape[2])) 27 | 28 | res, err := detModel.Predict([]onnxruntime.TensorValue{ 29 | { 30 | Value: input, 31 | Shape: shape, 32 | }, 33 | }) 34 | if err != nil { 35 | log.Println(err) 36 | return 37 | } 38 | 39 | if len(res) == 0 { 40 | log.Println("Failed get result") 41 | return 42 | } 43 | fmt.Printf("Success do predict, shape : %+v, result : %+v\n", res[0].Shape, res[0].Value) 44 | } 45 | 46 | func randFloats(min, max float32, n int) []float32 { 47 | res := make([]float32, n) 48 | for i := range res { 49 | res[i] = min + rand.Float32()*(max-min) 50 | } 51 | return res 52 | } 53 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/ivansuteja96/go-onnxruntime 2 | 3 | go 1.14 4 | --------------------------------------------------------------------------------
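
The bundled example exercises CPU inference only. As a complement, below is a minimal sketch (not a file in this repository) of how the CUDA execution provider exposed by `ORTSessionOptions.AppendExecutionProviderCUDA` could be wired up. It assumes a CUDA-enabled onnxruntime shared library installed under `/usr/local/lib` and reuses the sigmoid test model path from the example Dockerfile; note that the aarch64 release archive installed there is CPU-only, so a real GPU run needs a different onnxruntime build, and the option values shown are illustrative rather than tuned recommendations. For CPU runs, the `ORT_NUM_THREADS` environment variable read in `core.cpp` can additionally be set to cap intra-op threads.

```go
package main

import (
	"log"

	"github.com/ivansuteja96/go-onnxruntime"
)

func main() {
	env := onnxruntime.NewORTEnv(onnxruntime.ORT_LOGGING_LEVEL_WARNING, "cuda-example")
	defer env.Close()

	so := onnxruntime.NewORTSessionOptions()
	defer so.Close()

	// Register the CUDA execution provider before creating the session.
	// CudaOptions fields map onto OrtCUDAProviderOptions in core.cpp;
	// the values below are illustrative, not tuned recommendations.
	so.AppendExecutionProviderCUDA(onnxruntime.CudaOptions{
		DeviceID:              0,
		CudnnConvAlgoSearch:   onnxruntime.OrtCudnnConvAlgoSearchDefault,
		GPUMemorylimit:        1 << 30, // ~1 GiB GPU memory arena limit
		DoCopyInDefaultStream: true,
	})

	// Path matches the model fetched in the example Dockerfile; replace with your own model.
	sess, err := onnxruntime.NewORTSession(env, "/tmp/model/model.onnx", so)
	if err != nil {
		log.Fatal(err)
	}
	defer sess.Close()

	// The sigmoid test model expects a 3x4x5 float32 tensor.
	shape := []int64{3, 4, 5}
	input := make([]float32, 3*4*5)

	out, err := sess.Predict([]onnxruntime.TensorValue{
		{Value: input, Shape: shape},
	})
	if err != nil {
		log.Fatal(err)
	}
	log.Printf("output shape=%v value=%v", out[0].Shape, out[0].Value)
}
```

The sigmoid model is tiny, so this sketch only demonstrates the API surface; the execution-provider registration is where a GPU-capable build would actually change behavior.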