├── README.md ├── image_object_detect ├── .gitignore ├── README.md ├── car.png ├── go.mod ├── go.sum ├── image_object_detect.go └── yolov8n.onnx ├── mnist ├── .gitignore ├── README.md ├── eight.png ├── go.mod ├── go.sum ├── mnist.go ├── mnist.onnx ├── seven.png └── tiny_5.png ├── mnist_float16 ├── .gitignore ├── README.md ├── go.mod ├── go.sum ├── mnist_float16.go └── mnist_float16.onnx ├── non_tensor_outputs ├── .gitignore ├── generate_sklearn_network.py ├── go.mod ├── go.sum ├── non_tensor_outputs.go └── sklearn_randomforest.onnx ├── onnx_list_inputs_and_outputs ├── .gitignore ├── README.md ├── go.mod ├── go.sum └── onnx_list_inputs_and_outputs.go ├── sum_and_difference ├── .gitignore ├── README.md ├── example_network.onnx ├── generate_network.py ├── go.mod ├── go.sum ├── sum_and_difference.go └── sum_and_difference.onnx └── third_party ├── onnxruntime.dll ├── onnxruntime.so ├── onnxruntime_amd64.dylib ├── onnxruntime_arm64.dylib └── onnxruntime_arm64.so /README.md: -------------------------------------------------------------------------------- 1 | Example Applications for the `onnxruntime_go` Library 2 | ===================================================== 3 | 4 | This repository contains a collection of (mostly) simple standalone examples 5 | using the [`onnxruntime_go`](https://github.com/yalue/onnxruntime_go) library 6 | to run neural-network applications. 7 | 8 | 9 | Prerequisites 10 | ------------- 11 | 12 | You will need to be using a version of Go with cgo enabled---meaning that on 13 | Windows you'll need to have `gcc` available on your PATH. 14 | 15 | If you wish to use hardware acceleration such as CUDA, you'll need to have a 16 | compatible version of the `onnxruntime` library compiled with support for your 17 | platform of choice. 
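For systems without special acceleration requirements, the examples in this repository fall back to the prebuilt libraries bundled under `third_party/`, choosing a filename based on `runtime.GOOS` and `runtime.GOARCH`. A standalone sketch of that selection logic follows (simplified from the `getSharedLibPath` helpers used in the examples; unsupported combinations return an empty string here rather than panicking):

```go
package main

import (
	"fmt"
	"runtime"
)

// defaultLibraryName returns the onnxruntime shared library filename bundled
// under third_party/ for the given OS and architecture, mirroring the
// selection logic used by the examples. Returns "" for unsupported
// combinations (the real examples panic or print an error instead).
func defaultLibraryName(goos, goarch string) string {
	switch goos {
	case "windows":
		if goarch == "amd64" {
			return "onnxruntime.dll"
		}
	case "darwin":
		switch goarch {
		case "arm64":
			return "onnxruntime_arm64.dylib"
		case "amd64":
			return "onnxruntime_amd64.dylib"
		}
	case "linux":
		if goarch == "arm64" {
			return "onnxruntime_arm64.so"
		}
		return "onnxruntime.so"
	}
	return ""
}

func main() {
	// Print the default library for the platform this program runs on.
	fmt.Println(defaultLibraryName(runtime.GOOS, runtime.GOARCH))
}
```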
CoreML should almost always be available on Apple hardware, 18 | but other supported acceleration frameworks (e.g., TensorRT or CUDA) may have 19 | additional prerequisites, which are documented in 20 | [the official onnxruntime documentation](https://onnxruntime.ai/docs/execution-providers/). 21 | Note that not all execution providers supported by `onnxruntime` itself are 22 | supported by `onnxruntime_go`. 23 | 24 | The `onnxruntime` shared libraries for some common platforms are included 25 | under the `third_party/` directory in this repository. 26 | 27 | 28 | Usage 29 | ----- 30 | 31 | Navigate to any one of the subdirectories, and run `go build` to produce an 32 | executable on your system. Many executables will provide a mechanism for 33 | specifying a path to an `onnxruntime` shared library file. For example: 34 | 35 | ```bash 36 | cd sum_and_difference 37 | go build 38 | 39 | # You can specify any version of the onnxruntime library here, but this would 40 | # be the correct library version on 64-bit AMD or Intel Linux systems. 41 | ./sum_and_difference --onnxruntime_lib ../third_party/onnxruntime.so 42 | ``` 43 | 44 | Be aware that different examples may use different mechanisms for locating the 45 | correct shared library version. 46 | 47 | 48 | List of Examples 49 | ---------------- 50 | 51 | - `sum_and_difference`: This is the simplest example, copied from a unit test 52 | in the `onnxruntime_go` library. It uses a basic neural network (trained 53 | using a PyTorch script contained in the directory) on a tiny amount of 54 | hardcoded data. The source code is very heavily commented for reference. 55 | 56 | - `mnist`: This example runs a CNN trained to identify handwritten digits from 57 | the MNIST dataset. It processes a single input image and outputs the digit 58 | it is most likely to contain. 59 | 60 | - `mnist_float16`: This example is identical to the plain `mnist` example, 61 | except it uses a 16-bit network, including 16-bit inputs and outputs. 
It is 62 | intended to illustrate how to use a float16 `CustomDataTensor`. 63 | 64 | - `onnx_list_inputs_and_outputs`: This example prints the inputs and outputs 65 | of a user-specified .onnx file to stdout. It is intended to illustrate the 66 | usage of the `onnxruntime_go.GetInputOutputInfo` function. 67 | 68 | - `image_object_detect`: This example uses the YOLOv8 network to detect a list 69 | of objects in an input image. It also attempts to use CoreML if the 70 | `USE_COREML` environment variable is set to `true`. 71 | 72 | - `non_tensor_outputs`: This example runs a network produced by the `sklearn` 73 | Python library, which is notable for outputting ONNX `Map` and `Sequence` 74 | types. This example is meant to serve as a reference for how users may 75 | access `Map` and `Sequence` contents. 76 | 77 | Contributing and Opening New Issues 78 | ----------------------------------- 79 | 80 | PRs with new examples to this repository are welcome. Each example should be 81 | in its own subdirectory with its own go.mod file, and include only minimal 82 | dependencies (i.e., do not include several hundred megabytes of .onnx files or 83 | data). Each example should include a README, be formatted using `gofmt`, and 84 | contain ample comments to serve as a useful example to other users. 85 | 86 | Please limit open issues in this repository to bugs with existing examples. 87 | Issues are _not_ a place to request help with `onnxruntime` in general. Such 88 | issues will be ignored going forward. If you have not run your `.onnx` network 89 | using the `onnxruntime` library in Python, this is not the place to get help 90 | with it. Learning to use `onnxruntime` in Python is easier than in Go, and 91 | will give you a point of reference confirming that you understand the network 92 | you are trying to run, and that your inputs and outputs are correct. 93 | 94 | In short, this repository is intended to provide examples for using the 95 | `onnxruntime_go` wrapper specifically. 
Users are expected to already understand 96 | `.onnx` files and how to use `onnxruntime` in general. 97 | -------------------------------------------------------------------------------- /image_object_detect/.gitignore: -------------------------------------------------------------------------------- 1 | image_object_detect.exe 2 | image_object_detect 3 | 4 | -------------------------------------------------------------------------------- /image_object_detect/README.md: -------------------------------------------------------------------------------- 1 | Image Object Detection Using YOLOv8 2 | ================================= 3 | 4 | This example uses the included yolov8n.onnx network to detect objects in an 5 | image. For now, the example is hardcoded to process the included car.png image. 6 | It performs the detection several times to compute timing statistics. 7 | 8 | 9 | CoreML can be enabled by setting the `USE_COREML` environment variable to 10 | `true`. (Though this will cause the program to fail on systems where CoreML is 11 | not supported.) 12 | 13 | Running with CoreML 14 | ------------------- 15 | ```bash 16 | $ go build . 17 | $ USE_COREML=true ./image_object_detect 18 | 19 | Object: car Confidence: 0.50 Coordinates: (392.156250, 286.328125), (692.111755, 655.371094) 20 | Object: car Confidence: 0.50 Coordinates: (392.156250, 286.328125), (692.111755, 655.371094) 21 | Object: car Confidence: 0.50 Coordinates: (392.156250, 286.328125), (692.111755, 655.371094) 22 | Object: car Confidence: 0.50 Coordinates: (392.156250, 286.328125), (692.111755, 655.371094) 23 | Object: car Confidence: 0.50 Coordinates: (392.156250, 286.328125), (692.111755, 655.371094) 24 | Min Time: 17.401875ms, Max Time: 21.7065ms, Avg Time: 19.258691ms, Count: 5 25 | 50th: 18.485666ms, 90th: 21.7065ms, 99th: 21.7065ms 26 | ``` 27 | 28 | Run on the CPU only, without CoreML 29 | ----------------------------------- 30 | ```bash 31 | $ go build . 
32 | $ ./image_object_detect 33 | 34 | Object: car Confidence: 0.50 Coordinates: (392.655396, 285.742920), (691.901306, 656.455566) 35 | Object: car Confidence: 0.50 Coordinates: (392.655396, 285.742920), (691.901306, 656.455566) 36 | Object: car Confidence: 0.50 Coordinates: (392.655396, 285.742920), (691.901306, 656.455566) 37 | Object: car Confidence: 0.50 Coordinates: (392.655396, 285.742920), (691.901306, 656.455566) 38 | Object: car Confidence: 0.50 Coordinates: (392.655396, 285.742920), (691.901306, 656.455566) 39 | Min Time: 41.5205ms, Max Time: 58.348084ms, Avg Time: 46.154341ms, Count: 5 40 | 50th: 43.471958ms, 90th: 58.348084ms, 99th: 58.348084ms 41 | ``` 42 | (Note the slower execution times.) 43 | -------------------------------------------------------------------------------- /image_object_detect/car.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yalue/onnxruntime_go_examples/af5620ad77f175d73eb0e505aecce27e8e867642/image_object_detect/car.png -------------------------------------------------------------------------------- /image_object_detect/go.mod: -------------------------------------------------------------------------------- 1 | module github.com/yalue/onnxruntime_go_examples/image_object_detect 2 | 3 | go 1.21.0 4 | 5 | toolchain go1.21.4 6 | 7 | require ( 8 | github.com/8ff/prettyTimer v0.0.0-20230830184900-c96793faf613 9 | github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 10 | github.com/yalue/onnxruntime_go v1.13.0 11 | ) 12 | -------------------------------------------------------------------------------- /image_object_detect/go.sum: -------------------------------------------------------------------------------- 1 | github.com/8ff/prettyTimer v0.0.0-20230830184900-c96793faf613 h1:mIPSzE+OciNlYwNQs1qi7GoKRI3SKGKrVsGnap20iqQ= 2 | github.com/8ff/prettyTimer v0.0.0-20230830184900-c96793faf613/go.mod h1:iQAVuoCXBrrxT875kd25GCALLf+ulTOt/mCikuQs2j8= 3 | 
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6Oo2LfFZAehjjQMERAvZLEDnQ= 4 | github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8= 5 | github.com/yalue/onnxruntime_go v1.13.0 h1:5HDXHon3EukQMyYA7yPMed/raWaDE/gjwLOwnVoiwy8= 6 | github.com/yalue/onnxruntime_go v1.13.0/go.mod h1:b4X26A8pekNb1ACJ58wAXgNKeUCGEAQ9dmACut9Sm/4= 7 | -------------------------------------------------------------------------------- /image_object_detect/image_object_detect.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "image" 6 | _ "image/gif" 7 | _ "image/jpeg" 8 | _ "image/png" 9 | "os" 10 | "runtime" 11 | "sort" 12 | 13 | "github.com/8ff/prettyTimer" 14 | "github.com/nfnt/resize" 15 | ort "github.com/yalue/onnxruntime_go" 16 | ) 17 | 18 | var modelPath = "./yolov8n.onnx" 19 | var imagePath = "./car.png" 20 | var useCoreML = false 21 | 22 | type ModelSession struct { 23 | Session *ort.AdvancedSession 24 | Input *ort.Tensor[float32] 25 | Output *ort.Tensor[float32] 26 | } 27 | 28 | func main() { 29 | os.Exit(run()) 30 | } 31 | 32 | func run() int { 33 | timingStats := prettyTimer.NewTimingStats() 34 | 35 | if os.Getenv("USE_COREML") == "true" { 36 | useCoreML = true 37 | } 38 | 39 | // Read the input image into a image.Image object 40 | pic, e := loadImageFile(imagePath) 41 | if e != nil { 42 | fmt.Printf("Error loading input image: %s\n", e) 43 | return 1 44 | } 45 | originalWidth := pic.Bounds().Canon().Dx() 46 | originalHeight := pic.Bounds().Canon().Dy() 47 | 48 | modelSession, e := initSession() 49 | if e != nil { 50 | fmt.Printf("Error creating session and tensors: %s\n", e) 51 | return 1 52 | } 53 | defer modelSession.Destroy() 54 | 55 | // Run the detection 5 times 56 | for i := 0; i < 5; i++ { 57 | e = prepareInput(pic, modelSession.Input) 58 | if e != nil { 59 | fmt.Printf("Error converting image to network 
input: %s\n", e) 60 | return 1 61 | } 62 | 63 | timingStats.Start() 64 | e = modelSession.Session.Run() 65 | if e != nil { 66 | fmt.Printf("Error running ORT session: %s\n", e) 67 | return 1 68 | } 69 | timingStats.Finish() 70 | 71 | // Print the results 72 | boxes := processOutput(modelSession.Output.GetData(), originalWidth, 73 | originalHeight) 74 | for i, box := range boxes { 75 | fmt.Printf("Box %d: %s\n", i, &box) 76 | } 77 | } 78 | timingStats.PrintStats() 79 | return 0 80 | } 81 | 82 | func loadImageFile(filePath string) (image.Image, error) { 83 | f, e := os.Open(filePath) 84 | if e != nil { 85 | return nil, fmt.Errorf("Error opening %s: %w", filePath, e) 86 | } 87 | defer f.Close() 88 | pic, _, e := image.Decode(f) 89 | if e != nil { 90 | return nil, fmt.Errorf("Error decoding %s: %w", filePath, e) 91 | } 92 | return pic, nil 93 | } 94 | 95 | // Populates a yolov8n input tensor with the contents of the given image. 96 | func prepareInput(pic image.Image, dst *ort.Tensor[float32]) error { 97 | data := dst.GetData() 98 | channelSize := 640 * 640 99 | if len(data) < (channelSize * 3) { 100 | return fmt.Errorf("Destination tensor only holds %d floats, needs "+ 101 | "%d (make sure it's the right shape!)", len(data), channelSize*3) 102 | } 103 | redChannel := data[0:channelSize] 104 | greenChannel := data[channelSize : channelSize*2] 105 | blueChannel := data[channelSize*2 : channelSize*3] 106 | 107 | // Resize the image to 640x640 using Lanczos3 algorithm 108 | pic = resize.Resize(640, 640, pic, resize.Lanczos3) 109 | i := 0 110 | for y := 0; y < 640; y++ { 111 | for x := 0; x < 640; x++ { 112 | r, g, b, _ := pic.At(x, y).RGBA() 113 | redChannel[i] = float32(r>>8) / 255.0 114 | greenChannel[i] = float32(g>>8) / 255.0 115 | blueChannel[i] = float32(b>>8) / 255.0 116 | i++ 117 | } 118 | } 119 | 120 | return nil 121 | } 122 | 123 | func getSharedLibPath() string { 124 | if runtime.GOOS == "windows" { 125 | if runtime.GOARCH == "amd64" { 126 | return 
"../third_party/onnxruntime.dll" 127 | } 128 | } 129 | if runtime.GOOS == "darwin" { 130 | if runtime.GOARCH == "arm64" { 131 | return "../third_party/onnxruntime_arm64.dylib" 132 | } 133 | if runtime.GOARCH == "amd64" { 134 | return "../third_party/onnxruntime_amd64.dylib" 135 | } 136 | 137 | } 138 | if runtime.GOOS == "linux" { 139 | if runtime.GOARCH == "arm64" { 140 | return "../third_party/onnxruntime_arm64.so" 141 | } 142 | return "../third_party/onnxruntime.so" 143 | } 144 | panic("Unable to find a version of the onnxruntime library supporting this system.") 145 | } 146 | 147 | func initSession() (*ModelSession, error) { 148 | ort.SetSharedLibraryPath(getSharedLibPath()) 149 | err := ort.InitializeEnvironment() 150 | if err != nil { 151 | return nil, fmt.Errorf("Error initializing ORT environment: %w", err) 152 | } 153 | 154 | inputShape := ort.NewShape(1, 3, 640, 640) 155 | inputTensor, err := ort.NewEmptyTensor[float32](inputShape) 156 | if err != nil { 157 | return nil, fmt.Errorf("Error creating input tensor: %w", err) 158 | } 159 | outputShape := ort.NewShape(1, 84, 8400) 160 | outputTensor, err := ort.NewEmptyTensor[float32](outputShape) 161 | if err != nil { 162 | inputTensor.Destroy() 163 | return nil, fmt.Errorf("Error creating output tensor: %w", err) 164 | } 165 | options, err := ort.NewSessionOptions() 166 | if err != nil { 167 | inputTensor.Destroy() 168 | outputTensor.Destroy() 169 | return nil, fmt.Errorf("Error creating ORT session options: %w", err) 170 | } 171 | defer options.Destroy() 172 | 173 | // If CoreML is enabled, append the CoreML execution provider 174 | if useCoreML { 175 | err = options.AppendExecutionProviderCoreML(0) 176 | if err != nil { 177 | inputTensor.Destroy() 178 | outputTensor.Destroy() 179 | return nil, fmt.Errorf("Error enabling CoreML: %w", err) 180 | } 181 | } 182 | 183 | session, err := ort.NewAdvancedSession(modelPath, 184 | []string{"images"}, []string{"output0"}, 185 | []ort.ArbitraryTensor{inputTensor}, 186 | 
[]ort.ArbitraryTensor{outputTensor}, 187 | options) 188 | if err != nil { 189 | inputTensor.Destroy() 190 | outputTensor.Destroy() 191 | return nil, fmt.Errorf("Error creating ORT session: %w", err) 192 | } 193 | 194 | return &ModelSession{ 195 | Session: session, 196 | Input: inputTensor, 197 | Output: outputTensor, 198 | }, nil 199 | } 200 | 201 | func (m *ModelSession) Destroy() { 202 | m.Session.Destroy() 203 | m.Input.Destroy() 204 | m.Output.Destroy() 205 | } 206 | 207 | type boundingBox struct { 208 | label string 209 | confidence float32 210 | x1, y1, x2, y2 float32 211 | } 212 | 213 | func (b *boundingBox) String() string { 214 | return fmt.Sprintf("Object %s (confidence %f): (%f, %f), (%f, %f)", 215 | b.label, b.confidence, b.x1, b.y1, b.x2, b.y2) 216 | } 217 | 218 | // This loses precision, but recall that the boundingBox has already been 219 | // scaled up to the original image's dimensions. So, it will only lose 220 | // fractional pixels around the edges. 221 | func (b *boundingBox) toRect() image.Rectangle { 222 | return image.Rect(int(b.x1), int(b.y1), int(b.x2), int(b.y2)).Canon() 223 | } 224 | 225 | // Returns the area of b in pixels, after converting to an image.Rectangle. 
226 | func (b *boundingBox) rectArea() int { 227 | size := b.toRect().Size() 228 | return size.X * size.Y 229 | } 230 | 231 | func (b *boundingBox) intersection(other *boundingBox) float32 { 232 | r1 := b.toRect() 233 | r2 := other.toRect() 234 | intersected := r1.Intersect(r2).Canon().Size() 235 | return float32(intersected.X * intersected.Y) 236 | } 237 | 238 | func (b *boundingBox) union(other *boundingBox) float32 { 239 | intersectArea := b.intersection(other) 240 | totalArea := float32(b.rectArea() + other.rectArea()) 241 | return totalArea - intersectArea 242 | } 243 | 244 | // This won't be entirely precise due to conversion to the integral rectangles 245 | // from the image.Image library, but we're only using it to estimate which 246 | // boxes are overlapping too much, so some imprecision should be OK. 247 | func (b *boundingBox) iou(other *boundingBox) float32 { 248 | return b.intersection(other) / b.union(other) 249 | } 250 | 251 | func processOutput(output []float32, originalWidth, 252 | originalHeight int) []boundingBox { 253 | boundingBoxes := make([]boundingBox, 0, 8400) 254 | 255 | var classID int 256 | var probability float32 257 | 258 | // Iterate through the output array, considering 8400 indices 259 | for idx := 0; idx < 8400; idx++ { 260 | // Iterate through 80 classes and find the class with the highest probability 261 | probability = -1e9 262 | for col := 0; col < 80; col++ { 263 | currentProb := output[8400*(col+4)+idx] 264 | if currentProb > probability { 265 | probability = currentProb 266 | classID = col 267 | } 268 | } 269 | 270 | // If the probability is less than 0.5, continue to the next index 271 | if probability < 0.5 { 272 | continue 273 | } 274 | 275 | // Extract the coordinates and dimensions of the bounding box 276 | xc, yc := output[idx], output[8400+idx] 277 | w, h := output[2*8400+idx], output[3*8400+idx] 278 | x1 := (xc - w/2) / 640 * float32(originalWidth) 279 | y1 := (yc - h/2) / 640 * float32(originalHeight) 280 | x2 := 
(xc + w/2) / 640 * float32(originalWidth) 281 | y2 := (yc + h/2) / 640 * float32(originalHeight) 282 | 283 | // Append the bounding box to the result 284 | boundingBoxes = append(boundingBoxes, boundingBox{ 285 | label: yoloClasses[classID], 286 | confidence: probability, 287 | x1: x1, 288 | y1: y1, 289 | x2: x2, 290 | y2: y2, 291 | }) 292 | } 293 | 294 | // Sort the bounding boxes by probability 295 | sort.Slice(boundingBoxes, func(i, j int) bool { 296 | return boundingBoxes[i].confidence < boundingBoxes[j].confidence 297 | }) 298 | 299 | // Define a slice to hold the final result 300 | mergedResults := make([]boundingBox, 0, len(boundingBoxes)) 301 | 302 | // Iterate through sorted bounding boxes, removing overlaps 303 | for _, candidateBox := range boundingBoxes { 304 | overlapsExistingBox := false 305 | for _, existingBox := range mergedResults { 306 | if (&candidateBox).iou(&existingBox) > 0.7 { 307 | overlapsExistingBox = true 308 | break 309 | } 310 | } 311 | if !overlapsExistingBox { 312 | mergedResults = append(mergedResults, candidateBox) 313 | } 314 | } 315 | 316 | // This will still be in sorted order by confidence 317 | return mergedResults 318 | } 319 | 320 | // Array of YOLOv8 class labels 321 | var yoloClasses = []string{ 322 | "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", 323 | "traffic light", "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse", 324 | "sheep", "cow", "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie", 325 | "suitcase", "frisbee", "skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove", 326 | "skateboard", "surfboard", "tennis racket", "bottle", "wine glass", "cup", "fork", "knife", "spoon", 327 | "bowl", "banana", "apple", "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", 328 | "cake", "chair", "couch", "potted plant", "bed", "dining table", "toilet", "tv", "laptop", "mouse", 329 | 
"remote", "keyboard", "cell phone", "microwave", "oven", "toaster", "sink", "refrigerator", "book", 330 | "clock", "vase", "scissors", "teddy bear", "hair drier", "toothbrush", 331 | } 332 | -------------------------------------------------------------------------------- /image_object_detect/yolov8n.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yalue/onnxruntime_go_examples/af5620ad77f175d73eb0e505aecce27e8e867642/image_object_detect/yolov8n.onnx -------------------------------------------------------------------------------- /mnist/.gitignore: -------------------------------------------------------------------------------- 1 | mnist.exe 2 | mnist 3 | postprocessed_input_image.png 4 | 5 | -------------------------------------------------------------------------------- /mnist/README.md: -------------------------------------------------------------------------------- 1 | `onnxruntime_go`: MNIST Example 2 | =============================== 3 | 4 | This example makes use of the pre-trained MNIST network, obtained from the 5 | [official ONNX models repository](https://github.com/onnx/models/tree/ddbbd1274c8387e3745778705810c340dea3d8c7/validated/vision/classification/mnist). 6 | Specifically, the included `mnist.onnx` is MNIST-12 from the above link. 7 | 8 | This example uses the network to analyze single image files specified on the 9 | command line. 10 | 11 | Example Usage 12 | ------------- 13 | 14 | Run the program with `-help` to see all command-line flags. In general, you 15 | will need to supply it with an input image. 16 | 17 | ```bash 18 | ./mnist -image_path ./eight.png 19 | ./mnist -image_path ./tiny_5.png 20 | 21 | # There's an additional flag if you want to invert the image colors. The 22 | # network is trained on images with black backgrounds, so you may want to 23 | # invert images with white backgrounds. 
24 | ./mnist -image_path ./seven.png -invert_image 25 | ``` 26 | 27 | Note that the program will also create `postprocessed_input_image.png` in the 28 | current directory, showing the image that was passed to the neural network 29 | after resizing and converting to grayscale. 30 | 31 | -------------------------------------------------------------------------------- /mnist/eight.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yalue/onnxruntime_go_examples/af5620ad77f175d73eb0e505aecce27e8e867642/mnist/eight.png -------------------------------------------------------------------------------- /mnist/go.mod: -------------------------------------------------------------------------------- 1 | module github.com/yalue/onnxruntime_go_examples/mnist 2 | 3 | go 1.20 4 | 5 | require github.com/yalue/onnxruntime_go v1.13.0 6 | -------------------------------------------------------------------------------- /mnist/go.sum: -------------------------------------------------------------------------------- 1 | github.com/yalue/onnxruntime_go v1.13.0 h1:5HDXHon3EukQMyYA7yPMed/raWaDE/gjwLOwnVoiwy8= 2 | github.com/yalue/onnxruntime_go v1.13.0/go.mod h1:b4X26A8pekNb1ACJ58wAXgNKeUCGEAQ9dmACut9Sm/4= 3 | -------------------------------------------------------------------------------- /mnist/mnist.go: -------------------------------------------------------------------------------- 1 | // This is a command-line application that uses an onnx network to attempt to 2 | // classify handwritten digits trained on the MNIST dataset. 3 | // 4 | // This program shares a fair amount of boilerplate with the simpler 5 | // sum_and_difference example, which includes far more comments and may be an 6 | // easier starting point for someone entirely new to the onnxruntime_go 7 | // library. 
8 | package main 9 | 10 | import ( 11 | "flag" 12 | "fmt" 13 | ort "github.com/yalue/onnxruntime_go" 14 | "image" 15 | "image/color" 16 | _ "image/gif" 17 | _ "image/jpeg" 18 | "image/png" 19 | "os" 20 | "runtime" 21 | ) 22 | 23 | // For more comments, see the sum_and_difference example. 24 | func getDefaultSharedLibPath() string { 25 | if runtime.GOOS == "windows" { 26 | if runtime.GOARCH == "amd64" { 27 | return "../third_party/onnxruntime.dll" 28 | } 29 | } 30 | if runtime.GOOS == "darwin" { 31 | if runtime.GOARCH == "arm64" { 32 | return "../third_party/onnxruntime_arm64.dylib" 33 | } 34 | if runtime.GOARCH == "amd64" { 35 | return "../third_party/onnxruntime_amd64.dylib" 36 | } 37 | } 38 | if runtime.GOOS == "linux" { 39 | if runtime.GOARCH == "arm64" { 40 | return "../third_party/onnxruntime_arm64.so" 41 | } 42 | return "../third_party/onnxruntime.so" 43 | } 44 | fmt.Printf("Unable to determine a path to the onnxruntime shared library"+ 45 | " for OS \"%s\" and architecture \"%s\".\n", runtime.GOOS, 46 | runtime.GOARCH) 47 | return "" 48 | } 49 | 50 | // Implements the color interface 51 | type grayscaleFloat float32 52 | 53 | func (f grayscaleFloat) RGBA() (r, g, b, a uint32) { 54 | a = 0xffff 55 | v := uint32(f * 0xffff) 56 | if v > 0xffff { 57 | v = 0xffff 58 | } 59 | r = v 60 | g = v 61 | b = v 62 | return 63 | } 64 | 65 | // Used to satisfy the image interface as well as to help with formatting and 66 | // resizing an input image into the format expected as a network input. 67 | type ProcessedImage struct { 68 | // The number of "pixels" in the input image corresponding to a single 69 | // pixel in the 28x28 output image. 70 | dx, dy float32 71 | 72 | // The input image being transformed 73 | pic image.Image 74 | 75 | // If true, the grayscale values in the postprocessed image will be 76 | // inverted, so that dark colors in the original become light, and vice 77 | // versa. 
Recall that the network expects black backgrounds, so this should 78 | // be set to true for images with light backgrounds. 79 | Invert bool 80 | } 81 | 82 | func (p *ProcessedImage) ColorModel() color.Model { 83 | return color.Gray16Model 84 | } 85 | 86 | func (p *ProcessedImage) Bounds() image.Rectangle { 87 | return image.Rect(0, 0, 28, 28) 88 | } 89 | 90 | // Returns an average grayscale value using the pixels in the input image. 91 | func (p *ProcessedImage) At(x, y int) color.Color { 92 | if (x < 0) || (x >= 28) || (y < 0) || (y >= 28) { 93 | return color.Black 94 | } 95 | 96 | // Compute the window of pixels in the input image we'll be averaging. 97 | startX := int(float32(x) * p.dx) 98 | endX := int(float32(x+1) * p.dx) 99 | if endX == startX { 100 | endX = startX + 1 101 | } 102 | startY := int(float32(y) * p.dy) 103 | endY := int(float32(y+1) * p.dy) 104 | if endY == startY { 105 | endY = startY + 1 106 | } 107 | 108 | // Compute the average brightness over the window of pixels 109 | var sum float32 110 | var nPix int 111 | for row := startY; row < endY; row++ { 112 | for col := startX; col < endX; col++ { 113 | c := p.pic.At(col, row) 114 | grayValue := color.Gray16Model.Convert(c).(color.Gray16).Y 115 | sum += float32(grayValue) / 0xffff 116 | nPix++ 117 | } 118 | } 119 | 120 | brightness := grayscaleFloat(sum / float32(nPix)) 121 | if p.Invert { 122 | brightness = 1.0 - brightness 123 | } 124 | return brightness 125 | } 126 | 127 | // Returns a slice of data that can be used as the input to the onnx network. 
128 | func (p *ProcessedImage) GetNetworkInput() []float32 { 129 | toReturn := make([]float32, 0, 28*28) 130 | for row := 0; row < 28; row++ { 131 | for col := 0; col < 28; col++ { 132 | c := float32(p.At(col, row).(grayscaleFloat)) 133 | toReturn = append(toReturn, c) 134 | } 135 | } 136 | return toReturn 137 | } 138 | 139 | // Takes a path to an image file, loads the image, and returns a ProcessedImage 140 | // struct which can be used to obtain the neural network input. 141 | func NewProcessedImage(path string, invertBrightness bool) (*ProcessedImage, 142 | error) { 143 | f, e := os.Open(path) 144 | if e != nil { 145 | return nil, fmt.Errorf("Error opening %s: %w", path, e) 146 | } 147 | defer f.Close() 148 | originalPic, _, e := image.Decode(f) 149 | if e != nil { 150 | return nil, fmt.Errorf("Error decoding image %s: %w", path, e) 151 | } 152 | bounds := originalPic.Bounds().Canon() 153 | if (bounds.Min.X != 0) || (bounds.Min.Y != 0) { 154 | // Should never happen with the standard library. 155 | return nil, fmt.Errorf("Bounding rect of %s doesn't start at 0, 0", 156 | path) 157 | } 158 | return &ProcessedImage{ 159 | dx: float32(bounds.Dx()) / 28.0, 160 | dy: float32(bounds.Dy()) / 28.0, 161 | pic: originalPic, 162 | Invert: invertBrightness, 163 | }, nil 164 | } 165 | 166 | // Attempts to save the given image as a png. 167 | func saveImage(pic image.Image, path string) error { 168 | f, e := os.Create(path) 169 | if e != nil { 170 | return fmt.Errorf("Error creating %s: %w", path, e) 171 | } 172 | defer f.Close() 173 | e = png.Encode(f, pic) 174 | if e != nil { 175 | return fmt.Errorf("Error encoding PNG image to %s: %w", path, e) 176 | } 177 | return nil 178 | } 179 | 180 | // Takes a path to the onnxruntime shared library as well as the image file 181 | // containing a digit to be classified. The image file will be processed into 182 | // the format expected by the .onnx network. 
183 | // 184 | // If the network runs successfully, this will print the classification results 185 | // to stdout. 186 | func classifyDigit(onnxruntimeLibPath, imagePath string, 187 | invertBrightness bool) error { 188 | ort.SetSharedLibraryPath(onnxruntimeLibPath) 189 | e := ort.InitializeEnvironment() 190 | if e != nil { 191 | return fmt.Errorf("Error initializing the onnxruntime library: %w", e) 192 | } 193 | defer ort.DestroyEnvironment() 194 | 195 | // Load the input image and save the postprocessed version for a visual 196 | // inspection. 197 | inputImage, e := NewProcessedImage(imagePath, invertBrightness) 198 | if e != nil { 199 | return fmt.Errorf("Error loading input image: %w", e) 200 | } 201 | postprocessedPath := "./postprocessed_input_image.png" 202 | e = saveImage(inputImage, postprocessedPath) 203 | if e != nil { 204 | fmt.Printf("Error saving postprocessed input: %s. Continuing.\n", e) 205 | } else { 206 | fmt.Printf("Saved postprocessed input image to %s.\n", 207 | postprocessedPath) 208 | } 209 | 210 | // Create and populate the input tensor 211 | inputShape := ort.NewShape(1, 1, 28, 28) 212 | inputData := inputImage.GetNetworkInput() 213 | input, e := ort.NewTensor(inputShape, inputData) 214 | if e != nil { 215 | return fmt.Errorf("Error creating input tensor: %w", e) 216 | } 217 | defer input.Destroy() 218 | 219 | // Create the output tensor 220 | output, e := ort.NewEmptyTensor[float32](ort.NewShape(1, 10)) 221 | if e != nil { 222 | return fmt.Errorf("Error creating output tensor: %w", e) 223 | } 224 | defer output.Destroy() 225 | 226 | // The input and output names are required by this network; they can be 227 | // found on the MNIST ONNX models page linked in the README. 
228 | session, e := ort.NewAdvancedSession("./mnist.onnx", 229 | []string{"Input3"}, []string{"Plus214_Output_0"}, 230 | []ort.ArbitraryTensor{input}, []ort.ArbitraryTensor{output}, nil) 231 | if e != nil { 232 | return fmt.Errorf("Error creating MNIST network session: %w", e) 233 | } 234 | defer session.Destroy() 235 | 236 | // Run the network and print the results. 237 | e = session.Run() 238 | if e != nil { 239 | return fmt.Errorf("Error running the MNIST network: %w", e) 240 | } 241 | 242 | fmt.Printf("Output probabilities:\n") 243 | outputData := output.GetData() 244 | maxIndex := 0 245 | maxProbability := float32(-1.0e9) 246 | for i, v := range outputData { 247 | fmt.Printf(" %d: %f\n", i, v) 248 | if v > maxProbability { 249 | maxProbability = v 250 | maxIndex = i 251 | } 252 | } 253 | fmt.Printf("%s is probably a %d, with probability %f\n", imagePath, 254 | maxIndex, maxProbability) 255 | 256 | return nil 257 | } 258 | 259 | func run() int { 260 | var onnxruntimeLibPath string 261 | var imagePath string 262 | var invertImage bool 263 | flag.StringVar(&onnxruntimeLibPath, "onnxruntime_lib", 264 | getDefaultSharedLibPath(), 265 | "The path to the onnxruntime shared library for your system.") 266 | flag.StringVar(&imagePath, "image_path", "", 267 | "The image containing a digit to classify.") 268 | flag.BoolVar(&invertImage, "invert_image", false, 269 | "If set, the image's colors will be inverted before processing. "+ 270 | "The network expects inputs with dark backgrounds, so you should "+ 271 | "set this to true for images with light backgrounds.") 272 | flag.Parse() 273 | if onnxruntimeLibPath == "" { 274 | fmt.Println("You must specify a path to the onnxruntime shared " + 275 | "on your system. Run with -help for more information.") 276 | return 1 277 | } 278 | if imagePath == "" { 279 | fmt.Println("You must specify an input image. 
Run with -help for " + 280 | "more information.") 281 | return 1 282 | } 283 | e := classifyDigit(onnxruntimeLibPath, imagePath, invertImage) 284 | if e != nil { 285 | fmt.Printf("Error running network: %s\n", e) 286 | return 1 287 | } 288 | fmt.Printf("Everything seemed to run OK!\n") 289 | return 0 290 | } 291 | 292 | func main() { 293 | os.Exit(run()) 294 | } 295 | -------------------------------------------------------------------------------- /mnist/mnist.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yalue/onnxruntime_go_examples/af5620ad77f175d73eb0e505aecce27e8e867642/mnist/mnist.onnx -------------------------------------------------------------------------------- /mnist/seven.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yalue/onnxruntime_go_examples/af5620ad77f175d73eb0e505aecce27e8e867642/mnist/seven.png -------------------------------------------------------------------------------- /mnist/tiny_5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yalue/onnxruntime_go_examples/af5620ad77f175d73eb0e505aecce27e8e867642/mnist/tiny_5.png -------------------------------------------------------------------------------- /mnist_float16/.gitignore: -------------------------------------------------------------------------------- 1 | mnist_float16.exe 2 | mnist_float16 3 | postprocessed_input_image.png 4 | 5 | -------------------------------------------------------------------------------- /mnist_float16/README.md: -------------------------------------------------------------------------------- 1 | `onnxruntime_go`: Float16 MNIST Example 2 | ======================================= 3 | 4 | This example is nearly identical to the plain `mnist` example from this 5 | repository, but uses a model that has been converted to use 16-bit floats. 
This 6 | example is intended to illustrate how to convert inputs to 16-bit floating 7 | point values using the `github.com/x448/float16` package and the 8 | `CustomDataTensor` type from `onnxruntime_go`. 9 | 10 | The code has been mostly copied and pasted from the `../mnist` example. It 11 | differs only in a few places: 12 | 13 | - The `ProcessedImage.GetNetworkInput` function now converts each input pixel 14 | from a float32 grayscale value to a float16, and writes the float16 data 15 | into a slice of bytes. 16 | 17 | - The `input` and `output` tensors created in the `classifyDigit` function are 18 | now `CustomDataTensor`s, backed by slices of bytes. 19 | 20 | - The `convertFloat16Data` function has been added to convert the output 21 | tensor's bytes from `float16.Float16` data to a slice of `float32`s. 22 | 23 | The included `mnist_float16.onnx` network was created by using the 24 | `onnxconverter-common` python package on the `../mnist/mnist.onnx` network, 25 | using the process described on 26 | [this page](https://onnxruntime.ai/docs/performance/model-optimizations/float16.html). 27 | 28 | Example Usage 29 | ------------- 30 | 31 | This program is used in the exact same way as `../mnist`. Build it using 32 | `go build`, and run it with `-help` to see all command-line flags. It loads 33 | `mnist_float16.onnx` from the current directory. 34 | 35 | For example, 36 | ```bash 37 | go build . 38 | ./mnist_float16 -image_path ../mnist/eight.png 39 | ``` 40 | 41 | Will produce the following output: 42 | ``` 43 | Saved postprocessed input image to ./postprocessed_input_image.png. 44 | 0: 1.350586 45 | 1: 1.148438 46 | 2: 2.232422 47 | 3: 0.827148 48 | 4: -3.474609 49 | 5: 1.199219 50 | 6: -1.187500 51 | 7: -5.960938 52 | 8: 4.765625 53 | 9: -2.345703 54 | ../mnist/eight.png is probably a 8, with probability 4.765625 55 | Everything seemed to run OK! 
56 | ``` 57 | 58 | -------------------------------------------------------------------------------- /mnist_float16/go.mod: -------------------------------------------------------------------------------- 1 | module github.com/yalue/onnxruntime_go_examples/mnist_float16 2 | 3 | go 1.20 4 | 5 | require ( 6 | github.com/x448/float16 v0.8.4 7 | github.com/yalue/onnxruntime_go v1.13.0 8 | ) 9 | -------------------------------------------------------------------------------- /mnist_float16/go.sum: -------------------------------------------------------------------------------- 1 | github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= 2 | github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg= 3 | github.com/yalue/onnxruntime_go v1.13.0 h1:5HDXHon3EukQMyYA7yPMed/raWaDE/gjwLOwnVoiwy8= 4 | github.com/yalue/onnxruntime_go v1.13.0/go.mod h1:b4X26A8pekNb1ACJ58wAXgNKeUCGEAQ9dmACut9Sm/4= 5 | -------------------------------------------------------------------------------- /mnist_float16/mnist_float16.go: -------------------------------------------------------------------------------- 1 | // This is a command-line application that should behave identically to the 2 | // plain "mnist" example, but using float16 types. A large amount of this 3 | // program was simply copied from the mnist example, but modified to use 4 | // 16-bit floats with the help of github.com/x448/float16. 5 | package main 6 | 7 | import ( 8 | "encoding/binary" 9 | "flag" 10 | "fmt" 11 | "github.com/x448/float16" 12 | ort "github.com/yalue/onnxruntime_go" 13 | "image" 14 | "image/color" 15 | _ "image/gif" 16 | _ "image/jpeg" 17 | "image/png" 18 | "os" 19 | "runtime" 20 | ) 21 | 22 | // For more comments, see the sum_and_difference example. 
23 | func getDefaultSharedLibPath() string { 24 | if runtime.GOOS == "windows" { 25 | if runtime.GOARCH == "amd64" { 26 | return "../third_party/onnxruntime.dll" 27 | } 28 | } 29 | if runtime.GOOS == "darwin" { 30 | if runtime.GOARCH == "arm64" { 31 | return "../third_party/onnxruntime_arm64.dylib" 32 | } 33 | if runtime.GOARCH == "amd64" { 34 | return "../third_party/onnxruntime_amd64.dylib" 35 | } 36 | } 37 | if runtime.GOOS == "linux" { 38 | if runtime.GOARCH == "arm64" { 39 | return "../third_party/onnxruntime_arm64.so" 40 | } 41 | return "../third_party/onnxruntime.so" 42 | } 43 | fmt.Printf("Unable to determine a path to the onnxruntime shared library"+ 44 | " for OS \"%s\" and architecture \"%s\".\n", runtime.GOOS, 45 | runtime.GOARCH) 46 | return "" 47 | } 48 | 49 | // Implements the color interface 50 | type grayscaleFloat float32 51 | 52 | func (f grayscaleFloat) RGBA() (r, g, b, a uint32) { 53 | a = 0xffff 54 | v := uint32(f * 0xffff) 55 | if v > 0xffff { 56 | v = 0xffff 57 | } 58 | r = v 59 | g = v 60 | b = v 61 | return 62 | } 63 | 64 | // Used to satisfy the image interface as well as to help with formatting and 65 | // resizing an input image into the format expected as a network input. 66 | type ProcessedImage struct { 67 | // The number of "pixels" in the input image corresponding to a single 68 | // pixel in the 28x28 output image. 69 | dx, dy float32 70 | 71 | // The input image being transformed 72 | pic image.Image 73 | 74 | // If true, the grayscale values in the postprocessed image will be 75 | // inverted, so that dark colors in the original become light, and vice 76 | // versa. Recall that the network expects black backgrounds, so this should 77 | // be set to true for images with light backgrounds. 
78 | Invert bool 79 | } 80 | 81 | func (p *ProcessedImage) ColorModel() color.Model { 82 | return color.Gray16Model 83 | } 84 | 85 | func (p *ProcessedImage) Bounds() image.Rectangle { 86 | return image.Rect(0, 0, 28, 28) 87 | } 88 | 89 | // Returns an average grayscale value using the pixels in the input image. 90 | func (p *ProcessedImage) At(x, y int) color.Color { 91 | if (x < 0) || (x >= 28) || (y < 0) || (y >= 28) { 92 | return color.Black 93 | } 94 | 95 | // Compute the window of pixels in the input image we'll be averaging. 96 | startX := int(float32(x) * p.dx) 97 | endX := int(float32(x+1) * p.dx) 98 | if endX == startX { 99 | endX = startX + 1 100 | } 101 | startY := int(float32(y) * p.dy) 102 | endY := int(float32(y+1) * p.dy) 103 | if endY == startY { 104 | endY = startY + 1 105 | } 106 | 107 | // Compute the average brightness over the window of pixels 108 | var sum float32 109 | var nPix int 110 | for row := startY; row < endY; row++ { 111 | for col := startX; col < endX; col++ { 112 | c := p.pic.At(col, row) 113 | grayValue := color.Gray16Model.Convert(c).(color.Gray16).Y 114 | sum += float32(grayValue) / 0xffff 115 | nPix++ 116 | } 117 | } 118 | 119 | brightness := grayscaleFloat(sum / float32(nPix)) 120 | if p.Invert { 121 | brightness = 1.0 - brightness 122 | } 123 | return brightness 124 | } 125 | 126 | // Returns the float16 network inputs as a slice of bytes to be used with a 127 | // CustomDataTensor. This is where we convert the float32 grayscale image to 128 | // float16. 129 | func (p *ProcessedImage) GetNetworkInput() []byte { 130 | // We need two bytes per float16 pixel, and 28x28 to form the image. 131 | toReturn := make([]byte, 28*28*2) 132 | currentOffset := 0 133 | for row := 0; row < 28; row++ { 134 | for col := 0; col < 28; col++ { 135 | c := float32(p.At(col, row).(grayscaleFloat)) 136 | valueFloat16 := float16.Fromfloat32(c) 137 | // The float16.Float16 type is just a uint16 underneath; write its 138 | // bytes to the data slice. 
139 | binary.LittleEndian.PutUint16(toReturn[currentOffset:], 140 | uint16(valueFloat16)) 141 | currentOffset += 2 142 | } 143 | } 144 | return toReturn 145 | } 146 | 147 | // Takes a path to an image file, loads the image, and returns a ProcessedImage 148 | // struct which can be used to obtain the neural network input. 149 | func NewProcessedImage(path string, invertBrightness bool) (*ProcessedImage, 150 | error) { 151 | f, e := os.Open(path) 152 | if e != nil { 153 | return nil, fmt.Errorf("Error opening %s: %w", path, e) 154 | } 155 | defer f.Close() 156 | originalPic, _, e := image.Decode(f) 157 | if e != nil { 158 | return nil, fmt.Errorf("Error decoding image %s: %w", path, e) 159 | } 160 | bounds := originalPic.Bounds().Canon() 161 | if (bounds.Min.X != 0) || (bounds.Min.Y != 0) { 162 | // Should never happen with the standard library. 163 | return nil, fmt.Errorf("Bounding rect of %s doesn't start at 0, 0", 164 | path) 165 | } 166 | return &ProcessedImage{ 167 | dx: float32(bounds.Dx()) / 28.0, 168 | dy: float32(bounds.Dy()) / 28.0, 169 | pic: originalPic, 170 | Invert: invertBrightness, 171 | }, nil 172 | } 173 | 174 | // Attempts to save the given image as a png. 175 | func saveImage(pic image.Image, path string) error { 176 | f, e := os.Create(path) 177 | if e != nil { 178 | return fmt.Errorf("Error creating %s: %w", path, e) 179 | } 180 | defer f.Close() 181 | e = png.Encode(f, pic) 182 | if e != nil { 183 | return fmt.Errorf("Error encoding PNG image to %s: %w", path, e) 184 | } 185 | return nil 186 | } 187 | 188 | // Takes a list of float16 values as a slice of bytes, and converts each 189 | // float16 to a float32. 
190 | func convertFloat16Data(data []byte) ([]float32, error) { 191 | if (len(data) % 2) != 0 { 192 | return nil, fmt.Errorf("A slice of float16s must have an even length") 193 | } 194 | toReturn := make([]float32, len(data)/2) 195 | for i := range toReturn { 196 | valueUint16 := binary.LittleEndian.Uint16(data[i*2:]) 197 | valueFloat16 := float16.Frombits(valueUint16) 198 | toReturn[i] = valueFloat16.Float32() 199 | } 200 | return toReturn, nil 201 | } 202 | 203 | // Takes a path to the onnxruntime shared library as well as the image file 204 | // containing a digit to be classified. The image file will be processed into 205 | // the format expected by the .onnx network. 206 | // 207 | // If the network runs successfully, this will print the classification results 208 | // to stdout. 209 | func classifyDigit(onnxruntimeLibPath, imagePath string, 210 | invertBrightness bool) error { 211 | ort.SetSharedLibraryPath(onnxruntimeLibPath) 212 | e := ort.InitializeEnvironment() 213 | if e != nil { 214 | return fmt.Errorf("Error initializing the onnxruntime library: %w", e) 215 | } 216 | defer ort.DestroyEnvironment() 217 | 218 | // Load the input image and save the postprocessed version for a visual 219 | // inspection. 220 | inputImage, e := NewProcessedImage(imagePath, invertBrightness) 221 | if e != nil { 222 | return fmt.Errorf("Error loading input image: %w", e) 223 | } 224 | postprocessedPath := "./postprocessed_input_image.png" 225 | e = saveImage(inputImage, postprocessedPath) 226 | if e != nil { 227 | fmt.Printf("Error saving postprocessed input: %s. 
Continuing.\n", e) 228 | } else { 229 | fmt.Printf("Saved postprocessed input image to %s.\n", 230 | postprocessedPath) 231 | } 232 | 233 | // Create and populate the input tensor 234 | inputShape := ort.NewShape(1, 1, 28, 28) 235 | inputData := inputImage.GetNetworkInput() 236 | input, e := ort.NewCustomDataTensor(inputShape, inputData, 237 | ort.TensorElementDataTypeFloat16) 238 | if e != nil { 239 | return fmt.Errorf("Error creating input tensor: %w", e) 240 | } 241 | defer input.Destroy() 242 | 243 | // Create the output tensor. We need a 1x10 float16 tensor, with two bytes 244 | // per float16. 245 | outputShape := ort.NewShape(1, 10) 246 | outputData := make([]byte, outputShape.FlattenedSize()*2) 247 | output, e := ort.NewCustomDataTensor(outputShape, outputData, 248 | ort.TensorElementDataTypeFloat16) 249 | if e != nil { 250 | return fmt.Errorf("Error creating output tensor: %w", e) 251 | } 252 | defer output.Destroy() 253 | 254 | // The input and output names are required by this network; they can be 255 | // found on the MNIST ONNX models page linked in the README. 256 | session, e := ort.NewAdvancedSession("./mnist_float16.onnx", 257 | []string{"Input3"}, []string{"Plus214_Output_0"}, 258 | []ort.ArbitraryTensor{input}, []ort.ArbitraryTensor{output}, nil) 259 | if e != nil { 260 | return fmt.Errorf("Error creating MNIST network session: %w", e) 261 | } 262 | defer session.Destroy() 263 | 264 | // Run the network and print the results. 265 | e = session.Run() 266 | if e != nil { 267 | return fmt.Errorf("Error running the MNIST network: %w", e) 268 | } 269 | 270 | // Convert the outputs from float16 back to float32 to make them easier to 271 | // compare and print. 272 | outputFloat32, e := convertFloat16Data(output.GetData()) 273 | if e != nil { 274 | return fmt.Errorf("Error converting float16 bytes to float32's: %w", e) 275 | } 276 | 277 | // Find the most likely output. 
278 | maxIndex := 0 279 | maxProbability := float32(-1.0e9) 280 | for i, v := range outputFloat32 { 281 | fmt.Printf(" %d: %f\n", i, v) 282 | if v > maxProbability { 283 | maxProbability = v 284 | maxIndex = i 285 | } 286 | } 287 | fmt.Printf("%s is probably a %d, with probability %f\n", imagePath, 288 | maxIndex, maxProbability) 289 | 290 | return nil 291 | } 292 | 293 | func run() int { 294 | var onnxruntimeLibPath string 295 | var imagePath string 296 | var invertImage bool 297 | flag.StringVar(&onnxruntimeLibPath, "onnxruntime_lib", 298 | getDefaultSharedLibPath(), 299 | "The path to the onnxruntime shared library for your system.") 300 | flag.StringVar(&imagePath, "image_path", "", 301 | "The image containing a digit to classify.") 302 | flag.BoolVar(&invertImage, "invert_image", false, 303 | "If set, the image's colors will be inverted before processing. "+ 304 | "The network expects inputs with dark backgrounds, so you should "+ 305 | "set this to true for images with light backgrounds.") 306 | flag.Parse() 307 | if onnxruntimeLibPath == "" { 308 | fmt.Println("You must specify a path to the onnxruntime shared " + 309 | "library on your system. Run with -help for more information.") 310 | return 1 311 | } 312 | if imagePath == "" { 313 | fmt.Println("You must specify an input image. 
Run with -help for " + 314 | "more information.") 315 | return 1 316 | } 317 | e := classifyDigit(onnxruntimeLibPath, imagePath, invertImage) 318 | if e != nil { 319 | fmt.Printf("Error running network: %s\n", e) 320 | return 1 321 | } 322 | fmt.Printf("Everything seemed to run OK!\n") 323 | return 0 324 | } 325 | 326 | func main() { 327 | os.Exit(run()) 328 | } 329 | -------------------------------------------------------------------------------- /mnist_float16/mnist_float16.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yalue/onnxruntime_go_examples/af5620ad77f175d73eb0e505aecce27e8e867642/mnist_float16/mnist_float16.onnx -------------------------------------------------------------------------------- /non_tensor_outputs/.gitignore: -------------------------------------------------------------------------------- 1 | non_tensor_outputs 2 | non_tensor_outputs.exe 3 | 4 | -------------------------------------------------------------------------------- /non_tensor_outputs/generate_sklearn_network.py: -------------------------------------------------------------------------------- 1 | # This script is a modified version of the example from 2 | # https://pypi.org/project/skl2onnx/, which we use to produce 3 | # sklearn_randomforest.onnx. This network is used here solely because it's one 4 | # of the few I've seen that uses ONNX sequences and maps in an output. 5 | 6 | import numpy as np 7 | from sklearn.datasets import load_iris 8 | from sklearn.model_selection import train_test_split 9 | from sklearn.ensemble import RandomForestClassifier 10 | 11 | iris = load_iris() 12 | inputs, outputs = iris.data, iris.target 13 | inputs = inputs.astype(np.float32) 14 | inputs_train, inputs_test, outputs_train, outputs_test = train_test_split(inputs, outputs) 15 | classifier = RandomForestClassifier() 16 | classifier.fit(inputs_train, outputs_train) 17 | 18 | # Convert into ONNX format. 
19 | from skl2onnx import to_onnx 20 | output_filename = "sklearn_randomforest.onnx" 21 | onnx_content = to_onnx(classifier, inputs[:1]) 22 | with open(output_filename, "wb") as f: 23 | f.write(onnx_content.SerializeToString()) 24 | 25 | # Compute the prediction with onnxruntime. 26 | import onnxruntime as ort 27 | 28 | def float_formatter(f): 29 | return f"{float(f):.06f}" 30 | 31 | np.set_printoptions(formatter = {'float_kind': float_formatter}) 32 | session = ort.InferenceSession(output_filename) 33 | print(f"Input names: {[n.name for n in session.get_inputs()]!s}") 34 | print(f"Output names: {[o.name for o in session.get_outputs()]!s}") 35 | example_inputs = inputs_test.astype(np.float32)[:6] 36 | print(f"Inputs shape = {example_inputs.shape!s}") 37 | onnx_predictions = session.run(["output_label", "output_probability"], 38 | {"X": example_inputs}) 39 | labels = onnx_predictions[0] 40 | probabilities = onnx_predictions[1] 41 | 42 | print(f"Inputs to network: {example_inputs.astype(np.float32)}") 43 | print(f"ONNX predicted labels: {labels!s}") 44 | print(f"ONNX predicted probabilities: {probabilities!s}") 45 | 46 | -------------------------------------------------------------------------------- /non_tensor_outputs/go.mod: -------------------------------------------------------------------------------- 1 | module github.com/yalue/onnxruntime_go_examples/non_tensor_outputs 2 | 3 | go 1.20 4 | 5 | require github.com/yalue/onnxruntime_go v1.13.0 6 | -------------------------------------------------------------------------------- /non_tensor_outputs/go.sum: -------------------------------------------------------------------------------- 1 | github.com/yalue/onnxruntime_go v1.13.0 h1:5HDXHon3EukQMyYA7yPMed/raWaDE/gjwLOwnVoiwy8= 2 | github.com/yalue/onnxruntime_go v1.13.0/go.mod h1:b4X26A8pekNb1ACJ58wAXgNKeUCGEAQ9dmACut9Sm/4= 3 | -------------------------------------------------------------------------------- /non_tensor_outputs/non_tensor_outputs.go: 
-------------------------------------------------------------------------------- 1 | // This example illustrates how to access the contents of Maps and sequences, 2 | // using the random-forest sklearn network originally copied and modified from 3 | // here: http://onnx.ai/sklearn-onnx/. 4 | package main 5 | 6 | import ( 7 | "flag" 8 | "fmt" 9 | ort "github.com/yalue/onnxruntime_go" 10 | "os" 11 | "runtime" 12 | ) 13 | 14 | // For more comments, see the sum_and_difference example. 15 | func getDefaultSharedLibPath() string { 16 | if runtime.GOOS == "windows" { 17 | if runtime.GOARCH == "amd64" { 18 | return "../third_party/onnxruntime.dll" 19 | } 20 | } 21 | if runtime.GOOS == "darwin" { 22 | if runtime.GOARCH == "arm64" { 23 | return "../third_party/onnxruntime_arm64.dylib" 24 | } 25 | if runtime.GOARCH == "amd64" { 26 | return "../third_party/onnxruntime_amd64.dylib" 27 | } 28 | } 29 | if runtime.GOOS == "linux" { 30 | if runtime.GOARCH == "arm64" { 31 | return "../third_party/onnxruntime_arm64.so" 32 | } 33 | return "../third_party/onnxruntime.so" 34 | } 35 | fmt.Printf("Unable to determine a path to the onnxruntime shared library"+ 36 | " for OS \"%s\" and architecture \"%s\".\n", runtime.GOOS, 37 | runtime.GOARCH) 38 | return "" 39 | } 40 | 41 | func run() int { 42 | var onnxruntimeLibPath string 43 | flag.StringVar(&onnxruntimeLibPath, "onnxruntime_lib", 44 | getDefaultSharedLibPath(), 45 | "The path to the onnxruntime shared library for your system.") 46 | flag.Parse() 47 | if onnxruntimeLibPath == "" { 48 | fmt.Println("You must specify a path to the onnxruntime shared " + 49 | "library on your system. 
Run with -help for more information.") 50 | return 1 51 | } 52 | e := runSklearnNetwork(onnxruntimeLibPath) 53 | if e != nil { 54 | fmt.Printf("Encountered an error running the network: %s\n", e) 55 | return 1 56 | } 57 | return 0 58 | } 59 | 60 | func main() { 61 | os.Exit(run()) 62 | } 63 | 64 | func runSklearnNetwork(sharedLibPath string) error { 65 | ort.SetSharedLibraryPath(sharedLibPath) 66 | e := ort.InitializeEnvironment() 67 | if e != nil { 68 | return fmt.Errorf("Error initializing onnxruntime library: %w", e) 69 | } 70 | 71 | // Load the session. We'll use DynamicAdvancedSession so that onnxruntime 72 | // can automatically allocate the more complicated outputs for us. 73 | modelPath := "./sklearn_randomforest.onnx" 74 | inputNames := []string{"X"} 75 | outputNames := []string{"output_label", "output_probability"} 76 | session, e := ort.NewDynamicAdvancedSession(modelPath, inputNames, 77 | outputNames, nil) 78 | if e != nil { 79 | return fmt.Errorf("Error loading %s: %w", modelPath, e) 80 | } 81 | defer session.Destroy() 82 | 83 | // Create the 6x4 input tensor (6 vectors of 4 elements each). This data 84 | // is from information printed by generate_sklearn_network.py. 85 | inputShape := ort.NewShape(6, 4) 86 | inputValues := []float32{ 87 | 5.9, 3.0, 5.1, 1.8, 88 | 6.8, 2.8, 4.8, 1.4, 89 | 6.3, 2.3, 4.4, 1.3, 90 | 6.5, 3.0, 5.5, 1.8, 91 | 7.7, 2.8, 6.7, 2.0, 92 | 5.5, 2.5, 4.0, 1.3, 93 | } 94 | inputTensor, e := ort.NewTensor(inputShape, inputValues) 95 | if e != nil { 96 | return fmt.Errorf("Error creating input tensor: %w", e) 97 | } 98 | defer inputTensor.Destroy() 99 | 100 | // Create a two-element slice that will be populated by the values 101 | // automatically allocated while running the network. (Leaving the outputs 102 | // as nil allows DynamicAdvancedSession.Run() to allocate them.) 103 | outputValues := []ort.Value{nil, nil} 104 | 105 | // Actually run the network. 
106 | e = session.Run([]ort.Value{inputTensor}, outputValues) 107 | if e != nil { 108 | return fmt.Errorf("Error running %s: %w", modelPath, e) 109 | } 110 | // Any auto-allocated outputs must be manually destroyed when no longer 111 | // needed. 112 | defer outputValues[0].Destroy() 113 | defer outputValues[1].Destroy() 114 | fmt.Printf("Successfully ran %s!\n", modelPath) 115 | 116 | // The first output of this network is just a Tensor containing the labels 117 | // with the highest probabilities. 118 | labelTensor := outputValues[0].(*ort.Tensor[int64]) 119 | predictedLabels := labelTensor.GetData() 120 | for i, v := range predictedLabels { 121 | fmt.Printf("Predicted label for input %d: %d\n", i, v) 122 | } 123 | 124 | // The second output of this network is an ONNX Sequence of maps. The 125 | // sequence contains one map for each of the 6 input vectors. Each map 126 | // maps every possible label to its predicted probability for the 127 | // corresponding input vector. (You'll see that the label with the highest 128 | // probability was already provided in the first output.) 129 | sequence := outputValues[1].(*ort.Sequence) 130 | probabilityMaps, e := sequence.GetValues() 131 | if e != nil { 132 | return fmt.Errorf("Error getting contents of sequence: %w", e) 133 | } 134 | 135 | for i := range probabilityMaps { 136 | // An ONNX Map is represented by two tensors of the same size: one 137 | // containing keys and one containing values. keys.GetData()[i] 138 | // contains the key, and values.GetData()[i] contains the value the 139 | // key maps to. 
140 | m := probabilityMaps[i].(*ort.Map) 141 | keys, values, e := m.GetKeysAndValues() 142 | if e != nil { 143 | return fmt.Errorf("Error getting keys and values for map at "+ 144 | "index %d: %w", i, e) 145 | } 146 | keysTensor := keys.(*ort.Tensor[int64]) 147 | valuesTensor := values.(*ort.Tensor[float32]) 148 | 149 | fmt.Printf("Individual probabilities for input %d:\n", i) 150 | for j, key := range keysTensor.GetData() { 151 | value := valuesTensor.GetData()[j] 152 | fmt.Printf(" Label %d: %f\n", key, value) 153 | } 154 | } 155 | return nil 156 | } 157 | -------------------------------------------------------------------------------- /non_tensor_outputs/sklearn_randomforest.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yalue/onnxruntime_go_examples/af5620ad77f175d73eb0e505aecce27e8e867642/non_tensor_outputs/sklearn_randomforest.onnx -------------------------------------------------------------------------------- /onnx_list_inputs_and_outputs/.gitignore: -------------------------------------------------------------------------------- 1 | onnx_list_inputs_and_outputs 2 | onnx_list_inputs_and_outputs.exe 3 | 4 | -------------------------------------------------------------------------------- /onnx_list_inputs_and_outputs/README.md: -------------------------------------------------------------------------------- 1 | Getting ONNX Input and Output Information 2 | ========================================= 3 | 4 | This example project defines a command-line utility that prints the input and 5 | output information about a user-specified .onnx file to standard output. 6 | 7 | Example usage: 8 | ``` 9 | go build . 
10 | 11 | ./onnx_list_inputs_and_outputs -onnx_file ../image_object_detect/yolov8n.onnx 12 | ``` 13 | 14 | The above command should output something like the following: 15 | 16 | ``` 17 | 1 inputs to ../image_object_detect/yolov8n.onnx: 18 | Index 0: "images": [1 3 640 640], ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT 19 | 1 outputs from ../image_object_detect/yolov8n.onnx: 20 | Index 0: "output0": [1 84 8400], ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT 21 | ``` 22 | 23 | (The yolov8 network only has one input and one output: a 1x3x640x640 input, 24 | named "images", and a 1x84x8400 output, named "output0".) 25 | 26 | -------------------------------------------------------------------------------- /onnx_list_inputs_and_outputs/go.mod: -------------------------------------------------------------------------------- 1 | module github.com/yalue/onnxruntime_go_examples/onnx_list_inputs_and_outputs 2 | 3 | go 1.20 4 | 5 | require github.com/yalue/onnxruntime_go v1.13.0 6 | -------------------------------------------------------------------------------- /onnx_list_inputs_and_outputs/go.sum: -------------------------------------------------------------------------------- 1 | github.com/yalue/onnxruntime_go v1.13.0 h1:5HDXHon3EukQMyYA7yPMed/raWaDE/gjwLOwnVoiwy8= 2 | github.com/yalue/onnxruntime_go v1.13.0/go.mod h1:b4X26A8pekNb1ACJ58wAXgNKeUCGEAQ9dmACut9Sm/4= 3 | -------------------------------------------------------------------------------- /onnx_list_inputs_and_outputs/onnx_list_inputs_and_outputs.go: -------------------------------------------------------------------------------- 1 | // This is a simple command-line utility that takes a single .onnx file and 2 | // lists the inputs and outputs to it. 3 | package main 4 | 5 | import ( 6 | "flag" 7 | "fmt" 8 | ort "github.com/yalue/onnxruntime_go" 9 | "os" 10 | "runtime" 11 | ) 12 | 13 | // For more comments, see the sum_and_difference example. 
14 | func getDefaultSharedLibPath() string { 15 | if runtime.GOOS == "windows" { 16 | if runtime.GOARCH == "amd64" { 17 | return "../third_party/onnxruntime.dll" 18 | } 19 | } 20 | if runtime.GOOS == "darwin" { 21 | if runtime.GOARCH == "arm64" { 22 | return "../third_party/onnxruntime_arm64.dylib" 23 | } 24 | if runtime.GOARCH == "amd64" { 25 | return "../third_party/onnxruntime_amd64.dylib" 26 | } 27 | } 28 | if runtime.GOOS == "linux" { 29 | if runtime.GOARCH == "arm64" { 30 | return "../third_party/onnxruntime_arm64.so" 31 | } 32 | return "../third_party/onnxruntime.so" 33 | } 34 | fmt.Printf("Unable to determine a path to the onnxruntime shared library"+ 35 | " for OS \"%s\" and architecture \"%s\".\n", runtime.GOOS, 36 | runtime.GOARCH) 37 | return "" 38 | } 39 | 40 | // Prints the inputs and outputs of an onnx-format network to stdout. 41 | func showNetworkInputsAndOutputs(libPath, networkPath string) error { 42 | ort.SetSharedLibraryPath(libPath) 43 | e := ort.InitializeEnvironment() 44 | if e != nil { 45 | return fmt.Errorf("Error initializing onnxruntime library: %w", e) 46 | } 47 | inputs, outputs, e := ort.GetInputOutputInfo(networkPath) 48 | if e != nil { 49 | return fmt.Errorf("Error getting input and output info for %s: %w", 50 | networkPath, e) 51 | } 52 | fmt.Printf("%d inputs to %s:\n", len(inputs), networkPath) 53 | for i, v := range inputs { 54 | fmt.Printf(" Index %d: %s\n", i, &v) 55 | } 56 | fmt.Printf("%d outputs from %s:\n", len(outputs), networkPath) 57 | for i, v := range outputs { 58 | fmt.Printf(" Index %d: %s\n", i, &v) 59 | } 60 | return nil 61 | } 62 | 63 | func run() int { 64 | var onnxruntimeLibPath string 65 | var onnxNetworkPath string 66 | flag.StringVar(&onnxruntimeLibPath, "onnxruntime_lib", 67 | getDefaultSharedLibPath(), 68 | "The path to the onnxruntime shared library for your system.") 69 | flag.StringVar(&onnxNetworkPath, "onnx_file", "", 70 | "The path to the .onnx file to load.") 71 | flag.Parse() 72 | if 
onnxruntimeLibPath == "" { 73 | fmt.Println("You must specify a path to the onnxruntime shared " + 74 | "library on your system. Run with -help for more information.") 75 | return 1 76 | } 77 | if onnxNetworkPath == "" { 78 | fmt.Println("You must specify a .onnx network to list the inputs and" + 79 | " outputs for. Run with -help for more information.") 80 | return 1 81 | } 82 | e := showNetworkInputsAndOutputs(onnxruntimeLibPath, onnxNetworkPath) 83 | if e != nil { 84 | fmt.Printf("Error getting network inputs and outputs: %s\n", e) 85 | return 1 86 | } 87 | return 0 88 | } 89 | 90 | func main() { 91 | os.Exit(run()) 92 | } 93 | -------------------------------------------------------------------------------- /sum_and_difference/.gitignore: -------------------------------------------------------------------------------- 1 | sum_and_difference.exe 2 | sum_and_difference 3 | 4 | -------------------------------------------------------------------------------- /sum_and_difference/README.md: -------------------------------------------------------------------------------- 1 | Sum and Difference `onnxruntime_go` Example 2 | =========================================== 3 | 4 | This is a basic, heavily-commented command-line program that uses the 5 | `onnxruntime_go` library to load and run an ONNX-format neural network. 6 | 7 | Usage 8 | ----- 9 | 10 | Build the program using `go build`. After this, it should run without arguments 11 | on most systems: `./sum_and_difference`. If you encounter errors, you may need 12 | to specify a different version of the `onnxruntime` shared library, using the 13 | `-onnxruntime_lib` command-line flag. (Run the program with `-help` to see 14 | usage information.) 15 | 16 | ```bash 17 | go build . 18 | ./sum_and_difference 19 | ``` 20 | 21 | Should output the following if successful: 22 | ``` 23 | The network ran without errors. 
24 | Input data: [0.2 0.3 0.6 0.9] 25 | Approximate sum of inputs: 1.999988 26 | Approximate max difference between any two inputs: 0.607343 27 | The network seemed to run OK! 28 | ``` 29 | 30 | -------------------------------------------------------------------------------- /sum_and_difference/example_network.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yalue/onnxruntime_go_examples/af5620ad77f175d73eb0e505aecce27e8e867642/sum_and_difference/example_network.onnx -------------------------------------------------------------------------------- /sum_and_difference/generate_network.py: -------------------------------------------------------------------------------- 1 | # This script sets up and "trains" a toy pytorch network that maps a 1x4 2 | # vector to a 1x2 vector containing [sum, max difference] of the input 3 | # values. Finally, it exports the network to an ONNX file. 4 | # 5 | # This script is adapted from one used in unit-testing onnxruntime_go to use 6 | # a slightly bigger neural network. 7 | import torch 8 | from torch.nn.functional import relu 9 | 10 | def fake_dataset(size): 11 | """ Returns a dataset filled with our fake training data. """ 12 | inputs = torch.rand((size, 1, 4)) 13 | outputs = torch.zeros((size, 1, 2)) 14 | for i in range(size): 15 | outputs[i][0][0] = inputs[i][0].sum() 16 | outputs[i][0][1] = inputs[i][0].max() - inputs[i][0].min() 17 | return torch.utils.data.TensorDataset(inputs, outputs) 18 | 19 | class SumAndDiffModel(torch.nn.Module): 20 | """ Just a standard, fairly minimal, pytorch model for generating the NN. 21 | """ 22 | def __init__(self): 23 | super().__init__() 24 | # We'll do four 1x4 convolutions to make the network more interesting. 25 | self.conv = torch.nn.Conv1d(1, 4, 4) 26 | # We'll follow the conv with two FC layers to produce the outputs. 
The 27 | # inputs to the first FC layer are the 4 conv outputs concatenated with 28 | # the original input. 29 | self.fc1 = torch.nn.Linear(8, 32) 30 | self.fc2 = torch.nn.Linear(32, 2) 31 | 32 | def forward(self, data): 33 | batch_size = len(data) 34 | conv_out = relu(self.conv(data)) 35 | conv_flattened = torch.flatten(conv_out, start_dim=1) 36 | data_flattened = torch.flatten(data, start_dim=1) 37 | combined = torch.cat((conv_flattened, data_flattened), dim=1) 38 | output = relu(self.fc1(combined)) 39 | output = relu(self.fc2(output)) 40 | output = output.view(batch_size, 1, 2) 41 | return output 42 | 43 | def get_test_loss(model, loader, loss_function): 44 | """ Just runs a single epoch of data from the given loader. Returns the 45 | average loss per batch. The provided model is expected to be in eval mode. 46 | """ 47 | i = 0 48 | total_loss = 0.0 49 | for in_data, desired_result in loader: 50 | produced_result = model(in_data) 51 | loss = loss_function(desired_result, produced_result) 52 | total_loss += loss.item() 53 | i += 1 54 | return total_loss / i 55 | 56 | def save_model(model, output_filename): 57 | """ Saves the model to an onnx file with the given name. Assumes the model 58 | is in eval mode. """ 59 | print("Saving network to " + output_filename) 60 | dummy_input = torch.rand(1, 1, 4) 61 | input_names = ["1x4 Input Vector"] 62 | output_names = ["1x2 Output Vector"] 63 | torch.onnx.export(model, dummy_input, output_filename, 64 | input_names=input_names, output_names=output_names) 65 | return None 66 | 67 | def print_sample(model): 68 | """ Prints a sample input and output computation using the model. Expects 69 | the model to be in eval mode. 
""" 70 | example_input = torch.rand(1, 1, 4) 71 | result = model(example_input) 72 | print("Sample model execution:") 73 | print(" Example input: " + str(example_input)) 74 | print(" Produced output: " + str(result)) 75 | return None 76 | 77 | def main(): 78 | print("Generating train and test datasets...") 79 | train_data = fake_dataset(200 * 1000) 80 | train_loader = torch.utils.data.DataLoader(dataset=train_data, 81 | batch_size=16, shuffle=True) 82 | test_data = fake_dataset(10 * 1000) 83 | test_loader = torch.utils.data.DataLoader(dataset=test_data, 84 | batch_size=16) 85 | model = SumAndDiffModel() 86 | model.train() 87 | loss_function = torch.nn.L1Loss(reduction="mean") 88 | optimizer = torch.optim.Adam(model.parameters(), lr=0.0001) 89 | for epoch in range(4): 90 | i = 0 91 | total_loss = 0.0 92 | for in_data, desired_result in train_loader: 93 | i += 1 94 | produced_result = model(in_data) 95 | loss = loss_function(desired_result, produced_result) 96 | loss.backward() 97 | optimizer.step() 98 | optimizer.zero_grad() 99 | total_loss += loss.item() 100 | if (i % 1000) == 1: 101 | print("Epoch %d, iteration %d. 
Current loss = %f" % (epoch, i, 102 | loss.item())) 103 | train_loss = total_loss / i 104 | print(" => Average train-set loss: " + str(train_loss)) 105 | model.eval() 106 | with torch.no_grad(): 107 | test_loss = get_test_loss(model, test_loader, loss_function) 108 | model.train() 109 | print(" => Average test-set loss: " + str(test_loss)) 110 | 111 | model.eval() 112 | with torch.no_grad(): 113 | save_model(model, "example_network.onnx") 114 | print_sample(model) 115 | print("Done!") 116 | 117 | if __name__ == "__main__": 118 | main() 119 | 120 | -------------------------------------------------------------------------------- /sum_and_difference/go.mod: -------------------------------------------------------------------------------- 1 | module github.com/yalue/onnxruntime_go_examples/sum_and_difference 2 | 3 | go 1.20 4 | 5 | require github.com/yalue/onnxruntime_go v1.13.0 6 | -------------------------------------------------------------------------------- /sum_and_difference/go.sum: -------------------------------------------------------------------------------- 1 | github.com/yalue/onnxruntime_go v1.13.0 h1:5HDXHon3EukQMyYA7yPMed/raWaDE/gjwLOwnVoiwy8= 2 | github.com/yalue/onnxruntime_go v1.13.0/go.mod h1:b4X26A8pekNb1ACJ58wAXgNKeUCGEAQ9dmACut9Sm/4= 3 | -------------------------------------------------------------------------------- /sum_and_difference/sum_and_difference.go: -------------------------------------------------------------------------------- 1 | // This is a basic command-line application that serves as a minimal example 2 | // for loading and executing a neural network using the onnxruntime library. 3 | // 4 | // If you want to learn how to use the onnxruntime_go library, the 5 | // runTest function is the most important one here. The rest of this program 6 | // is mostly boilerplate for setting up a command-line program. 
7 | // 8 | // The actual network used by this program was generated by the included 9 | // generate_network.py pytorch script. It takes a 1x4 input vector of 32-bit 10 | // floats, and produces a 1x2 output vector of 32-bit floats. The network 11 | // attempts to populate the two values in the output vector with 1) the sum 12 | // of the four inputs, and 2) the maximum difference between any two of the 13 | // input values. 14 | package main 15 | 16 | import ( 17 | "flag" 18 | "fmt" 19 | ort "github.com/yalue/onnxruntime_go" 20 | "os" 21 | "runtime" 22 | ) 23 | 24 | // Attempts to find and return a path to a version of the onnxruntime shared 25 | // library compatible with the current OS and system architecture. 26 | func getDefaultSharedLibPath() string { 27 | // For now, the third_party directory includes libraries for x86_64 28 | // windows, ARM64 and AMD64 darwin, and x86_64 or ARM64 Linux. The point of 29 | // these examples is to show _how_ to select libraries, rather than to 30 | // provide broad support, so this will probably not be expanded. 31 | if runtime.GOOS == "windows" { 32 | if runtime.GOARCH == "amd64" { 33 | return "../third_party/onnxruntime.dll" 34 | } 35 | } 36 | if runtime.GOOS == "darwin" { 37 | if runtime.GOARCH == "arm64" { 38 | return "../third_party/onnxruntime_arm64.dylib" 39 | } 40 | if runtime.GOARCH == "amd64" { 41 | return "../third_party/onnxruntime_amd64.dylib" 42 | } 43 | } 44 | if runtime.GOOS == "linux" { 45 | if runtime.GOARCH == "arm64" { 46 | return "../third_party/onnxruntime_arm64.so" 47 | } 48 | return "../third_party/onnxruntime.so" 49 | } 50 | fmt.Printf("Unable to determine a path to the onnxruntime shared library"+ 51 | " for OS \"%s\" and architecture \"%s\".\n", runtime.GOOS, 52 | runtime.GOARCH) 53 | return "" 54 | } 55 | 56 | // Actually sets up and runs the neural network. Requires a path to the 57 | // onnxruntime shared library file. 
58 | func runTest(onnxruntimeLibPath string) error { 59 | // Step 1: Initialize the onnxruntime library after providing a path to the 60 | // shared library to use. 61 | ort.SetSharedLibraryPath(onnxruntimeLibPath) 62 | e := ort.InitializeEnvironment() 63 | if e != nil { 64 | return fmt.Errorf("Error initializing the onnxruntime library: %w", e) 65 | } 66 | // Clean up the onnxruntime library when we're done using it. 67 | defer ort.DestroyEnvironment() 68 | 69 | // Step 2: Create the input tensor. Tensors are wrappers around Go slices; 70 | // the onnxruntime networks access the data in these slices to read inputs 71 | // or write outputs. Here, we'll create a 1x4 input tensor initialized 72 | // with some preset data. The inputData slice can be modified directly to 73 | // change the input values, even after creating the inputTensor. 74 | inputData := []float32{0.2, 0.3, 0.6, 0.9} 75 | // The tensor's shape is actually 1x1x4 rather than 1x4 because the first 76 | // dimension in the PyTorch script was used for batch size. 77 | inputTensor, e := ort.NewTensor(ort.NewShape(1, 1, 4), inputData) 78 | if e != nil { 79 | return fmt.Errorf("Error creating the input tensor: %w", e) 80 | } 81 | // Tensors must always be destroyed when they're no longer needed to free 82 | // associated onnxruntime structures. Destroying the tensor object won't 83 | // change the underlying Go data slice, which can still be cleaned up by 84 | // Go's garbage collector when it's no longer referenced. 85 | defer inputTensor.Destroy() 86 | 87 | // Step 3: Create the output tensor. Since we don't need to initialize it, 88 | // we can use the NewEmptyTensor function to get a zero-filled tensor with 89 | // the required shape for this network. The library will automatically 90 | // allocate a Go slice with the necessary capacity in this case. To access 91 | // this slice, we can call outputTensor.GetData() after creating the tensor. 
92 | outputTensor, e := ort.NewEmptyTensor[float32](ort.NewShape(1, 1, 2)) 93 | if e != nil { 94 | return fmt.Errorf("Error creating the output tensor: %w", e) 95 | } 96 | defer outputTensor.Destroy() 97 | 98 | // Step 4: Load the network itself into an onnxruntime Session instance. 99 | // Note that we call "NewAdvancedSession"---this isn't particularly 100 | // "Advanced", but it's simply a newer version of the API that allows 101 | // specifying additional options (which we don't use here). onnxruntime 102 | // requires associating input and output tensors with names, which in this 103 | // case we set to "1x4 Input Vector" and "1x2 Output Vector" when creating 104 | // the network. (If you're curious, this was done when exporting the .onnx 105 | // file from the python script.) The last argument to 106 | // NewAdvancedSession is a pointer to a SessionOptions instance, which we 107 | // leave as nil to indicate that default options are OK. 108 | session, e := ort.NewAdvancedSession("./sum_and_difference.onnx", 109 | []string{"1x4 Input Vector"}, 110 | []string{"1x2 Output Vector"}, 111 | []ort.ArbitraryTensor{inputTensor}, 112 | []ort.ArbitraryTensor{outputTensor}, 113 | nil) 114 | if e != nil { 115 | return fmt.Errorf("Error creating the session: %w", e) 116 | } 117 | // The session must also always be destroyed to free internal data. 118 | // Destroying the session will not modify or destroy the input or output 119 | // tensors it was using. 120 | defer session.Destroy() 121 | 122 | // Step 5: Actually run the network. This will read the data from the input 123 | // tensor, and write to the output tensor. To re-run the network with 124 | // different inputs, we can simply modify the inputData slice before 125 | // calling Run() again. (Here, we only call it once, though.) 126 | e = session.Run() 127 | if e != nil { 128 | return fmt.Errorf("Error executing the network: %w", e) 129 | } 130 | 131 | // Step 6: Read the output data and present the results. 
The network may 132 | // not be very good, but it was designed to be a small test and not trained 133 | // for very long! 134 | outputData := outputTensor.GetData() 135 | fmt.Printf("The network ran without errors.\n") 136 | fmt.Printf(" Input data: %v\n", inputData) 137 | fmt.Printf(" Approximate sum of inputs: %f\n", outputData[0]) 138 | fmt.Printf(" Approximate max difference between any two inputs: %f\n", outputData[1]) 139 | return nil 140 | } 141 | 142 | func run() int { 143 | var onnxruntimeLibPath string 144 | flag.StringVar(&onnxruntimeLibPath, "onnxruntime_lib", 145 | getDefaultSharedLibPath(), 146 | "The path to the onnxruntime shared library for your system.") 147 | flag.Parse() 148 | if onnxruntimeLibPath == "" { 149 | fmt.Println("You must specify a path to the onnxruntime shared " + 150 | "library on your system. Run with -help for more information.") 151 | return 1 152 | } 153 | e := runTest(onnxruntimeLibPath) 154 | if e != nil { 155 | fmt.Printf("Encountered an error running the network: %s\n", e) 156 | return 1 157 | } 158 | fmt.Printf("The network seemed to run OK!\n") 159 | return 0 160 | } 161 | 162 | func main() { 163 | os.Exit(run()) 164 | } 165 | -------------------------------------------------------------------------------- /sum_and_difference/sum_and_difference.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yalue/onnxruntime_go_examples/af5620ad77f175d73eb0e505aecce27e8e867642/sum_and_difference/sum_and_difference.onnx -------------------------------------------------------------------------------- /third_party/onnxruntime.dll: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yalue/onnxruntime_go_examples/af5620ad77f175d73eb0e505aecce27e8e867642/third_party/onnxruntime.dll -------------------------------------------------------------------------------- /third_party/onnxruntime.so: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/yalue/onnxruntime_go_examples/af5620ad77f175d73eb0e505aecce27e8e867642/third_party/onnxruntime.so -------------------------------------------------------------------------------- /third_party/onnxruntime_amd64.dylib: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yalue/onnxruntime_go_examples/af5620ad77f175d73eb0e505aecce27e8e867642/third_party/onnxruntime_amd64.dylib -------------------------------------------------------------------------------- /third_party/onnxruntime_arm64.dylib: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yalue/onnxruntime_go_examples/af5620ad77f175d73eb0e505aecce27e8e867642/third_party/onnxruntime_arm64.dylib -------------------------------------------------------------------------------- /third_party/onnxruntime_arm64.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yalue/onnxruntime_go_examples/af5620ad77f175d73eb0e505aecce27e8e867642/third_party/onnxruntime_arm64.so --------------------------------------------------------------------------------
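The `sum_and_difference` network above approximates two exact quantities computed from its four inputs: their sum, and the difference between the largest and smallest value. As a minimal standalone sketch (plain Go, no onnxruntime dependency; the `expectedOutputs` helper is hypothetical and not part of any example or the `onnxruntime_go` API), the ground-truth values the network is trained toward can be computed like this:

```go
package main

import "fmt"

// expectedOutputs computes the exact values that the sum_and_difference
// network is trained to approximate: the sum of all inputs, and the
// difference between the largest and smallest input. This helper is
// illustrative only.
func expectedOutputs(inputs []float32) (sum, maxDiff float32) {
	smallest, largest := inputs[0], inputs[0]
	for _, v := range inputs {
		sum += v
		if v < smallest {
			smallest = v
		}
		if v > largest {
			largest = v
		}
	}
	return sum, largest - smallest
}

func main() {
	// The same hardcoded input used by sum_and_difference.go.
	inputData := []float32{0.2, 0.3, 0.6, 0.9}
	sum, maxDiff := expectedOutputs(inputData)
	fmt.Printf("Exact sum: %f\n", sum)
	fmt.Printf("Exact max difference: %f\n", maxDiff)
}
```

Compared against the sample output in the `sum_and_difference` README (approximate sum 1.999988, approximate max difference 0.607343), the exact values are 2.0 and 0.7, which illustrates the source comment that the network "may not be very good": its sum estimate is close, while its max-difference estimate is fairly rough.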