├── .gitignore ├── LICENSE ├── README.md ├── go.mod ├── go.sum ├── labels.go ├── main.go └── ort ├── allocator.cpp ├── allocator.go ├── allocator.h ├── api.cpp ├── api.go ├── api.h ├── custom-op-domain.go ├── custom-op.go ├── environment.cpp ├── environment.go ├── environment.h ├── memory-info.cpp ├── memory-info.go ├── memory-info.h ├── run-options.cpp ├── run-options.go ├── run-options.h ├── session-options.cpp ├── session-options.go ├── session-options.h ├── session.cpp ├── session.go ├── session.h ├── tensor-type-and-shape-info.cpp ├── tensor-type-and-shape-info.go ├── tensor-type-and-shape-info.h ├── type-info.cpp ├── type-info.go ├── type-info.h ├── value.cpp ├── value.go └── value.h /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | resnet/ 3 | squeezenet/ 4 | models/ 5 | images/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2020 David Daniel 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated 4 | documentation files (the "Software"), to deal in the Software without restriction, including without limitation the 5 | rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit 6 | persons to whom the Software is furnished to do so, subject to the following conditions: 7 | 8 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the 9 | Software. 10 | 11 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE 12 | WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR 13 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 14 | OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # go-onnx 2 | 3 | Go language bindings for ONNX runtime 4 | 5 | ## About 6 | I'm a fan of Go and have just started digging a bit deeper into machine learning. I heard about ONNX runtime and I'm 7 | a fan of standardization, so it seemed like a good place to start. I realized ONNX runtime didn't have Go language 8 | bindings, and I figured, if I can get that going, it'd probably be a great way to get started on my AI/ML journey. 9 | 10 | The initial goal was to replicate the functionality of the C example from the ONNX repository, 11 | [here](https://github.com/microsoft/onnxruntime/blob/master/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests.Capi/C_Api_Sample.cpp). 12 | 13 | At this point, the implemented functionality achieves the same result as the example noted above and I've, additionally, 14 | tested it with ResNet on image classification (example below and in main.go). 15 | 16 | The API is incomplete (compared to the functionality available in the C library), at this time. I may try to continue to 17 | build it out, as time permits, but would gladly accept help if anybody else is interested in this sort of thing. 18 | 19 | ## Using this library 20 | **Go-onnx** uses *cgo* and leverages the *onnxruntime* shared library, so to run your program which leverages 21 | **go-onnx**, you'll need to let *cgo* know where that library resides on your local system.
To do so, in your `main.go` 22 | (or wherever), include something like the following snippet: 23 | 24 | ```go 25 | /* 26 | #cgo LDFLAGS: -L/path/to/onnx/runtime/lib 27 | */ 28 | import "C" 29 | ``` 30 | 31 | The directory specified should contain the ONNX runtime shared library, `libonnxruntime.so`. If your ONNX runtime library is named 32 | differently, you may also need to add an `-l` flag that names it (for example, `-lonnxruntime` links `libonnxruntime.so`). 33 | 34 | ## Example 35 | For a new application, first get **go-onnx**: 36 | 37 | `go get github.com/dhdanie/goonnx` 38 | 39 | You'll also need to download the example ResNet model from [here](https://s3.amazonaws.com/onnx-model-zoo/resnet/resnet152v2/resnet152v2.onnx). 40 | 41 | Then, you should be able to run a basic demo application like the following (see main.go for a working demo): 42 | 43 | ```go 44 | package main 45 | 46 | /* 47 | #cgo LDFLAGS: -L/usr/local/lib/onnx -lonnxruntime 48 | */ 49 | import "C" 50 | ``` 51 | ... 52 | ```go 53 | func classifyResNet(rgbVals []float32) [][]float32 { 54 | defer timeTrack(time.Now(), "classifyResnet") 55 | 56 | logId := "log0001" 57 | 58 | var myCustomLogger ort.CustomLogger = func(severity ort.LoggingLevel, category string, codeLocation string, message string) { 59 | fmt.Printf("Custom Logger %d/%s/%s - %s\n", severity, category, codeLocation, message) 60 | } 61 | 62 | env, err := ort.NewEnvironmentWithCustomLogger(ort.LoggingLevelVerbose, logId, myCustomLogger) 63 | if err != nil { 64 | errorAndExit(err) 65 | } 66 | defer env.ReleaseEnvironment() 67 | 68 | opts := &ort.SessionOptions{ 69 | IntraOpNumThreads: 1, 70 | GraphOptimizationLevel: ort.GraphOptLevelEnableBasic, 71 | SessionLogID: logId, 72 | LogVerbosityLevel: 0, 73 | } 74 | 75 | session, err := ort.NewSession(env, "models/resnet152v2.onnx", opts) 76 | if err != nil { 77 | errorAndExit(err) 78 | } 79 | defer session.ReleaseSession() 80 | 81 | typeInfo, err := session.GetInputTypeInfo(0) 82 | if err != nil { 83 | errorAndExit(err) 84 | } 85 | tensorInfo, err :=
typeInfo.ToTensorInfo() 86 | if err != nil { 87 | errorAndExit(err) 88 | } 89 | memoryInfo, err := ort.NewCPUMemoryInfo(ort.AllocatorTypeArena, ort.MemTypeDefault) 90 | if err != nil { 91 | errorAndExit(err) 92 | } 93 | defer memoryInfo.ReleaseMemoryInfo() 94 | value, err := ort.NewTensorWithFloatDataAsValue(memoryInfo, "data", rgbVals, tensorInfo) 95 | if err != nil { 96 | errorAndExit(err) 97 | } 98 | inputValues := []ort.Value{ 99 | value, 100 | } 101 | outs, err := session.Run(&ort.RunOptions{}, inputValues) 102 | if err != nil { 103 | errorAndExit(err) 104 | } 105 | outputs := make([][]float32, len(outs)) 106 | for i, out := range outs { 107 | if out.GetName() != "resnetv27_dense0_fwd" { 108 | continue 109 | } 110 | outFloats, err := out.GetTensorMutableFloatData() 111 | if err != nil { 112 | errorAndExit(err) 113 | } 114 | outputs[i] = make([]float32, len(outFloats)) 115 | for j := range outFloats { 116 | outputs[i][j] = outFloats[j] 117 | } 118 | } 119 | 120 | return outputs 121 | } 122 | ``` 123 | 124 | ## License 125 | [MIT License](LICENSE) -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/dhdanie/goonnx 2 | 3 | go 1.14 4 | 5 | require github.com/disintegration/imaging v1.6.2 6 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/disintegration/imaging v1.6.2 h1:w1LecBlG2Lnp8B3jk5zSuNqd7b4DXhcjwek1ei82L+c= 2 | github.com/disintegration/imaging v1.6.2/go.mod h1:44/5580QXChDfwIclfc/PCwrr44amcmDAg8hxG0Ewe4= 3 | golang.org/x/image v0.0.0-20191009234506-e7c1f5e7dbb8 h1:hVwzHzIUGRjiF7EcUjqNxk3NCfkPxbDKRdnNE1Rpg0U= 4 | golang.org/x/image v0.0.0-20191009234506-e7c1f5e7dbb8/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= 5 | golang.org/x/text v0.3.0/go.mod 
h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= 6 | -------------------------------------------------------------------------------- /labels.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "encoding/json" 5 | "io/ioutil" 6 | ) 7 | 8 | func LoadLabels(classFile string) (map[int]string, error) { 9 | data, err := ioutil.ReadFile(classFile) 10 | if err != nil { 11 | return nil, err 12 | } 13 | var result map[int]string 14 | err = json.Unmarshal(data, &result) 15 | if err != nil { 16 | return nil, err 17 | } 18 | return result, nil 19 | } 20 | -------------------------------------------------------------------------------- /main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | /* 4 | #cgo LDFLAGS: -L/usr/local/lib/onnx -lonnxruntime 5 | */ 6 | import "C" 7 | import ( 8 | "fmt" 9 | "github.com/dhdanie/goonnx/ort" 10 | "github.com/disintegration/imaging" 11 | "log" 12 | "math" 13 | "os" 14 | "sort" 15 | "time" 16 | ) 17 | 18 | func main() { 19 | defer timeTrack(time.Now(), "main") 20 | 21 | //rgbs := preprocessImage("images/kitten.jpg") 22 | //rgbs := preprocessImage("images/dog.jpg") 23 | rgbs := preprocessImage("images/white-dog.jpg") 24 | //rgbs := preprocessImage("images/car.jpg") 25 | 26 | outputs := classifyResNet(rgbs) 27 | for _, output := range outputs { 28 | scores := NewScoresFromResults(output) 29 | scores = Softmax(scores) 30 | sort.Slice(scores, func(i, j int) bool { 31 | return scores[i].Score() > scores[j].Score() 32 | }) 33 | labels, _ := LoadLabels("models/imagenet1000_clsidx_to_labels.txt") 34 | for i := 0; i < 5; i++ { 35 | if labels != nil { 36 | label := labels[scores[i].ClassIndex()] 37 | fmt.Printf("%f: %s\n", scores[i].Score(), label) 38 | } else { 39 | fmt.Printf("%s\n", scores[i]) 40 | } 41 | } 42 | } 43 | } 44 | 45 | func timeTrack(start time.Time, name string) { 46 | elapsed := time.Since(start) 47 | 
log.Printf("%s took %s", name, elapsed) 48 | } 49 | 50 | func errorAndExit(err error) { 51 | _, _ = fmt.Fprintf(os.Stderr, "Error: %s\n", err.Error()) 52 | os.Exit(1) 53 | } 54 | 55 | func Transpose(rgbs []float32) []float32 { 56 | defer timeTrack(time.Now(), "Transpose") 57 | 58 | out := make([]float32, len(rgbs)) 59 | channelLength := len(rgbs) / 3 60 | for i := 0; i < channelLength; i++ { 61 | out[i] = rgbs[i*3] 62 | out[i+channelLength] = rgbs[i*3+1] 63 | out[i+channelLength*2] = rgbs[i*3+2] 64 | } 65 | return out 66 | } 67 | 68 | func preprocessImage(imageFile string) []float32 { 69 | defer timeTrack(time.Now(), "preprocessImage") 70 | 71 | src, err := imaging.Open(imageFile) 72 | if err != nil { 73 | errorAndExit(err) 74 | } 75 | 76 | rgbs := make([]float32, 224*224*3) 77 | 78 | result := imaging.Resize(src, 256, 256, imaging.Lanczos) 79 | result = imaging.CropAnchor(result, 224, 224, imaging.Center) 80 | j := 0 81 | for i := range result.Pix { 82 | if (i+1)%4 != 0 { 83 | rgbs[j] = float32(result.Pix[i]) 84 | j++ 85 | } 86 | } 87 | 88 | rgbs = Transpose(rgbs) 89 | channelLength := len(rgbs) / 3 90 | for i := 0; i < channelLength; i++ { 91 | rgbs[i] = normalize(rgbs[i]/255, 0.485, 0.229) 92 | rgbs[i+channelLength] = normalize(rgbs[i+channelLength]/255, 0.456, 0.224) 93 | rgbs[i+channelLength*2] = normalize(rgbs[i+channelLength*2]/255, 0.406, 0.225) 94 | } 95 | return rgbs 96 | } 97 | 98 | func normalize(in float32, m float32, s float32) float32 { 99 | return (in - m) / s 100 | } 101 | 102 | func Softmax(in []ClassScore) []ClassScore { 103 | defer timeTrack(time.Now(), "Softmax") 104 | 105 | out := make([]ClassScore, len(in)) 106 | 107 | inMax := max(in) 108 | var sum float32 = 0.0 109 | for i, val := range in { 110 | out[i] = &classScore{ 111 | classIndex: val.ClassIndex(), 112 | score: float32(math.Exp(float64(val.Score() - inMax.Score()))), 113 | } 114 | sum += out[i].Score() 115 | } 116 | 117 | for i, val := range out { 118 | out[i] = &classScore{ 119 | 
classIndex: val.ClassIndex(), 120 | score: RoundFloat32(val.Score() / sum), 121 | } 122 | } 123 | return out 124 | } 125 | 126 | func RoundFloat32(in float32) float32 { 127 | f64in := float64(in) 128 | return float32(math.Round(f64in*10000000) / 10000000) 129 | } 130 | 131 | func max(in []ClassScore) ClassScore { 132 | defer timeTrack(time.Now(), "max") 133 | 134 | var maxVal float32 = 0.0 135 | maxIndex := -1 136 | for i, val := range in { 137 | if val.Score() > maxVal { 138 | maxVal = val.Score() 139 | maxIndex = i 140 | } 141 | } 142 | return in[maxIndex] 143 | } 144 | 145 | func classifyResNet(rgbVals []float32) [][]float32 { 146 | defer timeTrack(time.Now(), "classifyResnet") 147 | 148 | logId := "log0001" 149 | 150 | var myCustomLogger ort.CustomLogger = func(severity ort.LoggingLevel, category string, codeLocation string, message string) { 151 | fmt.Printf("Custom Logger %d/%s/%s - %s\n", severity, category, codeLocation, message) 152 | } 153 | 154 | env, err := ort.NewEnvironmentWithCustomLogger(ort.LoggingLevelError, logId, myCustomLogger) 155 | //env, err := ort.NewEnvironment(ort.LoggingLevelVerbose, "abcde") 156 | if err != nil { 157 | errorAndExit(err) 158 | } 159 | defer env.ReleaseEnvironment() 160 | 161 | opts := &ort.SessionOptions{ 162 | IntraOpNumThreads: 1, 163 | GraphOptimizationLevel: ort.GraphOptLevelEnableBasic, 164 | SessionLogID: logId, 165 | LogVerbosityLevel: 0, 166 | } 167 | 168 | session, err := ort.NewSession(env, "models/resnet152v2.onnx", opts) 169 | if err != nil { 170 | errorAndExit(err) 171 | } 172 | defer session.ReleaseSession() 173 | 174 | typeInfo, err := session.GetInputTypeInfo(0) 175 | if err != nil { 176 | errorAndExit(err) 177 | } 178 | tensorInfo, err := typeInfo.ToTensorInfo() 179 | if err != nil { 180 | errorAndExit(err) 181 | } 182 | memoryInfo, err := ort.NewCPUMemoryInfo(ort.AllocatorTypeArena, ort.MemTypeDefault) 183 | if err != nil { 184 | errorAndExit(err) 185 | } 186 | defer memoryInfo.ReleaseMemoryInfo() 187 | 
value, err := ort.NewTensorWithFloatDataAsValue(memoryInfo, "data", rgbVals, tensorInfo) 188 | if err != nil { 189 | errorAndExit(err) 190 | } 191 | inputValues := []ort.Value{ 192 | value, 193 | } 194 | outs, err := session.Run(&ort.RunOptions{}, inputValues) 195 | if err != nil { 196 | errorAndExit(err) 197 | } 198 | outputs := make([][]float32, len(outs)) 199 | for i, out := range outs { 200 | if out.GetName() != "resnetv27_dense0_fwd" { 201 | continue 202 | } 203 | outFloats, err := out.GetTensorMutableFloatData() 204 | if err != nil { 205 | errorAndExit(err) 206 | } 207 | outputs[i] = make([]float32, len(outFloats)) 208 | for j := range outFloats { 209 | outputs[i][j] = outFloats[j] 210 | } 211 | } 212 | 213 | return outputs 214 | } 215 | 216 | type ClassScore interface { 217 | ClassIndex() int 218 | Score() float32 219 | Equals(other ClassScore) bool 220 | } 221 | 222 | type classScore struct { 223 | classIndex int 224 | score float32 225 | } 226 | 227 | func (s *classScore) ClassIndex() int { 228 | return s.classIndex 229 | } 230 | 231 | func (s *classScore) Score() float32 { 232 | return s.score 233 | } 234 | 235 | func (s *classScore) Equals(other ClassScore) bool { 236 | if s.score == other.Score() && s.classIndex == other.ClassIndex() { 237 | return true 238 | } 239 | return false 240 | } 241 | 242 | func (s classScore) String() string { 243 | return fmt.Sprintf("Class: %d, Score %f", s.classIndex, s.score) 244 | } 245 | 246 | func NewScoresFromResults(results []float32) []ClassScore { 247 | var scores []ClassScore 248 | for i, result := range results { 249 | scores = append(scores, &classScore{ 250 | classIndex: i, 251 | score: result, 252 | }) 253 | } 254 | return scores 255 | } 256 | -------------------------------------------------------------------------------- /ort/allocator.cpp: -------------------------------------------------------------------------------- 1 | extern "C" { 2 | #include 3 | #include "allocator.h" 4 | 5 | GetAllocatorResponse 
getAllocatorWithDefaultOptions(OrtApi *api) { 6 | OrtAllocator *allocator; 7 | OrtStatus *status; 8 | 9 | status = api->GetAllocatorWithDefaultOptions(&allocator); 10 | 11 | GetAllocatorResponse response; 12 | response.allocator = allocator; 13 | response.status = status; 14 | 15 | return response; 16 | } 17 | } -------------------------------------------------------------------------------- /ort/allocator.go: -------------------------------------------------------------------------------- 1 | package ort 2 | 3 | /* 4 | #include 5 | #include "allocator.h" 6 | */ 7 | import "C" 8 | 9 | type allocator struct { 10 | a *C.OrtAllocator 11 | } 12 | 13 | func newAllocatorWithDefaultOptions() (*allocator, error) { 14 | response := C.getAllocatorWithDefaultOptions(ortApi.ort) 15 | err := ortApi.ParseStatus(response.status) 16 | if err != nil { 17 | return nil, err 18 | } 19 | 20 | return &allocator{a: response.allocator}, nil 21 | } 22 | -------------------------------------------------------------------------------- /ort/allocator.h: -------------------------------------------------------------------------------- 1 | #ifndef GOONNX_ORT_ALLOCATOR 2 | #define GOONNX_ORT_ALLOCATOR 3 | #include 4 | 5 | typedef struct GetAllocatorResponse { 6 | OrtAllocator *allocator; 7 | OrtStatus *status; 8 | } GetAllocatorResponse; 9 | 10 | GetAllocatorResponse getAllocatorWithDefaultOptions(OrtApi *api); 11 | 12 | #endif -------------------------------------------------------------------------------- /ort/api.cpp: -------------------------------------------------------------------------------- 1 | extern "C" { 2 | #include 3 | #include 4 | #include "api.h" 5 | 6 | const OrtApi* getApi() { 7 | return OrtGetApiBase()->GetApi(ORT_API_VERSION); 8 | } 9 | 10 | const char* parseStatus(OrtApi* api, OrtStatus* status) { 11 | if(status != NULL) { 12 | const char* msg = api->GetErrorMessage(status); 13 | char *copy = (char *)malloc(strlen(msg)); 14 | strcpy(copy, msg); 15 | 
api->ReleaseStatus(status); 16 | return copy; 17 | } 18 | return NULL; 19 | } 20 | } -------------------------------------------------------------------------------- /ort/api.go: -------------------------------------------------------------------------------- 1 | package ort 2 | 3 | /* 4 | #cgo LDFLAGS: -L/usr/local/lib/onnx -lonnxruntime 5 | #include 6 | #include "api.h" 7 | */ 8 | import "C" 9 | import ( 10 | "fmt" 11 | "unsafe" 12 | ) 13 | 14 | type api struct { 15 | ort *C.OrtApi 16 | } 17 | 18 | var ortApi = newApi() 19 | 20 | func newApi() *api { 21 | return &api{ 22 | ort: C.getApi(), 23 | } 24 | } 25 | 26 | func (a *api) ParseStatus(status *C.OrtStatus) error { 27 | if status == nil { 28 | return nil 29 | } 30 | 31 | cMessage := C.parseStatus(a.ort, status) 32 | defer C.free(unsafe.Pointer(cMessage)) 33 | var message string 34 | message = C.GoString(cMessage) 35 | 36 | return fmt.Errorf("%s", message) 37 | } 38 | -------------------------------------------------------------------------------- /ort/api.h: -------------------------------------------------------------------------------- 1 | #ifndef GOONNX_ORT_API 2 | #define GOONNX_ORT_API 3 | #include 4 | 5 | const OrtApi* getApi(); 6 | const char* parseStatus(OrtApi* api, OrtStatus* status); 7 | 8 | #endif -------------------------------------------------------------------------------- /ort/custom-op-domain.go: -------------------------------------------------------------------------------- 1 | package ort 2 | 3 | /* 4 | #include 5 | */ 6 | import "C" 7 | 8 | type CustomOpDomain interface { 9 | //TODO 10 | AddCustomOp(op CustomOp) error 11 | toCCustomOpDomain() *C.OrtCustomOpDomain 12 | } 13 | 14 | func CreateCustomOpDomain(domain string) (CustomOpDomain, error) { 15 | //TODO 16 | return nil, nil 17 | } 18 | -------------------------------------------------------------------------------- /ort/custom-op.go: -------------------------------------------------------------------------------- 1 | package ort 2 | 3 
| type CustomOp interface { 4 | //TODO 5 | } 6 | -------------------------------------------------------------------------------- /ort/environment.cpp: -------------------------------------------------------------------------------- 1 | extern "C" { 2 | #include 3 | #include "environment.h" 4 | 5 | void logCustomWrapper(void *params, OrtLoggingLevel severity, const char *category, const char *logId, const char *codeLocation, const char *message) { 6 | logCustom(params, severity, (char *)category, (char *)logId, (char *)codeLocation, (char *) message); 7 | } 8 | 9 | OrtCreateEnvResponse createEnv(OrtApi* api, OrtLoggingLevel level, char* logId) { 10 | OrtEnv* env; 11 | OrtStatus* status; 12 | 13 | status = api->CreateEnv(level, logId, &env); 14 | 15 | OrtCreateEnvResponse response; 16 | response.env = env; 17 | response.status = status; 18 | 19 | return response; 20 | } 21 | 22 | OrtCreateEnvResponse createEnvWithCustomLogger(OrtApi* api, void *params, OrtLoggingLevel level, char* logId) { 23 | OrtEnv* env; 24 | OrtStatus* status; 25 | 26 | status = api->CreateEnvWithCustomLogger(logCustomWrapper, params, level, logId, &env); 27 | 28 | OrtCreateEnvResponse response; 29 | response.env = env; 30 | response.status = status; 31 | 32 | return response; 33 | } 34 | 35 | void releaseEnv(OrtApi* api, OrtEnv* env) { 36 | api->ReleaseEnv(env); 37 | } 38 | } -------------------------------------------------------------------------------- /ort/environment.go: -------------------------------------------------------------------------------- 1 | package ort 2 | 3 | /* 4 | #include 5 | #include "environment.h" 6 | */ 7 | import "C" 8 | import ( 9 | "fmt" 10 | "sync" 11 | "unsafe" 12 | ) 13 | 14 | type LoggingLevel int 15 | 16 | const ( 17 | LoggingLevelVerbose LoggingLevel = 0 18 | LoggingLevelInfo LoggingLevel = 1 19 | LoggingLevelWarning LoggingLevel = 2 20 | LoggingLevelError LoggingLevel = 3 21 | LoggingLevelFatal LoggingLevel = 4 22 | ) 23 | 24 | type CustomLogger 
func(severity LoggingLevel, category string, codeLocation string, message string) 25 | type cCustomLogger func(params unsafe.Pointer, severity C.OrtLoggingLevel, category *C.char, logId *C.char, codeLocation *C.char, message *C.char) 26 | 27 | type Environment interface { 28 | ReleaseEnvironment() 29 | } 30 | 31 | type environment struct { 32 | env *C.OrtEnv 33 | cLogId *C.char 34 | logger CustomLogger 35 | } 36 | 37 | var mu sync.Mutex 38 | var customLoggers = make(map[string]cCustomLogger) 39 | 40 | func NewEnvironment(loggingLevel LoggingLevel, logId string) (Environment, error) { 41 | logLevel, err := getOrtLoggingLevelForLoggingLevel(loggingLevel) 42 | if err != nil { 43 | return nil, err 44 | } 45 | 46 | cLogId := C.CString(logId) 47 | 48 | response := C.createEnv(ortApi.ort, logLevel, cLogId) 49 | err = ortApi.ParseStatus(response.status) 50 | if err != nil { 51 | return nil, err 52 | } 53 | 54 | return &environment{ 55 | env: response.env, 56 | cLogId: cLogId, 57 | logger: nil, 58 | }, nil 59 | } 60 | 61 | func NewEnvironmentWithCustomLogger(loggingLevel LoggingLevel, logId string, logger CustomLogger) (Environment, error) { 62 | logLevel, err := getOrtLoggingLevelForLoggingLevel(loggingLevel) 63 | if err != nil { 64 | return nil, err 65 | } 66 | 67 | cLogId := C.CString(logId) 68 | 69 | response := C.createEnvWithCustomLogger(ortApi.ort, nil, logLevel, cLogId) 70 | err = ortApi.ParseStatus(response.status) 71 | if err != nil { 72 | return nil, err 73 | } 74 | 75 | env := &environment{ 76 | env: response.env, 77 | cLogId: cLogId, 78 | logger: logger, 79 | } 80 | register(logId, env) 81 | return env, nil 82 | } 83 | 84 | func getOrtLoggingLevelForLoggingLevel(loggingLevel LoggingLevel) (C.OrtLoggingLevel, error) { 85 | switch loggingLevel { 86 | case LoggingLevelVerbose: 87 | return C.ORT_LOGGING_LEVEL_VERBOSE, nil 88 | case LoggingLevelInfo: 89 | return C.ORT_LOGGING_LEVEL_INFO, nil 90 | case LoggingLevelWarning: 91 | return C.ORT_LOGGING_LEVEL_WARNING, nil 
92 | case LoggingLevelError: 93 | return C.ORT_LOGGING_LEVEL_ERROR, nil 94 | case LoggingLevelFatal: 95 | return C.ORT_LOGGING_LEVEL_FATAL, nil 96 | } 97 | return 0, fmt.Errorf("invalid logging level %d", loggingLevel) 98 | } 99 | 100 | func getLoggingLeveForOrtLoggingLevel(ortLoggingLevel C.OrtLoggingLevel) (LoggingLevel, error) { 101 | switch ortLoggingLevel { 102 | case C.ORT_LOGGING_LEVEL_VERBOSE: 103 | return LoggingLevelVerbose, nil 104 | case C.ORT_LOGGING_LEVEL_INFO: 105 | return LoggingLevelInfo, nil 106 | case C.ORT_LOGGING_LEVEL_WARNING: 107 | return LoggingLevelWarning, nil 108 | case C.ORT_LOGGING_LEVEL_ERROR: 109 | return LoggingLevelError, nil 110 | case C.ORT_LOGGING_LEVEL_FATAL: 111 | return LoggingLevelFatal, nil 112 | } 113 | return 0, fmt.Errorf("invalid ORT logging level %d", int(ortLoggingLevel)) 114 | } 115 | 116 | func (e *environment) logCustom(params unsafe.Pointer, severity C.OrtLoggingLevel, category *C.char, logId *C.char, codeLocation *C.char, message *C.char) { 117 | if e.logger != nil { 118 | level, err := getLoggingLeveForOrtLoggingLevel(severity) 119 | if err != nil { 120 | level = LoggingLevelError 121 | } 122 | cat := C.GoString(category) 123 | loc := C.GoString(codeLocation) 124 | msg := C.GoString(message) 125 | 126 | e.logger(level, cat, loc, msg) 127 | } 128 | } 129 | 130 | func (e *environment) ReleaseEnvironment() { 131 | C.releaseEnv(ortApi.ort, e.env) 132 | C.free(unsafe.Pointer(e.cLogId)) 133 | } 134 | 135 | //export logCustom 136 | func logCustom(params unsafe.Pointer, severity C.OrtLoggingLevel, category *C.char, logId *C.char, codeLocation *C.char, message *C.char) { 137 | sLogId := C.GoString(logId) 138 | f := lookup(sLogId) 139 | if f != nil { 140 | f(params, severity, category, logId, codeLocation, message) 141 | } 142 | } 143 | 144 | func register(logId string, env *environment) { 145 | mu.Lock() 146 | defer mu.Unlock() 147 | 148 | customLoggers[logId] = env.logCustom 149 | } 150 | 151 | func lookup(logId string) 
cCustomLogger { 152 | mu.Lock() 153 | defer mu.Unlock() 154 | 155 | logger := customLoggers[logId] 156 | if logger == nil { 157 | return nil 158 | } 159 | return logger 160 | } 161 | -------------------------------------------------------------------------------- /ort/environment.h: -------------------------------------------------------------------------------- 1 | #ifndef GOONNX_ORT_ENVIRONMENT 2 | #define GOONNX_ORT_ENVIRONMENT 3 | 4 | #include 5 | 6 | typedef struct OrtCreateEnvResponse { 7 | const OrtEnv *env; 8 | OrtStatus *status; 9 | } OrtCreateEnvResponse; 10 | 11 | OrtCreateEnvResponse createEnv(OrtApi* api, OrtLoggingLevel level, char* logId); 12 | OrtCreateEnvResponse createEnvWithCustomLogger(OrtApi *api, void *params, OrtLoggingLevel level, char *logId); 13 | void releaseEnv(OrtApi* api, OrtEnv* env); 14 | extern void logCustom(void *param, OrtLoggingLevel severity, char *category, char *logId, char *codeLocation, char *message); 15 | 16 | #endif -------------------------------------------------------------------------------- /ort/memory-info.cpp: -------------------------------------------------------------------------------- 1 | extern "C" { 2 | #include 3 | #include "memory-info.h" 4 | 5 | OrtCreateCpuMemoryInfoResponse createCpuMemoryInfo(OrtApi *api, OrtAllocatorType allocatorType, OrtMemType memType) { 6 | OrtMemoryInfo *memoryInfo; 7 | OrtStatus *status; 8 | 9 | status = api->CreateCpuMemoryInfo(allocatorType, memType, &memoryInfo); 10 | 11 | OrtCreateCpuMemoryInfoResponse response; 12 | response.memoryInfo = memoryInfo; 13 | response.status = status; 14 | 15 | return response; 16 | } 17 | 18 | void releaseMemoryInfo(OrtApi *api, OrtMemoryInfo *memoryInfo) { 19 | api->ReleaseMemoryInfo(memoryInfo); 20 | } 21 | } -------------------------------------------------------------------------------- /ort/memory-info.go: -------------------------------------------------------------------------------- 1 | package ort 2 | 3 | /* 4 | #include 5 | #include 
"memory-info.h" 6 | */ 7 | import "C" 8 | import "fmt" 9 | 10 | type AllocatorType int 11 | 12 | const ( 13 | AllocatorTypeInvalid AllocatorType = -1 14 | AllocatorTypeDevice AllocatorType = 0 15 | AllocatorTypeArena AllocatorType = 1 16 | ) 17 | 18 | type MemType int 19 | 20 | const ( 21 | MemTypeCPUInput MemType = -2 22 | MemTypeCPUOutput MemType = -1 23 | MemTypeCPU MemType = MemTypeCPUOutput 24 | MemTypeDefault MemType = 0 25 | ) 26 | 27 | type MemoryInfo interface { 28 | ReleaseMemoryInfo() 29 | } 30 | 31 | type memoryInfo struct { 32 | allocatorType AllocatorType 33 | memType MemType 34 | cMemoryInfo *C.OrtMemoryInfo 35 | } 36 | 37 | func NewCPUMemoryInfo(allocatorType AllocatorType, memType MemType) (MemoryInfo, error) { 38 | cAllocatorType, err := getCAllocatorTypeForAllocatorType(allocatorType) 39 | if err != nil { 40 | return nil, err 41 | } 42 | cMemType, err := getCMemTypeForMemType(memType) 43 | if err != nil { 44 | return nil, err 45 | } 46 | 47 | response := C.createCpuMemoryInfo(ortApi.ort, cAllocatorType, cMemType) 48 | err = ortApi.ParseStatus(response.status) 49 | if err != nil { 50 | return nil, err 51 | } 52 | 53 | return &memoryInfo{ 54 | allocatorType: allocatorType, 55 | memType: memType, 56 | cMemoryInfo: response.memoryInfo, 57 | }, nil 58 | } 59 | 60 | func (i *memoryInfo) ReleaseMemoryInfo() { 61 | C.releaseMemoryInfo(ortApi.ort, i.cMemoryInfo) 62 | } 63 | 64 | func getCAllocatorTypeForAllocatorType(allocatorType AllocatorType) (C.OrtAllocatorType, error) { 65 | switch allocatorType { 66 | case AllocatorTypeInvalid: 67 | return C.Invalid, nil 68 | case AllocatorTypeDevice: 69 | return C.OrtDeviceAllocator, nil 70 | case AllocatorTypeArena: 71 | return C.OrtArenaAllocator, nil 72 | } 73 | return C.Invalid, fmt.Errorf("invalid allocator type %d", allocatorType) 74 | } 75 | 76 | func getCMemTypeForMemType(memType MemType) (C.OrtMemType, error) { 77 | switch memType { 78 | case MemTypeCPUInput: 79 | return C.OrtMemTypeCPUInput, nil 80 | case 
MemTypeCPUOutput: 81 | return C.OrtMemTypeCPUOutput, nil 82 | case MemTypeDefault: 83 | return C.OrtMemTypeDefault, nil 84 | } 85 | return -3, fmt.Errorf("invalid memory type %d", memType) 86 | } 87 | -------------------------------------------------------------------------------- /ort/memory-info.h: -------------------------------------------------------------------------------- 1 | #ifndef GOONNX_ORT_MEMORY_INFO 2 | #define GOONNX_ORT_MEMORY_INFO 3 | 4 | #include 5 | 6 | typedef struct OrtCreateCpuMemoryInfoResponse { 7 | OrtMemoryInfo *memoryInfo; 8 | OrtStatus *status; 9 | } OrtCreateCpuMemoryInfoResponse; 10 | 11 | OrtCreateCpuMemoryInfoResponse createCpuMemoryInfo(OrtApi *api, OrtAllocatorType allocatorType, OrtMemType memType); 12 | void releaseMemoryInfo(OrtApi *api, OrtMemoryInfo *memoryInfo); 13 | 14 | #endif -------------------------------------------------------------------------------- /ort/run-options.cpp: -------------------------------------------------------------------------------- 1 | extern "C" { 2 | #include 3 | #include "run-options.h" 4 | 5 | OrtCreateRunOptionsResponse createRunOptions(OrtApi *api, OrtCreateRunOptionsParameters *params) { 6 | OrtStatus *status; 7 | OrtRunOptions *options; 8 | 9 | status = api->CreateRunOptions(&options); 10 | if(status != NULL){ 11 | return respondRunOptionsErrorStatus(status); 12 | } 13 | 14 | if(params->tag != NULL) { 15 | status = api->RunOptionsSetRunTag(options, params->tag); 16 | if(status != NULL) { 17 | return releaseRunOptionsAndRespondErrorStatus(api, options, status); 18 | } 19 | } 20 | if(params->logVerbosityLevel > 0) { 21 | status = api->RunOptionsSetRunLogVerbosityLevel(options, params->logVerbosityLevel); 22 | if(status != NULL) { 23 | return releaseRunOptionsAndRespondErrorStatus(api, options, status); 24 | } 25 | } 26 | if(params->logSeverityLevel > 0) { 27 | status = api->RunOptionsSetRunLogSeverityLevel(options, params->logSeverityLevel); 28 | if(status != NULL) { 29 | return 
releaseRunOptionsAndRespondErrorStatus(api, options, status); 30 | } 31 | } 32 | if(params->terminate == 1) { 33 | status = api->RunOptionsSetTerminate(options); 34 | if(status != NULL) { 35 | return releaseRunOptionsAndRespondErrorStatus(api, options, status); 36 | } 37 | } 38 | 39 | OrtCreateRunOptionsResponse response; 40 | response.runOptions = options; 41 | response.status = NULL; 42 | 43 | return response; 44 | } 45 | 46 | OrtCreateRunOptionsResponse releaseRunOptionsAndRespondErrorStatus(OrtApi *api, OrtRunOptions *runOptions, OrtStatus *status) { 47 | api->ReleaseRunOptions(runOptions); 48 | return respondRunOptionsErrorStatus(status); 49 | } 50 | 51 | OrtCreateRunOptionsResponse respondRunOptionsErrorStatus(OrtStatus *status) { 52 | OrtCreateRunOptionsResponse response; 53 | 54 | response.status = status; 55 | response.runOptions = NULL; 56 | 57 | return response; 58 | } 59 | } -------------------------------------------------------------------------------- /ort/run-options.go: -------------------------------------------------------------------------------- 1 | package ort 2 | 3 | /* 4 | #include 5 | #include "run-options.h" 6 | */ 7 | import "C" 8 | import "unsafe" 9 | 10 | // RunOptions configures a single Session.Run call. A nil *RunOptions is converted to a nil C handle (ORT default run options). 11 | type RunOptions struct { 12 | Tag string 13 | LogVerbosityLevel int 14 | LogSeverityLevel int 15 | Terminate bool 16 | } 17 | 18 | type ortRunOptions struct { 19 | cRunOptions *C.OrtRunOptions 20 | } 21 | 22 | // toOrtRunOptions creates a new C OrtRunOptions from o; the caller owns the returned handle. 23 | // A nil receiver returns a nil handle and no error; OrtApi::Run treats NULL run options as "use the defaults". 24 | func (o *RunOptions) toOrtRunOptions() (*C.OrtRunOptions, error) { 25 | if o == nil { 26 | return nil, nil 27 | } 28 | roParams := C.OrtCreateRunOptionsParameters{} 29 | if len(o.Tag) > 0 { 30 | roParams.tag = C.CString(o.Tag) 31 | defer C.free(unsafe.Pointer(roParams.tag)) 32 | } 33 | roParams.logVerbosityLevel = C.int(o.LogVerbosityLevel) 34 | roParams.logSeverityLevel = C.int(o.LogSeverityLevel) 35 | if o.Terminate { 36 | roParams.terminate = C.int(1) 37 | } else { 38 | roParams.terminate = C.int(0) 39 | } 40 | 41 | response := C.createRunOptions(ortApi.ort, &roParams) 42 | err := ortApi.ParseStatus(response.status) 43 | if
err != nil { 44 | return nil, err 45 | } 46 | return response.runOptions, nil 47 | } 48 | -------------------------------------------------------------------------------- /ort/run-options.h: -------------------------------------------------------------------------------- 1 | #ifndef GOONNX_RUN_OPTIONS 2 | #define GOONNX_RUN_OPTIONS 3 | 4 | #include 5 | typedef struct OrtCreateRunOptionsParameters { 6 | const char *tag; 7 | int logVerbosityLevel; 8 | int logSeverityLevel; 9 | int terminate; 10 | } OrtCreateRunOptionsParameters; 11 | 12 | typedef struct OrtCreateRunOptionsResponse { 13 | OrtRunOptions *runOptions; 14 | OrtStatus *status; 15 | } OrtCreateRunOptionsResponse; 16 | 17 | OrtCreateRunOptionsResponse createRunOptions(OrtApi *api, OrtCreateRunOptionsParameters *params); 18 | OrtCreateRunOptionsResponse releaseRunOptionsAndRespondErrorStatus(OrtApi *api, OrtRunOptions *runOptions, OrtStatus *status); 19 | OrtCreateRunOptionsResponse respondRunOptionsErrorStatus(OrtStatus *status); 20 | 21 | #endif -------------------------------------------------------------------------------- /ort/session-options.cpp: -------------------------------------------------------------------------------- 1 | extern "C" { 2 | #include 3 | #include "session-options.h" 4 | 5 | OrtCreateSessionOptionsResponse createSessionOptions(OrtApi *api, OrtCreateSessionOptionsParams *params) { 6 | OrtStatus *status; 7 | OrtSessionOptions *sessionOptions; 8 | 9 | status = api->CreateSessionOptions(&sessionOptions); 10 | if(status != NULL) { 11 | return respondErrorStatus(status); 12 | } 13 | 14 | if(params->optimizedModelFilePath != NULL) { 15 | status = api->SetOptimizedModelFilePath(sessionOptions, params->optimizedModelFilePath); 16 | if(status != NULL) { 17 | return releaseAndRespondErrorStatus(api, sessionOptions, status); 18 | } 19 | } 20 | if(params->executionMode != 0) { 21 | status = api->SetSessionExecutionMode(sessionOptions, params->executionMode); 22 | if(status != NULL) { 23 | return
releaseAndRespondErrorStatus(api, sessionOptions, status); 24 | } 25 | } 26 | if(params->profilingEnabled == 1 && params->profileFilePrefix != NULL) { // profiling is only enabled when both the flag and a file prefix are supplied 27 | status = api->EnableProfiling(sessionOptions, params->profileFilePrefix); 28 | if(status != NULL) { 29 | return releaseAndRespondErrorStatus(api, sessionOptions, status); 30 | } 31 | } 32 | if(params->memPatternEnabled == 1) { // enable-only: any other value makes no call, leaving the option untouched 33 | status = api->EnableMemPattern(sessionOptions); 34 | if(status != NULL) { 35 | return releaseAndRespondErrorStatus(api, sessionOptions, status); 36 | } 37 | } 38 | if(params->cpuMemArenaEnabled == 1) { // enable-only, like memPatternEnabled 39 | status = api->EnableCpuMemArena(sessionOptions); 40 | if(status != NULL) { 41 | return releaseAndRespondErrorStatus(api, sessionOptions, status); 42 | } 43 | } 44 | if(params->logId != NULL) { 45 | status = api->SetSessionLogId(sessionOptions, params->logId); 46 | if(status != NULL) { 47 | return releaseAndRespondErrorStatus(api, sessionOptions, status); 48 | } 49 | } 50 | if(params->logVerbosityLevel > 0) { // values <= 0 are not forwarded 51 | status = api->SetSessionLogVerbosityLevel(sessionOptions, params->logVerbosityLevel); 52 | if(status != NULL) { 53 | return releaseAndRespondErrorStatus(api, sessionOptions, status); 54 | } 55 | } 56 | if(params->logSeverityLevel > 0) { // values <= 0 are not forwarded, so severity 0 cannot be requested here 57 | status = api->SetSessionLogSeverityLevel(sessionOptions, params->logSeverityLevel); 58 | if(status != NULL) { 59 | return releaseAndRespondErrorStatus(api, sessionOptions, status); 60 | } 61 | } 62 | if(params->graphOptimizationLevel != DefaultGraphOptimizationLevel) { // the default level is never applied explicitly 63 | status = api->SetSessionGraphOptimizationLevel(sessionOptions, params->graphOptimizationLevel); 64 | if(status != NULL) { 65 | return releaseAndRespondErrorStatus(api, sessionOptions, status); 66 | } 67 | } 68 | if(params->intraOpNumThreads != DefaultIntraOpNumThreads) { // a count equal to the header default is skipped 69 | status = api->SetIntraOpNumThreads(sessionOptions, params->intraOpNumThreads); 70 | if(status != NULL) { 71 | return releaseAndRespondErrorStatus(api, sessionOptions, status); 72 | } 73 | } 74
| if(params->interOpNumThreads != DefaultInterOpNumThreads) { // a count equal to the header default is skipped 75 | status = api->SetInterOpNumThreads(sessionOptions, params->interOpNumThreads); 76 | if(status != NULL) { 77 | return releaseAndRespondErrorStatus(api, sessionOptions, status); 78 | } 79 | } 80 | if(params->numCustomOpDomains > 0 && params->customOpDomains != NULL) { // domains are added in order; the first failure releases the options and returns 81 | for(int i = 0; i < params->numCustomOpDomains; i++) { 82 | status = api->AddCustomOpDomain(sessionOptions, params->customOpDomains[i]); 83 | if(status != NULL) { 84 | return releaseAndRespondErrorStatus(api, sessionOptions, status); 85 | } 86 | } 87 | } 88 | 89 | OrtCreateSessionOptionsResponse response; 90 | response.sessionOptions = sessionOptions; 91 | response.status = NULL; 92 | return response; 93 | } 94 | 95 | OrtCreateSessionOptionsResponse releaseAndRespondErrorStatus(OrtApi *api, OrtSessionOptions *sessionOptions, OrtStatus *status) { // error path: free the partially configured options, then report status 96 | api->ReleaseSessionOptions(sessionOptions); 97 | return respondErrorStatus(status); 98 | } 99 | 100 | OrtCreateSessionOptionsResponse respondErrorStatus(OrtStatus *status) { // failure response: sessionOptions is NULL and status carries the error 101 | OrtCreateSessionOptionsResponse response; 102 | 103 | response.status = status; 104 | response.sessionOptions = NULL; 105 | 106 | return response; 107 | } 108 | } -------------------------------------------------------------------------------- /ort/session-options.go: -------------------------------------------------------------------------------- 1 | package ort 2 | 3 | /* 4 | #include 5 | #include "session-options.h" 6 | */ 7 | import "C" 8 | import ( 9 | "fmt" 10 | "unsafe" 11 | ) 12 | 13 | type ExecutionMode int // converted to the C ExecutionMode by getOrtExecutionModeForExecutionMode 14 | type GraphOptimizationLevel int // converted to the C GraphOptimizationLevel by getOrtSessionGraphOptimizationLevelForGraphOptimizationLevel 15 | 16 | const ( 17 | ExecutionModeSequential ExecutionMode = 0 18 | ExecutionModeParallel ExecutionMode = 1 19 | ) 20 | const ( 21 | GraphOptLevelDisableAll GraphOptimizationLevel = 0 22 | GraphOptLevelEnableBasic GraphOptimizationLevel = 1 23 | GraphOptLevelEnableExtended GraphOptimizationLevel = 2 24 | GraphOptLevelEnableAll GraphOptimizationLevel = 99 25 | )
26 | 27 | const DefaultExecutionMode ExecutionMode = ExecutionModeSequential 28 | const DefaultGraphOptLevel GraphOptimizationLevel = GraphOptLevelEnableAll 29 | 30 | type ortSessionOptions struct { 31 | cOrtSessionOptions *C.OrtSessionOptions 32 | } 33 | 34 | type SessionOptions struct { 35 | OptimizedModelFilePath string 36 | ExecutionMode ExecutionMode 37 | ProfilingEnabled bool 38 | ProfileFilePrefix string 39 | MemPatternEnabled bool 40 | CPUMemArenaEnabled bool 41 | SessionLogID string 42 | LogVerbosityLevel int 43 | LogSeverityLevel int 44 | GraphOptimizationLevel GraphOptimizationLevel 45 | IntraOpNumThreads int 46 | InterOpNumThreads int 47 | CustomOpDomains []CustomOpDomain 48 | } 49 | 50 | func (o *SessionOptions) Clone() *SessionOptions { // copies every field except CustomOpDomains 51 | return &SessionOptions{ 52 | OptimizedModelFilePath: o.OptimizedModelFilePath, 53 | ExecutionMode: o.ExecutionMode, 54 | ProfilingEnabled: o.ProfilingEnabled, 55 | ProfileFilePrefix: o.ProfileFilePrefix, 56 | MemPatternEnabled: o.MemPatternEnabled, 57 | CPUMemArenaEnabled: o.CPUMemArenaEnabled, 58 | SessionLogID: o.SessionLogID, 59 | LogVerbosityLevel: o.LogVerbosityLevel, 60 | LogSeverityLevel: o.LogSeverityLevel, 61 | GraphOptimizationLevel: o.GraphOptimizationLevel, 62 | IntraOpNumThreads: o.IntraOpNumThreads, 63 | InterOpNumThreads: o.InterOpNumThreads, 64 | CustomOpDomains: nil, // not copied: the clone starts without custom op domains 65 | } 66 | } 67 | 68 | func (o *SessionOptions) toOrtSessionOptions() (*ortSessionOptions, error) { 69 | var err error 70 | if o == nil { o = &SessionOptions{GraphOptimizationLevel: DefaultGraphOptLevel} } // nil: every field then holds a value createSessionOptions skips, so ORT defaults apply 71 | soParams := C.OrtCreateSessionOptionsParams{} 72 | if len(o.OptimizedModelFilePath) > 0 { 73 | soParams.optimizedModelFilePath = C.CString(o.OptimizedModelFilePath) 74 | defer C.free(unsafe.Pointer(soParams.optimizedModelFilePath)) 75 | } else { 76 | soParams.optimizedModelFilePath = nil 77 | } 78 | soParams.executionMode, err = getOrtExecutionModeForExecutionMode(o.ExecutionMode) 79 | if err != nil { 80 | soParams.executionMode, _ = getOrtExecutionModeForExecutionMode(DefaultExecutionMode) 81 | err = nil
82 | } 83 | if o.ProfilingEnabled && len(o.ProfileFilePrefix) > 0 { 84 | soParams.profilingEnabled = C.int(1) 85 | soParams.profileFilePrefix = C.CString(o.ProfileFilePrefix) 86 | defer C.free(unsafe.Pointer(soParams.profileFilePrefix)) 87 | } else { 88 | soParams.profilingEnabled = C.int(0) 89 | soParams.profileFilePrefix = nil 90 | } 91 | if o.MemPatternEnabled { 92 | soParams.memPatternEnabled = C.int(1) 93 | } else { 94 | soParams.memPatternEnabled = C.int(0) 95 | } 96 | if o.CPUMemArenaEnabled { 97 | soParams.cpuMemArenaEnabled = C.int(1) 98 | } else { 99 | soParams.cpuMemArenaEnabled = C.int(0) 100 | } 101 | if len(o.SessionLogID) > 0 { 102 | soParams.logId = C.CString(o.SessionLogID) 103 | defer C.free(unsafe.Pointer(soParams.logId)) 104 | } else { 105 | soParams.logId = nil 106 | } 107 | soParams.logVerbosityLevel = C.int(o.LogVerbosityLevel) 108 | soParams.logSeverityLevel = C.int(o.LogSeverityLevel) 109 | soParams.graphOptimizationLevel, err = getOrtSessionGraphOptimizationLevelForGraphOptimizationLevel(o.GraphOptimizationLevel) 110 | if err != nil { 111 | soParams.graphOptimizationLevel, _ = getOrtSessionGraphOptimizationLevelForGraphOptimizationLevel(DefaultGraphOptLevel) 112 | err = nil 113 | } 114 | soParams.intraOpNumThreads = C.int(o.IntraOpNumThreads) 115 | soParams.interOpNumThreads = C.int(o.InterOpNumThreads) 116 | soParams.numCustomOpDomains = C.int(len(o.CustomOpDomains)) 117 | if len(o.CustomOpDomains) > 0 { 118 | cArray := C.malloc(C.size_t(len(o.CustomOpDomains)) * C.size_t(unsafe.Sizeof((*C.OrtCustomOpDomain)(nil)))) // C memory: &soParams is passed to C, and cgo forbids it from holding pointers to Go memory 119 | defer C.free(cArray) 120 | cCustomOpDomains := (*[1 << 28]*C.OrtCustomOpDomain)(cArray)[:len(o.CustomOpDomains):len(o.CustomOpDomains)] 121 | for i, customOpDomain := range o.CustomOpDomains { cCustomOpDomains[i] = customOpDomain.toCCustomOpDomain() } 122 | soParams.customOpDomains = &cCustomOpDomains[0] 123 | } else { 124 | soParams.customOpDomains = nil 125 | } 126 | 127 | response := C.createSessionOptions(ortApi.ort, &soParams) 128 | err = ortApi.ParseStatus(response.status) 129 | if err != nil { 130 | return nil, err 131 | } 132 | return &ortSessionOptions{cOrtSessionOptions:
response.sessionOptions}, nil 133 | } 134 | 135 | func getOrtSessionGraphOptimizationLevelForGraphOptimizationLevel(level GraphOptimizationLevel) (C.GraphOptimizationLevel, error) { 136 | switch level { 137 | case GraphOptLevelDisableAll: 138 | return C.ORT_DISABLE_ALL, nil 139 | case GraphOptLevelEnableBasic: 140 | return C.ORT_ENABLE_BASIC, nil 141 | case GraphOptLevelEnableExtended: 142 | return C.ORT_ENABLE_EXTENDED, nil 143 | case GraphOptLevelEnableAll: 144 | return C.ORT_ENABLE_ALL, nil 145 | } 146 | return 0, fmt.Errorf("invalid graph optimization level %d", level) 147 | } 148 | 149 | func getOrtExecutionModeForExecutionMode(executionMode ExecutionMode) (C.ExecutionMode, error) { 150 | switch executionMode { 151 | case ExecutionModeSequential: 152 | return C.ORT_SEQUENTIAL, nil 153 | case ExecutionModeParallel: 154 | return C.ORT_PARALLEL, nil 155 | } 156 | return 0, fmt.Errorf("invalid execution mode %d", executionMode) 157 | } 158 | -------------------------------------------------------------------------------- /ort/session-options.h: -------------------------------------------------------------------------------- 1 | #ifndef GOONNX_ORT_SESSION_OPTIONS 2 | #define GOONNX_ORT_SESSION_OPTIONS 3 | 4 | #include 5 | 6 | typedef struct OrtCreateSessionOptionsParams { 7 | ORTCHAR_T *optimizedModelFilePath; 8 | ExecutionMode executionMode; 9 | int profilingEnabled; 10 | const ORTCHAR_T *profileFilePrefix; 11 | int memPatternEnabled; 12 | int cpuMemArenaEnabled; 13 | const char *logId; 14 | int logVerbosityLevel; 15 | int logSeverityLevel; 16 | GraphOptimizationLevel graphOptimizationLevel; 17 | int intraOpNumThreads; 18 | int interOpNumThreads; 19 | int numCustomOpDomains; 20 | OrtCustomOpDomain **customOpDomains; 21 | } OrtCreateSessionOptionsParams; 22 | 23 | typedef struct OrtCreateSessionOptionsResponse { 24 | OrtSessionOptions *sessionOptions; 25 | OrtStatus *status; 26 | } OrtCreateSessionOptionsResponse; 27 | 28 | #define DefaultExecutionMode
ORT_SEQUENTIAL 29 | #define DefaultGraphOptimizationLevel ORT_ENABLE_ALL 30 | #define DefaultIntraOpNumThreads 0 31 | #define DefaultInterOpNumThreads 0 32 | 33 | OrtCreateSessionOptionsResponse createSessionOptions(OrtApi *api, OrtCreateSessionOptionsParams *params); 34 | OrtCreateSessionOptionsResponse releaseAndRespondErrorStatus(OrtApi *api, OrtSessionOptions *sessionOptions, OrtStatus *status); 35 | OrtCreateSessionOptionsResponse respondErrorStatus(OrtStatus *status); 36 | 37 | #endif -------------------------------------------------------------------------------- /ort/session.cpp: -------------------------------------------------------------------------------- 1 | extern "C" { 2 | #include 3 | #include 4 | #include "session.h" 5 | 6 | OrtCreateSessionResponse createSession(OrtApi *api, OrtEnv *env, const char *modelPath, OrtSessionOptions *sessionOptions) { 7 | OrtSession *session = NULL; // initialized so the response never carries an indeterminate pointer 8 | OrtStatus *status; 9 | 10 | status = api->CreateSession(env, modelPath, sessionOptions, &session); 11 | 12 | OrtCreateSessionResponse response; 13 | response.session = session; 14 | response.status = status; 15 | 16 | return response; 17 | } 18 | 19 | void releaseSession(OrtApi *api, OrtSession *session) { 20 | api->ReleaseSession(session); 21 | } 22 | 23 | void releaseSessionOptions(OrtApi *api, OrtSessionOptions *opts) { 24 | api->ReleaseSessionOptions(opts); 25 | } 26 | 27 | OrtGetIOCountResponse getInputCount(OrtApi *api, OrtSession *session) { 28 | size_t numInputNodes; 29 | OrtStatus *status; 30 | 31 | status = api->SessionGetInputCount(session, &numInputNodes); 32 | 33 | OrtGetIOCountResponse response; 34 | response.numNodes = numInputNodes; 35 | response.status = status; 36 | 37 | return response; 38 | } 39 | 40 | OrtGetIONameResponse getInputName(OrtApi *api, OrtSession *session, size_t i, OrtAllocator *allocator) { 41 | char *inputName = NULL; 42 | OrtStatus *status; 43 | 44 | status = api->SessionGetInputName(session, i, allocator, &inputName); // returned so the Go caller checks it before reading the name 45 | 46 | OrtGetIONameResponse
response; 47 | response.name = inputName; 48 | response.status = status; 49 | 50 | return response; 51 | } 52 | 53 | OrtGetIOTypeInfoResponse getInputTypeInfo(OrtApi *api, OrtSession *session, size_t i) { 54 | OrtTypeInfo *typeInfo; 55 | OrtStatus *status; 56 | 57 | status = api->SessionGetInputTypeInfo(session, i, &typeInfo); 58 | 59 | OrtGetIOTypeInfoResponse response; 60 | response.typeInfo = typeInfo; 61 | response.status = status; 62 | 63 | return response; 64 | } 65 | 66 | OrtGetIOCountResponse getOutputCount(OrtApi *api, OrtSession *session) { 67 | size_t numOutputNodes; 68 | OrtStatus *status; 69 | 70 | status = api->SessionGetOutputCount(session, &numOutputNodes); 71 | 72 | OrtGetIOCountResponse response; 73 | response.numNodes = numOutputNodes; 74 | response.status = status; 75 | 76 | return response; 77 | } 78 | 79 | OrtGetIONameResponse getOutputName(OrtApi *api, OrtSession *session, size_t i, OrtAllocator *allocator) { 80 | char *outputName = NULL; 81 | OrtStatus *status; 82 | 83 | status = api->SessionGetOutputName(session, i, allocator, &outputName); // returned so the Go caller checks it before reading the name 84 | 85 | OrtGetIONameResponse response; 86 | response.name = outputName; 87 | response.status = status; 88 | 89 | return response; 90 | } 91 | 92 | OrtGetIOTypeInfoResponse getOutputTypeInfo(OrtApi *api, OrtSession *session, size_t i) { 93 | OrtTypeInfo *typeInfo; 94 | OrtStatus *status; 95 | 96 | status = api->SessionGetOutputTypeInfo(session, i, &typeInfo); 97 | 98 | OrtGetIOTypeInfoResponse response; 99 | response.typeInfo = typeInfo; 100 | response.status = status; 101 | 102 | return response; 103 | } 104 | 105 | OrtRunResponse run(OrtApi *api, OrtSession *session, OrtRunOptions *runOptions, char **inputNames, OrtValue **input, 106 | size_t inputLen, char **outputNames, size_t outputNamesLen) { 107 | OrtRunResponse response; 108 | // ORT writes one OrtValue* per requested output; calloc leaves every slot NULL so ORT allocates each output value. 109 | response.output = (OrtValue **)calloc(outputNamesLen > 0 ? outputNamesLen : 1, sizeof(OrtValue *)); 110 | if(response.output == NULL) { 111 | response.status = api->CreateStatus(ORT_FAIL, "unable to allocate the Run output array"); 112 | return response; 113 | }
114 | // The caller owns response.output and must free() it whether or not Run succeeded. 115 | response.status = api->Run(session, runOptions, inputNames, input, inputLen, outputNames, outputNamesLen, response.output); 116 | return response; 117 | } 118 | 119 | void releaseRunOptions(OrtApi *api, OrtRunOptions *runOptions) { 120 | if(runOptions != NULL) { 121 | api->ReleaseRunOptions(runOptions); 122 | } 123 | } 124 | } -------------------------------------------------------------------------------- /ort/session.go: -------------------------------------------------------------------------------- 1 | package ort 2 | 3 | /* 4 | #include 5 | #include "session.h" 6 | */ 7 | import "C" 8 | import ( 9 | "fmt" 10 | "unsafe" 11 | ) 12 | 13 | type Session interface { 14 | GetInputCount() (int, error) 15 | GetInputName(index int) (string, error) 16 | GetInputNames() ([]string, error) 17 | GetInputTypeInfo(index int) (TypeInfo, error) 18 | GetInputTypeInfos() ([]TypeInfo, error) 19 | 20 | GetOutputCount() (int, error) 21 | GetOutputName(index int) (string, error) 22 | GetOutputNames() ([]string, error) 23 | GetOutputTypeInfo(index int) (TypeInfo, error) 24 | GetOutputTypeInfos() ([]TypeInfo, error) 25 | 26 | Run(runOptions *RunOptions, inputValues []Value) ([]Value, error) 27 | ReleaseSession() 28 | 29 | PrintIOInfo() 30 | } 31 | 32 | type session struct { 33 | inputCount int 34 | inputNames []string 35 | inputTypeInfos []TypeInfo 36 | outputCount int 37 | outputNames []string 38 | outputTypeInfos []TypeInfo 39 | cOpts *C.OrtSessionOptions 40 | cModelPath *C.char 41 | cSession *C.OrtSession 42 | alloc *allocator 43 | } 44 | 45 | func NewSession(env Environment, modelPath string, sessionOpts *SessionOptions) (Session, error) { 46 | e, ok := env.(*environment) 47 | if !ok { 48 | return nil, fmt.Errorf("invalid Environment type") 49 | } 50 | ortOpts, err := sessionOpts.toOrtSessionOptions() 51 | if err != nil { 52 | return nil, err 53 | } 54 | // C resources created from here on are released on every error path; on success the session owns them. 55 | cModelPath := C.CString(modelPath) 56 | cleanup := func() { C.free(unsafe.Pointer(cModelPath)); C.releaseSessionOptions(ortApi.ort, ortOpts.cOrtSessionOptions) } 57 | response := C.createSession(ortApi.ort, e.env, cModelPath, ortOpts.cOrtSessionOptions) 58 | if err = ortApi.ParseStatus(response.status); err != nil { 59 | cleanup() 60 | return nil, err 61 | } 62 | allocator, err := newAllocatorWithDefaultOptions() 63 | if err != nil { 64 | C.releaseSession(ortApi.ort, response.session) 65 | cleanup() 66 | return nil, err 67 | } 68 |
return &session{ 69 | inputCount: -1, 70 | inputNames: nil, 71 | inputTypeInfos: nil, 72 | outputCount: -1, 73 | outputNames: nil, 74 | outputTypeInfos: nil, 75 | cOpts: ortOpts.cOrtSessionOptions, 76 | cModelPath: cModelPath, 77 | cSession: response.session, 78 | alloc: allocator, 79 | }, nil 80 | } 81 | 82 | func (s *session) GetInputCount() (int, error) { 83 | if s.inputCount > -1 { 84 | return s.inputCount, nil 85 | } 86 | 87 | response := C.getInputCount(ortApi.ort, s.cSession) 88 | err := ortApi.ParseStatus(response.status) 89 | if err != nil { 90 | return 0, err 91 | } 92 | 93 | s.inputCount = int(response.numNodes) 94 | return s.inputCount, nil 95 | } 96 | 97 | func (s *session) GetInputName(index int) (string, error) { 98 | inputNames, err := s.GetInputNames() 99 | if err != nil { 100 | return "", err 101 | } 102 | if index < 0 || index >= len(inputNames) { 103 | return "", fmt.Errorf("invalid input index %d", index) 104 | } 105 | return inputNames[index], nil 106 | } 107 | 108 | func (s *session) getInputName(index int) (string, error) { 109 | i := C.size_t(index) 110 | 111 | response := C.getInputName(ortApi.ort, s.cSession, i, s.alloc.a) 112 | err := ortApi.ParseStatus(response.status) 113 | if err != nil { 114 | return "", err 115 | } 116 | 117 | name := C.GoString(response.name) 118 | C.free(unsafe.Pointer(response.name)) 119 | 120 | return name, nil 121 | } 122 | 123 | func (s *session) GetInputNames() ([]string, error) { 124 | if s.inputNames != nil { 125 | return s.inputNames, nil 126 | } 127 | 128 | inputCount, err := s.GetInputCount() 129 | if err != nil { 130 | return nil, err 131 | } 132 | 133 | s.inputNames = make([]string, inputCount) 134 | for i := 0; i < inputCount; i++ { 135 | s.inputNames[i], err = s.getInputName(i) 136 | if err != nil { 137 | s.inputNames = nil 138 | return nil, err 139 | } 140 | } 141 | return s.inputNames, nil 142 | } 143 | 144 | func (s *session) GetInputTypeInfo(index int) (TypeInfo, error) { 145 | typeInfos, err
:= s.GetInputTypeInfos() 146 | if err != nil { 147 | return nil, err 148 | } 149 | if index < 0 || index >= len(typeInfos) { 150 | return nil, fmt.Errorf("invalid input index %d", index) 151 | } 152 | return typeInfos[index], nil 153 | } 154 | 155 | func (s *session) GetInputTypeInfos() ([]TypeInfo, error) { 156 | if s.inputTypeInfos != nil { 157 | return s.inputTypeInfos, nil 158 | } 159 | 160 | numInputs, err := s.GetInputCount() 161 | if err != nil { 162 | return nil, err 163 | } 164 | s.inputTypeInfos = make([]TypeInfo, numInputs) 165 | for i := 0; i < numInputs; i++ { 166 | s.inputTypeInfos[i], err = s.getInputTypeInfo(i) 167 | if err != nil { 168 | s.inputTypeInfos = nil 169 | return nil, err 170 | } 171 | } 172 | return s.inputTypeInfos, nil 173 | } 174 | 175 | func (s *session) getInputTypeInfo(index int) (TypeInfo, error) { 176 | i := C.size_t(index) 177 | 178 | response := C.getInputTypeInfo(ortApi.ort, s.cSession, i) 179 | err := ortApi.ParseStatus(response.status) 180 | if err != nil { 181 | return nil, err 182 | } 183 | 184 | return newTypeInfo(response.typeInfo), nil 185 | } 186 | 187 | func (s *session) Run(runOpts *RunOptions, inputValues []Value) ([]Value, error) { 188 | ortRunOpts, err := runOpts.toOrtRunOptions() 189 | if err != nil { 190 | return nil, err 191 | } 192 | defer C.releaseRunOptions(ortApi.ort, ortRunOpts) // created for this call only, so released when Run returns 193 | outputNames, err := s.GetOutputNames() 194 | if err != nil { 195 | return nil, err 196 | } 197 | if len(inputValues) == 0 || len(outputNames) == 0 { 198 | return nil, fmt.Errorf("run requires at least one input value and at least one session output") 199 | } 200 | cOutputNames := stringsToCharArrayPtr(outputNames) 201 | defer freeCStrings(cOutputNames) 202 | cInputNames, cInputValues, err := valuesToOrtValueArray(inputValues) 203 | if err != nil { 204 | return nil, err 205 | } 206 | defer freeCStrings(cInputNames) 207 | inLen := C.size_t(len(inputValues)) 208 | outNamesLen := C.size_t(len(outputNames)) 209 | response := C.run(ortApi.ort, s.cSession, ortRunOpts, &cInputNames[0], &cInputValues[0], inLen, &cOutputNames[0], outNamesLen) 210 | defer C.free(unsafe.Pointer(response.output)) // C.run callocs the array of output pointers; only the array is freed here 211 | if err = ortApi.ParseStatus(response.status); err != nil { 212 | return nil,
err 213 | } 214 | 215 | return s.outputsToValueSlice(outputNames, response.output) 216 | } 217 | 218 | func (s *session) outputsToValueSlice(names []string, outputs **C.OrtValue) ([]Value, error) { 219 | length := len(names) 220 | tmpslice := (*[1 << 28]*C.OrtValue)(unsafe.Pointer(outputs))[:length:length] // one *C.OrtValue per output name, as filled in by C.run 221 | outValues := make([]Value, length) 222 | 223 | for i := 0; i < length; i++ { 224 | typeInfo, err := s.GetOutputTypeInfo(i) 225 | if err != nil { 226 | return nil, err 227 | } 228 | 229 | tensorInfo, err := typeInfo.ToTensorInfo() 230 | if err != nil { 231 | return nil, err 232 | } 233 | outValues[i] = newValue(names[i], tensorInfo, tmpslice[i]) 234 | } 235 | return outValues, nil 236 | } 237 | 238 | func stringsToCharArrayPtr(in []string) []*C.char { 239 | cStrings := make([]*C.char, len(in)) 240 | for i, inVal := range in { 241 | cStrings[i] = C.CString(inVal) 242 | } 243 | return cStrings 244 | } 245 | 246 | func freeCStrings(in []*C.char) { 247 | for i := 0; i < len(in); i++ { 248 | C.free(unsafe.Pointer(in[i])) 249 | } 250 | } 251 | 252 | func valuesToOrtValueArray(in []Value) ([]*C.char, []*C.OrtValue, error) { 253 | ortVals := make([]*C.OrtValue, len(in)) 254 | valNames := make([]*C.char, len(in)) 255 | for i, inVal := range in { 256 | sValue, ok := inVal.(*value) 257 | if !ok { 258 | return nil, nil, fmt.Errorf("invalid Value type") 259 | } 260 | valNames[i] = C.CString(inVal.GetName()) 261 | ortVals[i] = sValue.cOrtValue 262 | } 263 | return valNames, ortVals, nil 264 | } 265 | 266 | func (s *session) ReleaseSession() { 267 | for _, typeInfo := range s.inputTypeInfos { 268 | typeInfo.ReleaseTypeInfo() 269 | } 270 | for _, typeInfo := range s.outputTypeInfos { 271 | typeInfo.ReleaseTypeInfo() 272 | } 273 | 274 | C.releaseSession(ortApi.ort, s.cSession) 275 | C.releaseSessionOptions(ortApi.ort, s.cOpts) 276 | C.free(unsafe.Pointer(s.cModelPath)) 277 | } 278 | 279 | func (s *session) GetOutputCount() (int, error) { 280 | if s.outputCount > -1 { 281 |
return s.outputCount, nil 282 | } 283 | 284 | response := C.getOutputCount(ortApi.ort, s.cSession) 285 | err := ortApi.ParseStatus(response.status) 286 | if err != nil { 287 | return 0, err 288 | } 289 | 290 | s.outputCount = int(response.numNodes) 291 | return s.outputCount, nil 292 | } 293 | 294 | func (s *session) GetOutputName(index int) (string, error) { 295 | outputNames, err := s.GetOutputNames() 296 | if err != nil { 297 | return "", err 298 | } 299 | if index < 0 || index >= len(outputNames) { 300 | return "", fmt.Errorf("invalid output index %d", index) 301 | } 302 | return outputNames[index], nil 303 | } 304 | 305 | func (s *session) getOutputName(index int) (string, error) { 306 | i := C.size_t(index) 307 | 308 | response := C.getOutputName(ortApi.ort, s.cSession, i, s.alloc.a) 309 | err := ortApi.ParseStatus(response.status) 310 | if err != nil { 311 | return "", err 312 | } 313 | 314 | name := C.GoString(response.name) 315 | C.free(unsafe.Pointer(response.name)) 316 | 317 | return name, nil 318 | } 319 | 320 | func (s *session) GetOutputNames() ([]string, error) { 321 | if s.outputNames != nil { 322 | return s.outputNames, nil 323 | } 324 | 325 | outputCount, err := s.GetOutputCount() 326 | if err != nil { 327 | return nil, err 328 | } 329 | 330 | s.outputNames = make([]string, outputCount) 331 | for i := 0; i < outputCount; i++ { 332 | s.outputNames[i], err = s.getOutputName(i) 333 | if err != nil { 334 | s.outputNames = nil 335 | return nil, err 336 | } 337 | } 338 | return s.outputNames, nil 339 | } 340 | 341 | func (s *session) GetOutputTypeInfo(index int) (TypeInfo, error) { 342 | typeInfos, err := s.GetOutputTypeInfos() 343 | if err != nil { 344 | return nil, err 345 | } 346 | if index < 0 || index >= len(typeInfos) { 347 | return nil, fmt.Errorf("invalid output index %d", index) 348 | } 349 | return typeInfos[index], nil 350 | } 351 | 352 | func (s *session) GetOutputTypeInfos() ([]TypeInfo, error) { 353 | if s.outputTypeInfos != nil { 354 |
return s.outputTypeInfos, nil 355 | } 356 | 357 | numOutputs, err := s.GetOutputCount() 358 | if err != nil { 359 | return nil, err 360 | } 361 | s.outputTypeInfos = make([]TypeInfo, numOutputs) 362 | for i := 0; i < numOutputs; i++ { 363 | s.outputTypeInfos[i], err = s.getOutputTypeInfo(i) 364 | if err != nil { 365 | s.outputTypeInfos = nil 366 | return nil, err 367 | } 368 | } 369 | return s.outputTypeInfos, nil 370 | } 371 | 372 | func (s *session) getOutputTypeInfo(index int) (TypeInfo, error) { 373 | i := C.size_t(index) 374 | 375 | response := C.getOutputTypeInfo(ortApi.ort, s.cSession, i) 376 | err := ortApi.ParseStatus(response.status) 377 | if err != nil { 378 | return nil, err 379 | } 380 | 381 | return newTypeInfo(response.typeInfo), nil 382 | } 383 | 384 | func (s *session) PrintIOInfo() { 385 | fmt.Printf("*******************************\n") 386 | fmt.Printf("*** Session I/O Information ***\n") 387 | fmt.Printf("*******************************\n") 388 | inCount, err := s.GetInputCount() 389 | if err != nil { 390 | fmt.Printf("Error retrieving input count - %s\n", err.Error()) 391 | } else { 392 | fmt.Printf("Number of inputs: %d\n", inCount) 393 | for i := 0; i < inCount; i++ { 394 | fmt.Printf("Input %d:\n", i) 395 | name, err := s.GetInputName(i) 396 | if err != nil { 397 | fmt.Printf(" Error retrieving name\n") 398 | } else { 399 | fmt.Printf(" Name: %s\n", name) 400 | } 401 | typeInfo, err := s.GetInputTypeInfo(i) 402 | if err != nil { 403 | fmt.Printf(" Error retrieving type info\n") 404 | } else { 405 | s.printTensorTypeInfo(typeInfo) 406 | } 407 | } 408 | } 409 | outCount, err := s.GetOutputCount() 410 | if err != nil { 411 | fmt.Printf("Error retrieving output count - %s\n", err.Error()) 412 | } else { 413 | fmt.Printf("Number of outputs: %d\n", outCount) 414 | for i := 0; i < outCount; i++ { 415 | fmt.Printf("Output %d:\n", i) 416 | name, err := s.GetOutputName(i) 417 | if err != nil { 418 | fmt.Printf(" Error retrieving name\n") 419 | } else
{ 420 | fmt.Printf(" Name: %s\n", name) 421 | } 422 | typeInfo, err := s.GetOutputTypeInfo(i) 423 | if err != nil { 424 | fmt.Printf(" Error retrieving type info\n") 425 | } else { 426 | s.printTensorTypeInfo(typeInfo) 427 | } 428 | } 429 | } 430 | } 431 | 432 | func (s *session) printTensorTypeInfo(typeInfo TypeInfo) { 433 | tensorInfo, err := typeInfo.ToTensorInfo() 434 | if err != nil { 435 | fmt.Printf(" Error converting type info to tensor info\n") 436 | } else { 437 | onnxElementType, err := tensorInfo.GetElementType() 438 | if err != nil { 439 | fmt.Printf(" Error retrieving element type\n") 440 | } else { 441 | fmt.Printf(" Element Type: %d\n", onnxElementType) 442 | } 443 | dimsCount, err := tensorInfo.GetDimensionsCount() 444 | if err != nil { 445 | fmt.Printf(" Error retrieving dimensions count\n") 446 | } else { 447 | fmt.Printf(" Dimensions Count: %d\n", dimsCount) 448 | } 449 | dims, err := tensorInfo.GetDimensions() 450 | if err != nil { 451 | fmt.Printf(" Error retrieving dimensions\n") 452 | } else { 453 | for j, dim := range dims { 454 | fmt.Printf(" dim %d size: %d\n", j, dim) 455 | } 456 | } 457 | } 458 | } 459 | -------------------------------------------------------------------------------- /ort/session.h: -------------------------------------------------------------------------------- 1 | #ifndef GOONNX_ORT_SESSION 2 | #define GOONNX_ORT_SESSION 3 | 4 | #include 5 | 6 | typedef struct OrtCreateSessionResponse { 7 | OrtSession *session; 8 | OrtStatus *status; 9 | } OrtCreateSessionResponse; 10 | 11 | typedef struct OrtGetIOCountResponse { 12 | size_t numNodes; 13 | OrtStatus *status; 14 | } OrtGetIOCountResponse; 15 | 16 | typedef struct OrtGetIONameResponse { 17 | char *name; 18 | OrtStatus *status; 19 | } OrtGetIONameResponse; 20 | 21 | typedef struct OrtGetIOTypeInfoResponse { 22 | OrtTypeInfo *typeInfo; 23 | OrtStatus *status; 24 | } OrtGetIOTypeInfoResponse; 25 | 26 | typedef struct OrtRunResponse { 27 | OrtValue **output; // calloc'd array with one slot per requested output; the caller must free() it 28 | OrtStatus
*status; 29 | } OrtRunResponse; 30 | 31 | OrtCreateSessionResponse createSession(OrtApi *api, OrtEnv *env, const char *modelPath, 32 | OrtSessionOptions *sessionOptions); 33 | void releaseSession(OrtApi *api, OrtSession *session); 34 | void releaseSessionOptions(OrtApi *api, OrtSessionOptions *opts); 35 | void releaseRunOptions(OrtApi *api, OrtRunOptions *runOptions); 36 | OrtGetIOCountResponse getInputCount(OrtApi *api, OrtSession *session); 37 | OrtGetIONameResponse getInputName(OrtApi *api, OrtSession *session, size_t i, OrtAllocator *allocator); 38 | OrtGetIOTypeInfoResponse getInputTypeInfo(OrtApi *api, OrtSession *session, size_t i); 39 | OrtGetIOCountResponse getOutputCount(OrtApi *api, OrtSession *session); 40 | OrtGetIONameResponse getOutputName(OrtApi *api, OrtSession *session, size_t i, OrtAllocator *allocator); 41 | OrtGetIOTypeInfoResponse getOutputTypeInfo(OrtApi *api, OrtSession *session, size_t i); 42 | OrtRunResponse run(OrtApi *api, OrtSession *session, OrtRunOptions *runOptions, char **inputNames, OrtValue **input, 43 | size_t inputLen, char **outputNames, size_t outputNamesLen); 44 | 45 | #endif -------------------------------------------------------------------------------- /ort/tensor-type-and-shape-info.cpp: -------------------------------------------------------------------------------- 1 | extern "C" { 2 | #include 3 | #include "tensor-type-and-shape-info.h" 4 | 5 | OrtGetTensorElementTypeResponse getTensorElementType(OrtApi *api, OrtTensorTypeAndShapeInfo *typeInfo) { 6 | ONNXTensorElementDataType dataType; 7 | OrtStatus *status; 8 | 9 | status = api->GetTensorElementType(typeInfo, &dataType); 10 | 11 | OrtGetTensorElementTypeResponse response; 12 | response.dataType = dataType; 13 | response.status = status; 14 | 15 | return response; 16 | } 17 | 18 | OrtGetDimensionsCountResponse getDimensionsCount(OrtApi *api, OrtTensorTypeAndShapeInfo *typeInfo) { 19 | size_t numDims; 20 | OrtStatus *status; 21 | 22 | status = api->GetDimensionsCount(typeInfo, &numDims); 23 | 24 | OrtGetDimensionsCountResponse response;
25 | response.numDims = numDims; 26 | response.status = status; 27 | 28 | return response; 29 | } 30 | 31 | OrtGetDimensionsResponse getDimensions(OrtApi *api, OrtTensorTypeAndShapeInfo *typeInfo, size_t numDims) { 32 | int64_t *dims; 33 | OrtStatus *status; 34 | 35 | dims = (int64_t *)malloc(numDims * sizeof(int64_t)); // freed by the Go caller (GetDimensions) 36 | 37 | status = api->GetDimensions(typeInfo, dims, numDims); 38 | 39 | OrtGetDimensionsResponse response; 40 | response.dims = dims; 41 | response.status = status; 42 | 43 | return response; 44 | } 45 | 46 | OrtGetSymbolicDimensionsResponse getSymbolicDimensions(OrtApi *api, OrtTensorTypeAndShapeInfo *typeInfo, size_t numDims) { 47 | const char **allParams = (const char **)calloc(numDims > 0 ? numDims : 1, sizeof(const char *)); // ORT writes up to numDims entries, so one pointer is not enough 48 | OrtGetSymbolicDimensionsResponse response; 49 | response.dimParams = NULL; 50 | if(allParams == NULL) { response.status = api->CreateStatus(ORT_FAIL, "unable to allocate symbolic dimension array"); return response; } 51 | response.status = api->GetSymbolicDimensions(typeInfo, allParams, numDims); 52 | if(response.status == NULL && numDims > 0) { 53 | response.dimParams = allParams[0]; // only the first name fits this response; ORT returns pointers owned by typeInfo, not by allParams 54 | } 55 | free(allParams); 56 | return response; 57 | } 58 | 59 | 60 | } -------------------------------------------------------------------------------- /ort/tensor-type-and-shape-info.go: -------------------------------------------------------------------------------- 1 | package ort 2 | 3 | /* 4 | #include 5 | #include "tensor-type-and-shape-info.h" 6 | */ 7 | import "C" 8 | import ( 9 | "reflect" 10 | "unsafe" 11 | ) 12 | 13 | type ONNXTensorElementDataType int 14 | 15 | const ( 16 | TensorElemDataTypeUndefined ONNXTensorElementDataType = 0 17 | TensorElemDataTypeFloat ONNXTensorElementDataType = 1 18 | TensorElemDataTypeUInt8 ONNXTensorElementDataType = 2 19 | TensorElemDataTypeInt8 ONNXTensorElementDataType = 3 20 | TensorElemDataTypeUInt16 ONNXTensorElementDataType = 4 21 | TensorElemDataTypeInt16 ONNXTensorElementDataType = 5 22 | TensorElemDataTypeInt32 ONNXTensorElementDataType = 6 23 | TensorElemDataTypeInt64 ONNXTensorElementDataType = 7 24 | TensorElemDataTypeString ONNXTensorElementDataType = 8 25 | TensorElemDataTypeBool ONNXTensorElementDataType = 9 26 | TensorElemDataTypeFloat16
ONNXTensorElementDataType = 10 27 | TensorElemDataTypeDouble ONNXTensorElementDataType = 11 28 | TensorElemDataTypeUInt32 ONNXTensorElementDataType = 12 29 | TensorElemDataTypeUInt64 ONNXTensorElementDataType = 13 30 | TensorElemDataTypeComplex64 ONNXTensorElementDataType = 14 31 | TensorElemDataTypeComplex128 ONNXTensorElementDataType = 15 32 | TensorElemDataTypeBFloat16 ONNXTensorElementDataType = 16 33 | ) 34 | 35 | type TensorTypeAndShapeInfo interface { // element type and shape of a tensor; the implementation caches each value after the first successful query 36 | GetElementType() (ONNXTensorElementDataType, error) 37 | GetDimensionsCount() (int, error) 38 | GetDimensions() ([]int64, error) 39 | } 40 | 41 | type tensorTypeAndShapeInfo struct { 42 | elementType ONNXTensorElementDataType 43 | dimCount int 44 | dims []int64 45 | cTensorInfo *C.OrtTensorTypeAndShapeInfo 46 | } 47 | 48 | func newTensorTypeAndShapeInfo(cTensorInfo *C.OrtTensorTypeAndShapeInfo) *tensorTypeAndShapeInfo { 49 | return &tensorTypeAndShapeInfo{ 50 | elementType: -1, 51 | dimCount: -1, 52 | dims: nil, 53 | cTensorInfo: cTensorInfo, 54 | } 55 | } 56 | 57 | func (i *tensorTypeAndShapeInfo) GetElementType() (ONNXTensorElementDataType, error) { 58 | if i.elementType > -1 { 59 | return i.elementType, nil 60 | } 61 | 62 | response := C.getTensorElementType(ortApi.ort, i.cTensorInfo) 63 | err := ortApi.ParseStatus(response.status) 64 | if err != nil { 65 | return TensorElemDataTypeUndefined, err 66 | } 67 | 68 | i.elementType = getONNXTensorElementDataTypeForOrtTensorElementDataType(response.dataType) 69 | return i.elementType, nil 70 | } 71 | 72 | func (i *tensorTypeAndShapeInfo) GetDimensionsCount() (int, error) { 73 | if i.dimCount > -1 { 74 | return i.dimCount, nil 75 | } 76 | 77 | response := C.getDimensionsCount(ortApi.ort, i.cTensorInfo) 78 | err := ortApi.ParseStatus(response.status) 79 | if err != nil { 80 | return 0, err 81 | } 82 | 83 | i.dimCount = int(response.numDims) 84 | return i.dimCount, nil 85 | } 86 | 87 | func (i *tensorTypeAndShapeInfo) GetDimensions() ([]int64, error) { // result is a Go-owned copy, cached after the first successful call 88 | if
i.dims != nil { 89 | return i.dims, nil 90 | } 91 | 92 | numDims, err := i.GetDimensionsCount() 93 | if err != nil { 94 | return nil, err 95 | } 96 | 97 | cNumDims := C.size_t(numDims) 98 | 99 | response := C.getDimensions(ortApi.ort, i.cTensorInfo, cNumDims) 100 | defer C.free(unsafe.Pointer(response.dims)) // getDimensions mallocs this buffer; its contents are copied out below 101 | if err = ortApi.ParseStatus(response.status); err != nil { 102 | return nil, err 103 | } 104 | var cDims []int64 // temporary view over the C buffer, valid only until the deferred free 105 | sliceHeader := (*reflect.SliceHeader)(unsafe.Pointer(&cDims)) 106 | sliceHeader.Cap = numDims 107 | sliceHeader.Len = numDims 108 | sliceHeader.Data = uintptr(unsafe.Pointer(response.dims)) 109 | i.dims = append(make([]int64, 0, numDims), cDims...) // copy into Go memory so the cache outlives the C buffer 110 | return i.dims, nil 111 | } 112 | 113 | func (i *tensorTypeAndShapeInfo) cGetDimensions() (interface{}, error) { 114 | dims, err := i.GetDimensions() 115 | if err != nil { 116 | return nil, err 117 | } 118 | if len(dims) == 0 { return (*int64)(nil), nil } // scalar shape: there is no first element to point at 119 | return &dims[0], nil 120 | } 121 | 122 | func (i *tensorTypeAndShapeInfo) cGetDimensionsCount() (C.size_t, error) { 123 | numDims, err := i.GetDimensionsCount() 124 | if err != nil { 125 | return 0, err 126 | } 127 | return C.size_t(numDims), nil 128 | } 129 | 130 | func (i *tensorTypeAndShapeInfo) cGetElementType() (C.ONNXTensorElementDataType, error) { 131 | elemType, err := i.GetElementType() 132 | if err != nil { 133 | return 0, err 134 | } 135 | return getOrtTensorElementDataTypeForONNXTensorElementDataType(elemType), nil 136 | } 137 | 138 | func getONNXTensorElementDataTypeForOrtTensorElementDataType(ortType C.ONNXTensorElementDataType) ONNXTensorElementDataType { 139 | return ONNXTensorElementDataType(ortType) 140 | } 141 | 142 | func getOrtTensorElementDataTypeForONNXTensorElementDataType(onnxType ONNXTensorElementDataType) C.ONNXTensorElementDataType { 143 | return C.ONNXTensorElementDataType(onnxType) 144 | } 145 | -------------------------------------------------------------------------------- /ort/tensor-type-and-shape-info.h: -------------------------------------------------------------------------------- 1 | #ifndef GOONNX_ORT_TENSOR_TYPE_AND_SHAPE_INFO 2 | #define
GOONNX_ORT_TENSOR_TYPE_AND_SHAPE_INFO 3 | 4 | #include 5 | 6 | typedef struct OrtGetTensorElementTypeResponse { 7 | ONNXTensorElementDataType dataType; 8 | OrtStatus *status; 9 | } OrtGetTensorElementTypeResponse; 10 | 11 | typedef struct OrtGetDimensionsCountResponse { 12 | size_t numDims; 13 | OrtStatus *status; 14 | } OrtGetDimensionsCountResponse; 15 | 16 | typedef struct OrtGetDimensionsResponse { 17 | int64_t *dims; 18 | OrtStatus *status; 19 | } OrtGetDimensionsResponse; 20 | 21 | typedef struct OrtGetSymbolicDimensionsResponse { 22 | const char *dimParams; 23 | OrtStatus *status; 24 | } OrtGetSymbolicDimensionsResponse; 25 | 26 | OrtGetTensorElementTypeResponse getTensorElementType(OrtApi *api, OrtTensorTypeAndShapeInfo *typeInfo); 27 | OrtGetDimensionsCountResponse getDimensionsCount(OrtApi *api, OrtTensorTypeAndShapeInfo *typeInfo); 28 | OrtGetDimensionsResponse getDimensions(OrtApi *api, OrtTensorTypeAndShapeInfo *typeInfo, size_t numDims); 29 | OrtGetSymbolicDimensionsResponse getSymbolicDimensions(OrtApi *api, OrtTensorTypeAndShapeInfo *typeInfo, size_t numDims); 30 | 31 | #endif -------------------------------------------------------------------------------- /ort/type-info.cpp: -------------------------------------------------------------------------------- 1 | extern "C" { 2 | #include 3 | #include "type-info.h" 4 | 5 | void releaseTypeInfo(OrtApi *api, OrtTypeInfo *typeInfo) { 6 | api->ReleaseTypeInfo(typeInfo); 7 | } 8 | 9 | OrtCastTypeInfoToTensorInfoResponse castTypeInfoToTensorInfo(OrtApi *api, OrtTypeInfo *typeInfo) { 10 | const OrtTensorTypeAndShapeInfo *tensorInfo; 11 | OrtStatus *status; 12 | 13 | status = api->CastTypeInfoToTensorInfo(typeInfo, &tensorInfo); 14 | 15 | OrtCastTypeInfoToTensorInfoResponse response; 16 | response.tensorInfo = tensorInfo; 17 | response.status = status; 18 | 19 | return response; 20 | } 21 | } -------------------------------------------------------------------------------- /ort/type-info.go: 
-------------------------------------------------------------------------------- 1 | package ort 2 | 3 | /* 4 | #include 5 | #include "type-info.h" 6 | */ 7 | import "C" 8 | import "fmt" 9 | // ONNXType identifies the kind of an ONNX value (tensor, sequence, map, ...). 10 | type ONNXType int 11 | 12 | const ( 13 | TypeUnknown ONNXType = 0 14 | TypeTensor ONNXType = 1 15 | TypeSequence ONNXType = 2 16 | TypeMap ONNXType = 3 17 | TypeOpaque ONNXType = 4 18 | TypeSparseTensor ONNXType = 5 19 | ) 20 | // TypeInfo wraps a C OrtTypeInfo handle. 21 | type TypeInfo interface { 22 | ToTensorInfo() (TensorTypeAndShapeInfo, error) 23 | ReleaseTypeInfo() 24 | } 25 | // typeInfo records whether its handle has been released, so ReleaseTypeInfo is safe to call more than once. 26 | type typeInfo struct { 27 | released bool 28 | cTypeInfo *C.OrtTypeInfo 29 | } 30 | // newTypeInfo wraps cTypeInfo. The wrapper frees it in ReleaseTypeInfo, so callers must not release it separately. 31 | func newTypeInfo(cTypeInfo *C.OrtTypeInfo) TypeInfo { 32 | return &typeInfo{ 33 | released: false, 34 | cTypeInfo: cTypeInfo, 35 | } 36 | } 37 | // ToTensorInfo returns the tensor type-and-shape view of this type info, or an error if it does not describe a tensor. The result borrows memory owned by the OrtTypeInfo and queries it lazily, so use it before calling ReleaseTypeInfo. 38 | func (i *typeInfo) ToTensorInfo() (TensorTypeAndShapeInfo, error) { 39 | if i.cTypeInfo == nil { 40 | return nil, fmt.Errorf("TypeInfo incorrectly instantiated") 41 | } 42 | 43 | response := C.castTypeInfoToTensorInfo(ortApi.ort, i.cTypeInfo) 44 | err := ortApi.ParseStatus(response.status) 45 | if err != nil { 46 | return nil, err 47 | } 48 | if response.tensorInfo == nil { return nil, fmt.Errorf("type info does not describe a tensor") } 49 | return newTensorTypeAndShapeInfo(response.tensorInfo), nil 50 | } 51 | // ReleaseTypeInfo frees the underlying OrtTypeInfo; calls after the first are no-ops. 52 | func (i *typeInfo) ReleaseTypeInfo() { 53 | if !i.released { 54 | C.releaseTypeInfo(ortApi.ort, i.cTypeInfo) 55 | i.cTypeInfo = nil 56 | i.released = true 57 | } 58 | } 59 | -------------------------------------------------------------------------------- /ort/type-info.h: -------------------------------------------------------------------------------- 1 | #ifndef GOONNX_ORT_INPUT_TYPE_INFO 2 | #define GOONNX_ORT_INPUT_TYPE_INFO 3 | 4 | #include 5 | 6 | typedef struct OrtCastTypeInfoToTensorInfoResponse { 7 | const OrtTensorTypeAndShapeInfo *tensorInfo; 8 | OrtStatus *status; 9 | } OrtCastTypeInfoToTensorInfoResponse; 10 | 11 | void releaseTypeInfo(OrtApi *api, OrtTypeInfo *typeInfo); 12 | OrtCastTypeInfoToTensorInfoResponse castTypeInfoToTensorInfo(OrtApi *api, OrtTypeInfo
*typeInfo); 13 | 14 | #endif -------------------------------------------------------------------------------- /ort/value.cpp: -------------------------------------------------------------------------------- 1 | extern "C" { 2 | #include 3 | #include "value.h" 4 | // Thin wrappers that return each OrtApi status together with its out-parameter in one struct, so cgo callers can read both. 5 | OrtCreateTensorWithDataAsOrtValueResponse createTensorWithDataAsOrtValue(OrtApi *api, OrtMemoryInfo *memoryInfo, 6 | void *data, size_t dataLen, int64_t *shape, size_t shapeLen, ONNXTensorElementDataType type) { 7 | OrtValue *value = nullptr; 8 | OrtStatus *status; 9 | 10 | status = api->CreateTensorWithDataAsOrtValue(memoryInfo, data, dataLen, shape, shapeLen, type, &value); 11 | 12 | OrtCreateTensorWithDataAsOrtValueResponse response; 13 | response.value = value; 14 | response.status = status; 15 | 16 | return response; 17 | } 18 | 19 | OrtIsTensorResponse isTensor(OrtApi *api, OrtValue *value) { 20 | int isTensor = 0; 21 | OrtStatus *status; 22 | 23 | status = api->IsTensor(value, &isTensor); 24 | 25 | OrtIsTensorResponse response; 26 | response.isTensor = isTensor; 27 | response.status = status; 28 | 29 | return response; 30 | } 31 | 32 | OrtGetTensorMutableFloatDataResponse getTensorMutableFloatData(OrtApi *api, OrtValue *value) { 33 | float *out = nullptr; 34 | OrtStatus *status; 35 | 36 | status = api->GetTensorMutableData(value, (void **)&out); 37 | 38 | OrtGetTensorMutableFloatDataResponse response; 39 | response.status = status; 40 | response.out = out; 41 | 42 | return response; 43 | } 44 | } -------------------------------------------------------------------------------- /ort/value.go: -------------------------------------------------------------------------------- 1 | package ort 2 | 3 | /* 4 | #include 5 | #include "value.h" 6 | */ 7 | import "C" 8 | import ( 9 | "fmt" 10 | "reflect" 11 | "unsafe" 12 | ) 13 | // Value wraps a C OrtValue handle together with its name and tensor type/shape info. 14 | type Value interface { 15 | GetName() string 16 | IsTensor() (bool, error) 17 | GetTensorMutableFloatData() ([]float32, error) 18 | } 19 | // value implements Value. goData keeps Go-allocated input buffers reachable, because ORT uses a caller-supplied tensor buffer in place without copying it. 20 | type value struct { 21 | name string 22 | typeInfo
 TensorTypeAndShapeInfo 23 | cOrtValue *C.OrtValue; goData interface{} 24 | } 25 | // newValue wraps an existing C OrtValue. 26 | func newValue(name string, typeInfo TensorTypeAndShapeInfo, cOrtValue *C.OrtValue) Value { 27 | return &value{ 28 | name: name, 29 | typeInfo: typeInfo, 30 | cOrtValue: cOrtValue, 31 | } 32 | } 33 | // NewTensorWithFloatDataAsValue creates a tensor backed directly by the float32 values in inData (they are not copied), shaped by typeInfo. inData must be non-empty and must not be modified while the returned value is in use. 34 | func NewTensorWithFloatDataAsValue(memInfo MemoryInfo, name string, inData []float32, typeInfo TensorTypeAndShapeInfo) (Value, error) { 35 | if len(inData) == 0 { return nil, fmt.Errorf("input data must not be empty") } 36 | inLen := C.size_t(uintptr(len(inData)) * unsafe.Sizeof(inData[0])) 37 | 38 | sMemInfo, ok := memInfo.(*memoryInfo) 39 | if !ok { 40 | return nil, fmt.Errorf("invalid memory info type") 41 | } 42 | 43 | sTypeInfo, ok := typeInfo.(*tensorTypeAndShapeInfo) 44 | if !ok { 45 | return nil, fmt.Errorf("invalid tensor type and shape info") 46 | } 47 | 48 | dims, err := sTypeInfo.GetDimensions() 49 | if err != nil { 50 | return nil, err 51 | } 52 | dimCount, err := sTypeInfo.cGetDimensionsCount() 53 | if err != nil { 54 | return nil, err 55 | } 56 | elemType, err := sTypeInfo.cGetElementType() 57 | if err != nil { 58 | return nil, err 59 | } 60 | 61 | response := C.createTensorWithDataAsOrtValue(ortApi.ort, sMemInfo.cMemoryInfo, unsafe.Pointer(&inData[0]), inLen, cDimsPointer(dims), dimCount, elemType) 62 | err = ortApi.ParseStatus(response.status) 63 | if err != nil { 64 | return nil, err 65 | } 66 | 67 | return &value{ 68 | name: name, 69 | typeInfo: typeInfo, 70 | cOrtValue: response.value, goData: inData, 71 | }, nil 72 | } 73 | // NewTensorWithDataAsValue creates a tensor backed directly by the raw bytes in inData (they are not copied); the element type and shape come from typeInfo. inData must be non-empty and must not be modified while the returned value is in use. 74 | func NewTensorWithDataAsValue(memInfo MemoryInfo, inData []byte, typeInfo TensorTypeAndShapeInfo) (Value, error) { 75 | inLen := C.size_t(len(inData)) 76 | if len(inData) == 0 { return nil, fmt.Errorf("input data must not be empty") } 77 | sMemInfo, ok := memInfo.(*memoryInfo) 78 | if !ok { 79 | return nil, fmt.Errorf("invalid memory info type") 80 | } 81 | 82 | sTypeInfo, ok := typeInfo.(*tensorTypeAndShapeInfo) 83 | if !ok { 84 | return nil, fmt.Errorf("invalid tensor type and shape info") 85 | } 86 | 87 | dims, err := sTypeInfo.GetDimensions() 88 | if err != nil { 89 | return nil, err 90 |
} 91 | dimCount, err := sTypeInfo.cGetDimensionsCount() 92 | if err != nil { 93 | return nil, err 94 | } 95 | elemType, err := sTypeInfo.cGetElementType() 96 | if err != nil { 97 | return nil, err 98 | } 99 | 100 | response := C.createTensorWithDataAsOrtValue(ortApi.ort, sMemInfo.cMemoryInfo, unsafe.Pointer(&inData[0]), inLen, cDimsPointer(dims), dimCount, elemType) 101 | err = ortApi.ParseStatus(response.status) 102 | if err != nil { 103 | return nil, err 104 | } 105 | 106 | return &value{ 107 | typeInfo: typeInfo, cOrtValue: response.value, goData: inData, 108 | }, nil 109 | } 110 | // GetName returns the name the value was created with (empty for values built by NewTensorWithDataAsValue). 111 | func (v *value) GetName() string { 112 | return v.name 113 | } 114 | // GetTensorMutableFloatData returns a Go copy of the tensor contents read as float32. The element count is derived from the value's type info. 115 | func (v *value) GetTensorMutableFloatData() ([]float32, error) { 116 | response := C.getTensorMutableFloatData(ortApi.ort, v.cOrtValue) 117 | err := ortApi.ParseStatus(response.status) 118 | if err != nil { 119 | return nil, err 120 | } 121 | if v.typeInfo == nil { return nil, fmt.Errorf("value %q has no tensor type and shape info", v.name) } 122 | n, err := v.calcDataSize() 123 | if err != nil { 124 | return nil, err 125 | } 126 | 127 | var data []float32 128 | sliceHeader := (*reflect.SliceHeader)(unsafe.Pointer(&data)) 129 | sliceHeader.Cap = int(n) 130 | sliceHeader.Len = int(n) 131 | sliceHeader.Data = uintptr(unsafe.Pointer(response.out)) 132 | 133 | output := append([]float32(nil), data...)
 134 | 135 | return output, nil 136 | } 137 | // IsTensor reports whether the underlying OrtValue is a tensor. 138 | func (v *value) IsTensor() (bool, error) { 139 | response := C.isTensor(ortApi.ort, v.cOrtValue) 140 | err := ortApi.ParseStatus(response.status) 141 | if err != nil { 142 | return false, err 143 | } 144 | if response.isTensor == 1 { 145 | return true, nil 146 | } 147 | return false, nil 148 | } 149 | // calcDataSize returns the element count implied by the value's type info (the product of all dimensions). Negative (unresolved) dimensions are rejected instead of producing a bogus count. 150 | func (v *value) calcDataSize() (int64, error) { 151 | dims, err := v.typeInfo.GetDimensions() 152 | if err != nil { 153 | return -1, err 154 | } 155 | var total int64 = 1 156 | for _, dim := range dims { 157 | if dim < 0 { return -1, fmt.Errorf("cannot compute data size: shape %v has unresolved dimensions", dims) } 158 | total = total * dim 159 | } 160 | return total, nil 161 | } 162 | 163 | // cDimsPointer returns a C pointer to the first element of dims, or nil for a rank-0 (scalar) shape, where &dims[0] would panic. 164 | func cDimsPointer(dims []int64) *C.int64_t { 165 | if len(dims) == 0 { 166 | return nil 167 | } 168 | return (*C.int64_t)(&dims[0]) 169 | } 170 | -------------------------------------------------------------------------------- /ort/value.h: -------------------------------------------------------------------------------- 1 | #ifndef GOONNX_ORT_VALUE 2 | #define GOONNX_ORT_VALUE 3 | 4 | #include 5 | 6 | typedef struct OrtCreateTensorWithDataAsOrtValueResponse { 7 | OrtValue *value; 8 | OrtStatus *status; 9 | } OrtCreateTensorWithDataAsOrtValueResponse; 10 | 11 | typedef struct OrtIsTensorResponse { 12 | int isTensor; 13 | OrtStatus *status; 14 | } OrtIsTensorResponse; 15 | 16 | typedef struct OrtGetTensorMutableFloatDataResponse { 17 | float *out; 18 | OrtStatus *status; 19 | } OrtGetTensorMutableFloatDataResponse; 20 | 21 | OrtCreateTensorWithDataAsOrtValueResponse createTensorWithDataAsOrtValue(OrtApi *api, OrtMemoryInfo *memoryInfo, 22 | void *data, size_t dataLen, int64_t *shape, size_t shapeLen, ONNXTensorElementDataType type); 23 | 24 | OrtIsTensorResponse isTensor(OrtApi *api, OrtValue *value); 25 | 26 | OrtGetTensorMutableFloatDataResponse getTensorMutableFloatData(OrtApi *api, OrtValue *value); 27 | 28 | #endif --------------------------------------------------------------------------------