├── .gitignore ├── .sccignore ├── LICENSE ├── README.md ├── docs ├── CHANGELOG.md └── benchmarks.md ├── go.mod ├── go.sum ├── internal ├── benchmarks │ ├── add1.onnx │ ├── add1div2.onnx │ ├── benchmarks.go │ ├── knights_sbert_test.go │ ├── mean_norm_sqrt_add1div2.onnx │ ├── rob_sentences_embeddings.bin │ ├── rob_sentences_test.go │ ├── short_programs_test.go │ └── sqrt_add1div2.onnx ├── cmd │ └── protoc_onnx_protos │ │ └── main.go └── protos │ ├── README.md │ ├── onnx-data.pb.go │ ├── onnx-data.proto │ ├── onnx-ml.pb.go │ ├── onnx-ml.proto │ ├── onnx-operators-ml.pb.go │ ├── onnx-operators-ml.proto │ └── protos.go ├── onnx-go.ipynb ├── onnx-py.ipynb └── onnx ├── dtypes.go ├── dynamicshape.go ├── dynamicshape_test.go ├── graph.go ├── linear_test.go ├── linear_test.onnx ├── materialize.go ├── onnx.go ├── ops.go ├── ops_test.go ├── prettyprint.go ├── tensor.go └── variables.go /.gitignore: -------------------------------------------------------------------------------- 1 | # If you prefer the allow list template instead of the deny list, see community template: 2 | # https://github.com/github/gitignore/blob/main/community/Golang/Go.AllowList.gitignore 3 | # 4 | # Binaries for programs and plugins 5 | *.exe 6 | *.exe~ 7 | *.dll 8 | *.so 9 | *.dylib 10 | 11 | # Test binary, built with `go test -c` 12 | *.test 13 | 14 | # Output of the go coverage tool, specifically when used with LiteIDE 15 | *.out 16 | 17 | # Dependency directories (remove the comment below to include it) 18 | # vendor/ 19 | 20 | # Go workspace file 21 | go.work 22 | go.work.sum 23 | 24 | # env file 25 | .env 26 | 27 | 28 | # GoLand IDE 29 | .idea/ 30 | -------------------------------------------------------------------------------- /.sccignore: -------------------------------------------------------------------------------- 1 | internal/protos 2 | LICENSE 3 | .gitignore 4 | .idea 5 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 
29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ONNX-GoMLX: from ONNX to GoMLX and back
2 | 
3 | [![GoDev](https://img.shields.io/badge/go.dev-reference-007d9c?logo=go&logoColor=white)](https://pkg.go.dev/github.com/gomlx/onnx-gomlx?tab=doc)
4 | 
5 | ## 📖 Overview
6 | ONNX-GoMLX converts [ONNX models](https://onnx.ai/) (`.onnx` suffix) to
7 | [GoMLX (an accelerated machine learning framework for Go)](https://github.com/gomlx/gomlx) and optionally back to ONNX.
8 | 
9 | The main use cases so far are:
10 | 
11 | 1. **Fine-tuning**: import an inference-only ONNX model to GoMLX and use its auto-differentiation and training loop to
12 |    fine-tune models. It allows saving the fine-tuned model as a GoMLX checkpoint or exporting the fine-tuned weights
13 |    back to the ONNX model. This can also be used to expand/combine models.
14 | 2. **Inference**: run an ONNX model from Go without having to include [ONNX Runtime](https://onnxruntime.ai/) (or Python)
15 |    -- at the cost of including XLA/PJRT (currently the only backend for GoMLX). It also allows one to extend the
16 |    model with extra ML pre/post-processing using GoMLX (image transformations, normalization, combining models,
17 |    building ensembles, etc.). This may be interesting for large/expensive models, or for high throughput on large
18 |    batches.
19 |    * Note: if you simply want pure Go inference of ONNX models, see
20 |      [github.com/AdvancedClimateSystems/gonnx](https://github.com/AdvancedClimateSystems/gonnx) or
21 |      [github.com/oramasearch/onnx-go](https://github.com/oramasearch/onnx-go). They will be slower (~8x, measured on a BERT-based SentenceEncoder model, comparing `gonnx` vs `ONNXRuntime`) than
22 |      the XLA inference (or `onnxruntime`) for large projects, but for many use cases it doesn't matter, and they
23 |      are a much smaller, pure Go dependency. CPU only (no GPU support).
24 | 
25 | ## Coverage of ONNX Ops Set
26 | 
27 | At least ten or so models are working so far. For instance:
28 | 
29 | * [Sentence Encoding all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2)
30 |   has been working perfectly, see the example below.
31 | * [ONNX-GoMLX demo/development notebook](https://github.com/gomlx/onnx-gomlx/blob/main/onnx-go.ipynb): it serves both as a functional test and as a demo of what the library can do.
32 | 
33 | But **not all operations ("ops") are converted yet**. If you try it and find one that is not,
34 | please let us know (create an "issue") and we will be happy to try to convert it -- generally,
35 | all the required scaffolding and tooling is already there, and
36 | converting ops has been very easy.
37 | 
38 | ## 🎓 Example
39 | 
40 | We download (and cache) the [all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2)
41 | model using [github.com/gomlx/go-huggingface](https://github.com/gomlx/go-huggingface).
42 | 
43 | The tokens are hardcoded for now -- eventually [github.com/gomlx/go-huggingface](https://github.com/gomlx/go-huggingface) should also
44 | do the tokenization for various models.
45 | 
46 | ```go
47 | import (
48 |     "github.com/gomlx/onnx-gomlx/onnx"
49 |     "github.com/gomlx/go-huggingface/hub"
50 | )
51 | 
52 | ...
53 | 
54 | // Download and cache the ONNX model from HuggingFace.
55 | hfAuthToken := os.Getenv("HF_TOKEN")
56 | hfModelID := "sentence-transformers/all-MiniLM-L6-v2"
57 | repo := hub.New(hfModelID).WithAuth(hfAuthToken)
58 | modelPath := must.M1(repo.DownloadFile("onnx/model.onnx"))
59 | 
60 | // Parse the ONNX model.
61 | model := must.M1(onnx.ReadFile(modelPath))
62 | 
63 | // Convert ONNX variables (model weights) to a GoMLX Context -- which stores variables and can be checkpointed (saved):
64 | ctx := context.New()
65 | must.M(model.VariablesToContext(ctx))
66 | 
67 | // Execute it with GoMLX/XLA:
68 | sentences := []string{
69 |     "This is an example sentence",
70 |     "Each sentence is converted"}
71 | //... tokenize ...
72 | inputIDs := [][]int64{
73 |     {101, 2023, 2003, 2019, 2742, 6251, 102},
74 |     { 101, 2169, 6251, 2003, 4991, 102, 0}}
75 | tokenTypeIDs := [][]int64{
76 |     {0, 0, 0, 0, 0, 0, 0},
77 |     {0, 0, 0, 0, 0, 0, 0}}
78 | attentionMask := [][]int64{
79 |     {1, 1, 1, 1, 1, 1, 1},
80 |     {1, 1, 1, 1, 1, 1, 0}}
81 | var embeddings []*tensors.Tensor
82 | embeddings = context.ExecOnceN( // Execute a GoMLX computation graph with a context.
83 |     backends.New(), // GoMLX backend to use (defaults to XLA).
84 |     ctx, // The Context stores the model variables/weights and optional hyperparameters.
85 |     func(ctx *context.Context, inputs []*Node) []*Node {
86 |         // Convert the ONNX model (in `model`) to a GoMLX computation graph. It returns a slice of values (with only one for this model).
87 |         return model.CallGraph(ctx, inputs[0].Graph(), map[string]*Node{
88 |             "input_ids": inputs[0],
89 |             "attention_mask": inputs[1],
90 |             "token_type_ids": inputs[2]})
91 |     },
92 |     inputIDs, attentionMask, tokenTypeIDs) // Inputs to the GoMLX function.
93 | fmt.Printf("Embeddings: %s", embeddings)
94 | ```
95 | 
96 | The output looks like:
97 | 
98 | ```
99 | Embeddings: [2][7][384]float32{
100 |  {{-0.0886, -0.0368, 0.0180, ..., 0.0261, 0.0912, -0.0152},
101 |   {-0.0200, -0.0014, -0.0177, ..., 0.0204, 0.0522, 0.1991},
102 |   {-0.0196, -0.0336, -0.0319, ..., 0.0203, 0.0709, 0.0644},
103 |   ...,
104 |   {-0.0253, 0.0408, 0.0125, ..., -0.0270, 0.0377, 0.1133},
105 |   {-0.0140, -0.0275, 0.0796, ..., -0.0748, 0.0774, -0.0657},
106 |   {0.0318, -0.0032, -0.0210, ..., 0.0387, 0.0191, -0.0059}},
107 |  {{-0.0886, -0.0368, 0.0180, ..., 0.0261, 0.0912, -0.0152},
108 |   {0.0304, 0.0531, -0.0238, ..., -0.1011, 0.0218, 0.0473},
109 |   {-0.0027, -0.0508, 0.0805, ..., -0.0777, 0.0881, -0.0560},
110 |   ...,
111 |   {0.0928, 0.0165, -0.0976, ..., 0.0449, 0.0390, -0.0182},
112 |   {0.0231, 0.0090, -0.0213, ..., 0.0232, 0.0191, -0.0066},
113 |   {-0.0213, 0.0019, 0.0043, ..., 0.0561, 0.0170, 0.0256}}}
114 | ```
115 | 
116 | ## Fine-Tuning
117 | 
118 | 1. Extract the ONNX model's weights to a GoMLX `Context`: see `Model.VariablesToContext()`.
119 | 2. Use `Model.CallGraph()` in your GoMLX model function (see the example just above).
120 | 3. Train the model as usual in GoMLX.
121 | 4. Depending on how you are going to use the model:
122 |    1. Save the model as a GoMLX checkpoint, as usual.
123 |    2. Save the model by updating the ONNX model: after training use `Model.ContextToONNX()` to copy the updated variable
124 |       values from the GoMLX `Context` back to the ONNX model (in-memory), and then use `Model.Write()` or
125 |       `Model.SaveToFile()` to save the updated ONNX model to disk, as in the sketch below.
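For reference, the save-back flow of step 4.2 looks roughly like this. It is a minimal sketch, assuming `model` and `ctx` come from steps 1-3 above; the output file name is just an example:

```go
// Copy the updated variable values from the GoMLX Context back into the
// in-memory ONNX model proto.
must.M(model.ContextToONNX(ctx))

// Write the updated ONNX model to disk ("finetuned.onnx" is an example name).
must.M(model.SaveToFile("finetuned.onnx"))
```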
126 | 
127 | ## Benchmarks
128 | 
129 | We have some GoMLX/XLA and ONNX Runtime (Microsoft) benchmarks in [this spreadsheet](https://docs.google.com/spreadsheets/d/1ikpJH6rVVHq8ES-IA8U4lkKH4XsTSpRyZewXwGTgits/edit?usp=sharing),
130 | tested on the sentence encoder model we were interested in. This was used during development, and it reflects
131 | how performance improved over time -- the numbers at the bottom of the sheets are the most current and accurate.
132 | 
133 | See [docs/benchmarks.md](docs/benchmarks.md) for more information.
134 | 
135 | ## 🤝 Collaborating
136 | 
137 | Collaboration is very welcome: either in the form of code, or simply with ideas with real applicability. Don't
138 | hesitate to start a discussion or open an issue in the repository.
139 | 
140 | If you are interested, we have two notebooks we use to compare results. They are a good starting point for anyone curious:
141 | 
142 | * [Go Version](https://github.com/gomlx/onnx-gomlx/blob/main/onnx-go.ipynb)
143 | * [ONNX Python Version](https://github.com/gomlx/onnx-gomlx/blob/main/onnx-py.ipynb)
144 | 
145 | ## 🥳 Thanks
146 | 
147 | * This project was born from brainstorming with the talented folks at [Knights Analytics](https://www.knightsanalytics.com/).
148 |   Without their insights and enthusiasm this wouldn't have gotten off the ground.
149 | * [ONNX](https://onnx.ai/) is such a nice open-source standard for communicating models across different implementations.
150 | * [OpenXLA/XLA](https://github.com/openxla/xla), the open-source backend engine by Google that powers this implementation.
151 | * Sources of inspiration:
152 |   * https://github.com/knights-analytics/onnx-gomlx
153 |   * https://github.com/oramasearch/onnx-go
154 | 
--------------------------------------------------------------------------------
/docs/CHANGELOG.md:
--------------------------------------------------------------------------------
1 | # v0.2.3 2025/05/31
2 | 
3 | * Added saving/checking of output values for internal/benchmarks: allows it to be
4 |   used as a functional test during the development of the GoMLX SimpleGo backend.
5 | * Updated dependencies to the latest GoMLX v0.19.5.
6 | 
7 | # v0.2.2 2025/05/22
8 | 
9 | * Added Min and Max operators.
10 | * Updated dependency to GoMLX v0.19.3.
11 | 
12 | # v0.2.1 2025/05/01
13 | 
14 | * Updated to GoMLX v0.19.1.
15 | * Included the default GoMLX backends.
16 | 
17 | # v0.2.0 2025/02/02
18 | 
19 | * Updated to GoMLX v0.17.0.
20 | * Added bitwise operators.
21 | * Added parallel benchmarks.
22 | * Added benchmarks documentation.
23 | 
24 | # v0.1.5 🎄 2024/12/19 🎄
25 | 
26 | * Added `internal/benchmarks` package: see progress in https://docs.google.com/spreadsheets/d/1ikpJH6rVVHq8ES-IA8U4lkKH4XsTSpRyZewXwGTgits/edit?gid=1753191050#gid=1753191050
27 |   * Benchmarks ONNX models with XLA and ONNX Runtime (ORT), on CPU and GPU:
28 |     * Very simple models
29 |     * KnightsAnalytics/all-MiniLM-L6-v2
30 |     * Slices (parts of) KnightsAnalytics/all-MiniLM-L6-v2
31 | * Updated dependencies to GoMLX 0.16.1 with lots of accelerations.
32 | 
33 | # v0.1.4 - 2024/11/28
34 | 
35 | * Added Flatten op support.
36 | 
37 | # v0.1.3 - 2024/11/21
38 | 
39 | * Added ContextToONNX to save variables back to the ONNX model (in memory).
40 | * Refactored internal/togomlx to inside the onnx/ subdir.
41 | * Added Model.Write and Model.SaveToFile.
42 | 
43 | # v0.1.2 - 2024/11/17
44 | 
45 | * Added LSTM op support, with a small example.
46 | 
47 | # v0.1.1 - 2024/11/15
48 | 
49 | * Assume some variables are constant during constant-expression evaluation.
50 | * Improved pretty-printing of attributes: include their values for small values.
51 | * New ops: Range, Tile, CumSum, Not, Tanh, GatherElements, several standard unary and binary operators.
52 | * Fixed ops: Where.
53 | 
54 | # v0.1.0
55 | 
56 | * First working version -- for a few models.
57 | * Constant-expression evaluation during model build: needed for parameters that are fed dynamically
58 |   to ONNX, but require static values in GoMLX/XLA.
--------------------------------------------------------------------------------
/docs/benchmarks.md:
--------------------------------------------------------------------------------
1 | # Benchmarks
2 | 
3 | The first use case for onnx-gomlx (the one that prompted us to start the project) was to allow serving (inference) and
4 | fine-tuning of sentence encoder models for Knights Analytics using XLA.
5 | 
6 | So the benchmarks use that sentence-encoder model as reference. There are two variations:
7 | 
8 | 1. Benchmarks the [_KnightsAnalytics/all-MiniLM-L6-v2_](https://huggingface.co/KnightsAnalytics/all-MiniLM-L6-v2) fine-tuned
9 |    sentence encoder model with random sentences truncated to 128 tokens, sampled from
10 |    [HuggingFaceFW/fineweb](https://huggingface.co/datasets/HuggingFaceFW/fineweb).
11 |    See `internal/benchmarks/knights_sbert_test.go`.
12 | 2. Benchmarks the same [_KnightsAnalytics/all-MiniLM-L6-v2_](https://huggingface.co/KnightsAnalytics/all-MiniLM-L6-v2)
13 |    model with a small fixed list of titles (~13 tokens each). This is the `internal/benchmarks/rob_sentences_test.go` test.
14 | 
15 | The benchmarks cover both GoMLX+XLA/PJRT execution and ORT (Microsoft ONNX Runtime), both in CPU and GPU versions.
16 | They also include throughput measurements when using parallelization.
17 | 
18 | The benchmarks include **full**-model benchmarking (`Full` suffix) and partial-model benchmarking -- the latter was used
19 | during development of **onnx-gomlx** to identify slowness.
20 | 
21 | ## Glossary
22 | 
23 | * **ORT**: ONNX Runtime, the Microsoft runtime to execute ONNX models. Supports CPUs and GPUs.
24 | * **XLA/PJRT**: A Google library to JIT-compile ML models and then execute them. Supports CPUs and various accelerators (GPU, ROCm, TPU, etc.).
25 | 
26 | ## Practical considerations
27 | 
28 | * If using Intel CPUs (or any heterogeneous CPU setup), make sure to exclude the slow cores (E-cores on Intel).
29 |   On Linux, with an Intel 12900K, use for instance `sudo chcpu -d 16-23` or `taskset 0xFFFF ./benchmarks.test...`.
30 |   See the examples of running benchmarks below.
31 | * For tests that use ONNXRuntime you must set `ORT_SO_PATH` to point to your ORT `.so` file (either the CPU or the CUDA one).
32 | * There is a lot of variance. Literally, even the weather may impact it: I suppose the CPU/GPU can be
33 |   temperature-throttled differently.
34 | 
35 | ## Running benchmarks
36 | 
37 | The benchmark code is in `internal/benchmarks`, and the examples below assume you are in that subdirectory.
38 | 
39 | Notice the ORT `.so` files are installed in `~/.local/lib` for these examples.
40 | 
41 | You have to set the flag `--bench_duration=10s` (or some other amount of time). If you leave it at the default of 0,
42 | no benchmark tests will run. These are not Go benchmarks, but rather tests built to run benchmarks.
43 | 
44 | Examples of commands to run benchmarks on an Intel 12900K:
45 | 
46 | * Running KnightsAnalytics SBert with GoMLX/XLA + ORT CPU:
47 | 
48 | ```
49 | go test -c .
&& GOMLX_BACKEND=xla:cpu ORT_SO_PATH=/home/janpf/.local/lib/libonnxruntime.so taskset 0xFFFF ./benchmarks.test -test.run=BenchKnights.*Full -test.v --bench_duration=10s
50 | ```
51 | 
52 | * Running KnightsAnalytics SBert with ORT GPU:
53 | 
54 | ```
55 | go test -c . && GOMLX_BACKEND=xla:cuda ORT_SO_PATH=/home/janpf/.local/lib/libonnxruntime_gpu.so taskset 0xFFFF ./benchmarks.test -test.run=BenchKnights.*Full -test.v --bench_duration=10s
56 | ```
57 | 
58 | * Experimenting with XLA_FLAGS:
59 | 
60 | ```
61 | $ go test -c . && GOMLX_BACKEND=xla:cpu XLA_FLAGS='--xla_cpu_enable_fast_math=true --xla_cpu_fast_math_honor_nans=false --xla_cpu_fast_math_honor_infs=false --xla_cpu_fast_math_honor_division=false --xla_cpu_fast_math_honor_functions=false --xla_cpu_enable_concurrency_optimized_scheduler=true' ORT_SO_PATH=/home/janpf/lib/libonnxruntime.so taskset 0xFFFF ./benchmarks.test -test.run=BenchKnights.*FullXLA -test.v --bench_duration=10s
62 | ```
63 | 
64 | * Running the RobSentences benchmark -- which includes parallelization on CPU:
65 | 
66 | ```
67 | $ go test -c . && GOMLX_BACKEND=xla:cpu ORT_SO_PATH=/home/janpf/.local/lib/libonnxruntime.so ./benchmarks.test -test.run=BenchRob -test.v --bench_duration=10s
68 | ```
69 | 
--------------------------------------------------------------------------------
/go.mod:
--------------------------------------------------------------------------------
1 | module github.com/gomlx/onnx-gomlx
2 | 
3 | go 1.24
4 | 
5 | toolchain go1.24.2
6 | 
7 | require (
8 | 	github.com/chewxy/math32 v1.11.1
9 | 	github.com/daulet/tokenizers v1.20.2
10 | 	github.com/gomlx/exceptions v0.0.3
11 | 	github.com/gomlx/go-huggingface v0.1.1
12 | 	github.com/gomlx/gomlx v0.19.4
13 | 	github.com/gomlx/gopjrt v0.7.1
14 | 	github.com/janpfeifer/go-benchmarks v0.1.1
15 | 	github.com/janpfeifer/must v0.2.0
16 | 	github.com/parquet-go/parquet-go v0.24.0
17 | 	github.com/pkg/errors v0.9.1
18 | 	github.com/stretchr/testify v1.10.0
19 | 	github.com/yalue/onnxruntime_go v1.13.0
20 | 	google.golang.org/protobuf v1.36.6
21 | 	k8s.io/klog/v2 v2.130.1
22 | )
23 | 
24 | require (
25 | 	github.com/andybalholm/brotli v1.1.0 // indirect
26 | 	github.com/davecgh/go-spew v1.1.1 // indirect
27 | 	github.com/dustin/go-humanize v1.0.1 // indirect
28 | 	github.com/go-logr/logr v1.4.2 // indirect
29 | 	github.com/google/uuid v1.6.0 // indirect
30 | 	github.com/klauspost/compress v1.17.9 // indirect
31 | 	github.com/kr/pretty v0.3.1 // indirect
32 | 	github.com/mattn/go-runewidth v0.0.16 // indirect
33 | 	github.com/olekukonko/tablewriter v0.0.5 // indirect
34 | 	github.com/pierrec/lz4/v4 v4.1.21 // indirect
35 | 	github.com/pmezard/go-difflib v1.0.0 // indirect
36 | 	github.com/rivo/uniseg v0.4.7 // indirect
37 | 	github.com/streadway/quantile v0.0.0-20220407130108-4246515d968d // indirect
38 | 	github.com/x448/float16 v0.8.4 // indirect
39 | 	golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0 // indirect
40 | 	golang.org/x/sys v0.32.0 // indirect
41 | 	gopkg.in/yaml.v3 v3.0.1 // indirect
42 | )
43 | 
--------------------------------------------------------------------------------
/go.sum:
--------------------------------------------------------------------------------
1 | github.com/andybalholm/brotli v1.1.0 h1:eLKJA0d02Lf0mVpIDgYnqXcUn0GqVmEFny3VuID1U3M=
2 | github.com/andybalholm/brotli v1.1.0/go.mod h1:sms7XGricyQI9K10gOSf56VKKWS4oLer58Q+mhRPtnY=
3 | github.com/chewxy/math32 v1.11.1 h1:b7PGHlp8KjylDoU8RrcEsRuGZhJuz8haxnKfuMMRqy8=
4 | github.com/chewxy/math32 v1.11.1/go.mod h1:dOB2rcuFrCn6UHrze36WSLVPKtzPMRAQvBvUwkSsLqs=
5 | 
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= 6 | github.com/daulet/tokenizers v1.20.2 h1:tlq/vIOiBTKDPets3596aFvmJYLn3XI6LFKq4q9LKhQ= 7 | github.com/daulet/tokenizers v1.20.2/go.mod h1:tGnMdZthXdcWY6DGD07IygpwJqiPvG85FQUnhs/wSCs= 8 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 9 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 10 | github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= 11 | github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= 12 | github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY= 13 | github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= 14 | github.com/gomlx/exceptions v0.0.3 h1:HKnTgEjj4jlmhr8zVFkTP9qmV1ey7ypYYosQ8GzXWuM= 15 | github.com/gomlx/exceptions v0.0.3/go.mod h1:uHL0TQwJ0xaV2/snJOJV6hSE4yRmhhfymuYgNredGxU= 16 | github.com/gomlx/go-huggingface v0.1.1 h1:N+I8nfXZ+pA4KfWR3Nlk9RaxSy6Cyeg5Q+XIr6OlUJQ= 17 | github.com/gomlx/go-huggingface v0.1.1/go.mod h1:xbiYrFHjJFkJZVSxYasYwYBvyy7EQZikTaQ25ox57Gs= 18 | github.com/gomlx/gomlx v0.19.1 h1:SHuj1yVqrg4E/DX74eVfdNQ/xW5ijhBRkZ8USkRZcUk= 19 | github.com/gomlx/gomlx v0.19.1/go.mod h1:RdhIh7sixw3HWFevfiM7DOPu2SAATjKcOzBI/ZV7Inc= 20 | github.com/gomlx/gomlx v0.19.3 h1:Vr0MK/657uiRO64DT5GCZSf3uJ1sVNq2vfRSPLAHcrU= 21 | github.com/gomlx/gomlx v0.19.3/go.mod h1:RdhIh7sixw3HWFevfiM7DOPu2SAATjKcOzBI/ZV7Inc= 22 | github.com/gomlx/gomlx v0.19.4 h1:71I3VeBS00IUUBZcE9mZxmMx65qL7O3CtumaHK6ghp0= 23 | github.com/gomlx/gomlx v0.19.4/go.mod h1:6zkDUqqdEl16DgoVZrTaEH643gf3wvQFAknnGgFQXOE= 24 | github.com/gomlx/gopjrt v0.7.0 h1:7TwlK+mRGTkqQYHemxwzIPGGIG9Q8fN7GnmqRspRMHY= 25 | github.com/gomlx/gopjrt v0.7.0/go.mod h1:HJn0wemLuFPxHr7P7zvuiloLR2OsLv60HV/4obxTXVc= 26 | github.com/gomlx/gopjrt v0.7.1 h1:OGZbMN7CCn2dU+CDr65InDO0XPmrC5NUnJ9STXa/lXE= 27 | github.com/gomlx/gopjrt v0.7.1/go.mod h1:VswjttDY1uSllQ+Vs69P4kgsH3EkFEHADUCdDbfgh0Y= 28 | github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= 29 | github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= 30 | github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= 31 | github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= 32 | github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM= 33 | github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg= 34 | github.com/janpfeifer/go-benchmarks v0.1.1 h1:gLLy07/JrOKSnMWeUxSnjTdhkglgmrNR2IBDnR4kRqw= 35 | github.com/janpfeifer/go-benchmarks v0.1.1/go.mod h1:5AagXCOUzevvmYFQalcgoa4oWPyH1IkZNckolGWfiSM= 36 | github.com/janpfeifer/must v0.2.0 h1:yWy1CE5gtk1i2ICBvqAcMMXrCMqil9CJPkc7x81fRdQ= 37 | github.com/janpfeifer/must v0.2.0/go.mod h1:S6c5Yg/YSMR43cJw4zhIq7HFMci90a7kPY9XA4c8UIs= 38 | github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA= 39 | github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= 40 | github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= 41 | github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= 42 | github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= 43 | github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= 44 | github.com/mattn/go-runewidth v0.0.9/go.mod 
h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI= 45 | github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc= 46 | github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= 47 | github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec= 48 | github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY= 49 | github.com/parquet-go/parquet-go v0.24.0 h1:VrsifmLPDnas8zpoHmYiWDZ1YHzLmc7NmNwPGkI2JM4= 50 | github.com/parquet-go/parquet-go v0.24.0/go.mod h1:OqBBRGBl7+llplCvDMql8dEKaDqjaFA/VAPw+OJiNiw= 51 | github.com/pierrec/lz4/v4 v4.1.21 h1:yOVMLb6qSIDP67pl/5F7RepeKYu/VmTyEXvuMI5d9mQ= 52 | github.com/pierrec/lz4/v4 v4.1.21/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4= 53 | github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= 54 | github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= 55 | github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= 56 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 57 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 58 | github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= 59 | github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= 60 | github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= 61 | github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= 62 | github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= 63 | github.com/streadway/quantile v0.0.0-20220407130108-4246515d968d h1:X4+kt6zM/OVO6gbJdAfJR60MGPsqCzbtXNnjoGqdfAs= 64 | github.com/streadway/quantile v0.0.0-20220407130108-4246515d968d/go.mod h1:lbP8tGiBjZ5YWIc2fzuRpTaz0b/53vT6PEs3QuAWzuU= 65 | github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= 66 | github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= 67 | github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= 68 | github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg= 69 | github.com/yalue/onnxruntime_go v1.13.0 h1:5HDXHon3EukQMyYA7yPMed/raWaDE/gjwLOwnVoiwy8= 70 | github.com/yalue/onnxruntime_go v1.13.0/go.mod h1:b4X26A8pekNb1ACJ58wAXgNKeUCGEAQ9dmACut9Sm/4= 71 | golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0 h1:R84qjqJb5nVJMxqWYb3np9L5ZsaDtB+a39EqjV0JSUM= 72 | golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0/go.mod h1:S9Xr4PYopiDyqSyp5NjCrhFrqg6A5zA2E/iPHPhqnS8= 73 | golang.org/x/sys v0.32.0 h1:s77OFDvIQeibCmezSnk/q6iAfkdiQaJi4VzroCFrN20= 74 | golang.org/x/sys v0.32.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= 75 | google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY= 76 | google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY= 77 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 78 | gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= 79 | gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= 80 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 81 | gopkg.in/yaml.v3 v3.0.1/go.mod 
h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
82 | k8s.io/klog/v2 v2.130.1 h1:n9Xl7H1Xvksem4KFG4PYbdQCQxqc/tTUyrgXaOhHSzk=
83 | k8s.io/klog/v2 v2.130.1/go.mod h1:3Jpz1GvMt720eyJH1ckRHK1EDfpxISzJ7I9OYgaDtPE=
84 | 
--------------------------------------------------------------------------------
/internal/benchmarks/add1.onnx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gomlx/onnx-gomlx/c68337bcf8365e94e3691fa67192ab1e38a3c77a/internal/benchmarks/add1.onnx
--------------------------------------------------------------------------------
/internal/benchmarks/add1div2.onnx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gomlx/onnx-gomlx/c68337bcf8365e94e3691fa67192ab1e38a3c77a/internal/benchmarks/add1div2.onnx
--------------------------------------------------------------------------------
/internal/benchmarks/benchmarks.go:
--------------------------------------------------------------------------------
1 | // Package benchmarks implements support functionality for the benchmark tests for ONNX models
2 | // in XLA and in ONNXRuntime.
3 | package benchmarks
4 | 
5 | import (
6 | 	"fmt"
7 | 	"math"
8 | 	"testing"
9 | 
10 | 	_ "github.com/gomlx/gomlx/backends/default"
11 | 	"github.com/gomlx/gomlx/types/tensors"
12 | 	"github.com/pkg/errors"
13 | 	"github.com/stretchr/testify/require"
14 | )
15 | 
16 | // requireSameTensorsFloat32 compares two tensors and fails the test if they are not within a delta margin.
17 | func requireSameTensorsFloat32(t *testing.T, want, got *tensors.Tensor, delta float64) {
18 | 	// Make sure shapes are the same.
19 | 	require.True(t, got.Shape().Equal(want.Shape()))
20 | 	flatIdx := 0
21 | 	gotFlat := tensors.CopyFlatData[float32](got)
22 | 	wantFlat := tensors.CopyFlatData[float32](want)
23 | 	var mismatches int
24 | 	for indices := range got.Shape().Iter() {
25 | 		gotValue := gotFlat[flatIdx]
26 | 		wantValue := wantFlat[flatIdx]
27 | 		if math.Abs(float64(gotValue)-float64(wantValue)) > delta {
28 | 			if mismatches < 3 {
29 | 				fmt.Printf("\tIndex %v (flatIdx=%d) has a mismatch: got %f, want %f\n", indices, flatIdx, gotValue, wantValue)
30 | 			} else if mismatches == 3 {
31 | 				fmt.Printf("\t...\n")
32 | 			}
33 | 			mismatches++
34 | 		}
35 | 		flatIdx++
36 | 	}
37 | 	if mismatches > 0 {
38 | 		fmt.Printf("Found %d mismatches in tensors\n", mismatches)
39 | 		panic(errors.Errorf("found %d mismatches in tensors", mismatches))
40 | 	}
41 | }
42 | 
--------------------------------------------------------------------------------
/internal/benchmarks/knights_sbert_test.go:
--------------------------------------------------------------------------------
1 | package benchmarks
2 | 
3 | import (
4 | 	"flag"
5 | 	"fmt"
6 | 	"os"
7 | 	"path"
8 | 	"runtime"
9 | 	"strings"
10 | 	"sync"
11 | 	"testing"
12 | 	"unicode/utf8"
13 | 
14 | 	dtok "github.com/daulet/tokenizers"
15 | 	"github.com/gomlx/exceptions"
16 | 	"github.com/gomlx/go-huggingface/hub"
17 | 	"github.com/gomlx/gomlx/graph"
18 | 	"github.com/gomlx/gomlx/graph/graphtest"
19 | 	"github.com/gomlx/gomlx/ml/context"
20 | 	"github.com/gomlx/gomlx/types/shapes"
21 | 	"github.com/gomlx/gomlx/types/tensors"
22 | 	"github.com/gomlx/gopjrt/dtypes"
23 | 	"github.com/gomlx/onnx-gomlx/internal/protos"
24 | 	"github.com/gomlx/onnx-gomlx/onnx"
25 | 	benchmarks "github.com/janpfeifer/go-benchmarks"
26 | 	"github.com/janpfeifer/must"
27 | 	parquet "github.com/parquet-go/parquet-go"
28 | 	ort "github.com/yalue/onnxruntime_go"
29 | 	"google.golang.org/protobuf/proto"
30 | 
) 31 | 32 | var ( 33 | // HuggingFace authentication token read from environment. 34 | // It can be created in https://huggingface.co 35 | // Some files may require it for downloading. 36 | hfAuthToken = os.Getenv("HF_TOKEN") 37 | 38 | KnightsAnalyticsSBertID = "KnightsAnalytics/all-MiniLM-L6-v2" 39 | FineWebID = "HuggingFaceFW/fineweb" 40 | FineWebSampleFile = "sample/10BT/000_00000.parquet" 41 | 42 | // Benchmark hyperparameters. 43 | BatchSizes = []int{1, 16, 32, 64} // {1, 16, 64} 44 | SequenceLength = 128 // Shouldn't be changed, since the tokenizer is hard-coded to pad to 128. 45 | NumSentences = 128 // 10_000 46 | 47 | flagBenchDuration = flag.Duration("bench_duration", 0, "Benchmark duration, typically use 10 seconds. If left as 0, benchmark tests are disabled") 48 | flagPrintXLAGraph = flag.Bool("xla_graph", false, "Prints XLA graph") 49 | flagExcludePadded = flag.Bool("exclude_padded", false, "Exclude sentences with less than 128 tokens") 50 | 51 | // Save embeddings to file, one per example, named /tmp/embeddings_%03d.bin 52 | flagSaveEmbeddings = flag.Bool("save_embeddings", false, "Save embeddings to file, one per example, named embeddings-%03d.bin") 53 | flagCheckEmbeddings = flag.Bool("check_embeddings", false, "Check embeddings generated match the ones loaded from files") 54 | ) 55 | 56 | // tokenizedSentence stores the tokenized input for models of a sentence. 57 | type tokenizedSentence struct { 58 | Encoding [3][]int64 // IDs, Masks, tokenTypeIDs 59 | } 60 | 61 | // fineWebEntry: inspection of fields in parquet file done with tool in 62 | // github.com/xitongsys/parquet-go/tool/parquet-tools. 63 | // 64 | // The parquet annotations are described in: https://pkg.go.dev/github.com/parquet-go/parquet-go#SchemaOf 65 | type fineWebEntry struct { 66 | Text string `parquet:"text,snappy"` 67 | ID string `parquet:"id,snappy"` 68 | Dump string `parquet:"dump,snappy"` 69 | URL string `parquet:"url,snappy"` 70 | Score float64 `parquet:"language_score"` 71 | } 72 | 73 | // trimString returns s trimmed to at most maxLength runes. If trimmed it appends "…" at the end. 74 | func trimString(s string, maxLength int) string { 75 | if utf8.RuneCountInString(s) <= maxLength { 76 | return s 77 | } 78 | runes := []rune(s) 79 | return string(runes[:maxLength-1]) + "…" 80 | } 81 | 82 | func padOrTrim[T any](n int, values []T, padding T) []T { 83 | if len(values) >= n { 84 | return values[:n] 85 | } 86 | newValues := make([]T, n) 87 | copy(newValues, values) 88 | for ii := len(values); ii < n; ii++ { 89 | newValues[ii] = padding 90 | } 91 | return newValues 92 | } 93 | 94 | // sampleFineWeb returns the first n tokenized sentences from a 2Gb sample of the FineWeb dataset. 95 | // 96 | // The modelID is used to download the tokenization model. 97 | // 98 | // sequenceLen is the length of each sentence in number of tokens. 99 | // If the original sentence is longer, it is truncated. 100 | // If it is shorter, it is padded. 101 | func sampleFineWeb(modelID string, n, sequenceLen int) []tokenizedSentence { 102 | results := make([]tokenizedSentence, n) 103 | 104 | // Download repo file. 105 | repo := hub.New(FineWebID).WithType(hub.RepoTypeDataset).WithAuth(hfAuthToken) 106 | localSampleFile := must.M1(repo.DownloadFile(FineWebSampleFile)) 107 | 108 | // Parquet reading using parquet-go: it's somewhat cumbersome (to open the file it needs its size!?), but it works. 
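	// (Note: parquet.OpenFile takes an io.ReaderAt plus the total file size, hence the
	// os.Stat call below to get the size before opening.)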
109 | 	schema := parquet.SchemaOf(&fineWebEntry{})
110 | 	fSize := must.M1(os.Stat(localSampleFile)).Size()
111 | 	fReader := must.M1(os.Open(localSampleFile))
112 | 	fParquet := must.M1(parquet.OpenFile(fReader, fSize))
113 | 	reader := parquet.NewGenericReader[fineWebEntry](fParquet, schema)
114 | 	defer func() { _ = reader.Close() }()
115 | 
116 | 	// Create tokenizer: it is configured by the "tokenizer.json" to a max_length of 128, with padding.
117 | 	repoTokenizer := hub.New(modelID).WithAuth(hfAuthToken)
118 | 	localFile := must.M1(repoTokenizer.DownloadFile("tokenizer.json"))
119 | 	tokenizer := must.M1(dtok.FromFile(localFile))
120 | 	defer func() { _ = tokenizer.Close() }()
121 | 
122 | 	// Read a batch at a time and tokenize.
123 | 	const maxBatchSize = 32
124 | 	current := 0
125 | 	for current < n {
126 | 		batchSize := min(maxBatchSize, n-current)
127 | 		rows := make([]fineWebEntry, batchSize)
128 | 		numRead := must.M1(reader.Read(rows))
129 | 		if numRead == 0 {
130 | 			break
131 | 		}
132 | 		for _, row := range rows[:numRead] {
133 | 			encoding := tokenizer.EncodeWithOptions(row.Text, false,
134 | 				dtok.WithReturnTypeIDs(),
135 | 				dtok.WithReturnAttentionMask(),
136 | 			)
137 | 
138 | 			if *flagExcludePadded {
139 | 				var countMasked int
140 | 				for _, id := range encoding.AttentionMask {
141 | 					if id == 0 {
142 | 						countMasked++
143 | 					}
144 | 				}
145 | 				if countMasked > 0 {
146 | 					continue
147 | 				}
148 | 			}
149 | 
150 | 			results[current].Encoding[0] = padOrTrim(sequenceLen,
151 | 				sliceMap(encoding.IDs, func(id uint32) int64 { return int64(id) }),
152 | 				0)
153 | 			results[current].Encoding[1] = padOrTrim(sequenceLen,
154 | 				sliceMap(encoding.AttentionMask, func(id uint32) int64 { return int64(id) }),
155 | 				0)
156 | 			results[current].Encoding[2] = padOrTrim(sequenceLen,
157 | 				sliceMap(encoding.TypeIDs, func(id uint32) int64 { return int64(id) }),
158 | 				0)
159 | 			current++
160 | 		}
161 | 	}
162 | 	if current < n {
163 | 		exceptions.Panicf("requested %d sentences to sample, got only %d", n, current)
164 | 	}
165 | 	return results
166 | }
167 | 
168 | var (
169 | 	tokenizedExamples []tokenizedSentence
170 | 	tokenizedExamplesOnce sync.Once
171 | )
172 | 
173 | func initTokenizedExamples() {
174 | 	tokenizedExamplesOnce.Do(func() {
175 | 		fmt.Printf("Tokenizing %d sentences of length %d...\n", NumSentences, SequenceLength)
176 | 		tokenizedExamples = sampleFineWeb(KnightsAnalyticsSBertID, NumSentences, SequenceLength)
177 | 		fmt.Printf("\tfinished tokenizing.\n")
178 | 	})
179 | }
180 | 
181 | func benchmarkONNXModelWithXLA(withHeader bool, name, onnxModelPath string, batchSize int,
182 | 	targetNodeNames ...string) {
183 | 	initTokenizedExamples()
184 | 	if NumSentences < batchSize {
185 | 		exceptions.Panicf("batchSize(%d) must be <= the number of sentences sampled (%d)", batchSize, NumSentences)
186 | 	}
187 | 
188 | 	// Build model
189 | 	backend := graphtest.BuildTestBackend()
190 | 	model := must.M1(onnx.ReadFile(onnxModelPath))
191 | 	ctx := context.New()
192 | 	must.M(model.VariablesToContext(ctx))
193 | 	ctx = ctx.Reuse()
194 | 	exec := context.NewExec(backend, ctx, func(ctx *context.Context, tokenIDs, attentionMask, tokenTypeIDs *graph.Node) *graph.Node {
195 | 		//fmt.Printf("Exec inputs (tokens, mask, types): %s, %s, %s\n", tokenIDs.Shape(), attentionMask.Shape(), tokenTypeIDs.Shape())
196 | 		g := tokenIDs.Graph()
197 | 		outputs := model.CallGraph(ctx, g,
198 | 			map[string]*graph.Node{
199 | 				"input_ids": tokenIDs,
200 | 				"attention_mask": attentionMask,
201 | 				"token_type_ids": tokenTypeIDs,
202 | 			}, targetNodeNames...)
203 | 		if *flagPrintXLAGraph {
204 | 			fmt.Printf("Graph:\n%s\n", g)
205 | 		}
206 | 		return outputs[0]
207 | 	})
208 | 	defer exec.Finalize()
209 | 
210 | 	// Create input tensors:
211 | 	var inputTensors [3]*tensors.Tensor // tokenIDs, attentionMask, tokenTypeIDs
212 | 	for ii := range inputTensors {
213 | 		inputTensors[ii] = tensors.FromShape(shapes.Make(dtypes.Int64, batchSize, SequenceLength))
214 | 	}
215 | 
216 | 	runIdx := 0
217 | 	sentenceIdx := 0
218 | 	testFn := benchmarks.NamedFunction{
219 | 		Name: fmt.Sprintf("XLA/%s/BatchSize=%2d:", name, batchSize),
220 | 		Func: func() {
221 | 			// Create batch for each input tensor.
222 | 			for inputIdx, t := range inputTensors {
223 | 				tensors.MutableFlatData[int64](t, func(flat []int64) {
224 | 					for exampleIdx := range batchSize {
225 | 						sample := tokenizedExamples[sentenceIdx+exampleIdx]
226 | 						copy(flat[exampleIdx*SequenceLength:], sample.Encoding[inputIdx])
227 | 					}
228 | 				})
229 | 			}
230 | 
231 | 			// Execute program.
232 | 			//start := time.Now()
233 | 			output := exec.Call(inputTensors[0], inputTensors[1], inputTensors[2])[0]
234 | 			tensors.ConstFlatData(output, func(flat []float32) {
235 | 				if runIdx == 0 {
236 | 					fmt.Printf("\t> Last value of result: %v\n", flat[len(flat)-1])
237 | 				}
238 | 			})
239 | 			output.FinalizeAll()
240 | 
241 | 			//elapsed := time.Since(start)
242 | 			//if elapsed > 200*time.Microsecond {
243 | 			//	fmt.Printf("runIdx=%d, sentenceIdx=%d: elapsed=%s\n", runIdx, sentenceIdx, elapsed)
244 | 			//}
245 | 
246 | 			// Next batch.
247 | 			runIdx++
248 | 			sentenceIdx += batchSize
249 | 			if sentenceIdx+batchSize >= NumSentences {
250 | 				sentenceIdx = 0
251 | 			}
252 | 		},
253 | 	}
254 | 
255 | 	runtime.LockOSThread()
256 | 	defer runtime.UnlockOSThread()
257 | 	benchmarks.New(testFn).
258 | 		WithWarmUps(128).
259 | 		WithDuration(*flagBenchDuration).
260 | 		WithHeader(withHeader).
261 | 		Done()
262 | }
263 | 
264 | // ortInitFn will execute only once.
265 | var (
266 | 	ortInitFn = sync.OnceFunc(func() {
267 | 		ortPath := os.Getenv("ORT_SO_PATH")
268 | 		if ortPath == "" {
269 | 			exceptions.Panicf("Please set environment ORT_SO_PATH with the path to your ONNX Runtime dynamic linked library")
270 | 		}
271 | 		if strings.Contains(ortPath, "gpu") {
272 | 			ortIsCUDA = true
273 | 		}
274 | 		ort.SetSharedLibraryPath(ortPath)
275 | 		must.M(ort.InitializeEnvironment())
276 | 		// Since we may run this function multiple times, we never destroy the environment.
277 | 		//defer func() { _ = ort.DestroyEnvironment() }()
278 | 	})
279 | 	ortIsCUDA bool
280 | )
281 | 
282 | func benchmarkONNXModelWithORT(withHeader bool,
283 | 	name, onnxModelPath string, batchSize int,
284 | 	outputNodeName string, outputNodeShape shapes.Shape) {
285 | 	ortInitFn()
286 | 
287 | 	// Tokenize examples from FineWeb (or from testSentences)
288 | 	initTokenizedExamples()
289 | 	if NumSentences < batchSize {
290 | 		exceptions.Panicf("batchSize(%d) must be <= the number of sentences sampled (%d)", batchSize, NumSentences)
291 | 	}
292 | 
293 | 	// Create input and output tensors.
294 | 	var inputTensors [3]*ort.Tensor[int64]
295 | 	inputShape := ort.NewShape(int64(batchSize), int64(SequenceLength))
296 | 	for ii := range inputTensors {
297 | 		inputTensors[ii] = must.M1(ort.NewEmptyTensor[int64](inputShape))
298 | 	}
299 | 	outputShape := ort.NewShape(sliceMap(outputNodeShape.Dimensions, func(dim int) int64 { return int64(dim) })...)
300 | 	outputTensor := must.M1(ort.NewEmptyTensor[float32](outputShape))
301 | 
302 | 	// Create session with ONNX program.
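	// If ORT_SO_PATH pointed at a GPU build of the runtime (ortIsCUDA), attach the
	// CUDA execution provider to the session options; otherwise ORT's default CPU
	// provider is used.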
303 | var options *ort.SessionOptions 304 | if ortIsCUDA { 305 | options = must.M1(ort.NewSessionOptions()) 306 | cudaOptions := must.M1(ort.NewCUDAProviderOptions()) 307 | // must.M(cudaOptions.Update(map[string]string{"device_id": "0"})) 308 | must.M(options.AppendExecutionProviderCUDA(cudaOptions)) 309 | } 310 | session := must.M1(ort.NewAdvancedSession( 311 | onnxModelPath, 312 | []string{"input_ids", "attention_mask", "token_type_ids"}, 313 | []string{outputNodeName}, 314 | sliceMap(inputTensors[:], func(t *ort.Tensor[int64]) ort.Value { return t }), 315 | []ort.Value{outputTensor}, options)) 316 | defer func() { must.M(session.Destroy()) }() 317 | 318 | sentenceIdx := 0 319 | runIdx := 0 320 | testFn := benchmarks.NamedFunction{ 321 | Name: fmt.Sprintf("ORT/%s/BatchSize=%2d:", name, batchSize), 322 | Func: func() { 323 | // Create batch for each input tensor. 324 | for inputIdx, t := range inputTensors { 325 | flat := t.GetData() 326 | for batchIdx := range batchSize { 327 | sample := tokenizedExamples[sentenceIdx+batchIdx] 328 | copy(flat[batchIdx*SequenceLength:], sample.Encoding[inputIdx]) 329 | } 330 | } 331 | 332 | // Execute program. 333 | must.M(session.Run()) 334 | 335 | flat := outputTensor.GetData() 336 | if runIdx == 0 { 337 | fmt.Printf("\t> Last value of result: %v\n", flat[len(flat)-1]) 338 | } 339 | 340 | // Next batch. 341 | sentenceIdx += batchSize 342 | if sentenceIdx+batchSize >= NumSentences { 343 | sentenceIdx = 0 344 | } 345 | runIdx++ 346 | }, 347 | } 348 | 349 | benchmarks.New(testFn). 350 | WithWarmUps(10). 351 | WithDuration(*flagBenchDuration). 352 | WithHeader(withHeader). 353 | Done() 354 | } 355 | 356 | func TestBenchKnightsSBertFullORT(t *testing.T) { 357 | if testing.Short() || *flagBenchDuration == 0 { 358 | t.SkipNow() 359 | } 360 | repo := hub.New(KnightsAnalyticsSBertID).WithAuth(hfAuthToken) 361 | onnxModelPath := must.M1(repo.DownloadFile("model.onnx")) 362 | for ii, batchSize := range BatchSizes { 363 | benchmarkONNXModelWithORT(ii == 0, "Full", onnxModelPath, batchSize, 364 | "last_hidden_state", shapes.Make(dtypes.Float32, batchSize, SequenceLength, 384)) 365 | } 366 | } 367 | 368 | func TestBenchKnightsSBertFullXLA(t *testing.T) { 369 | if testing.Short() || *flagBenchDuration == 0 { 370 | t.SkipNow() 371 | } 372 | repo := hub.New(KnightsAnalyticsSBertID).WithAuth(hfAuthToken) 373 | onnxModelPath := must.M1(repo.DownloadFile("model.onnx")) 374 | for _, batchSize := range BatchSizes { 375 | benchmarkONNXModelWithXLA(false, "Full", onnxModelPath, batchSize) 376 | } 377 | } 378 | 379 | func recursivelyTagNode(allNodes, usedNodes map[string]*protos.NodeProto, outputName string) { 380 | if _, found := usedNodes[outputName]; found { 381 | return 382 | } 383 | node := allNodes[outputName] 384 | if node == nil { 385 | // Likely node is a variable or an input, simply ignore. 386 | return 387 | } 388 | usedNodes[outputName] = node 389 | for _, inputNode := range node.Input { 390 | recursivelyTagNode(allNodes, usedNodes, inputNode) 391 | } 392 | } 393 | 394 | // saveONNXModelWithOutput reads an ONNX model from fromPath, changes its output to 395 | // the node named newOutputNode and then saves the modified model to toPath. 396 | func saveONNXModelWithOutput(fromPath, toPath, newOutputNode string) (shapePerBatchSize map[int]shapes.Shape) { 397 | model := must.M1(onnx.ReadFile(fromPath)) 398 | 399 | // Find output shape for each batchSize. 
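	// ORT needs preallocated output tensors with static shapes, so build the graph
	// once per batch size (with GoMLX/XLA) just to discover the shape of the new
	// output node for each batch size.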
400 | shapePerBatchSize = make(map[int]shapes.Shape, len(BatchSizes)) 401 | backend := graphtest.BuildTestBackend() 402 | ctx := context.New() 403 | must.M(model.VariablesToContext(ctx)) 404 | ctx = ctx.Reuse() 405 | for _, batchSize := range BatchSizes { 406 | g := graph.NewGraph(backend, fmt.Sprintf("batchSize=%d", batchSize)) 407 | var inputs [3]*graph.Node 408 | inputsNames := []string{"token_ids", "attention_mask", "token_type_ids"} 409 | for ii := range inputs { 410 | inputs[ii] = graph.Parameter(g, inputsNames[ii], shapes.Make(dtypes.Int64, batchSize, SequenceLength)) 411 | } 412 | output := model.CallGraph(ctx, g, 413 | map[string]*graph.Node{ 414 | "input_ids": inputs[0], 415 | "attention_mask": inputs[1], 416 | "token_type_ids": inputs[2], 417 | }, newOutputNode)[0] 418 | shapePerBatchSize[batchSize] = output.Shape().Clone() 419 | g.Finalize() 420 | fmt.Printf("\tbatch size %d: shape %s\n", batchSize, output.Shape()) 421 | } 422 | 423 | // Change output in model proto. 424 | graphProto := model.Proto.Graph 425 | newOutput := &protos.ValueInfoProto{ 426 | Name: newOutputNode, 427 | } 428 | graphProto.Output = []*protos.ValueInfoProto{newOutput} 429 | 430 | // Mark nodes that are needed to generate target output node. 431 | allNodes := make(map[string]*protos.NodeProto, 2*len(graphProto.Node)) 432 | for _, node := range graphProto.Node { 433 | for _, outputName := range node.Output { 434 | allNodes[outputName] = node 435 | } 436 | } 437 | usedNodes := make(map[string]*protos.NodeProto, len(allNodes)) 438 | recursivelyTagNode(allNodes, usedNodes, newOutputNode) 439 | fmt.Printf("\t%d nodes kept out of %d.\n", len(usedNodes), len(graphProto.Node)) 440 | graphProto.Node = make([]*protos.NodeProto, 0, len(usedNodes)) 441 | for _, node := range usedNodes { 442 | graphProto.Node = append(graphProto.Node, node) 443 | } 444 | 445 | // Save model 446 | contents := must.M1(proto.Marshal(&model.Proto)) 447 | must.M(os.WriteFile(toPath, contents, 0644)) 448 | 449 | return 450 | } 451 | 452 | // ModelSlicesOutputs points to intermediary outputs in the KnightsAnalytics/all-MiniLM-L6-v2 model. 
453 | var ModelSlicesOutputs = [][2]string{
454 | 	// Format: {<ONNX output node name>, <short benchmark name>}.
455 | 	//{"/embeddings/Add_output_0", "embeddingGather"},
456 | 	//{"/embeddings/LayerNorm/Add_1_output_0", "embeddingsLayerNorm"},
457 | 
458 | 	//{"/embeddings/LayerNorm/ReduceMean_output_0", "ReduceMean0"},
459 | 	//{"/embeddings/LayerNorm/Sub_output_0", "LayerNorm0Shifted"},
460 | 
461 | 	//{"/embeddings/LayerNorm/Pow_output_0", "LayerNorm0Squares"},
462 | 	//{"/embeddings/LayerNorm/ReduceMean_1_output_0", "LayerNorm0SquaresMean"},
463 | 	//{"/embeddings/LayerNorm/Add_output_0", "LayerNorm0SquaresMeanEpsilon"},
464 | 	//{"/embeddings/LayerNorm/Sqrt_output_0", "layerNorm0Scale"},
465 | 	//{"/embeddings/LayerNorm/Div_output_0", "layerNorm0ScaleNormalized"},
466 | 	//{"/embeddings/LayerNorm/Mul_output_0", "layerNorm0Scaled"},
467 | 
468 | 	//{"/embeddings/LayerNorm/Add_1_output_0", "attentionLayer0.PreValueMul"},
469 | 	{"/encoder/layer.0/attention/self/value/MatMul_output_0", "attentionLayer0.ValueMul"},
470 | 
471 | 	//{"/encoder/layer.0/attention/self/Reshape_3_output_0", "attentionLayer0"},
472 | 	//{"/encoder/layer.0/attention/output/Add_output_0", "attentionLayer0"},
473 | 
474 | 	//{"/encoder/layer.1/attention/self/Reshape_3_output_0", "attentionLayer1"},
475 | }
476 | 
477 | func TestBenchKnightsSBertSliceXLA(t *testing.T) {
478 | 	if testing.Short() || *flagBenchDuration == 0 {
479 | 		t.SkipNow()
480 | 	}
481 | 	repo := hub.New(KnightsAnalyticsSBertID).WithAuth(hfAuthToken)
482 | 	onnxModelPath := must.M1(repo.DownloadFile("model.onnx"))
483 | 	for _, modelSlice := range ModelSlicesOutputs {
484 | 		name := modelSlice[1]
485 | 		outputNodeName := modelSlice[0]
486 | 		for ii, batchSize := range BatchSizes {
487 | 			displayHeader := ii == 0
488 | 			benchmarkONNXModelWithXLA(displayHeader, name, onnxModelPath, batchSize, outputNodeName)
489 | 		}
490 | 	}
491 | }
492 | 
493 | func TestBenchKnightsSBertSliceORT(t *testing.T) {
494 | 	if testing.Short() || *flagBenchDuration == 0 {
495 | 		t.SkipNow()
496 | 	}
497 | 	repo := hub.New(KnightsAnalyticsSBertID).WithAuth(hfAuthToken)
498 | 	onnxModelPath := must.M1(repo.DownloadFile("model.onnx"))
499 | 
500 | 	for _, modelSlice := range ModelSlicesOutputs {
501 | 		name := modelSlice[1]
502 | 		outputNodeName := modelSlice[0]
503 | 		editedModelPath := path.Join(t.TempDir(), name) + ".onnx"
504 | 		shapesPerBatchSize := saveONNXModelWithOutput(onnxModelPath, editedModelPath, outputNodeName)
505 | 		_ = shapesPerBatchSize
506 | 		for ii, batchSize := range BatchSizes {
507 | 			displayHeader := ii == 0
508 | 			benchmarkONNXModelWithORT(displayHeader, name, editedModelPath, batchSize,
509 | 				outputNodeName, shapesPerBatchSize[batchSize])
510 | 		}
511 | 	}
512 | }
513 | 
--------------------------------------------------------------------------------
/internal/benchmarks/mean_norm_sqrt_add1div2.onnx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gomlx/onnx-gomlx/c68337bcf8365e94e3691fa67192ab1e38a3c77a/internal/benchmarks/mean_norm_sqrt_add1div2.onnx
--------------------------------------------------------------------------------
/internal/benchmarks/rob_sentences_embeddings.bin:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gomlx/onnx-gomlx/c68337bcf8365e94e3691fa67192ab1e38a3c77a/internal/benchmarks/rob_sentences_embeddings.bin
--------------------------------------------------------------------------------
/internal/benchmarks/rob_sentences_test.go:
--------------------------------------------------------------------------------
1 | package benchmarks
2 | 
3 | // This file extends knights_sbert_test, but defines the test sentences in robSentences.
4 | 
5 | import (
6 | 	"fmt"
7 | 	"runtime"
8 | 	"strconv"
9 | 	"sync"
10 | 	"testing"
11 | 	"time"
12 | 
13 | 	dtok "github.com/daulet/tokenizers"
14 | 	"github.com/gomlx/exceptions"
15 | 	"github.com/gomlx/go-huggingface/hub"
16 | 	"github.com/gomlx/gomlx/graph"
17 | 	"github.com/gomlx/gomlx/graph/graphtest"
18 | 	"github.com/gomlx/gomlx/ml/context"
19 | 	"github.com/gomlx/gomlx/types/shapes"
20 | 	"github.com/gomlx/gomlx/types/tensors"
21 | 	"github.com/gomlx/gomlx/types/xsync"
22 | 	"github.com/gomlx/gopjrt/dtypes"
23 | 	"github.com/gomlx/onnx-gomlx/onnx"
24 | 	"github.com/janpfeifer/go-benchmarks"
25 | 	"github.com/janpfeifer/must"
26 | 	ort "github.com/yalue/onnxruntime_go"
27 | )
28 | 
29 | var (
30 | 	robSentences = []string{
31 | 		"robert smith junior",
32 | 		"francis ford coppola",
33 | 		"robert smith",
34 | 		"Tech Innovators Inc. Launches Revolutionary AI Platform",
35 | 		"Green Energy Solutions Unveils Next-Gen Solar Panels",
36 | 		"Global Ventures Co. Secures $2 Billion in Funding",
37 | 		"Creative Minds Studio Launches Virtual Creativity Hub",
38 | 		"Healthcare Partners Ltd. Introduces AI-Driven Diagnostics",
39 | 		"Future Finance Group Predicts Key Market Trends for 2024",
40 | 		"Premier Logistics LLC Expands Into New International Markets",
41 | 		"Dynamic Marketing Agency Announces Strategic Partnership",
42 | 		"Eco-Friendly Products Corp. Debuts Sustainable Tech Line",
43 | 		"Blue Ocean Enterprises Leads the Way in Marine Technology",
44 | 		"NextGen Software Solutions Rolls Out New Cloud Suite",
45 | 		"Innovative Construction Co. Breaks Ground on Green Projects",
46 | 		"Precision Engineering Ltd. Redefines Robotics Efficiency",
47 | 		"Elite Consulting Group Forecasts Industry Growth in 2024",
48 | 		"Urban Development LLC Transforms City Skylines Nationwide",
49 | 		"Digital Media Concepts Sets New Standards for AI Content Delivery",
50 | 		"Community Builders Inc. Wins National Housing Award",
51 | 		"Trusted Insurance Brokers Introduces Smart Policy Options",
52 | 		"Advanced Manufacturing Corp. Showcases Cutting-Edge Automation",
53 | 		"Visionary Design Studio Redefines Modern Architecture",
54 | 		"Strategic Investment Partners Reveals Key Acquisitions",
55 | 		"Modern Retail Solutions Integrates AI Shopping Experiences",
56 | 		"Efficient Energy Systems Revolutionizes Grid Technology",
57 | 		"High-Tech Components Inc. Develops Next-Gen Processors",
58 | 		"Education Outreach Network Empowers Communities with New Programs",
59 | 		"Healthcare Innovations Ltd. Drives Breakthrough in Medical Research",
60 | 		"Creative Film Productions Wins Prestigious Global Awards",
61 | 		"Global Trade Services Expands Globalized Shipping Network",
62 | 		"NextLevel Sports Management Signs High-Profile Athletes",
63 | 		//"Sustainable Agriculture Group Promotes Organic Farming",
64 | 		//"Tech Innovators Inc. to Host Annual Tech Summit This Fall",
65 | 		//"Cloud Based Solutions Unveils New Secure Data Services",
66 | 	}
67 | )
68 | 
69 | // initializeRobSentences tokenizes the fixed robSentences (as opposed to using FineWeb, the default)
70 | // and pads or trims every encoding to one fixed sequence length.
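// Every example is forced to the same fixed length (13 tokens, see sequenceLen below), so all batches
// share a single tensor shape: shorter encodings are zero-padded and longer ones trimmed. padOrTrim is
// defined elsewhere in this package; judging by its usage here it presumably behaves like:
//
//	padOrTrim(4, []int64{1, 2}, 0)    // -> [1, 2, 0, 0]
//	padOrTrim(2, []int64{7, 8, 9}, 0) // -> [7, 8]
//
// When minNumExamples > len(robSentences), the tokenized examples are repeated cyclically to fill the
// returned slice.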
71 | func initializeRobSentences(minNumExamples int) []tokenizedSentence { 72 | numSentences := len(robSentences) 73 | results := make([]tokenizedSentence, max(numSentences, minNumExamples)) 74 | 75 | // Create tokenizer: it is configured by the "tokenizer.json" to a max_length of 128, with padding. 76 | repoTokenizer := hub.New(KnightsAnalyticsSBertID).WithAuth(hfAuthToken) 77 | localFile := must.M1(repoTokenizer.DownloadFile("tokenizer.json")) 78 | tokenizer := must.M1(dtok.FromFile(localFile)) 79 | defer func() { _ = tokenizer.Close() }() 80 | 81 | for idxSentence, sentence := range robSentences { 82 | encoding := tokenizer.EncodeWithOptions(sentence, false, 83 | dtok.WithReturnTypeIDs(), 84 | dtok.WithReturnAttentionMask(), 85 | ) 86 | 87 | // Find sequenceLen for sentence. 88 | sequenceLen := len(encoding.AttentionMask) 89 | for sequenceLen > 0 && encoding.AttentionMask[sequenceLen-1] == 0 { 90 | sequenceLen-- 91 | } 92 | sequenceLen = 13 93 | 94 | results[idxSentence].Encoding[0] = padOrTrim(sequenceLen, 95 | sliceMap(encoding.IDs, func(id uint32) int64 { return int64(id) }), 96 | 0) 97 | results[idxSentence].Encoding[1] = padOrTrim(sequenceLen, 98 | sliceMap(encoding.AttentionMask, func(id uint32) int64 { return int64(id) }), 99 | 0) 100 | results[idxSentence].Encoding[2] = padOrTrim(sequenceLen, 101 | sliceMap(encoding.TypeIDs, func(id uint32) int64 { return int64(id) }), 102 | 0) 103 | } 104 | 105 | // Replicate extra examples at the end. 106 | for ii := numSentences; ii < len(results); ii++ { 107 | results[ii] = results[ii-numSentences] // Keep repeating. 108 | } 109 | return results 110 | } 111 | 112 | // formatDuration formats the duration with 2 decimal places but keeping the unit suffix. 113 | func formatDuration(d time.Duration) string { 114 | s := d.String() 115 | i := 0 116 | for ; i < len(s); i++ { 117 | if (s[i] < '0' || s[i] > '9') && s[i] != '.' { 118 | break 119 | } 120 | } 121 | // Found the time unit (the suffix) 122 | num := s[:i] 123 | unit := s[i:] 124 | f, err := strconv.ParseFloat(num, 64) 125 | if err != nil { 126 | return s 127 | } 128 | return fmt.Sprintf("%.2f%s", f, unit) 129 | } 130 | 131 | func implParallelBenchmark[E any]( 132 | name string, 133 | numWorkers, batchSize int, header bool, 134 | warmUpRuns int, 135 | inputFn func() E, 136 | workerFn func(workerIdx int, e E)) { 137 | // Parallelization: 138 | var wg sync.WaitGroup 139 | done := xsync.NewLatch() 140 | 141 | // Start producer of inputs: 142 | // - We add some buffer, because we don't want the preparation of the inputs (producer) 143 | // to be a bottleneck or even accounted for. 144 | examplesChan := make(chan E, numWorkers) 145 | wg.Add(1) 146 | go func() { 147 | defer wg.Done() 148 | // Create input and output tensors. 149 | for { 150 | e := inputFn() 151 | // Write the example or interrupt. 152 | select { 153 | case <-done.WaitChan(): 154 | // Finished executing, simply exit. 155 | return 156 | case examplesChan <- e: 157 | // Move forward to produce the next input example. 158 | } 159 | } 160 | }() 161 | 162 | // Start consumers: 163 | finishedCounter := make(chan struct{}) 164 | for workerIdx := range numWorkers { 165 | wg.Add(1) 166 | go func(workerIdx int) { 167 | defer wg.Done() 168 | runtime.LockOSThread() 169 | defer runtime.UnlockOSThread() 170 | 171 | for { 172 | var e E 173 | select { 174 | case <-done.WaitChan(): 175 | return 176 | case e = <-examplesChan: 177 | // Received next input. 
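				// Both selects in this loop race against the "done" latch: the receive above, so a
				// worker doesn't hang once the benchmark ends, and the send to finishedCounter below,
				// so a worker doesn't block after the benchmark loop stops draining the counter.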
178 | } 179 | workerFn(workerIdx, e) 180 | select { 181 | case <-done.WaitChan(): 182 | return 183 | case finishedCounter <- struct{}{}: 184 | // Accounted for, loop to next. 185 | } 186 | } 187 | }(workerIdx) 188 | } 189 | 190 | // Benchmark function is simply reading out finished 191 | testFn := benchmarks.NamedFunction{ 192 | Name: name, 193 | Func: func() { 194 | <-finishedCounter 195 | }, 196 | } 197 | benchmarks.New(testFn). 198 | WithWarmUps(warmUpRuns). 199 | WithDuration(*flagBenchDuration). 200 | WithHeader(header). 201 | WithInnerRepeats(batchSize). // Report will be "per example". 202 | WithPrettyPrintFn(formatDuration). 203 | Done() 204 | 205 | // done.Trigger will signal all goroutines to end. 206 | done.Trigger() 207 | wg.Wait() 208 | } 209 | 210 | func implBenchRobSentencesORT(parallelization, batchSize int, header bool) { 211 | name := fmt.Sprintf("ORT/RobSentences/Parallel=%02d/BatchSize=%02d", parallelization, batchSize) 212 | outputNodeName := "last_hidden_state" 213 | embeddingSize := 384 214 | 215 | // Tokenize Rob's sentences. 216 | examples := initializeRobSentences(batchSize) 217 | if len(examples) < batchSize { 218 | exceptions.Panicf("batchSize(%d) must be <= to the number of examples (%d)", batchSize, len(examples)) 219 | } 220 | 221 | // Create session with ONNX program. 222 | ortInitFn() 223 | repoModel := hub.New(KnightsAnalyticsSBertID).WithAuth(hfAuthToken) 224 | onnxModelPath := must.M1(repoModel.DownloadFile("model.onnx")) 225 | var options *ort.SessionOptions 226 | if ortIsCUDA { 227 | options = must.M1(ort.NewSessionOptions()) 228 | cudaOptions := must.M1(ort.NewCUDAProviderOptions()) 229 | // must.M(cudaOptions.Update(map[string]string{"device_id": "0"})) 230 | must.M(options.AppendExecutionProviderCUDA(cudaOptions)) 231 | } else { 232 | if parallelization > 1 { 233 | options = must.M1(ort.NewSessionOptions()) 234 | must.M(options.SetIntraOpNumThreads(1)) 235 | must.M(options.SetInterOpNumThreads(1)) 236 | must.M(options.SetCpuMemArena(false)) 237 | must.M(options.SetMemPattern(false)) 238 | } 239 | } 240 | 241 | // Create sessions, one per parallel run. 242 | sessions := make([]*ort.DynamicAdvancedSession, parallelization) 243 | for pIdx := range parallelization { 244 | sessions[pIdx] = must.M1(ort.NewDynamicAdvancedSession( 245 | onnxModelPath, 246 | []string{"input_ids", "attention_mask", "token_type_ids"}, []string{outputNodeName}, 247 | options)) 248 | } 249 | defer func() { 250 | for _, session := range sessions { 251 | must.M(session.Destroy()) 252 | } 253 | }() 254 | 255 | // Generating examples for sessions. 256 | type ExampleInput [3]*ort.Tensor[int64] 257 | sentenceIdx := 0 258 | inputFn := func() (inputTensors ExampleInput) { 259 | sentenceLen := len(examples[sentenceIdx].Encoding[0]) 260 | inputShape := ort.NewShape(int64(batchSize), int64(sentenceLen)) 261 | for ii := range inputTensors { 262 | inputTensors[ii] = must.M1(ort.NewEmptyTensor[int64](inputShape)) 263 | } 264 | // Create batch for each input tensor. 265 | for inputIdx, t := range inputTensors { 266 | flat := t.GetData() 267 | for inBatchIdx := range batchSize { 268 | example := examples[(sentenceIdx+inBatchIdx)%len(examples)] 269 | copy(flat[inBatchIdx*sentenceLen:], example.Encoding[inputIdx]) 270 | } 271 | } 272 | // Next batch. 
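		// (advance with wrap-around, so a full batch always fits in range). Also note the input
		// tensors above are freshly allocated on every call; that happens on the producer goroutine
		// feeding examplesChan, so the allocation cost stays outside the measured worker time.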
273 | sentenceIdx += batchSize 274 | if sentenceIdx+batchSize >= len(examples) { 275 | sentenceIdx = 0 276 | } 277 | return 278 | } 279 | 280 | // workerFn is executed in each goroutine -- one per parallelization 281 | workerFn := func(workerIdx int, inputTensors ExampleInput) { 282 | session := sessions[workerIdx] 283 | sentenceLen := inputTensors[0].GetShape()[1] 284 | outputShape := ort.NewShape(int64(batchSize), int64(sentenceLen), int64(embeddingSize)) 285 | outputTensor := must.M1(ort.NewEmptyTensor[float32](outputShape)) 286 | // Execute program. 287 | must.M(session.Run( 288 | []ort.Value{inputTensors[0], inputTensors[1], inputTensors[2]}, 289 | []ort.Value{outputTensor}, 290 | )) 291 | } 292 | 293 | // Benchmark function is simply reading out finished 294 | warmUpRuns := 10 295 | implParallelBenchmark(name, parallelization, batchSize, header, warmUpRuns, inputFn, workerFn) 296 | } 297 | 298 | const robSentencesEmbeddingsFileName = "rob_sentences_embeddings.bin" 299 | 300 | func implBenchRobSentencesXLA(t *testing.T, parallelization, batchSize int, header bool) { 301 | name := fmt.Sprintf("XLA/RobSentences/Parallel=%02d/BatchSize=%02d", parallelization, batchSize) 302 | // Make sure to release all resources no longer in use. 303 | for _ = range 10 { 304 | runtime.GC() 305 | } 306 | 307 | // Tokenize Rob's sentences. 308 | examples := initializeRobSentences(batchSize) 309 | if len(examples) < batchSize { 310 | exceptions.Panicf("batchSize(%d) must be <= to the number of examples (%d)", batchSize, len(examples)) 311 | } 312 | if (*flagSaveEmbeddings || *flagCheckEmbeddings) && batchSize != len(robSentences) { 313 | exceptions.Panicf("batchSize(%d) must be %d (all robSentences) when saving embeddings (--save_embeddings) or "+ 314 | "checking embeddings (--check_embeddings)", batchSize, len(robSentences)) 315 | } 316 | 317 | // Build model 318 | repoModel := hub.New(KnightsAnalyticsSBertID).WithAuth(hfAuthToken) 319 | onnxModelPath := must.M1(repoModel.DownloadFile("model.onnx")) 320 | backend := graphtest.BuildTestBackend() 321 | model := must.M1(onnx.ReadFile(onnxModelPath)) 322 | ctx := context.New() 323 | must.M(model.VariablesToContext(ctx)) 324 | ctx = ctx.Reuse() 325 | exec := context.NewExec(backend, ctx, func(ctx *context.Context, tokenIDs, attentionMask, tokenTypeIDs *graph.Node) *graph.Node { 326 | //fmt.Printf("Exec inputs (tokens, mask, types): %s, %s, %s\n", tokenIDs.Shape(), attentionMask.Shape(), tokenTypeIDs.Shape()) 327 | g := tokenIDs.Graph() 328 | outputs := model.CallGraph(ctx, g, 329 | map[string]*graph.Node{ 330 | "input_ids": tokenIDs, 331 | "attention_mask": attentionMask, 332 | "token_type_ids": tokenTypeIDs, 333 | }) 334 | if *flagPrintXLAGraph { 335 | fmt.Printf("Graph:\n%s\n", g) 336 | } 337 | return outputs[0] 338 | }) 339 | defer exec.Finalize() 340 | 341 | // Load expected results. 342 | var referenceEmbeddings *tensors.Tensor 343 | if *flagCheckEmbeddings { 344 | var err error 345 | referenceEmbeddings, err = tensors.Load(robSentencesEmbeddingsFileName) 346 | if err != nil { 347 | panic(err) 348 | } 349 | } 350 | 351 | // Generating examples for sessions. 
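	// Unlike the ORT variant above, which allocates fresh input tensors per batch, this version
	// recycles them through a sync.Pool: inputFn takes a set of three (batchSize, 13) int64 tensors
	// from the pool and workerFn returns them once the call completes, so steady-state iterations
	// allocate almost nothing.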
352 | type ExampleInput [3]*tensors.Tensor 353 | sentenceLen := 13 354 | inputsPool := sync.Pool{ 355 | New: func() any { 356 | var inputTensors ExampleInput 357 | for ii := range inputTensors { 358 | inputTensors[ii] = tensors.FromShape(shapes.Make(dtypes.Int64, batchSize, sentenceLen)) 359 | } 360 | return inputTensors 361 | }, 362 | } 363 | nextSentenceIdx := 0 364 | inputFn := func() (inputTensors ExampleInput) { 365 | inputTensors = inputsPool.Get().(ExampleInput) 366 | for inputIdx := range inputTensors { 367 | t := inputTensors[inputIdx] 368 | tensors.MutableFlatData[int64](t, func(flat []int64) { 369 | for inBatchIdx := range batchSize { 370 | example := examples[(nextSentenceIdx+inBatchIdx)%len(examples)] 371 | copy(flat[inBatchIdx*sentenceLen:], example.Encoding[inputIdx]) 372 | } 373 | }) 374 | } 375 | // Next batch. 376 | nextSentenceIdx = (nextSentenceIdx + batchSize) % len(examples) 377 | return 378 | } 379 | 380 | if *flagSaveEmbeddings { 381 | // Run inline and save the resulting embeddings: 382 | fmt.Println("Generating embeddings to save:") 383 | inputTensors := inputFn() 384 | output := exec.Call(inputTensors[0], inputTensors[1], inputTensors[2])[0] 385 | fmt.Printf("\tSaving reference embeddings to %q - shape=%s, embedding[0, 0, 0]=%.3f, token[0, 0]=%d\n", 386 | robSentencesEmbeddingsFileName, 387 | output.Shape(), 388 | tensors.CopyFlatData[float32](output)[0], 389 | tensors.CopyFlatData[int64](inputTensors[0])[0]) 390 | err := output.Save(robSentencesEmbeddingsFileName) 391 | if err != nil { 392 | panic(err) 393 | } 394 | output.FinalizeAll() 395 | return 396 | } 397 | 398 | var workerCount int 399 | workerFn := func(workerIdx int, inputTensors ExampleInput) { 400 | defer inputsPool.Put(inputTensors) 401 | output := exec.Call(inputTensors[0], inputTensors[1], inputTensors[2])[0] 402 | tensors.ConstFlatData(output, func(flat []float32) { 403 | // Force local copy: this should be part of the cost. 
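			// (The closure body is intentionally a no-op: the cost charged here is presumably the
			// device-to-host transfer that tensors.ConstFlatData performs to expose the output as a
			// flat []float32.)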
404 | _ = flat 405 | }) 406 | if referenceEmbeddings != nil { 407 | requireSameTensorsFloat32(t, referenceEmbeddings, output, checkingEmbeddingsDelta) 408 | } 409 | workerCount++ 410 | output.FinalizeAll() 411 | } 412 | 413 | // Benchmark function is simply reading out finished 414 | warmUpRuns := 2 * len(examples) 415 | if *flagCheckEmbeddings { 416 | warmUpRuns = 1 417 | } 418 | implParallelBenchmark( 419 | name, parallelization, batchSize, header, warmUpRuns, inputFn, workerFn) 420 | } 421 | 422 | func TestRobSentences_BenchORT(t *testing.T) { 423 | if testing.Short() || *flagBenchDuration == 0 { 424 | t.SkipNow() 425 | } 426 | count := 0 427 | for _, parallelism := range []int{4} { // {2, 3, 4, 6, 8} { 428 | for _, batchSize := range []int{256} { // 1, 2, 4, 8, 16, 32} { 429 | implBenchRobSentencesORT(parallelism, batchSize, count == 0) 430 | count++ 431 | } 432 | } 433 | } 434 | 435 | func TestRobSentences_BenchXLA(t *testing.T) { 436 | if testing.Short() || *flagBenchDuration == 0 { 437 | t.SkipNow() 438 | } 439 | count := 0 440 | // Change parallelism/batchSize according to backend, see best values in the bottom 441 | // of the "Rob Sentences" sheet in: 442 | // https://docs.google.com/spreadsheets/d/1ikpJH6rVVHq8ES-IA8U4lkKH4XsTSpRyZewXwGTgits/edit?gid=397722581#gid=397722581 443 | for _, parallelism := range []int{48} { // {4, 6, 8} { 444 | for _, batchSize := range []int{32} { // 1, 2, 4, 8, 16, 32} { 445 | implBenchRobSentencesXLA(t, parallelism, batchSize, count == 0) 446 | count++ 447 | } 448 | } 449 | } 450 | 451 | func TestRobSentences_SaveEmbeddings(t *testing.T) { 452 | if !*flagSaveEmbeddings { 453 | fmt.Println("Skipping SaveEmbeddings test, --save_embeddings not set.") 454 | t.SkipNow() 455 | return 456 | } 457 | implBenchRobSentencesXLA(t, 1, len(robSentences), false) 458 | } 459 | 460 | const checkingEmbeddingsDelta = 1e-2 461 | 462 | func TestRobSentences_CheckEmbeddings(t *testing.T) { 463 | if !*flagCheckEmbeddings { 464 | fmt.Println("Skipping CheckEmbeddings test, --check_embeddings not set.") 465 | t.SkipNow() 466 | return 467 | } 468 | implBenchRobSentencesXLA(t, 1, len(robSentences), false) 469 | } 470 | -------------------------------------------------------------------------------- /internal/benchmarks/short_programs_test.go: -------------------------------------------------------------------------------- 1 | package benchmarks 2 | 3 | // Benchmark results for this file using: 4 | // 5 | // - GoMLX: v0.15.4rc / GoPJRT: 0.4.10rc 6 | // - ONNX Runtime v1.20.1 7 | // - Command used: 8 | // go test . -test.bench=. 9 | // 10 | // - cpu: 12th Gen Intel(R) Core(TM) i9-12900K: GOMAXPROC=24, 12 cores (4P, 8E), 24 hyperthread cores. 
11 | // - results: https://docs.google.com/spreadsheets/d/1ikpJH6rVVHq8ES-IA8U4lkKH4XsTSpRyZewXwGTgits/edit?gid=0#gid=0 12 | 13 | import ( 14 | "fmt" 15 | "github.com/chewxy/math32" 16 | "github.com/gomlx/exceptions" 17 | "github.com/gomlx/gomlx/backends" 18 | "github.com/gomlx/gomlx/graph" 19 | "github.com/gomlx/gomlx/graph/graphtest" 20 | "github.com/gomlx/gomlx/types/shapes" 21 | "github.com/gomlx/gomlx/types/tensors" 22 | "github.com/gomlx/gopjrt/dtypes" 23 | "github.com/gomlx/onnx-gomlx/onnx" 24 | "github.com/janpfeifer/must" 25 | ort "github.com/yalue/onnxruntime_go" 26 | "k8s.io/klog/v2" 27 | "os" 28 | "runtime" 29 | "sync" 30 | "testing" 31 | ) 32 | 33 | func init() { 34 | klog.InitFlags(nil) 35 | } 36 | 37 | var ( 38 | TestShapes = []shapes.Shape{ 39 | //shapes.Make(dtypes.Float32, 1, 1), 40 | //shapes.Make(dtypes.Float32, 10, 10), 41 | //shapes.Make(dtypes.Float32, 100, 100), 42 | shapes.Make(dtypes.Float32, 1000, 1000), 43 | } 44 | numShapes = len(TestShapes) 45 | 46 | SmallTestPrograms = [][2]string{ 47 | {"add1.onnx", "f(x)=x+1"}, 48 | {"add1div2.onnx", "f(x)=(x+1)/2"}, 49 | {"sqrt_add1div2.onnx", "f(x)=Sqrt((x+1)/2)"}, 50 | {"mean_norm_sqrt_add1div2.onnx", "f(x)=Sqrt((x+1)/2) - ReduceMean(Sqrt((x+1)/2))"}, 51 | } 52 | numPrograms = len(SmallTestPrograms) 53 | 54 | // testGoPrograms is a Go version of the SmallTestPrograms, as a per-element function. 55 | testGoPrograms = []goVectorFunc{ 56 | parallelizeGoVectorFunc(func(inputs, outputs []float32) { 57 | for ii, v := range inputs { 58 | outputs[ii] = v + 1 59 | } 60 | }), 61 | parallelizeGoVectorFunc(func(inputs, outputs []float32) { 62 | for ii, v := range inputs { 63 | outputs[ii] = (v + 1) * 0.5 64 | } 65 | }), 66 | parallelizeGoVectorFunc(func(inputs, outputs []float32) { 67 | for ii, v := range inputs { 68 | outputs[ii] = math32.Sqrt((v + 1) * 0.5) 69 | } 70 | }), 71 | } 72 | 73 | benchmarkNameSuffix = "|GOMAXPROCS" 74 | ) 75 | 76 | // goVectorFunc defines the signature for functions that process slices. 77 | type goVectorFunc func(inputs, outputs []float32) 78 | 79 | // sliceMap executes the given function sequentially for every element on in, and returns a mapped slice. 80 | func sliceMap[In, Out any](in []In, fn func(e In) Out) (out []Out) { 81 | out = make([]Out, len(in)) 82 | for ii, e := range in { 83 | out[ii] = fn(e) 84 | } 85 | return 86 | } 87 | 88 | // parallelizeGoVectorFunc takes a goVectorFunc and parallelizes its execution if the input size is large enough. 89 | func parallelizeGoVectorFunc(fn goVectorFunc) goVectorFunc { 90 | return func(inputs, outputs []float32) { 91 | numInputs := len(inputs) 92 | if numInputs < 100_000 { // Threshold for parallelization. Tune this value. 93 | fn(inputs, outputs) 94 | return 95 | } 96 | 97 | numCPU := runtime.NumCPU() 98 | chunkSize := numInputs / numCPU 99 | var wg sync.WaitGroup 100 | wg.Add(numCPU) 101 | for i := 0; i < numCPU; i++ { 102 | start := i * chunkSize 103 | end := (i + 1) * chunkSize 104 | if i == numCPU-1 { 105 | end = numInputs // Handle any remainder 106 | } 107 | go func(start, end int) { 108 | defer wg.Done() 109 | fn(inputs[start:end], outputs[start:end]) 110 | }(start, end) 111 | } 112 | wg.Wait() 113 | } 114 | } 115 | 116 | // BenchmarkSmallXLAExec executes SmallTestPrograms on XLA using the normal Exec method. 117 | // We try not to count the time for tensor transfers in and out. 118 | func BenchmarkSmallXLAExec(b *testing.B) { 119 | // Check conversion. 
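	// One graph.Exec is created per program; graph.NewExec JIT-compiles and caches an executable the
	// first time it sees each input shape, so the warm-up below absorbs compilation and the timed
	// loop measures only the per-call execution path.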
120 | backend := graphtest.BuildTestBackend() 121 | execs := make([]*graph.Exec, numPrograms) 122 | for progIdx, program := range SmallTestPrograms { 123 | model := must.M1(onnx.ReadFile(program[0])) 124 | execs[progIdx] = graph.NewExec(backend, func(x *graph.Node) *graph.Node { 125 | g := x.Graph() 126 | outputs := model.CallGraph(nil, g, map[string]*graph.Node{"X": x}) 127 | return outputs[0] 128 | }) 129 | } 130 | 131 | // Pre-allocate tensors. 132 | numShapes := len(TestShapes) 133 | inputTensors := make([]*tensors.Tensor, numShapes) 134 | outputTensors := make([]*tensors.Tensor, numShapes) 135 | for shapeIdx, s := range TestShapes { 136 | inputTensors[shapeIdx] = tensors.FromShape(s) 137 | outputTensors[shapeIdx] = tensors.FromShape(s) 138 | } 139 | 140 | // Run tests for each shape/program combination. 141 | for shapeIdx, s := range TestShapes { 142 | for progIdx, program := range SmallTestPrograms { 143 | exec := execs[progIdx] 144 | b.Run(fmt.Sprintf("shape=%s/%s%s", s, program[1], benchmarkNameSuffix), 145 | func(b *testing.B) { 146 | // Set input to value of v. 147 | x := inputTensors[shapeIdx] 148 | tensors.MutableFlatData[float32](x, func(flat []float32) { 149 | for ii := range flat { 150 | flat[ii] = float32(shapeIdx*numPrograms + progIdx + 1) 151 | } 152 | }) 153 | 154 | // WarmUp: 155 | for _ = range 10 { 156 | tmpOutput := exec.Call(x)[0] 157 | tmpOutput.FinalizeAll() 158 | } 159 | b.ResetTimer() 160 | 161 | for _ = range b.N { 162 | tmpOutput := exec.Call(x)[0] 163 | tmpOutput.FinalizeAll() 164 | } 165 | }) 166 | } 167 | } 168 | } 169 | 170 | // BenchmarkSmallXLADirect benchmarks SmallTestPrograms using direct GoMLX execution. 171 | // We try not to count the time for tensor transfers in and out. 172 | func BenchmarkSmallXLADirect(b *testing.B) { 173 | // Create executables. 174 | backend := graphtest.BuildTestBackend() 175 | numShapes := len(TestShapes) 176 | graphPerShapePerProgram := make([][]*graph.Graph, numShapes) 177 | inputTensors := make([]*tensors.Tensor, numShapes) 178 | outputTensors := make([]*tensors.Tensor, numShapes) 179 | for shapeIdx, s := range TestShapes { 180 | graphPerShapePerProgram[shapeIdx] = make([]*graph.Graph, numPrograms) 181 | for progIdx, program := range SmallTestPrograms { 182 | model := must.M1(onnx.ReadFile(program[0])) 183 | g := graph.NewGraph(backend, fmt.Sprintf("Graph #%d", shapeIdx)) 184 | x := graph.Parameter(g, "X", s) 185 | y := model.CallGraph(nil, g, map[string]*graph.Node{"X": x})[0] 186 | g.Compile(y) 187 | graphPerShapePerProgram[shapeIdx][progIdx] = g 188 | inputTensors[shapeIdx] = tensors.FromShape(s) 189 | outputTensors[shapeIdx] = tensors.FromShape(s) 190 | } 191 | } 192 | 193 | // Run tests for each shape/program combination. 194 | for shapeIdx, s := range TestShapes { 195 | for progIdx, program := range SmallTestPrograms { 196 | b.Run(fmt.Sprintf("shape=%s/%s%s", s, program[1], benchmarkNameSuffix), 197 | func(b *testing.B) { 198 | // Set input to value of v. 
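				// The fill value is a distinct constant per (shape, program) pair, so each
				// sub-benchmark's output is distinguishable. The x.Buffer(backend) call below uploads
				// the input to the device once, outside the timed loop, consistent with not counting
				// transfer time.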
199 | x := inputTensors[shapeIdx] 200 | tensors.MutableFlatData[float32](x, func(flat []float32) { 201 | for ii := range flat { 202 | flat[ii] = float32(shapeIdx*numPrograms + progIdx + 1) 203 | } 204 | }) 205 | xBuf := x.Buffer(backend) 206 | g := graphPerShapePerProgram[shapeIdx][progIdx] 207 | 208 | // WarmUp: 209 | for _ = range 10 { 210 | tmpOutput := g.RunWithBuffers( 211 | []backends.Buffer{xBuf}, 212 | []bool{false})[0] 213 | tmpOutput.FinalizeAll() 214 | } 215 | b.ResetTimer() 216 | 217 | for _ = range b.N { 218 | tmpOutput := g.RunWithBuffers( 219 | []backends.Buffer{xBuf}, 220 | []bool{false})[0] 221 | tmpOutput.FinalizeAll() 222 | } 223 | }) 224 | } 225 | } 226 | } 227 | 228 | // BenchmarkSmallORT benchmarks the SmallTestPrograms using ONNX Runtime. 229 | // We try not to count the time for tensor transfers in and out. 230 | func BenchmarkSmallORT(b *testing.B) { 231 | ortPath := os.Getenv("ORT_SO_PATH") 232 | if ortPath == "" { 233 | exceptions.Panicf("Please set environment ORT_SO_PATH with the path to your ONNX Runtime dynamic linked library") 234 | } 235 | ort.SetSharedLibraryPath(ortPath) 236 | must.M(ort.InitializeEnvironment()) 237 | defer func() { _ = ort.DestroyEnvironment() }() 238 | 239 | // Create a session for each tensor shape: 240 | numShapes := len(TestShapes) 241 | sessions := make([][]*ort.AdvancedSession, numShapes) 242 | inputsPerShape := make([]*ort.Tensor[float32], 0, numShapes) 243 | outputsPerShape := make([]*ort.Tensor[float32], 0, numShapes) 244 | for shapeIdx, s := range TestShapes { 245 | inputData := make([]float32, s.Size()) 246 | dims64 := sliceMap(s.Dimensions, func(dim int) int64 { return int64(dim) }) 247 | inputShape := ort.NewShape(dims64...) 248 | inputTensor := must.M1(ort.NewTensor(inputShape, inputData)) 249 | inputsPerShape = append(inputsPerShape, inputTensor) 250 | outputTensor := must.M1(ort.NewEmptyTensor[float32](inputShape)) 251 | outputsPerShape = append(outputsPerShape, outputTensor) 252 | 253 | sessions[shapeIdx] = make([]*ort.AdvancedSession, numPrograms) 254 | for progIdx, program := range SmallTestPrograms { 255 | sessions[shapeIdx][progIdx] = must.M1(ort.NewAdvancedSession( 256 | program[0], 257 | []string{"X"}, 258 | []string{"Y"}, 259 | []ort.Value{inputTensor}, 260 | []ort.Value{outputTensor}, 261 | nil)) 262 | } 263 | } 264 | 265 | // Run tests for each shape/program combination. 266 | for shapeIdx, s := range TestShapes { 267 | for progIdx, program := range SmallTestPrograms { 268 | b.Run(fmt.Sprintf("shape=%s/%s%s", s, program[1], benchmarkNameSuffix), 269 | func(b *testing.B) { 270 | // Set input to value of v. 271 | input := inputsPerShape[shapeIdx] 272 | data := input.GetData() 273 | for ii := range data { 274 | data[ii] = float32(shapeIdx*numPrograms + progIdx + 1) 275 | } 276 | session := sessions[shapeIdx][progIdx] 277 | 278 | // WarmUp: 279 | for _ = range 10 { 280 | must.M(session.Run()) 281 | } 282 | b.ResetTimer() 283 | 284 | for _ = range b.N { 285 | must.M(session.Run()) 286 | } 287 | }) 288 | } 289 | } 290 | } 291 | 292 | func BenchmarkPureGo(b *testing.B) { 293 | // Pre-allocate tensors. 294 | numShapes := len(TestShapes) 295 | inputTensors := make([]*tensors.Tensor, numShapes) 296 | outputTensors := make([]*tensors.Tensor, numShapes) 297 | for shapeIdx, s := range TestShapes { 298 | inputTensors[shapeIdx] = tensors.FromShape(s) 299 | outputTensors[shapeIdx] = tensors.FromShape(s) 300 | } 301 | 302 | // Run tests for each shape/program combination. 
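	// Note: testGoPrograms defines only three kernels while SmallTestPrograms lists four programs, so
	// progIdx == 3 (the mean-normalization program) would index out of range below. A Go kernel for it
	// cannot reuse parallelizeGoVectorFunc as-is, because ReduceMean is a global reduction rather than
	// a per-chunk operation; a sequential sketch of such a kernel (an illustration, not part of the
	// original list):
	//
	//	func(inputs, outputs []float32) {
	//		var sum float32
	//		for ii, v := range inputs {
	//			outputs[ii] = math32.Sqrt((v + 1) * 0.5)
	//			sum += outputs[ii]
	//		}
	//		mean := sum / float32(len(inputs))
	//		for ii := range outputs {
	//			outputs[ii] -= mean
	//		}
	//	}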
303 | for shapeIdx, s := range TestShapes { 304 | x := inputTensors[shapeIdx] 305 | y := outputTensors[shapeIdx] 306 | for progIdx, program := range SmallTestPrograms { 307 | b.Run(fmt.Sprintf("shape=%s/%s%s", s, program[1], benchmarkNameSuffix), 308 | func(b *testing.B) { 309 | // Set value: 310 | tensors.MutableFlatData[float32](x, func(flat []float32) { 311 | for ii := range flat { 312 | flat[ii] = float32(shapeIdx*numPrograms + progIdx + 1) 313 | } 314 | }) 315 | testProgram := testGoPrograms[progIdx] 316 | 317 | // Warm-up: 318 | for _ = range 10 { 319 | tensors.ConstFlatData[float32](x, func(inputs []float32) { 320 | tensors.MutableFlatData[float32](y, func(outputs []float32) { 321 | testProgram(inputs, outputs) 322 | }) 323 | }) 324 | } 325 | 326 | // Benchmark 327 | b.ResetTimer() 328 | for _ = range b.N { 329 | tensors.ConstFlatData[float32](x, func(inputs []float32) { 330 | tensors.MutableFlatData[float32](y, func(outputs []float32) { 331 | testProgram(inputs, outputs) 332 | }) 333 | }) 334 | } 335 | }) 336 | } 337 | } 338 | } 339 | -------------------------------------------------------------------------------- /internal/benchmarks/sqrt_add1div2.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gomlx/onnx-gomlx/c68337bcf8365e94e3691fa67192ab1e38a3c77a/internal/benchmarks/sqrt_add1div2.onnx -------------------------------------------------------------------------------- /internal/cmd/protoc_onnx_protos/main.go: -------------------------------------------------------------------------------- 1 | // protoc_onnx_protos compiles the .proto from the https://github.com/onnx/onnx sources to subpackages of 2 | // onnx-gomlx/internal/protos. 3 | // 4 | // It uses the standard `protoc` tool. Remember to update your go protoc plugin with 5 | // `go install google.golang.org/protobuf/cmd/protoc-gen-go@latest` 6 | // 7 | // It should be executed under the onnx-gomlx/internal/protos directory -- suggested as a go:generate -- 8 | // and it requires ONNX_SRC to be set to a cloned github.com/onnx/onnx repository. 9 | // 10 | // It copies over the proto files from "${ONNX_SRC}" to the current directory, editing its contents to 11 | // onnx-gomlx repository. Then it executes `protoc` to generate the Go code. 12 | package main 13 | 14 | import ( 15 | "fmt" 16 | "log" 17 | "os" 18 | "os/exec" 19 | "path/filepath" 20 | "regexp" 21 | ) 22 | 23 | const ( 24 | onnxSrcEnvVar = "ONNX_SRC" 25 | goProtosPackage = "github.com/gomlx/onnx-gomlx/internal/protos" 26 | ) 27 | 28 | var protoFiles = []string{ 29 | // We use the "onnx-ml.proto3" version of this file. See brief mention here: 30 | // https://github.com/onnx/onnx/blob/main/docs/IR.md 31 | "onnx-ml.proto", 32 | "onnx-operators-ml.proto", 33 | "onnx-data.proto", 34 | } 35 | 36 | // must log.Fatalf if err is not nil. 37 | func must(err error) { 38 | if err != nil { 39 | log.Fatalf("Error:\n%+v\n", err) 40 | } 41 | } 42 | 43 | // must1 is like must, but returns the output if there was no error. 44 | func must1[T any](t T, err error) T { 45 | must(err) 46 | return t 47 | } 48 | 49 | func main() { 50 | onnxSrc := os.Getenv(onnxSrcEnvVar) 51 | if onnxSrc == "" { 52 | fmt.Fprintf(os.Stderr, "Please set %s to the directory containing the locally cloned github.com/onnx/onnx repository.\n", onnxSrcEnvVar) 53 | os.Exit(1) 54 | } 55 | 56 | // Generate the --go_opt=M... 
flags
57 | 	goOpts := make([]string, len(protoFiles))
58 | 	for ii, protoFile := range protoFiles {
59 | 		goOpts[ii] = fmt.Sprintf("--go_opt=M%s=%s", protoFile, goProtosPackage)
60 | 	}
61 | 
62 | 	// Read each file from $ONNX_SRC, remove any go_package options, rewrite the package to "protos", and write to the
63 | 	// current directory.
64 | 	for _, protoFile := range protoFiles {
65 | 		// Fix file and write to current directory.
66 | 		protoPath := filepath.Join(onnxSrc, "onnx", protoFile) + "3" // Use the version 3 proto file.
67 | 		protoContents := must1(os.ReadFile(protoPath))
68 | 		protoContents = removeGoPackageOption(protoContents)
69 | 		protoContents = fixPackageName(protoContents)
70 | 		protoContents = fixImports(protoContents)
71 | 		must(os.WriteFile(protoFile, protoContents, 0644))
72 | 	}
73 | 
74 | 	// Compile each of the proto files.
75 | 	for _, protoFile := range protoFiles {
76 | 		// Construct the protoc command.
77 | 		args := []string{
78 | 			"--go_out=.",
79 | 			"-I=.",
80 | 			fmt.Sprintf("-I=%s/onnx", onnxSrc),
81 | 			fmt.Sprintf("--go_opt=module=%s", goProtosPackage),
82 | 		}
83 | 		args = append(args, goOpts...)
84 | 		args = append(args, protoFile)
85 | 		cmd := exec.Command("protoc", args...)
86 | 		cmd.Stdout = os.Stdout
87 | 		cmd.Stderr = os.Stderr
88 | 
89 | 		if err := cmd.Run(); err != nil {
90 | 			_, _ = fmt.Fprintf(os.Stderr, "Error executing protoc for %s: %v\n", protoFile, err)
91 | 			_, _ = fmt.Fprintf(os.Stderr, "Command:\n%s\n", cmd)
92 | 			os.Exit(1)
93 | 		}
94 | 	}
95 | }
96 | 
97 | var reRemoveGoPackageOption = regexp.MustCompile(`option\s+go_package\s*=\s*"[^"]*?";`)
98 | 
99 | func removeGoPackageOption(content []byte) []byte {
100 | 	return reRemoveGoPackageOption.ReplaceAll(content, []byte{})
101 | }
102 | 
103 | var rePackageName = regexp.MustCompile(`package onnx;`)
104 | 
105 | func fixPackageName(content []byte) []byte {
106 | 	return rePackageName.ReplaceAll(content, []byte(`package protos;`))
107 | }
108 | 
109 | var reImports = regexp.MustCompile(`import\s+"onnx/(.*)3";`)
110 | 
111 | func fixImports(content []byte) []byte {
112 | 	return reImports.ReplaceAll(content, []byte(`import "$1";`))
113 | }
114 | 
--------------------------------------------------------------------------------
/internal/protos/README.md:
--------------------------------------------------------------------------------
1 | # Protobuf Generated Files
2 | 
3 | All files in this directory are generated by the `onnx-gomlx/internal/cmd/protoc_onnx_protos` tool, except for
4 | the `protos.go` file, which includes the one `//go:generate go run ../cmd/protoc_onnx_protos` line.
5 | 
6 | Notice there are two variants of the ONNX protos: `onnx.proto` and `onnx-ml.proto`, and one can't include both,
7 | since they redefine each other. Why would they do that :sad: !? (and not document it in the proto files ...)
8 | 
9 | Anyway, this project takes the `onnx-ml.proto`, because according to the [IR documentation](https://github.com/onnx/onnx/blob/main/docs/IR.md)
10 | it seems to be more complete, even though `onnx-gomlx` does not necessarily support all of its operations.
11 | 
--------------------------------------------------------------------------------
/internal/protos/onnx-data.proto:
--------------------------------------------------------------------------------
1 | //
2 | // WARNING: This file is automatically generated! Please edit onnx.in.proto.
3 | //
4 | 
5 | 
6 | // SPDX-License-Identifier: Apache-2.0
7 | 
8 | 
9 | syntax = "proto3";
10 | 
11 | package protos;
12 | import "onnx-ml.proto";
13 | 
14 | // This file contains the proto definitions for MapProto and
15 | // SequenceProto. These protos are used to represent the data structures
16 | // of maps and sequence for use in test data or ModelProto.
17 | 
18 | // Sequences
19 | //
20 | // Defines a dense, ordered, collection of elements that are of homogeneous types.
21 | // Sequences can be made out of tensors, maps, or sequences.
22 | //
23 | // If a sequence is made out of tensors, the tensors must have the same element
24 | // type (i.e. int32). In some cases, the tensors in a sequence can have different
25 | // shapes. Whether the tensors can have different shapes or not depends on the
26 | // type/shape associated with the corresponding "ValueInfo". For example,
27 | // "Sequence<Tensor<int32, [M,N]>>" means that all tensors have same shape. However,
28 | // "Sequence<Tensor<int32, [omitted, omitted]>>" means they can have different
29 | // shapes (all of rank 2), where "omitted" means the corresponding dimension has
30 | // no symbolic/constant value. Finally, "Sequence<Tensor<int32, omitted>>" means
31 | // that the different tensors can have different ranks, when the "shape" itself
32 | // is omitted from the tensor-type. For a more complete description, refer to
33 | // https://github.com/onnx/onnx/blob/main/docs/IR.md#static-tensor-shapes.
34 | //
35 | message SequenceProto {
36 | 
37 |   string name = 1;
38 | 
39 |   enum DataType {
40 |     UNDEFINED = 0;
41 |     TENSOR = 1;
42 |     SPARSE_TENSOR = 2;
43 |     SEQUENCE = 3;
44 |     MAP = 4;
45 |     OPTIONAL = 5;
46 |   }
47 | 
48 |   // The data type of the element.
49 |   // This field MUST have a valid SequenceProto.DataType value
50 |   int32 elem_type = 2;
51 | 
52 |   // For TensorProto values.
53 |   // When this field is present, the elem_type field MUST be TENSOR.
54 |   repeated TensorProto tensor_values = 3;
55 | 
56 |   // For SparseTensorProto values.
57 |   // When this field is present, the elem_type field MUST be SPARSE_TENSOR.
58 |   repeated SparseTensorProto sparse_tensor_values = 4;
59 | 
60 |   // For SequenceProto values, allowing sequences to be of themselves.
61 |   // When this field is present, the elem_type field MUST be SEQUENCE.
62 |   repeated SequenceProto sequence_values = 5;
63 | 
64 |   // For MapProto values.
65 |   // When this field is present, the elem_type field MUST be MAP.
66 |   repeated MapProto map_values = 6;
67 | 
68 |   // For OptionalProto values.
69 |   // When this field is present, the elem_type field MUST be Optional.
70 |   repeated OptionalProto optional_values = 7;
71 | 
72 | }
73 | 
74 | 
75 | // Maps
76 | //
77 | // Specifies an associative table, defined by keys and values.
78 | // MapProto is formed with a repeated field of keys (of type INT8, INT16, INT32,
79 | // INT64, UINT8, UINT16, UINT32, UINT64, or STRING) and values (of type TENSOR,
80 | // SPARSE_TENSOR, SEQUENCE, or MAP). Key types and value types have to remain
81 | // the same throughout the instantiation of the MapProto.
82 | //
83 | message MapProto {
84 | 
85 |   string name = 1;
86 | 
87 |   // All MapProto data types must have the same length of keys and values.
88 | 
89 |   // The data type of the key.
90 |   // This field MUST have a valid TensorProto.DataType value of
91 |   // INT8, INT16, INT32, INT64, UINT8, UINT16, UINT32, UINT64, or STRING
92 |   int32 key_type = 2;
93 | 
94 |   // Every element of keys has to be one of the following data types
95 |   // INT8, INT16, INT32, INT64, UINT8, UINT16, UINT32, UINT64, or STRING.
96 | // The integer cases are represented by the repeated int64 field keys below. 97 | repeated int64 keys = 3; 98 | 99 | // If keys are strings, they are represented by the repeated bytes field 100 | // string_keys below. 101 | repeated bytes string_keys = 4; 102 | 103 | // MapProto values are represented in a SequenceProto of the same length as the 104 | // repeated keys field and have to be one of the following data types 105 | // TENSOR, SPARSE_TENSOR, MAP, SEQUENCE. 106 | SequenceProto values = 5; 107 | } 108 | 109 | // Optional 110 | // 111 | // 112 | message OptionalProto { 113 | 114 | string name = 1; 115 | 116 | enum DataType { 117 | UNDEFINED = 0; 118 | TENSOR = 1; 119 | SPARSE_TENSOR = 2; 120 | SEQUENCE = 3; 121 | MAP = 4; 122 | OPTIONAL = 5; 123 | } 124 | 125 | // The data type of the element, identifies if the OptionalProto value 126 | // is Tensor, Sparse Tensor, Sequence, Map, or Optional. 127 | // The type of the optional value MUST match the elem_type specified. 128 | // This field MUST have a valid OptionalProto.DataType value. 129 | int32 elem_type = 2; 130 | 131 | // For TensorProto value. 132 | // When this field is present, the elem_type field MUST be TENSOR. 133 | TensorProto tensor_value = 3; 134 | 135 | // For SparseTensorProto value. 136 | // When this field is present, the elem_type field MUST be SPARSE_TENSOR. 137 | SparseTensorProto sparse_tensor_value = 4; 138 | 139 | // For SequenceProto value. 140 | // When this field is present, the elem_type field MUST be SEQUENCE. 141 | SequenceProto sequence_value = 5; 142 | 143 | // For MapProto value. 144 | // When this field is present, the elem_type field MUST be MAP. 145 | MapProto map_value = 6; 146 | 147 | // For OptionalProto value, allowing optional to be of itself (completeness) 148 | // When this field is present, the elem_type field MUST be OPTIONAL. 149 | OptionalProto optional_value = 7; 150 | 151 | } 152 | 153 | // For using protobuf-lite 154 | option optimize_for = LITE_RUNTIME; 155 | 156 | -------------------------------------------------------------------------------- /internal/protos/onnx-operators-ml.pb.go: -------------------------------------------------------------------------------- 1 | // 2 | // WARNING: This file is automatically generated! Please edit onnx.in.proto. 3 | // 4 | 5 | // Copyright (c) ONNX Project Contributors. 6 | // Licensed under the Apache-2.0 license. 7 | 8 | // Code generated by protoc-gen-go. DO NOT EDIT. 9 | // versions: 10 | // protoc-gen-go v1.35.1 11 | // protoc v3.21.12 12 | // source: onnx-operators-ml.proto 13 | 14 | package protos 15 | 16 | import ( 17 | protoreflect "google.golang.org/protobuf/reflect/protoreflect" 18 | protoimpl "google.golang.org/protobuf/runtime/protoimpl" 19 | reflect "reflect" 20 | sync "sync" 21 | ) 22 | 23 | const ( 24 | // Verify that this generated code is sufficiently up-to-date. 25 | _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) 26 | // Verify that runtime/protoimpl is sufficiently up-to-date. 27 | _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) 28 | ) 29 | 30 | // An OperatorProto represents the immutable specification of the signature 31 | // and semantics of an operator. 32 | // 33 | // Operators are declared as part of an OperatorSet, which also defines the 34 | // domain name for the set. 
35 | // 36 | // Operators are uniquely identified by a three part identifier 37 | // 38 | // (domain, op_type, since_version) 39 | // 40 | // where 41 | // 42 | // *domain* is the domain of an operator set that 43 | // contains this operator specification. 44 | // 45 | // *op_type* is the name of the operator as referenced by a 46 | // NodeProto.op_type 47 | // 48 | // *since_version* is the version of the operator set that 49 | // this operator was initially declared in. 50 | type OperatorProto struct { 51 | state protoimpl.MessageState 52 | sizeCache protoimpl.SizeCache 53 | unknownFields protoimpl.UnknownFields 54 | 55 | // The name of the operator within a domain. 56 | // This field MUST be present in this version of the IR. 57 | OpType string `protobuf:"bytes,1,opt,name=op_type,json=opType,proto3" json:"op_type,omitempty"` 58 | // The version of the operator set that first introduced this 59 | // operator. This value MUST be the same value as the 60 | // opset_version of the operator set that first published this operator. 61 | // Subsequent versions of the operator set MUST NOT alter the signature 62 | // or semantics of the operator once published as STABLE. 63 | // This field MUST be present in this version of the IR. 64 | SinceVersion int64 `protobuf:"varint,2,opt,name=since_version,json=sinceVersion,proto3" json:"since_version,omitempty"` 65 | // This field indicates whether the syntax, semantics, or presence 66 | // of this operator is in an experimental or stable stage. Once an 67 | // operator is published as STABLE, it's syntax and semantics MUST NOT 68 | // change in subsequent versions of the operator set. 69 | // When an operator is published as EXPERIMENTAL, the syntax and semantics 70 | // of the operator MAY change across operator set versions. 71 | // Operators "become" stable by deprecating the experimental version and 72 | // introducing a new stable operator with the same op_type. 73 | Status OperatorStatus `protobuf:"varint,3,opt,name=status,proto3,enum=protos.OperatorStatus" json:"status,omitempty"` 74 | // A human-readable documentation for this operator. Markdown is allowed. 75 | DocString string `protobuf:"bytes,10,opt,name=doc_string,json=docString,proto3" json:"doc_string,omitempty"` 76 | } 77 | 78 | func (x *OperatorProto) Reset() { 79 | *x = OperatorProto{} 80 | mi := &file_onnx_operators_ml_proto_msgTypes[0] 81 | ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) 82 | ms.StoreMessageInfo(mi) 83 | } 84 | 85 | func (x *OperatorProto) String() string { 86 | return protoimpl.X.MessageStringOf(x) 87 | } 88 | 89 | func (*OperatorProto) ProtoMessage() {} 90 | 91 | func (x *OperatorProto) ProtoReflect() protoreflect.Message { 92 | mi := &file_onnx_operators_ml_proto_msgTypes[0] 93 | if x != nil { 94 | ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) 95 | if ms.LoadMessageInfo() == nil { 96 | ms.StoreMessageInfo(mi) 97 | } 98 | return ms 99 | } 100 | return mi.MessageOf(x) 101 | } 102 | 103 | // Deprecated: Use OperatorProto.ProtoReflect.Descriptor instead. 
104 | func (*OperatorProto) Descriptor() ([]byte, []int) { 105 | return file_onnx_operators_ml_proto_rawDescGZIP(), []int{0} 106 | } 107 | 108 | func (x *OperatorProto) GetOpType() string { 109 | if x != nil { 110 | return x.OpType 111 | } 112 | return "" 113 | } 114 | 115 | func (x *OperatorProto) GetSinceVersion() int64 { 116 | if x != nil { 117 | return x.SinceVersion 118 | } 119 | return 0 120 | } 121 | 122 | func (x *OperatorProto) GetStatus() OperatorStatus { 123 | if x != nil { 124 | return x.Status 125 | } 126 | return OperatorStatus_EXPERIMENTAL 127 | } 128 | 129 | func (x *OperatorProto) GetDocString() string { 130 | if x != nil { 131 | return x.DocString 132 | } 133 | return "" 134 | } 135 | 136 | // An OperatorSetProto represents an immutable set of immutable operator 137 | // specifications. 138 | // 139 | // The domain of the set (OperatorSetProto.domain) is a reverse-DNS name 140 | // that disambiguates operator sets defined by independent entities. 141 | // 142 | // The version of the set (opset_version) is a monotonically increasing 143 | // integer that indicates changes to the membership of the operator set. 144 | // 145 | // Operator sets are uniquely identified by a two part identifier (domain, opset_version) 146 | // 147 | // Like ModelProto, OperatorSetProto is intended as a top-level file/wire format, 148 | // and thus has the standard format headers in addition to the operator set information. 149 | type OperatorSetProto struct { 150 | state protoimpl.MessageState 151 | sizeCache protoimpl.SizeCache 152 | unknownFields protoimpl.UnknownFields 153 | 154 | // All OperatorSetProtos start with a distingushed byte sequence to disambiguate 155 | // protobuf files containing OperatorSets from other content. 156 | // This field MUST be "ONNXOPSET" 157 | // This field MUST be present in this version of the IR 158 | Magic string `protobuf:"bytes,1,opt,name=magic,proto3" json:"magic,omitempty"` 159 | // All OperatorSetProtos indicate the version of the IR syntax and semantics 160 | // they adhere to. It is always IR_VERSION. 161 | // This field MUST be present in this version of the IR 162 | IrVersion int64 `protobuf:"varint,2,opt,name=ir_version,json=irVersion,proto3" json:"ir_version,omitempty"` 163 | // The prerelease component of the SemVer of the IR. 164 | // This field MAY be absent in this version of the IR 165 | IrVersionPrerelease string `protobuf:"bytes,3,opt,name=ir_version_prerelease,json=irVersionPrerelease,proto3" json:"ir_version_prerelease,omitempty"` 166 | // The build metadata component of the SemVer of the IR. 167 | // This field MAY be absent in this version of the IR 168 | IrBuildMetadata string `protobuf:"bytes,7,opt,name=ir_build_metadata,json=irBuildMetadata,proto3" json:"ir_build_metadata,omitempty"` 169 | // Domain name of the operator set, in reverse DNS form (e.g., com.acme.dnnops). 170 | Domain string `protobuf:"bytes,4,opt,name=domain,proto3" json:"domain,omitempty"` 171 | // The version of the set of operators. This is a simple int value 172 | // that is monotonically increasing as new versions of the operator set 173 | // are published. All operators in this set MUST have since_version 174 | // <= opset_version. 175 | OpsetVersion int64 `protobuf:"varint,5,opt,name=opset_version,json=opsetVersion,proto3" json:"opset_version,omitempty"` 176 | // A human-readable documentation for this set of operators. Markdown is allowed. 
177 | DocString string `protobuf:"bytes,6,opt,name=doc_string,json=docString,proto3" json:"doc_string,omitempty"` 178 | // The operators specified by this operator set. 179 | // The (name, version) MUST be unique across all OperatorProtos in operator 180 | Operator []*OperatorProto `protobuf:"bytes,8,rep,name=operator,proto3" json:"operator,omitempty"` 181 | // The functions specified by this operator set. 182 | // The (name, version) MUST be unique across all OperatorProtos/FunctionProtos in operator/functions 183 | Functions []*FunctionProto `protobuf:"bytes,9,rep,name=functions,proto3" json:"functions,omitempty"` 184 | } 185 | 186 | func (x *OperatorSetProto) Reset() { 187 | *x = OperatorSetProto{} 188 | mi := &file_onnx_operators_ml_proto_msgTypes[1] 189 | ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) 190 | ms.StoreMessageInfo(mi) 191 | } 192 | 193 | func (x *OperatorSetProto) String() string { 194 | return protoimpl.X.MessageStringOf(x) 195 | } 196 | 197 | func (*OperatorSetProto) ProtoMessage() {} 198 | 199 | func (x *OperatorSetProto) ProtoReflect() protoreflect.Message { 200 | mi := &file_onnx_operators_ml_proto_msgTypes[1] 201 | if x != nil { 202 | ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) 203 | if ms.LoadMessageInfo() == nil { 204 | ms.StoreMessageInfo(mi) 205 | } 206 | return ms 207 | } 208 | return mi.MessageOf(x) 209 | } 210 | 211 | // Deprecated: Use OperatorSetProto.ProtoReflect.Descriptor instead. 212 | func (*OperatorSetProto) Descriptor() ([]byte, []int) { 213 | return file_onnx_operators_ml_proto_rawDescGZIP(), []int{1} 214 | } 215 | 216 | func (x *OperatorSetProto) GetMagic() string { 217 | if x != nil { 218 | return x.Magic 219 | } 220 | return "" 221 | } 222 | 223 | func (x *OperatorSetProto) GetIrVersion() int64 { 224 | if x != nil { 225 | return x.IrVersion 226 | } 227 | return 0 228 | } 229 | 230 | func (x *OperatorSetProto) GetIrVersionPrerelease() string { 231 | if x != nil { 232 | return x.IrVersionPrerelease 233 | } 234 | return "" 235 | } 236 | 237 | func (x *OperatorSetProto) GetIrBuildMetadata() string { 238 | if x != nil { 239 | return x.IrBuildMetadata 240 | } 241 | return "" 242 | } 243 | 244 | func (x *OperatorSetProto) GetDomain() string { 245 | if x != nil { 246 | return x.Domain 247 | } 248 | return "" 249 | } 250 | 251 | func (x *OperatorSetProto) GetOpsetVersion() int64 { 252 | if x != nil { 253 | return x.OpsetVersion 254 | } 255 | return 0 256 | } 257 | 258 | func (x *OperatorSetProto) GetDocString() string { 259 | if x != nil { 260 | return x.DocString 261 | } 262 | return "" 263 | } 264 | 265 | func (x *OperatorSetProto) GetOperator() []*OperatorProto { 266 | if x != nil { 267 | return x.Operator 268 | } 269 | return nil 270 | } 271 | 272 | func (x *OperatorSetProto) GetFunctions() []*FunctionProto { 273 | if x != nil { 274 | return x.Functions 275 | } 276 | return nil 277 | } 278 | 279 | var File_onnx_operators_ml_proto protoreflect.FileDescriptor 280 | 281 | var file_onnx_operators_ml_proto_rawDesc = []byte{ 282 | 0x0a, 0x17, 0x6f, 0x6e, 0x6e, 0x78, 0x2d, 0x6f, 0x70, 0x65, 0x72, 0x61, 0x74, 0x6f, 0x72, 0x73, 283 | 0x2d, 0x6d, 0x6c, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 284 | 0x73, 0x1a, 0x0d, 0x6f, 0x6e, 0x6e, 0x78, 0x2d, 0x6d, 0x6c, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 285 | 0x22, 0x9c, 0x01, 0x0a, 0x0d, 0x4f, 0x70, 0x65, 0x72, 0x61, 0x74, 0x6f, 0x72, 0x50, 0x72, 0x6f, 286 | 0x74, 0x6f, 0x12, 0x17, 0x0a, 0x07, 0x6f, 0x70, 0x5f, 0x74, 0x79, 0x70, 0x65, 0x18, 0x01, 0x20, 287 | 0x01, 
0x28, 0x09, 0x52, 0x06, 0x6f, 0x70, 0x54, 0x79, 0x70, 0x65, 0x12, 0x23, 0x0a, 0x0d, 0x73, 288 | 0x69, 0x6e, 0x63, 0x65, 0x5f, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, 289 | 0x28, 0x03, 0x52, 0x0c, 0x73, 0x69, 0x6e, 0x63, 0x65, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 290 | 0x12, 0x2e, 0x0a, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0e, 291 | 0x32, 0x16, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x73, 0x2e, 0x4f, 0x70, 0x65, 0x72, 0x61, 0x74, 292 | 0x6f, 0x72, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 293 | 0x12, 0x1d, 0x0a, 0x0a, 0x64, 0x6f, 0x63, 0x5f, 0x73, 0x74, 0x72, 0x69, 0x6e, 0x67, 0x18, 0x0a, 294 | 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x64, 0x6f, 0x63, 0x53, 0x74, 0x72, 0x69, 0x6e, 0x67, 0x22, 295 | 0xeb, 0x02, 0x0a, 0x10, 0x4f, 0x70, 0x65, 0x72, 0x61, 0x74, 0x6f, 0x72, 0x53, 0x65, 0x74, 0x50, 296 | 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x14, 0x0a, 0x05, 0x6d, 0x61, 0x67, 0x69, 0x63, 0x18, 0x01, 0x20, 297 | 0x01, 0x28, 0x09, 0x52, 0x05, 0x6d, 0x61, 0x67, 0x69, 0x63, 0x12, 0x1d, 0x0a, 0x0a, 0x69, 0x72, 298 | 0x5f, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x09, 299 | 0x69, 0x72, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x32, 0x0a, 0x15, 0x69, 0x72, 0x5f, 300 | 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x5f, 0x70, 0x72, 0x65, 0x72, 0x65, 0x6c, 0x65, 0x61, 301 | 0x73, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x13, 0x69, 0x72, 0x56, 0x65, 0x72, 0x73, 302 | 0x69, 0x6f, 0x6e, 0x50, 0x72, 0x65, 0x72, 0x65, 0x6c, 0x65, 0x61, 0x73, 0x65, 0x12, 0x2a, 0x0a, 303 | 0x11, 0x69, 0x72, 0x5f, 0x62, 0x75, 0x69, 0x6c, 0x64, 0x5f, 0x6d, 0x65, 0x74, 0x61, 0x64, 0x61, 304 | 0x74, 0x61, 0x18, 0x07, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0f, 0x69, 0x72, 0x42, 0x75, 0x69, 0x6c, 305 | 0x64, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x12, 0x16, 0x0a, 0x06, 0x64, 0x6f, 0x6d, 306 | 0x61, 0x69, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x64, 0x6f, 0x6d, 0x61, 0x69, 307 | 0x6e, 0x12, 0x23, 0x0a, 0x0d, 0x6f, 0x70, 0x73, 0x65, 0x74, 0x5f, 0x76, 0x65, 0x72, 0x73, 0x69, 308 | 0x6f, 0x6e, 0x18, 0x05, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0c, 0x6f, 0x70, 0x73, 0x65, 0x74, 0x56, 309 | 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x1d, 0x0a, 0x0a, 0x64, 0x6f, 0x63, 0x5f, 0x73, 0x74, 310 | 0x72, 0x69, 0x6e, 0x67, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x64, 0x6f, 0x63, 0x53, 311 | 0x74, 0x72, 0x69, 0x6e, 0x67, 0x12, 0x31, 0x0a, 0x08, 0x6f, 0x70, 0x65, 0x72, 0x61, 0x74, 0x6f, 312 | 0x72, 0x18, 0x08, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x73, 313 | 0x2e, 0x4f, 0x70, 0x65, 0x72, 0x61, 0x74, 0x6f, 0x72, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x52, 0x08, 314 | 0x6f, 0x70, 0x65, 0x72, 0x61, 0x74, 0x6f, 0x72, 0x12, 0x33, 0x0a, 0x09, 0x66, 0x75, 0x6e, 0x63, 315 | 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0x09, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x70, 0x72, 316 | 0x6f, 0x74, 0x6f, 0x73, 0x2e, 0x46, 0x75, 0x6e, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x50, 0x72, 0x6f, 317 | 0x74, 0x6f, 0x52, 0x09, 0x66, 0x75, 0x6e, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x42, 0x02, 0x48, 318 | 0x03, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, 319 | } 320 | 321 | var ( 322 | file_onnx_operators_ml_proto_rawDescOnce sync.Once 323 | file_onnx_operators_ml_proto_rawDescData = file_onnx_operators_ml_proto_rawDesc 324 | ) 325 | 326 | func file_onnx_operators_ml_proto_rawDescGZIP() []byte { 327 | file_onnx_operators_ml_proto_rawDescOnce.Do(func() { 328 | 
file_onnx_operators_ml_proto_rawDescData = protoimpl.X.CompressGZIP(file_onnx_operators_ml_proto_rawDescData) 329 | }) 330 | return file_onnx_operators_ml_proto_rawDescData 331 | } 332 | 333 | var file_onnx_operators_ml_proto_msgTypes = make([]protoimpl.MessageInfo, 2) 334 | var file_onnx_operators_ml_proto_goTypes = []any{ 335 | (*OperatorProto)(nil), // 0: protos.OperatorProto 336 | (*OperatorSetProto)(nil), // 1: protos.OperatorSetProto 337 | (OperatorStatus)(0), // 2: protos.OperatorStatus 338 | (*FunctionProto)(nil), // 3: protos.FunctionProto 339 | } 340 | var file_onnx_operators_ml_proto_depIdxs = []int32{ 341 | 2, // 0: protos.OperatorProto.status:type_name -> protos.OperatorStatus 342 | 0, // 1: protos.OperatorSetProto.operator:type_name -> protos.OperatorProto 343 | 3, // 2: protos.OperatorSetProto.functions:type_name -> protos.FunctionProto 344 | 3, // [3:3] is the sub-list for method output_type 345 | 3, // [3:3] is the sub-list for method input_type 346 | 3, // [3:3] is the sub-list for extension type_name 347 | 3, // [3:3] is the sub-list for extension extendee 348 | 0, // [0:3] is the sub-list for field type_name 349 | } 350 | 351 | func init() { file_onnx_operators_ml_proto_init() } 352 | func file_onnx_operators_ml_proto_init() { 353 | if File_onnx_operators_ml_proto != nil { 354 | return 355 | } 356 | file_onnx_ml_proto_init() 357 | type x struct{} 358 | out := protoimpl.TypeBuilder{ 359 | File: protoimpl.DescBuilder{ 360 | GoPackagePath: reflect.TypeOf(x{}).PkgPath(), 361 | RawDescriptor: file_onnx_operators_ml_proto_rawDesc, 362 | NumEnums: 0, 363 | NumMessages: 2, 364 | NumExtensions: 0, 365 | NumServices: 0, 366 | }, 367 | GoTypes: file_onnx_operators_ml_proto_goTypes, 368 | DependencyIndexes: file_onnx_operators_ml_proto_depIdxs, 369 | MessageInfos: file_onnx_operators_ml_proto_msgTypes, 370 | }.Build() 371 | File_onnx_operators_ml_proto = out.File 372 | file_onnx_operators_ml_proto_rawDesc = nil 373 | file_onnx_operators_ml_proto_goTypes = nil 374 | file_onnx_operators_ml_proto_depIdxs = nil 375 | } 376 | -------------------------------------------------------------------------------- /internal/protos/onnx-operators-ml.proto: -------------------------------------------------------------------------------- 1 | // 2 | // WARNING: This file is automatically generated! Please edit onnx.in.proto. 3 | // 4 | 5 | 6 | // Copyright (c) ONNX Project Contributors. 7 | // Licensed under the Apache-2.0 license. 8 | 9 | syntax = "proto3"; 10 | 11 | package protos; 12 | import "onnx-ml.proto"; 13 | 14 | // 15 | // This file contains the proto definitions for OperatorSetProto and 16 | // OperatorProto. OperatorSetProtos are used to describe a versioned 17 | // set of operators that can be used by a ModelProto. 18 | // 19 | // Like ModelProto, OperatorSetProto is defined as a top-level file/wire 20 | // format, however their usage is different. 21 | // 22 | // ModelProto files are used to describe executable graphs that can be 23 | // executed directly by a framework, runtime, or engine. 24 | // 25 | // OperatorSetProto files are used to describe a set of operators that are 26 | // available in a given environment. The file TBD.TBD is the OperatorSetProto 27 | // that describes the ONNX standard operators. 28 | // 29 | 30 | // An OperatorProto represents the immutable specification of the signature 31 | // and semantics of an operator. 32 | // 33 | // Operators are declared as part of an OperatorSet, which also defines the 34 | // domain name for the set. 
35 | // 36 | // Operators are uniquely identified by a three-part identifier 37 | // (domain, op_type, since_version) 38 | // where 39 | // *domain* is the domain of an operator set that 40 | // contains this operator specification. 41 | // 42 | // *op_type* is the name of the operator as referenced by a 43 | // NodeProto.op_type 44 | // 45 | // *since_version* is the version of the operator set that 46 | // this operator was initially declared in. 47 | // 48 | message OperatorProto { 49 | // The name of the operator within a domain. 50 | // This field MUST be present in this version of the IR. 51 | string op_type = 1; 52 | 53 | // The version of the operator set that first introduced this 54 | // operator. This value MUST be the same value as the 55 | // opset_version of the operator set that first published this operator. 56 | // Subsequent versions of the operator set MUST NOT alter the signature 57 | // or semantics of the operator once published as STABLE. 58 | // This field MUST be present in this version of the IR. 59 | int64 since_version = 2; 60 | 61 | // This field indicates whether the syntax, semantics, or presence 62 | // of this operator is in an experimental or stable stage. Once an 63 | // operator is published as STABLE, its syntax and semantics MUST NOT 64 | // change in subsequent versions of the operator set. 65 | // When an operator is published as EXPERIMENTAL, the syntax and semantics 66 | // of the operator MAY change across operator set versions. 67 | // Operators "become" stable by deprecating the experimental version and 68 | // introducing a new stable operator with the same op_type. 69 | OperatorStatus status = 3; 70 | 71 | // Eventually we will declare the signature of the operator here 72 | 73 | // A human-readable documentation for this operator. Markdown is allowed. 74 | string doc_string = 10; 75 | } 76 | 77 | // An OperatorSetProto represents an immutable set of immutable operator 78 | // specifications. 79 | // 80 | // The domain of the set (OperatorSetProto.domain) is a reverse-DNS name 81 | // that disambiguates operator sets defined by independent entities. 82 | // 83 | // The version of the set (opset_version) is a monotonically increasing 84 | // integer that indicates changes to the membership of the operator set. 85 | // 86 | // 87 | // Operator sets are uniquely identified by a two-part identifier (domain, opset_version) 88 | // 89 | // Like ModelProto, OperatorSetProto is intended as a top-level file/wire format, 90 | // and thus has the standard format headers in addition to the operator set information. 91 | // 92 | message OperatorSetProto { 93 | // All OperatorSetProtos start with a distinguished byte sequence to disambiguate 94 | // protobuf files containing OperatorSets from other content. 95 | // This field MUST be "ONNXOPSET" 96 | // This field MUST be present in this version of the IR 97 | string magic = 1; 98 | 99 | // All OperatorSetProtos indicate the version of the IR syntax and semantics 100 | // they adhere to. It is always IR_VERSION. 101 | // This field MUST be present in this version of the IR 102 | int64 ir_version = 2; 103 | 104 | // The prerelease component of the SemVer of the IR. 105 | // This field MAY be absent in this version of the IR 106 | string ir_version_prerelease = 3; 107 | 108 | // The build metadata component of the SemVer of the IR.
109 | // This field MAY be absent in this version of the IR 110 | string ir_build_metadata = 7; 111 | 112 | // Domain name of the operator set, in reverse DNS form (e.g., com.acme.dnnops). 113 | string domain = 4; 114 | 115 | // The version of the set of operators. This is a simple int value 116 | // that is monotonically increasing as new versions of the operator set 117 | // are published. All operators in this set MUST have since_version 118 | // <= opset_version. 119 | int64 opset_version = 5; 120 | 121 | // A human-readable documentation for this set of operators. Markdown is allowed. 122 | string doc_string = 6; 123 | 124 | // The operators specified by this operator set. 125 | // The (name, version) MUST be unique across all OperatorProtos in operator 126 | repeated OperatorProto operator = 8; 127 | 128 | // The functions specified by this operator set. 129 | // The (name, version) MUST be unique across all OperatorProtos/FunctionProtos in operator/functions 130 | repeated FunctionProto functions = 9; 131 | } 132 | 133 | 134 | // For using protobuf-lite 135 | option optimize_for = LITE_RUNTIME; 136 | 137 | -------------------------------------------------------------------------------- /internal/protos/protos.go: -------------------------------------------------------------------------------- 1 | // Package protos is empty, it simply includes a rule to generate all the sub-packages with protobuf generated code: 2 | package protos 3 | 4 | //go:generate go run ../cmd/protoc_onnx_protos 5 | -------------------------------------------------------------------------------- /onnx-py.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "cfe7380b-e9b9-457a-bdf3-eabeae4b43d3", 6 | "metadata": {}, 7 | "source": [ 8 | "# Create Linear Model trivial ONNX model" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": 86, 14 | "id": "39897ce6-4ab9-44b8-8262-aad96635d476", 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "import onnx\n", 19 | "from onnx import TensorProto\n", 20 | "from onnx.helper import (\n", 21 | " make_model, make_node, make_graph,\n", 22 | " make_tensor_value_info)\n", 23 | "from onnx.checker import check_model\n" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": 2, 29 | "id": "d0a5111e-9542-4a90-b33e-7809b092a211", 30 | "metadata": {}, 31 | "outputs": [], 32 | "source": [ 33 | "feature_dim = 5\n", 34 | "X = make_tensor_value_info('X', TensorProto.FLOAT, [\"batch_size\", feature_dim])\n", 35 | "Y = make_tensor_value_info('Y', TensorProto.FLOAT, [\"batch_size\"])\n", 36 | "A_initializer = onnx.helper.make_tensor('A', TensorProto.FLOAT, [feature_dim], [100.0, 10.0, 1.0, 0.1, 0.01])\n", 37 | "B_initializer = onnx.helper.make_tensor('B', TensorProto.FLOAT, [], [7000.0])\n", 38 | "node1 = make_node('MatMul', ['X', 'A'], ['XA'], 'XA')\n", 39 | "node2 = make_node('Add', ['XA', 'B'], ['Y'], 'Y')\n", 40 | "graph = make_graph([node1, node2], 'lr', [X], [Y], initializer=[A_initializer, B_initializer])\n", 41 | "onnx_model = make_model(graph)\n", 42 | "check_model(onnx_model)\n", 43 | "with open(\"linear_regression.onnx\", \"wb\") as f:\n", 44 | " f.write(onnx_model.SerializeToString())" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": 3, 50 | "id": "740e9376-6c3b-4c0f-b15c-2e7c10f80174", 51 | "metadata": {}, 52 | "outputs": [ 53 | { 54 | "name": "stdout", 55 | "output_type": "stream", 56 | "text": [ 57 | "[[ 1. 2. 
3. 4. 5.]\n", 58 | " [ 6. 7. 8. 9. 10.]]\n", 59 | "[7123.45 7679. ]\n" 60 | ] 61 | } 62 | ], 63 | "source": [ 64 | "import onnxruntime as ort\n", 65 | "import numpy as np\n", 66 | "\n", 67 | "x = np.arange(10, dtype=np.float32)+1\n", 68 | "x = np.reshape(x, (2, 5))\n", 69 | "print(x)\n", 70 | "ort_sess = ort.InferenceSession('linear_regression.onnx')\n", 71 | "outputs = ort_sess.run(['Y'], {'X': x})\n", 72 | "print(outputs[0])" 73 | ] 74 | }, 75 | { 76 | "cell_type": "markdown", 77 | "id": "c2477f0c-1ed3-4c39-aa1c-2a76dfb195c6", 78 | "metadata": {}, 79 | "source": [ 80 | "### MatMul WAT: How does it work on the edge cases ?" 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "execution_count": 12, 86 | "id": "ec0933b3-8e72-4993-9029-35a329c20147", 87 | "metadata": {}, 88 | "outputs": [ 89 | { 90 | "name": "stdout", 91 | "output_type": "stream", 92 | "text": [ 93 | "(2, 1, 7, 32)\n", 94 | "(12, 32, 7)\n", 95 | "(2, 12, 32, 32)\n" 96 | ] 97 | } 98 | ], 99 | "source": [ 100 | "import numpy\n", 101 | "\n", 102 | "lhs = (np.arange(2 * 1 * 7 * 32, dtype=np.float32)+1) / 1000.0\n", 103 | "lhs = np.reshape(lhs, (2, 1, 7, 32))\n", 104 | "print(lhs.shape)\n", 105 | "rhs = (np.arange(12*7*32, dtype=np.float32)+1) / 1000.0\n", 106 | "rhs = np.reshape(rhs, (12, 32, 7))\n", 107 | "print(rhs.shape)\n", 108 | "res = np.matmul(rhs, lhs)\n", 109 | "print(res.shape)\n" 110 | ] 111 | }, 112 | { 113 | "cell_type": "markdown", 114 | "id": "f81a6f50-c45a-47cb-a297-78b92b525cf9", 115 | "metadata": {}, 116 | "source": [ 117 | "# Experimenting with model [`sentence-transformers/all-MiniLM-L6-v2`](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2)\n", 118 | "\n", 119 | "Normalization formulation:\n", 120 | "\n", 121 | "$$\n", 122 | "v = \\frac{v}{\\max(\\lVert v \\rVert_p, \\epsilon)}.\n", 123 | "$$\n", 124 | "\n", 125 | "## Imports" 126 | ] 127 | }, 128 | { 129 | "cell_type": "code", 130 | "execution_count": 6, 131 | "id": "d6c9d79b-48fa-41c8-929f-9850b587a35d", 132 | "metadata": {}, 133 | "outputs": [], 134 | "source": [ 135 | "from transformers import AutoTokenizer, AutoModel\n", 136 | "import torch\n", 137 | "import torch.nn.functional as F\n", 138 | "import onnxruntime as ort\n", 139 | "import numpy as np" 140 | ] 141 | }, 142 | { 143 | "cell_type": "markdown", 144 | "id": "2df0e2fe-16c9-41cc-a4ec-fb6aaf4e28f8", 145 | "metadata": {}, 146 | "source": [ 147 | "### Imports, create `tokenizer` and `model`" 148 | ] 149 | }, 150 | { 151 | "cell_type": "code", 152 | "execution_count": 7, 153 | "id": "f1b980b1-3745-40f6-98da-f22fe864c86a", 154 | "metadata": {}, 155 | "outputs": [ 156 | { 157 | "name": "stderr", 158 | "output_type": "stream", 159 | "text": [ 160 | "2024-10-31 07:39:20.623591: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. 
To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n", 161 | "2024-10-31 07:39:20.752880: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", 162 | "2024-10-31 07:39:20.795165: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", 163 | "2024-10-31 07:39:20.811348: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n", 164 | "2024-10-31 07:39:20.914578: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n", 165 | "To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n", 166 | "2024-10-31 07:39:21.558344: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n" 167 | ] 168 | } 169 | ], 170 | "source": [ 171 | "# Load model from HuggingFace Hub\n", 172 | "tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')\n", 173 | "model = AutoModel.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')" 174 | ] 175 | }, 176 | { 177 | "cell_type": "markdown", 178 | "id": "a67b6904-6073-4070-bcbd-5e44a9ef799c", 179 | "metadata": {}, 180 | "source": [ 181 | "### Sentences and tokens" 182 | ] 183 | }, 184 | { 185 | "cell_type": "code", 186 | "execution_count": 63, 187 | "id": "2f7d2828-33ff-4e49-905e-d01214cfd1a3", 188 | "metadata": {}, 189 | "outputs": [ 190 | { 191 | "name": "stdout", 192 | "output_type": "stream", 193 | "text": [ 194 | "Encoded input:\n", 195 | "{'input_ids': tensor([[ 101, 2023, 2003, 2019, 2742, 6251, 102],\n", 196 | " [ 101, 2169, 6251, 2003, 4991, 102, 0]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0],\n", 197 | " [0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1],\n", 198 | " [1, 1, 1, 1, 1, 1, 0]])}\n" 199 | ] 200 | } 201 | ], 202 | "source": [ 203 | "# Sentences we want sentence embeddings for\n", 204 | "sentences = ['This is an example sentence', 'Each sentence is converted']\n", 205 | "\n", 206 | "# Tokenize sentences\n", 207 | "encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')\n", 208 | "print(\"Encoded input:\")\n", 209 | "print(encoded_input)" 210 | ] 211 | }, 212 | { 213 | "cell_type": "markdown", 214 | "id": "e158d375-c48f-432c-a9c5-4324ebb24409", 215 | "metadata": {}, 216 | "source": [ 217 | "### Inference with ONNX" 218 | ] 219 | }, 220 | { 221 | "cell_type": "code", 222 | "execution_count": 64, 223 | "id": "3aad9c06-6b29-43a9-b368-dd4af93a8aee", 224 | "metadata": {}, 225 | "outputs": [ 226 | { 227 | "name": "stdout", 228 | "output_type": "stream", 229 | "text": [ 230 | "(2, 7, 384)\n", 231 | "[[[ 0.03656479 -0.01616146 0.1682453 ... 0.05540764 -0.16443957\n", 232 | " -0.29669833]\n", 233 | " [ 0.7239094 0.6399461 0.18878399 ... 0.5945502 0.6205655\n", 234 | " 0.489683 ]\n", 235 | " [ 0.00637847 0.02030473 0.04475658 ... 0.34638238 1.3169885\n", 236 | " -0.16695468]\n", 237 | " ...\n", 238 | " [ 0.1479177 -0.06426162 0.14569402 ... 0.8837387 -0.33155778\n", 239 | " 0.2975315 ]\n", 240 | " [ 0.52124625 0.6562965 0.5607001 ... 
-0.03988977 0.04121367\n", 241 | " -1.4035654 ]\n", 242 | " [ 1.0824106 0.7140344 0.39859214 ... -0.23005268 0.32431406\n", 243 | " -1.0312778 ]]\n", 244 | "\n", 245 | " [[ 0.2802185 0.11647302 -0.04178832 ... 0.27105364 -0.16846775\n", 246 | " -0.29611403]\n", 247 | " [ 0.87294626 0.4544794 -0.10909736 ... 0.13654931 0.45797268\n", 248 | " -0.20415133]\n", 249 | " [ 0.4751616 0.5731077 0.63044137 ... 0.6525696 0.5612419\n", 250 | " -1.3268433 ]\n", 251 | " ...\n", 252 | " [ 0.61133045 0.79203445 -0.4684846 ... 0.08543227 1.0591549\n", 253 | " -0.2983293 ]\n", 254 | " [ 0.4115055 1.0945691 0.23854384 ... 0.8983636 0.3683571\n", 255 | " -0.733289 ]\n", 256 | " [ 0.13744976 0.55554354 0.26777348 ... 0.5426259 0.46651605\n", 257 | " -0.52835524]]]\n" 258 | ] 259 | } 260 | ], 261 | "source": [ 262 | "ort_sess = ort.InferenceSession('model.onnx')\n", 263 | "outputKey = 'last_hidden_state'\n", 264 | "inputs = {key: value.numpy() for key, value in encoded_input.data.items()}\n", 265 | "modelOutput = ort_sess.run([outputKey], inputs)[0]\n", 266 | "print(f\"{modelOutput.shape}\")\n", 267 | "print(modelOutput)" 268 | ] 269 | }, 270 | { 271 | "cell_type": "code", 272 | "execution_count": 136, 273 | "id": "a41d2981-06a6-452d-a703-2947dff64346", 274 | "metadata": {}, 275 | "outputs": [ 276 | { 277 | "data": { 278 | "text/plain": [ 279 | "['last_hidden_state']" 280 | ] 281 | }, 282 | "execution_count": 136, 283 | "metadata": {}, 284 | "output_type": "execute_result" 285 | } 286 | ], 287 | "source": [ 288 | "model.graph.output." 289 | ] 290 | }, 291 | { 292 | "cell_type": "code", 293 | "execution_count": 162, 294 | "id": "5763b114-bb6d-4ea0-af92-e65ea136d8bf", 295 | "metadata": {}, 296 | "outputs": [], 297 | "source": [ 298 | "def probeOnnxNodeOutput(node_output_name, inputs):\n", 299 | " model = onnx.load(\"model.onnx\")\n", 300 | " del model.graph.output[:]\n", 301 | " model.graph.output.append(onnx.ValueInfoProto(name=node_output_name))\n", 302 | " assert len(model.graph.output) == 1\n", 303 | " onnx.save(model, \"modified_model.onnx\") \n", 304 | " ort_sess = ort.InferenceSession('modified_model.onnx')\n", 305 | " return ort_sess.run([node_output_name], inputs)[0]\n", 306 | "\n", 307 | "def p(node_output_name):\n", 308 | " output = probeOnnxNodeOutput(node_output_name, inputs)\n", 309 | " print(f\"\\n{node_output_name}: f{output.shape}\")\n", 310 | " print(output)" 311 | ] 312 | }, 313 | { 314 | "cell_type": "code", 315 | "execution_count": 151, 316 | "id": "29bd8837-7704-4607-9fb1-1376c49927ee", 317 | "metadata": {}, 318 | "outputs": [], 319 | "source": [ 320 | "with open('model_shapes.txt', 'a') as f:\n", 321 | " for node in model.graph.node:\n", 322 | " for node_output_name in node.output:\n", 323 | " output = probeOnnxNodeOutput(node_output_name, inputs)\n", 324 | " print(f\"{node_output_name}\\t{output.dtype}\\t{output.shape}\", file=f)\n", 325 | " f.flush()\n", 326 | " " 327 | ] 328 | }, 329 | { 330 | "cell_type": "code", 331 | "execution_count": 167, 332 | "id": "943ad75a-4d03-4047-9af5-acc696c1f250", 333 | "metadata": {}, 334 | "outputs": [ 335 | { 336 | "name": "stdout", 337 | "output_type": "stream", 338 | "text": [ 339 | "\n", 340 | "/embeddings/Add_1_output_0: f(2, 7, 384)\n", 341 | "[[[-0.08855709 -0.03675481 0.01803644 ... 0.02607179 0.09117168\n", 342 | " -0.01518174]\n", 343 | " [-0.02002142 -0.00136943 -0.01765827 ... 0.02036703 0.05219622\n", 344 | " 0.19905484]\n", 345 | " [-0.01959006 -0.03363657 -0.03186595 ... 
0.02031087 0.07087033\n", 346 | " 0.06444595]\n", 347 | " ...\n", 348 | " [-0.02530987 0.04081389 0.01253615 ... -0.02695212 0.03774461\n", 349 | " 0.11325061]\n", 350 | " [-0.01395568 -0.02749825 0.07956143 ... -0.07483339 0.07742585\n", 351 | " -0.06570429]\n", 352 | " [ 0.03182676 -0.00320992 -0.02103326 ... 0.03869266 0.01906986\n", 353 | " -0.00592621]]\n", 354 | "\n", 355 | " [[-0.08855709 -0.03675481 0.01803644 ... 0.02607179 0.09117168\n", 356 | " -0.01518174]\n", 357 | " [ 0.03040212 0.05308453 -0.02380589 ... -0.10111795 0.02182422\n", 358 | " 0.0473295 ]\n", 359 | " [-0.00270701 -0.05080456 0.08054851 ... -0.07771945 0.08808091\n", 360 | " -0.05600649]\n", 361 | " ...\n", 362 | " [ 0.0927911 0.01653565 -0.09761265 ... 0.04492704 0.03896102\n", 363 | " -0.01817189]\n", 364 | " [ 0.02310666 0.00902908 -0.02130682 ... 0.02319211 0.01912827\n", 365 | " -0.00660186]\n", 366 | " [-0.02132826 0.00192266 0.00427087 ... 0.05611002 0.01698602\n", 367 | " 0.02561522]]]\n" 368 | ] 369 | } 370 | ], 371 | "source": [ 372 | "# p(\"/embeddings/Slice_output_0\")\n", 373 | "# p(\"/embeddings/position_embeddings/Gather_output_0\")\n", 374 | "#p(\"token_type_ids\")\n", 375 | "#p(\"embeddings.token_type_embeddings.weight\")\n", 376 | "#p(\"/embeddings/token_type_embeddings/Gather_output_0\")\n", 377 | "# p(\"/embeddings/Add_output_0\")\n", 378 | "p(\"/embeddings/Add_1_output_0\")" 379 | ] 380 | }, 381 | { 382 | "cell_type": "markdown", 383 | "id": "c02ca9f7-cef6-41a6-983b-d7b01ab28db4", 384 | "metadata": {}, 385 | "source": [ 386 | "### Model Inference with HuggingFace/PyTorch version" 387 | ] 388 | }, 389 | { 390 | "cell_type": "code", 391 | "execution_count": 32, 392 | "id": "99cc66c8-139b-4b4f-9de5-49a38de3a13c", 393 | "metadata": {}, 394 | "outputs": [ 395 | { 396 | "name": "stdout", 397 | "output_type": "stream", 398 | "text": [ 399 | "torch.Size([2, 7, 384])\n", 400 | "tensor([[[ 0.0366, -0.0162, 0.1682, ..., 0.0554, -0.1644, -0.2967],\n", 401 | " [ 0.7239, 0.6399, 0.1888, ..., 0.5946, 0.6206, 0.4897],\n", 402 | " [ 0.0064, 0.0203, 0.0448, ..., 0.3464, 1.3170, -0.1670],\n", 403 | " ...,\n", 404 | " [ 0.1479, -0.0643, 0.1457, ..., 0.8837, -0.3316, 0.2975],\n", 405 | " [ 0.5212, 0.6563, 0.5607, ..., -0.0399, 0.0412, -1.4036],\n", 406 | " [ 1.0824, 0.7140, 0.3986, ..., -0.2301, 0.3243, -1.0313]],\n", 407 | "\n", 408 | " [[ 0.2802, 0.1165, -0.0418, ..., 0.2711, -0.1685, -0.2961],\n", 409 | " [ 0.8729, 0.4545, -0.1091, ..., 0.1365, 0.4580, -0.2042],\n", 410 | " [ 0.4752, 0.5731, 0.6304, ..., 0.6526, 0.5612, -1.3268],\n", 411 | " ...,\n", 412 | " [ 0.6113, 0.7920, -0.4685, ..., 0.0854, 1.0592, -0.2983],\n", 413 | " [ 0.4115, 1.0946, 0.2385, ..., 0.8984, 0.3684, -0.7333],\n", 414 | " [ 0.1374, 0.5555, 0.2678, ..., 0.5426, 0.4665, -0.5284]]])\n" 415 | ] 416 | } 417 | ], 418 | "source": [ 419 | "#Mean Pooling - Take attention mask into account for correct averaging\n", 420 | "def mean_pooling(model_output, attention_mask):\n", 421 | " token_embeddings = model_output[0] #First element of model_output contains all token embeddings\n", 422 | " input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()\n", 423 | " return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)\n", 424 | "\n", 425 | "# Compute token embeddings\n", 426 | "with torch.no_grad():\n", 427 | " model_output = model(**encoded_input)\n", 428 | "\n", 429 | "print(f\"{model_output.last_hidden_state.shape}\")\n", 430 | 
"print(model_output.last_hidden_state)\n", 431 | "\n", 432 | "if False:\n", 433 | " # Disabled for now\n", 434 | " # Perform pooling\n", 435 | " sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])\n", 436 | " \n", 437 | " # Normalize embeddings\n", 438 | " sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)\n", 439 | " \n", 440 | " print(f\"Sentence embeddings: {sentence_embeddings.shape}\")\n", 441 | " print(sentence_embeddings)\n" 442 | ] 443 | }, 444 | { 445 | "cell_type": "markdown", 446 | "id": "a0bfb072-1929-4a65-a1e1-aff4c7a1fa9f", 447 | "metadata": {}, 448 | "source": [ 449 | "## LSTM model example\n", 450 | "\n", 451 | "This will build a trivial PyTorch LSTM model, and save it (randomly initialized) to a ONNX model. Then we read the model and take note of the output with a fixed input, to compare with the GoMLX implementation.\n" 452 | ] 453 | }, 454 | { 455 | "cell_type": "code", 456 | "execution_count": 1, 457 | "id": "ac6ef71e-0612-47f5-96ef-914b451e9826", 458 | "metadata": {}, 459 | "outputs": [], 460 | "source": [] 461 | }, 462 | { 463 | "cell_type": "code", 464 | "execution_count": 20, 465 | "id": "2f0c6253-3962-4734-80f0-8b7e1c03dc89", 466 | "metadata": {}, 467 | "outputs": [ 468 | { 469 | "name": "stdout", 470 | "output_type": "stream", 471 | "text": [ 472 | "tensor([[0, 1, 2, 3, 4, 5, 6]], dtype=torch.int32)\n", 473 | "tensor([[0.1168, 0.1587, 0.1992]], grad_fn=)\n" 474 | ] 475 | } 476 | ], 477 | "source": [ 478 | "import torch\n", 479 | "from torch import nn\n", 480 | "\n", 481 | "class TextClassificationModel(nn.Module):\n", 482 | " def __init__(self, vocab_size, embedding_dim, hidden_dim, output_size):\n", 483 | " super(TextClassificationModel, self).__init__()\n", 484 | " self.embedding = nn.Embedding(vocab_size, embedding_dim)\n", 485 | " self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)\n", 486 | " self.fc = nn.Linear(hidden_dim, output_size)\n", 487 | "\n", 488 | " def forward(self, x):\n", 489 | " x = self.embedding(x)\n", 490 | " _, (hn, _) = self.lstm(x)\n", 491 | " hn = hn.squeeze(0)\n", 492 | " out = self.fc(hn)\n", 493 | " return out\n", 494 | "\n", 495 | "model = TextClassificationModel(30522, 5, 11, 3)\n", 496 | "test_input = torch.tensor([[0, 1, 2, 3, 4, 5, 6]], dtype=torch.int32)\n", 497 | "print(test_input)\n", 498 | "print(model(test_input))\n", 499 | "\n", 500 | "onnx_file_path = \"test_lstm.onnx\"\n", 501 | "torch.onnx.export(\n", 502 | " model,\n", 503 | " test_input,\n", 504 | " onnx_file_path,\n", 505 | " input_names=[\"input\"],\n", 506 | " output_names=[\"output\"],\n", 507 | " dynamic_axes={\"input\": {1: \"sequence_length\"}},\n", 508 | " opset_version=20\n", 509 | ")" 510 | ] 511 | }, 512 | { 513 | "cell_type": "code", 514 | "execution_count": 21, 515 | "id": "936051e3-b91a-45a9-a4f1-ec78ce1a4d9d", 516 | "metadata": {}, 517 | "outputs": [ 518 | { 519 | "name": "stdout", 520 | "output_type": "stream", 521 | "text": [ 522 | "input = \t[[0 1 2 3 4 5 6]]\n", 523 | "lstm(x) =\t[[0.11684047 0.15874878 0.19921872]]\n" 524 | ] 525 | } 526 | ], 527 | "source": [ 528 | "import onnxruntime as ort\n", 529 | "import numpy as np\n", 530 | "\n", 531 | "x = np.arange(7, dtype=np.int32)\n", 532 | "x = np.reshape(x, (1, 7))\n", 533 | "print(f\"input = \\t{x}\")\n", 534 | "ort_sess = ort.InferenceSession('test_lstm.onnx')\n", 535 | "outputs = ort_sess.run(['output'], {'input': x})\n", 536 | "print(f\"lstm(x) =\\t{outputs[0]}\")" 537 | ] 538 | } 539 | ], 540 | "metadata": { 541 | "kernelspec": { 542 | 
"display_name": "Python 3 (ipykernel)", 543 | "language": "python", 544 | "name": "python3" 545 | }, 546 | "language_info": { 547 | "codemirror_mode": { 548 | "name": "ipython", 549 | "version": 3 550 | }, 551 | "file_extension": ".py", 552 | "mimetype": "text/x-python", 553 | "name": "python", 554 | "nbconvert_exporter": "python", 555 | "pygments_lexer": "ipython3", 556 | "version": "3.12.3" 557 | } 558 | }, 559 | "nbformat": 4, 560 | "nbformat_minor": 5 561 | } 562 | -------------------------------------------------------------------------------- /onnx/dtypes.go: -------------------------------------------------------------------------------- 1 | // Package togomlx contains several conversion utilities from ONNX and GoMLX. 2 | package onnx 3 | 4 | import ( 5 | "github.com/gomlx/gopjrt/dtypes" 6 | "github.com/gomlx/onnx-gomlx/internal/protos" 7 | "github.com/pkg/errors" 8 | ) 9 | 10 | // dtypeForONNX converts an ONNX data type to a gomlx data type. 11 | func dtypeForONNX(onnxDType protos.TensorProto_DataType) (dtypes.DType, error) { 12 | switch onnxDType { 13 | case protos.TensorProto_FLOAT: 14 | return dtypes.Float32, nil 15 | case protos.TensorProto_DOUBLE: 16 | return dtypes.Float64, nil 17 | case protos.TensorProto_INT32: 18 | return dtypes.Int32, nil 19 | case protos.TensorProto_INT64: 20 | return dtypes.Int64, nil 21 | case protos.TensorProto_UINT8: 22 | return dtypes.Uint8, nil 23 | case protos.TensorProto_INT8: 24 | return dtypes.Int8, nil 25 | case protos.TensorProto_INT16: 26 | return dtypes.Int16, nil 27 | case protos.TensorProto_UINT16: 28 | return dtypes.Uint16, nil 29 | case protos.TensorProto_UINT32: 30 | return dtypes.Uint32, nil 31 | case protos.TensorProto_UINT64: 32 | return dtypes.Uint64, nil 33 | case protos.TensorProto_BOOL: 34 | return dtypes.Bool, nil 35 | case protos.TensorProto_COMPLEX64: 36 | return dtypes.Complex64, nil 37 | case protos.TensorProto_COMPLEX128: 38 | return dtypes.Complex128, nil 39 | default: 40 | return dtypes.InvalidDType, errors.Errorf("unsupported/unknown ONNX data type %v", onnxDType) 41 | } 42 | } 43 | -------------------------------------------------------------------------------- /onnx/dynamicshape.go: -------------------------------------------------------------------------------- 1 | package onnx 2 | 3 | import ( 4 | "fmt" 5 | "github.com/gomlx/gomlx/types/shapes" 6 | "github.com/gomlx/gopjrt/dtypes" 7 | "github.com/gomlx/onnx-gomlx/internal/protos" 8 | "github.com/pkg/errors" 9 | "strconv" 10 | "strings" 11 | ) 12 | 13 | // DynamicShape represents a shape for which some of the axes have unknown dimensions. 14 | // 15 | // Similar to GoMLX Shape but some of the dimensions may be -1, denoting an undefined dimension. 16 | // 17 | // Dimensions may also be named, in which case shapes of inputs and outputs with the same name should match. 18 | type DynamicShape struct { 19 | dtypes.DType 20 | Dimensions []int 21 | Names []string 22 | } 23 | 24 | // UnnamedDynamicDimension is a placeholder name for an unnamed dynamic dimension, that doesn't necessarily match any other (in inputs/outputs). 25 | const UnnamedDynamicDimension = "?" 26 | 27 | // makeDynamicShapeFromProto converts from a tensor proto type to a DynamicShape. 
28 | func makeDynamicShapeFromProto(proto *protos.TypeProto_Tensor) (dshape DynamicShape, err error) { 29 | dshape.DType, err = dtypeForONNX(protos.TensorProto_DataType(proto.GetElemType())) 30 | if err != nil { 31 | return 32 | } 33 | dshape.Names = make([]string, len(proto.Shape.Dim)) 34 | dshape.Dimensions = make([]int, len(proto.Shape.Dim)) 35 | for ii, dProto := range proto.Shape.Dim { 36 | if dim, ok := dProto.GetValue().(*protos.TensorShapeProto_Dimension_DimValue); ok { 37 | dshape.Names[ii] = strconv.Itoa(int(dim.DimValue)) 38 | dshape.Dimensions[ii] = int(dim.DimValue) 39 | } else if dimParam, ok := dProto.GetValue().(*protos.TensorShapeProto_Dimension_DimParam); ok { 40 | dshape.Names[ii] = dimParam.DimParam 41 | dshape.Dimensions[ii] = -1 42 | } else { 43 | dshape.Names[ii] = UnnamedDynamicDimension // Unnamed dynamic dimension. 44 | dshape.Dimensions[ii] = -1 45 | } 46 | } 47 | return 48 | } 49 | 50 | // Rank returns the DynamicShape's rank. 51 | func (dshape DynamicShape) Rank() int { 52 | return len(dshape.Dimensions) 53 | } 54 | 55 | // String implements fmt.Stringer. 56 | func (dshape DynamicShape) String() string { 57 | if len(dshape.Dimensions) == 0 { 58 | return fmt.Sprintf("(%s)", dshape.DType) 59 | } 60 | return fmt.Sprintf("(%s) [%s]", dshape.DType, strings.Join(dshape.Names, ", ")) 61 | } 62 | 63 | // ValidateInputs checks that the given inputs have shapes compatible with the model's input DynamicShapes. 64 | func (m *Model) ValidateInputs(inputsShapes ...shapes.Shape) error { 65 | if len(inputsShapes) != len(m.InputsNames) { 66 | return errors.Errorf("model takes %d inputs, but %d inputs provided", 67 | len(m.InputsNames), len(inputsShapes)) 68 | } 69 | dimValues := make(map[string]int) 70 | for idx, input := range inputsShapes { 71 | name := m.InputsNames[idx] 72 | givenShape := input.Shape() 73 | wantShape := m.InputsShapes[idx] 74 | if givenShape.Rank() != wantShape.Rank() { 75 | return errors.Errorf("model input #%d (%q) should be rank %d, got rank %d instead", 76 | idx, name, wantShape.Rank(), givenShape.Rank()) 77 | } 78 | if givenShape.DType != wantShape.DType { 79 | return errors.Errorf("model input #%d (%q) should have dtype %s, got dtype %s instead", 80 | idx, name, wantShape.DType, givenShape.DType) 81 | } 82 | for axis, wantDim := range wantShape.Dimensions { 83 | gotDim := givenShape.Dim(axis) 84 | if wantDim > 0 { 85 | if wantDim != gotDim { 86 | return errors.Errorf("model input #%d (%q) has invalid shape: want %s, got %s", 87 | idx, name, wantShape, givenShape) 88 | } 89 | } else { 90 | dimName := wantShape.Names[axis] 91 | var found bool 92 | wantDim, found = dimValues[dimName] 93 | if !found { 94 | // Define dynamic shape based on input.
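// The first input that binds a named dynamic dimension fixes its value; later occurrences of the same name must match it.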
95 | dimValues[dimName] = gotDim 96 | } else if wantDim != gotDim { 97 | return errors.Errorf("model input #%d (%q) shaped %s got mismatched shape %s for axis %q (wanted dim %d)", 98 | idx, name, wantShape, givenShape, dimName, wantDim) 99 | } 100 | } 101 | } 102 | } 103 | return nil 104 | } 105 | -------------------------------------------------------------------------------- /onnx/dynamicshape_test.go: -------------------------------------------------------------------------------- 1 | package onnx 2 | 3 | import ( 4 | "github.com/gomlx/gomlx/types/shapes" 5 | "github.com/gomlx/gopjrt/dtypes" 6 | "github.com/stretchr/testify/require" 7 | "testing" 8 | ) 9 | 10 | func TestValidateInputs(t *testing.T) { 11 | m := &Model{ 12 | InputsNames: []string{"i0", "i1"}, 13 | InputsShapes: []DynamicShape{ 14 | DynamicShape{ 15 | DType: dtypes.Float32, 16 | Dimensions: []int{-1, -1}, 17 | Names: []string{"batch_size", "feature_dim"}, 18 | }, 19 | DynamicShape{ 20 | DType: dtypes.Int32, 21 | Dimensions: []int{-1, 3}, 22 | Names: []string{"batch_size", "other"}, 23 | }, 24 | }, 25 | } 26 | 27 | // Example valid input, batch_size=5 28 | require.NoError(t, m.ValidateInputs( 29 | shapes.Make(dtypes.Float32, 5, 7), 30 | shapes.Make(dtypes.Int32, 5, 3))) 31 | 32 | // Wrong dtype: 33 | require.Error(t, m.ValidateInputs( 34 | shapes.Make(dtypes.Float32, 5, 7), 35 | shapes.Make( /**/ dtypes.Int64, 5, 3))) 36 | 37 | // Wrong rank: 38 | require.Error(t, m.ValidateInputs( 39 | shapes.Make(dtypes.Float32, 5, 7 /**/, 1), 40 | shapes.Make(dtypes.Int32, 5, 3))) 41 | 42 | // Fixed dimension not matching: 43 | require.Error(t, m.ValidateInputs( 44 | shapes.Make(dtypes.Float32, 5, 7), 45 | shapes.Make(dtypes.Int32, 5 /**/, 4))) 46 | 47 | // Dynamic dimension not matching: 48 | require.Error(t, m.ValidateInputs( 49 | shapes.Make(dtypes.Float32, 5, 7), 50 | shapes.Make(dtypes.Int32 /**/, 6, 3))) 51 | } 52 | -------------------------------------------------------------------------------- /onnx/graph.go: -------------------------------------------------------------------------------- 1 | package onnx 2 | 3 | import ( 4 | "fmt" 5 | "runtime" 6 | 7 | "github.com/gomlx/exceptions" 8 | . "github.com/gomlx/gomlx/graph" 9 | "github.com/gomlx/gomlx/ml/context" 10 | "github.com/gomlx/gomlx/ml/layers/activations" 11 | "github.com/gomlx/gomlx/types" 12 | "github.com/gomlx/gomlx/types/shapes" 13 | "github.com/gomlx/onnx-gomlx/internal/protos" 14 | ) 15 | 16 | // sliceMap executes the given function sequentially for every element of in, and returns a mapped slice. 17 | func sliceMap[In, Out any](in []In, fn func(e In) Out) (out []Out) { 18 | out = make([]Out, len(in)) 19 | for ii, e := range in { 20 | out[ii] = fn(e) 21 | } 22 | return 23 | } 24 | 25 | // CallGraph calls the ONNX graph, and hence builds it with GoMLX ops. 26 | // This can be used for inference or training. 27 | // 28 | // If the model has any variables, call Model.VariablesToContext first (only once) to upload all 29 | // variable values from the ONNX model to the context -- or load them from a checkpoint if you saved one. 30 | // 31 | // If the model has no variables, the context in ctx can be set to nil. 32 | // 33 | // The inputs (a map of input name to its graph.Node) can be given as normal input parameters to the graph or as 34 | // static constants -- see WithInputsAsConstants. 35 | // Set the inputs as constants if they are meant to be interpreted as constants (static) values, that won't change 36 | // in different inference/training steps.
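//
// A minimal usage sketch, for illustration only (the file name "model.onnx", the input name "X", and the
// variables backend and inputTensor are placeholders; see linear_test.go for a working end-to-end example):
//
//	model, _ := onnx.ReadFile("model.onnx")
//	ctx := context.New()
//	_ = model.VariablesToContext(ctx) // Upload the ONNX variables into the context once.
//	y := context.ExecOnce(backend, ctx, func(ctx *context.Context, x *graph.Node) *graph.Node {
//		return model.CallGraph(ctx, x.Graph(), map[string]*graph.Node{"X": x})[0]
//	}, inputTensor)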
37 | // 38 | // If outputNames is not given, it will output the model's registered outputs. Alternatively, you can select 39 | // any list of node outputs to generate. It will return the values for the selected outputs. 40 | // 41 | // The graph being built is given in g. 42 | // 43 | // As in GoMLX graph functions, it panics (throws exceptions) in case of errors. 44 | func (m *Model) CallGraph(ctx *context.Context, g *Graph, inputs map[string]*Node, outputNames ...string) (outputs []*Node) { 45 | if ctx != nil { 46 | ctx = ctx.In(ModelScope).Checked(false) 47 | } 48 | 49 | // Sanity check of things we don't support yet. 50 | if len(m.Proto.Functions) > 0 { 51 | exceptions.Panicf("onnx.CallGraph does not support ONNX functions") 52 | } 53 | if len(m.Proto.Graph.SparseInitializer) > 0 { 54 | exceptions.Panicf("onnx.CallGraph does not support ONNX SparseTensors") 55 | } 56 | 57 | // If no outputNames were given, take the model outputs. 58 | if len(outputNames) == 0 { 59 | outputNames = m.OutputsNames 60 | } 61 | 62 | // Map the given inputs to the corresponding ONNX inputs, and report (throw exception) if there are 63 | // any discrepancies. 64 | // Also initialize convertedOutputs with the given/converted inputs. 65 | convertedOutputs := make(map[string]*Node) 66 | missingInputs := types.MakeSet[string]() 67 | repeatedInputs := types.MakeSet[string]() 68 | unknownInputs := types.MakeSet[string]() 69 | for inputIdx, inputName := range m.InputsNames { 70 | if inputName == "" { 71 | inputName = fmt.Sprintf("#%d", inputIdx) 72 | } 73 | inputN := inputs[inputName] 74 | if inputN == nil { 75 | staticValue := m.inputsAsConstants[inputName] 76 | if staticValue != nil { 77 | inputN = Const(g, staticValue) 78 | } else { 79 | missingInputs.Insert(inputName) 80 | continue 81 | } 82 | } else { 83 | if _, found := m.inputsAsConstants[inputName]; found { 84 | repeatedInputs.Insert(inputName) 85 | } 86 | } 87 | convertedOutputs[inputName] = inputN 88 | } 89 | for givenName := range inputs { 90 | if _, found := convertedOutputs[givenName]; !found { 91 | unknownInputs.Insert(givenName) 92 | } 93 | } 94 | for givenName := range m.inputsAsConstants { 95 | if _, found := convertedOutputs[givenName]; !found { 96 | unknownInputs.Insert(givenName) 97 | } 98 | } 99 | if len(missingInputs) > 0 || len(unknownInputs) > 0 || len(repeatedInputs) > 0 { 100 | exceptions.Panicf("onnx.CallGraph() called with wrong inputs: missing inputs=%q; unknown given inputs=%q; inputs given normally and as constant inputs=%q", 101 | missingInputs, unknownInputs, repeatedInputs) 102 | } 103 | 104 | // Validate the input shapes. 105 | err := m.ValidateInputs(sliceMap(m.InputsNames, func(inputName string) shapes.Shape { return convertedOutputs[inputName].Shape() })...) 106 | if err != nil { 107 | panic(err) 108 | } 109 | 110 | // Convert variables: create the GoMLX nodes corresponding to the ONNX model variables. 111 | if len(m.Proto.Graph.Initializer) > 0 && ctx == nil { 112 | exceptions.Panicf("onnx.CallGraph(): model has variables, but a nil context was given") 113 | panic(nil) // for lint benefit. 114 | } 115 | 116 | // Convert all nodes recursively, which will implicitly yield a topological order. 117 | for _, target := range outputNames { 118 | m.recursiveCallGraph(ctx, g, target, convertedOutputs) 119 | } 120 | 121 | // Pick the outputs.
122 | outputs = make([]*Node, len(outputNames)) 123 | var found bool 124 | for outputIdx, nodeName := range outputNames { 125 | outputs[outputIdx], found = convertedOutputs[nodeName] 126 | if !found { 127 | exceptions.Panicf("output node %q not found", nodeName) 128 | } 129 | } 130 | 131 | // Make sure all temporarily allocated tensors on device are freed. 132 | for range 3 { 133 | runtime.GC() 134 | } 135 | return outputs 136 | } 137 | 138 | // recursiveCallGraph recursively creates a GoMLX graph for the target output name. 139 | // The convertedOutputs is used both as input, and as output to store the converted nodes. 140 | func (m *Model) recursiveCallGraph(ctx *context.Context, g *Graph, nodeOutputName string, convertedOutputs map[string]*Node) { 141 | if _, found := convertedOutputs[nodeOutputName]; found { 142 | // Already converted. 143 | return 144 | } 145 | 146 | // Is it the output of a variable? 147 | if _, found := m.variableNameToValue[nodeOutputName]; found { 148 | varName := SafeVarName(nodeOutputName) 149 | v := ctx.GetVariable(varName) 150 | if v == nil { 151 | exceptions.Panicf("variable %q (named %q in ONNX) has not been uploaded yet to context -- did you forget to call onnx.Model.VariablesToContext?", 152 | varName, nodeOutputName) 153 | panic(nil) // for lint benefit. 154 | } 155 | convertedOutputs[nodeOutputName] = v.ValueGraph(g) 156 | return 157 | } 158 | 159 | onnxNode, found := m.nodeOutputToNode[nodeOutputName] 160 | if !found { 161 | exceptions.Panicf("ONNX node output %q not found as the output of any Op, and not a variable or input either -- could it be a node name, and not a node **output** name?", nodeOutputName) 162 | } 163 | 164 | // Recursively converts the inputs of the onnxNode: 165 | for _, inputName := range onnxNode.Input { 166 | if inputName == "" { 167 | // Probably an optional parameter, not used. LSTM nodes have this. 168 | continue 169 | } 170 | m.recursiveCallGraph(ctx, g, inputName, convertedOutputs) 171 | } 172 | 173 | // Convert the node itself. 174 | m.convertNode(ctx, g, onnxNode, convertedOutputs) 175 | } 176 | 177 | // opRequiresContext checks if the given operation type requires a context. 178 | // Currently only LSTM. 179 | func opRequiresContext(opType string) bool { 180 | return opType == "LSTM" 181 | } 182 | 183 | // convertNode converts a single ONNX node to a GoMLX node. 184 | // 185 | // Previously converted nodes are given in convertedOutputs. 186 | // The converted output(s) are updated into `convertedOutputs`. 187 | // 188 | // It panics (throws exceptions) in case of errors. 189 | // 190 | // TODO: One of the ONNX broadcasting rules is not applied by default in GoMLX/XLA for binary operators, namely: 191 | // 192 | // "The tensors that have too few dimensions can have their shapes prepended with a dimension of length 1 to satisfy property 2." 193 | // 194 | // See the definitions in: 195 | // . https://openxla.org/xla/broadcasting 196 | // . https://github.com/onnx/onnx/blob/main/docs/Broadcasting.md 197 | func (m *Model) convertNode(ctx *context.Context, g *Graph, node *protos.NodeProto, convertedOutputs map[string]*Node) { 198 | if node.Overload != "" { 199 | exceptions.Panicf("overload %q to in-model function in ONNX model not implemented in node %q", node.Overload, node.Name) 200 | } 201 | 202 | // Convert the node: the usual case is that there is only one output. 203 | // If res is not nil, it is set to convertedOutputs[output[0]]. 204 | // Anything different must be implemented by the specific op switch.
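// Example of the rule quoted in the TODO above: adding ONNX shapes (3, 4) and (4,) first prepends a 1 to the lower-rank operand -- (4,) becomes (1, 4) -- before XLA-style broadcasting applies; presumably the convertBinaryOp wrapper used below (defined elsewhere in this package) is where that reconciliation happens.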
205 | var res *Node 206 | inputs := sliceMap(node.Input, func(n string) *Node { return convertedOutputs[n] }) 207 | switch node.OpType { 208 | // Binary operators: see note on differences on default broadcasting. 209 | case "Add": 210 | res = convertBinaryOp(Add, inputs[0], inputs[1]) 211 | case "Sub": 212 | res = convertBinaryOp(Sub, inputs[0], inputs[1]) 213 | case "Mul": 214 | res = convertBinaryOp(Mul, inputs[0], inputs[1]) 215 | case "Div": 216 | res = convertBinaryOp(Div, inputs[0], inputs[1]) 217 | case "Pow": 218 | //res = convertBinaryOp(Pow, inputs[0], inputs[1]) 219 | res = convertPow(m, convertedOutputs, node, inputs) 220 | case "And": 221 | res = convertBinaryOp(LogicalAnd, inputs[0], inputs[1]) 222 | case "Or": 223 | res = convertBinaryOp(LogicalOr, inputs[0], inputs[1]) 224 | case "Xor": 225 | res = convertBinaryOp(LogicalXor, inputs[0], inputs[1]) 226 | case "BitwiseAnd": 227 | res = convertBinaryOp(BitwiseAnd, inputs[0], inputs[1]) 228 | case "BitwiseOr": 229 | res = convertBinaryOp(BitwiseOr, inputs[0], inputs[1]) 230 | case "BitwiseXor": 231 | res = convertBinaryOp(BitwiseXor, inputs[0], inputs[1]) 232 | case "Equal": 233 | res = convertBinaryOp(Equal, inputs[0], inputs[1]) 234 | case "Less": 235 | res = convertBinaryOp(LessThan, inputs[0], inputs[1]) 236 | case "LessOrEqual": 237 | res = convertBinaryOp(LessOrEqual, inputs[0], inputs[1]) 238 | case "Greater": 239 | res = convertBinaryOp(GreaterThan, inputs[0], inputs[1]) 240 | case "GreaterOrEqual": 241 | res = convertBinaryOp(GreaterOrEqual, inputs[0], inputs[1]) 242 | 243 | // Unary operators 244 | case "Sqrt": 245 | res = Sqrt(inputs[0]) 246 | case "Exp": 247 | res = Exp(inputs[0]) 248 | case "Log": 249 | res = Log(inputs[0]) 250 | case "Erf": 251 | res = Erf(inputs[0]) 252 | case "Relu": 253 | res = activations.Relu(inputs[0]) 254 | case "Abs": 255 | res = Abs(inputs[0]) 256 | case "Neg": 257 | res = Neg(inputs[0]) 258 | case "Sign": 259 | res = Sign(inputs[0]) 260 | case "Ceil": 261 | res = Ceil(inputs[0]) 262 | case "Floor": 263 | res = Floor(inputs[0]) 264 | case "Identity": 265 | res = Identity(inputs[0]) 266 | case "Not": 267 | res = LogicalNot(inputs[0]) 268 | case "BitwiseNot": 269 | res = BitwiseNot(inputs[0]) 270 | case "Tanh": 271 | res = Tanh(inputs[0]) 272 | 273 | // Ops with equivalents: 274 | case "MatMul": 275 | res = MatMul(inputs[0], inputs[1]) 276 | 277 | // Ops with special behavior: 278 | case "Clip": 279 | res = convertClip(node, inputs) 280 | case "Where": 281 | res = convertWhere(node, inputs) 282 | case "Min": 283 | res = convertMin(inputs) 284 | case "Max": 285 | res = convertMax(inputs) 286 | 287 | // Ops with attributes: 288 | case "Constant": 289 | res = convertConstant(node, g) 290 | case "Gather": 291 | res = convertGather(node, inputs) 292 | case "GatherElements": 293 | res = convertGatherElements(node, inputs) 294 | case "Shape": 295 | res = convertShape(node, inputs) 296 | case "Concat": 297 | res = convertConcat(node, inputs) 298 | case "Softmax": 299 | res = convertSoftmax(node, inputs) 300 | case "Cast": 301 | res = convertCast(node, inputs) 302 | case "Transpose": 303 | res = convertTranspose(node, inputs) 304 | case "Gemm": 305 | res = convertGemm(node, inputs) 306 | case "Flatten": 307 | res = convertFlatten(node, inputs) 308 | 309 | // Ops that require constant sub-expression materialization: 310 | // they take dynamic (graph) values in ONNX, but only take static values in XLA 311 | case "Squeeze": 312 | res = convertSqueeze(m, convertedOutputs, node, inputs) 313 | case 
"Unsqueeze": 314 | res = convertUnsqueeze(m, convertedOutputs, node, inputs) 315 | case "Slice": 316 | res = convertSlice(m, convertedOutputs, node, inputs) 317 | case "Reshape": 318 | res = convertReshape(m, convertedOutputs, node, inputs) 319 | case "ReduceMean": 320 | res = convertReduceMean(m, convertedOutputs, node, inputs) 321 | case "ConstantOfShape": 322 | res = convertConstantOfShape(m, convertedOutputs, node, inputs) 323 | case "Expand": 324 | res = convertExpand(m, convertedOutputs, node, inputs) 325 | case "Tile": 326 | res = convertTile(m, convertedOutputs, node, inputs) 327 | case "Range": 328 | res = convertRange(m, convertedOutputs, node, inputs) 329 | case "CumSum": 330 | res = convertCumSum(m, convertedOutputs, node, inputs) 331 | 332 | // Full ML layers ops: 333 | case "LSTM": 334 | res = convertLSTM(m, convertedOutputs, node, inputs) 335 | 336 | default: 337 | exceptions.Panicf("unimplemented ONNX op %q in %s", node.OpType, nodeToString(node)) 338 | } 339 | if res != nil { 340 | convertedOutputs[node.Output[0]] = res 341 | } else { 342 | exceptions.Panicf("nil output for ONNX node %q", node.Name) 343 | } 344 | } 345 | -------------------------------------------------------------------------------- /onnx/linear_test.go: -------------------------------------------------------------------------------- 1 | package onnx 2 | 3 | import ( 4 | "fmt" 5 | "github.com/gomlx/gomlx/graph" 6 | "github.com/gomlx/gomlx/graph/graphtest" 7 | "github.com/gomlx/gomlx/ml/context" 8 | "github.com/gomlx/gomlx/types/tensors" 9 | "github.com/gomlx/gopjrt/dtypes" 10 | "github.com/stretchr/testify/require" 11 | "testing" 12 | ) 13 | 14 | import _ "github.com/gomlx/gomlx/backends/default" 15 | 16 | // TestEndToEnd based on the `linear_test.onnx` minimalistic model. 17 | // Only a couple of ops tested, but from end-to-end, including if changes can be saved 18 | // again (ContextToONNX). 19 | func TestEndToEnd(t *testing.T) { 20 | model, err := ReadFile("linear_test.onnx") 21 | fmt.Printf("%s\n", model) 22 | require.NoError(t, err) 23 | require.Len(t, model.InputsNames, 1) 24 | require.Equal(t, "X", model.InputsNames[0]) 25 | require.Len(t, model.OutputsNames, 1) 26 | require.Equal(t, "Y", model.OutputsNames[0]) 27 | 28 | require.Equal(t, model.OutputsShapes[0].Rank(), 1) 29 | require.Equal(t, "batch_size", model.OutputsShapes[0].Names[0]) 30 | 31 | // Verify correct setting of variables. 32 | ctx := context.New() 33 | require.NoError(t, model.VariablesToContext(ctx)) 34 | for v := range ctx.IterVariables() { 35 | fmt.Printf("\tVariable %q: %s\n", v.ScopeAndName(), v.Value()) 36 | } 37 | vA := ctx.In(ModelScope).GetVariable("A") 38 | require.NotNil(t, vA) 39 | require.Equal(t, 1, vA.Shape().Rank()) 40 | require.Equal(t, 5, vA.Shape().Dim(0)) 41 | require.Equal(t, []float32{100, 10, 1, 0.1, 0.01}, tensors.CopyFlatData[float32](vA.Value())) 42 | vB := ctx.In(ModelScope).GetVariable("B") 43 | require.NotNil(t, vB) 44 | require.Equal(t, 0, vB.Shape().Rank()) 45 | require.Equal(t, float32(7000), tensors.ToScalar[float32](vB.Value())) 46 | 47 | // Check conversion. 
48 | backend := graphtest.BuildTestBackend() 49 | y := context.ExecOnce(backend, ctx, func(ctx *context.Context, x *graph.Node) *graph.Node { 50 | g := x.Graph() 51 | outputs := model.CallGraph(ctx, g, map[string]*graph.Node{"X": x}) 52 | vB = ctx.In(ModelScope).GetVariable("B") 53 | vB.SetValueGraph(graph.OnePlus(vB.ValueGraph(g))) 54 | return outputs[0] 55 | }, [][]float32{{1, 2, 3, 4, 5}}) // BatchSize = 1 56 | require.NoError(t, y.Shape().Check(dtypes.Float32, 1)) 57 | require.InDeltaSlice(t, []float32{7123.45}, tensors.CopyFlatData[float32](y), 0.1) 58 | 59 | // Save change of variable "B" to ONNX model. 60 | require.NoError(t, model.ContextToONNX(ctx)) 61 | tensorProto, found := model.variableNameToValue["B"] 62 | require.True(t, found, "Didn't find B variable") 63 | require.Equal(t, []float32{7001}, tensorProto.FloatData, "ONNX variable B initial value was not updated.") 64 | } 65 | -------------------------------------------------------------------------------- /onnx/linear_test.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gomlx/onnx-gomlx/c68337bcf8365e94e3691fa67192ab1e38a3c77a/onnx/linear_test.onnx -------------------------------------------------------------------------------- /onnx/materialize.go: -------------------------------------------------------------------------------- 1 | package onnx 2 | 3 | import ( 4 | "fmt" 5 | "github.com/gomlx/exceptions" 6 | . "github.com/gomlx/gomlx/graph" 7 | "github.com/gomlx/gomlx/types" 8 | "github.com/gomlx/gomlx/types/tensors" 9 | "github.com/gomlx/onnx-gomlx/internal/protos" 10 | "github.com/pkg/errors" 11 | "strings" 12 | ) 13 | 14 | // nonConstantDependencies returns the non-constant dependencies: inputs or variables. 15 | func (m *Model) nonConstantDependencies(nodeOutputName string) (inputs, variables []string, contextNodes []*protos.NodeProto) { 16 | visitedNodes := types.MakeSet[string]() 17 | return m.recursiveNonConstantDependencies(nodeOutputName, visitedNodes, inputs, variables, contextNodes) 18 | } 19 | 20 | // recursiveNonConstantDependencies is the recursive implementation of nonConstantDependencies. 21 | // Use nonConstantDependencies. 22 | func (m *Model) recursiveNonConstantDependencies(name string, visitedNodes types.Set[string], 23 | nonConstInputs, variables []string, contextNodes []*protos.NodeProto) ([]string, []string, []*protos.NodeProto) { 24 | visitedNodes.Insert(name) 25 | if _, found := m.variableNameToValue[name]; found { 26 | // Record a variable dependency. 27 | if m.isVariableConstant(name) { 28 | // Constant variable, ok. 29 | return nonConstInputs, variables, contextNodes 30 | } 31 | variables = append(variables, name) 32 | return nonConstInputs, variables, contextNodes 33 | } 34 | if m.inputsNameSet.Has(name) { 35 | // Input dependency is recorded as non-constant only if the input is not fed as a constant. 36 | if m.inputsAsConstants == nil || m.inputsAsConstants[name] == nil { 37 | nonConstInputs = append(nonConstInputs, name) 38 | } 39 | return nonConstInputs, variables, contextNodes 40 | } 41 | 42 | // Recurse into the inputs of the node that generated the `name` output. 43 | node := m.nodeOutputToNode[name] 44 | if node == nil { 45 | exceptions.Panicf("nonConstantDependencies given an unknown node output name %q", name) 46 | panic(nil) // lint.
47 | } 48 | if opRequiresContext(node.OpType) { 49 | contextNodes = append(contextNodes, node) 50 | } 51 | if node.OpType == "Shape" { 52 | // Shape op returns a static value after converting to GoMLX, independent of inputs. 53 | // So we don't recurse into its inputs. 54 | return nonConstInputs, variables, contextNodes 55 | } 56 | for _, input := range node.Input { 57 | if visitedNodes.Has(input) { 58 | continue 59 | } 60 | nonConstInputs, variables, contextNodes = m.recursiveNonConstantDependencies(input, visitedNodes, nonConstInputs, variables, contextNodes) 61 | } 62 | return nonConstInputs, variables, contextNodes 63 | } 64 | 65 | // isVariableConstant tries to guess if the variable can be used as a constant during the graph construction. 66 | // For instance as the dimension for a "Reshape" or axis for a "Slice" method. 67 | // Some ONNX models use variables instead of constants. 68 | // 69 | // varName must be an existing variable name. 70 | func (m *Model) isVariableConstant(varName string) bool { 71 | sizeLimit := 100 // Max size to be accepted as constant. 72 | lowerName := strings.ToLower(varName) 73 | if strings.Contains(lowerName, "constant") { 74 | // If there is "constant" in the name, we assume constant at a higher size. 75 | sizeLimit = 10_000 76 | } else if strings.Contains(lowerName, "const") { 77 | // With less confidence... 78 | sizeLimit = 1_000 79 | } 80 | tensorProto := m.variableNameToValue[varName] 81 | shape, err := Shape(tensorProto) 82 | if err != nil { 83 | panic(errors.WithMessagef(err, "ONNX variable %q has an invalid shape", varName)) 84 | } 85 | return shape.DType.IsInt() && shape.Size() <= sizeLimit 86 | } 87 | 88 | // materializeConstantExpression materializes a node to its constant expression. 89 | // 90 | // This is required for ONNX ops that take dynamic values (like axes and shapes), but for which GoMLX only accepts 91 | // static (materialized) values. 92 | // 93 | // If the node depends on non-constant values (like input parameters), this returns an error. 94 | func (m *Model) materializeConstantExpression(nodeOutputName string, convertedOutputs map[string]*Node) (*tensors.Tensor, error) { 95 | // Easy reply: if the node is already a constant. 96 | node := convertedOutputs[nodeOutputName] 97 | if node == nil { 98 | return nil, errors.Errorf("node output %q hasn't been converted yet, so we can't materializeConstantExpression!?", nodeOutputName) 99 | } 100 | if node.Type() == NodeTypeConstant { 101 | return node.ConstantValue(), nil 102 | } 103 | 104 | // See if it is possible: if the subgraph that generated the node is a constant expression. 105 | nonConstInputs, nonConstVariables, contextNodes := m.nonConstantDependencies(nodeOutputName) 106 | if len(nonConstInputs) > 0 || len(nonConstVariables) > 0 || len(contextNodes) > 0 { 107 | // Add shape info for variables. 108 | varDesc := make([]string, 0, len(nonConstVariables)) 109 | for _, varName := range nonConstVariables { 110 | // We discard the error, because we know this conversion works already, to have reached this point. 111 | shape, _ := Shape(m.variableNameToValue[varName]) 112 | varDesc = append(varDesc, fmt.Sprintf("%q (%s)", varName, shape)) 113 | } 114 | opsDesc := make([]string, 0, len(contextNodes)) 115 | for _, node := range contextNodes { 116 | // Describe the ops that require a context.
117 | 			opsDesc = append(opsDesc, node.String())
118 | 		}
119 | 		return nil, errors.Errorf("cannot materialize constant/static value for %q: it depends on non-constant values: inputs=%q, variables: %s, ops with context: %s",
120 | 			nodeOutputName, nonConstInputs, strings.Join(varDesc, ", "), strings.Join(opsDesc, ", "))
121 | 	}
122 | 
123 | 	// Evaluate the constant sub-expression in a newly created sub-graph.
124 | 	backend := node.Graph().Backend()
125 | 	var result *tensors.Tensor
126 | 	err := exceptions.TryCatch[error](func() {
127 | 		result = ExecOnce(backend, func(g *Graph) *Node {
128 | 			constConvertedOutputs := make(map[string]*Node)
129 | 			m.recursiveMaterializeConstantExpression(nodeOutputName, g, constConvertedOutputs, convertedOutputs)
130 | 			return constConvertedOutputs[nodeOutputName]
131 | 		})
132 | 	})
133 | 	if err != nil {
134 | 		return nil, errors.WithMessage(err, "while evaluating constant sub-expression")
135 | 	}
136 | 	return result, nil
137 | }
138 | 
139 | // recursiveMaterializeConstantExpression builds, in a new GoMLX graph, the constant expressions collected in constConvertedOutputs.
140 | // It may read from the original converted graph in originalConvertedOutput, but it doesn't change it.
141 | func (m *Model) recursiveMaterializeConstantExpression(nodeOutputName string, g *Graph, constConvertedOutputs, originalConvertedOutput map[string]*Node) {
142 | 	if _, found := constConvertedOutputs[nodeOutputName]; found {
143 | 		// Already converted.
144 | 		return
145 | 	}
146 | 
147 | 	// Check in the original graph being converted whether this node was converted as a constant (for instance, nodes like "Shape"),
148 | 	// in which case we take the constant value and inject it directly into the new constant-expression graph.
149 | 	if originalNode, found := originalConvertedOutput[nodeOutputName]; found {
150 | 		if originalNode.Type() == NodeTypeConstant {
151 | 			// Duplicate the constant in the new graph.
152 | 			constConvertedOutputs[nodeOutputName] = Const(g, originalNode.ConstantValue())
153 | 			return
154 | 		}
155 | 	}
156 | 
157 | 	// Check for constant variables.
158 | 	if tensorNode, found := m.variableNameToValue[nodeOutputName]; found {
159 | 		if !m.isVariableConstant(nodeOutputName) {
160 | 			exceptions.Panicf("attempting to materialize variable %q as a constant, but it doesn't look like a constant", nodeOutputName)
161 | 		}
162 | 		t, err := tensorToGoMLX(tensorNode)
163 | 		if err != nil {
164 | 			panic(errors.WithMessagef(err, "attempting to materialize variable %q as constant", nodeOutputName))
165 | 		}
166 | 		constConvertedOutputs[nodeOutputName] = Const(g, t)
167 | 		// TODO: mark the variable as used for a constant-expression, make sure it is also used in the final model,
168 | 		// and try to make it so that the graph is rebuilt if the variable changes.
169 | 		return
170 | 	}
171 | 
172 | 	// Find the node generating this output.
173 | 	onnxNode, found := m.nodeOutputToNode[nodeOutputName]
174 | 	if !found {
175 | 		exceptions.Panicf("ONNX node %q not found as the output of an Op, and not a constant either -- is this really a constant expression!?", nodeOutputName)
176 | 	}
177 | 	if opRequiresContext(onnxNode.OpType) {
178 | 		// The operation requires a context, which is not supported when materializing constant sub-expressions.
179 | 		exceptions.Panicf("attempting to materialize expression with operation %q, which is not supported for materialization: %s", onnxNode.OpType, onnxNode)
180 | 	}
181 | 
182 | 	// Recursively convert the inputs of onnxNode:
183 | 	for _, inputName := range onnxNode.Input {
184 | 		m.recursiveMaterializeConstantExpression(inputName, g, constConvertedOutputs, originalConvertedOutput)
185 | 	}
186 | 
187 | 	// And now convert the node itself.
188 | 	m.convertNode(nil, g, onnxNode, constConvertedOutputs)
189 | }
--------------------------------------------------------------------------------
/onnx/onnx.go:
--------------------------------------------------------------------------------
1 | // Package onnx provides functionality to parse ONNX models and generate the corresponding GoMLX model graph.
2 | //
3 | // - Parse: converts a serialized ONNX ModelProto to a Model.
4 | // - ReadFile: reads a file and calls Parse. It returns a Model.
5 | // - Model: object holding information about an ONNX model. It can be used to generate the corresponding GoMLX
6 | //   model graph, which can be executed for inference or used in a training loop for fine-tuning. It can also be
7 | //   used to populate a context with the variables of the ONNX model.
8 | package onnx
9 | 
10 | import (
11 | 	"github.com/gomlx/gomlx/types"
12 | 	"github.com/gomlx/onnx-gomlx/internal/protos"
13 | 	"github.com/pkg/errors"
14 | 	"google.golang.org/protobuf/proto"
15 | 	"io"
16 | 	"os"
17 | )
18 | 
19 | // Model represents a parsed ONNX file.
20 | type Model struct {
21 | 	onnxFileName string
22 | 	Proto protos.ModelProto
23 | 	nodeOutputToNode map[string]*protos.NodeProto
24 | 
25 | 	// Names used for variables and inputs: these behave like internal outputs, except they come from an input
26 | 	// or a variable rather than from a node. Used to introspect the graph.
27 | 	inputsNameSet types.Set[string]
28 | 	variableNameToValue map[string]*protos.TensorProto
29 | 
30 | 	name string
31 | 	InputsNames, OutputsNames []string
32 | 	InputsShapes, OutputsShapes []DynamicShape
33 | 
34 | 	// inputsAsConstants: see WithInputsAsConstants
35 | 	inputsAsConstants map[string]any
36 | }
37 | 
38 | // Parse parses an ONNX model into an internal representation that can be used to build a GoMLX graph.
39 | func Parse(contents []byte) (*Model, error) {
40 | 	m := &Model{}
41 | 	err := proto.Unmarshal(contents, &m.Proto)
42 | 	if err != nil {
43 | 		return nil, errors.Wrap(err, "failed to parse ONNX model proto")
44 | 	}
45 | 
46 | 	// Parse inputs and outputs.
47 | 	m.name = m.Proto.Graph.Name
48 | 	m.inputsNameSet = types.MakeSet[string]()
49 | 	m.InputsNames = make([]string, len(m.Proto.Graph.Input))
50 | 	m.InputsShapes = make([]DynamicShape, len(m.Proto.Graph.Input))
51 | 	for ii, input := range m.Proto.Graph.Input {
52 | 		m.InputsNames[ii] = input.Name
53 | 		m.inputsNameSet.Insert(input.Name)
54 | 
55 | 		tensorType, ok := input.Type.Value.(*protos.TypeProto_TensorType)
56 | 		if !ok {
57 | 			return nil, errors.Errorf("input #%d (%q) is not a tensor, not sure how to handle it", ii, input.Name)
58 | 		}
59 | 		m.InputsShapes[ii], err = makeDynamicShapeFromProto(tensorType.TensorType)
60 | 		if err != nil {
61 | 			return nil, errors.WithMessagef(err, "while parsing input #%d (%q)", ii, input.Name)
62 | 		}
63 | 	}
64 | 	m.OutputsNames = make([]string, len(m.Proto.Graph.Output))
65 | 	m.OutputsShapes = make([]DynamicShape, len(m.Proto.Graph.Output))
66 | 	for ii, output := range m.Proto.Graph.Output {
67 | 		m.OutputsNames[ii] = output.Name
68 | 		tensorType, ok := output.Type.Value.(*protos.TypeProto_TensorType)
69 | 		if !ok {
70 | 			return nil, errors.Errorf("output #%d (%q) is not a tensor, not sure how to handle it", ii, output.Name)
71 | 		}
72 | 		m.OutputsShapes[ii], err = makeDynamicShapeFromProto(tensorType.TensorType)
73 | 		if err != nil {
74 | 			return nil, errors.WithMessagef(err, "while parsing output #%d (%q)", ii, output.Name)
75 | 		}
76 | 	}
77 | 
78 | 	// Map of variable names to their initializer values.
79 | 	m.variableNameToValue = make(map[string]*protos.TensorProto)
80 | 	for _, tensorProto := range m.Proto.Graph.Initializer {
81 | 		m.variableNameToValue[tensorProto.Name] = tensorProto
82 | 	}
83 | 
84 | 	// Map the intermediary node outputs to the nodes that create them.
85 | 	m.nodeOutputToNode = make(map[string]*protos.NodeProto)
86 | 	for _, node := range m.Proto.Graph.Node {
87 | 		for _, outputName := range node.GetOutput() {
88 | 			if otherNode, found := m.nodeOutputToNode[outputName]; found {
89 | 				return nil, errors.Errorf("invalid graph: node output name %q used by 2 different nodes: (1) %s, (2) %s",
90 | 					outputName, nodeToString(otherNode), nodeToString(node))
91 | 			}
92 | 			m.nodeOutputToNode[outputName] = node
93 | 		}
94 | 	}
95 | 	return m, nil
96 | }
97 | 
98 | // ReadFile parses an ONNX model file into an internal representation that can be used to build a GoMLX graph.
99 | // Note that any large constants in the model are converted to variables.
100 | func ReadFile(filePath string) (*Model, error) {
101 | 	contents, err := os.ReadFile(filePath)
102 | 	if err != nil {
103 | 		return nil, errors.Wrapf(err, "failed to read ONNX model file in %s", filePath)
104 | 	}
105 | 	m, err := Parse(contents)
106 | 	if err != nil {
107 | 		return nil, err
108 | 	}
109 | 	m.onnxFileName = filePath
110 | 	return m, nil
111 | }
112 | 
113 | // Name returns the name of the model graph.
114 | func (m *Model) Name() string { return m.name }
115 | 
116 | // Inputs returns the names and DynamicShapes of the inputs.
117 | func (m *Model) Inputs() (names []string, dshapes []DynamicShape) {
118 | 	return m.InputsNames, m.InputsShapes
119 | }
120 | 
121 | // Outputs returns a description of the outputs.
122 | func (m *Model) Outputs() (names []string, dshapes []DynamicShape) {
123 | 	return m.OutputsNames, m.OutputsShapes
124 | }
125 | 
126 | // NumInputs returns the number of inputs this graph takes.
127 | func (m *Model) NumInputs() int {
128 | 	return len(m.InputsNames)
129 | }
130 | 
131 | // WithInputsAsConstants marks inputs to be considered as constants that do not vary across examples in training
132 | // or inference.
133 | // Use this immediately after the creation of the Model; later changes can cause inconsistencies.
134 | //
135 | // This makes these inputs constants in the graph, and they should no longer be passed to CallGraph as inputs.
136 | //
137 | // The value each input maps to is converted to a tensor with tensors.FromAnyValue.
138 | func (m *Model) WithInputsAsConstants(inputsAsConstants map[string]any) *Model {
139 | 	m.inputsAsConstants = inputsAsConstants
140 | 	return m
141 | }
142 | 
143 | // Write serializes the ONNX model to the given writer (usually a file).
144 | //
145 | // This is useful if the model variables were updated (e.g., fine-tuning in GoMLX) and one wants to save the
146 | // model.
147 | // See ContextToONNX to copy over the variables in GoMLX's Context (presumably after some training/update) to the
148 | // ONNX model proto.
149 | //
150 | // See also Model.SaveToFile.
151 | func (m *Model) Write(w io.Writer) error {
152 | 	content, err := proto.Marshal(&m.Proto)
153 | 	if err != nil {
154 | 		return errors.Wrapf(err, "failed to serialize ONNX model proto")
155 | 	}
156 | 	_, err = w.Write(content)
157 | 	if err != nil {
158 | 		return errors.Wrapf(err, "failed to write serialized ONNX model proto")
159 | 	}
160 | 	return nil
161 | }
162 | 
163 | // SaveToFile serializes the ONNX model to the given file.
164 | //
165 | // This is useful if the model variables were updated (e.g., fine-tuning in GoMLX) and one wants to save the
166 | // model.
167 | // See ContextToONNX to copy over the variables in GoMLX's Context (presumably after some training/update) to the
168 | // ONNX model proto.
169 | func (m *Model) SaveToFile(path string) error {
170 | 	f, err := os.Create(path)
171 | 	if err != nil {
172 | 		return errors.Wrapf(err, "failed to save ONNX model proto to %s", path)
173 | 	}
174 | 	err = m.Write(f)
175 | 	if err != nil {
176 | 		_ = f.Close()
177 | 		return err
178 | 	}
179 | 	err = f.Close()
180 | 	if err != nil {
181 | 		return errors.Wrapf(err, "failed to save ONNX model proto to %s", path)
182 | 	}
183 | 	return nil
184 | }
185 | 
--------------------------------------------------------------------------------
/onnx/ops_test.go:
--------------------------------------------------------------------------------
1 | package onnx
2 | 
3 | import (
4 | 	"fmt"
5 | 	.
"github.com/gomlx/gomlx/graph" 6 | "github.com/gomlx/gomlx/graph/graphtest" 7 | "github.com/gomlx/gomlx/types/shapes" 8 | "github.com/gomlx/gomlx/types/tensors" 9 | "github.com/gomlx/gopjrt/dtypes" 10 | "github.com/stretchr/testify/assert" 11 | "testing" 12 | ) 13 | 14 | func TestONNXWhere(t *testing.T) { 15 | graphtest.RunTestGraphFn(t, "Where(): Dense", func(g *Graph) (inputs, outputs []*Node) { 16 | cond := ConvertDType(Iota(g, shapes.Make(dtypes.Int32, 3, 2), -1), dtypes.Bool) 17 | onTrue := OnePlus(IotaFull(g, shapes.Make(dtypes.Float32, 3, 2))) 18 | onFalse := Neg(onTrue) 19 | inputs = []*Node{cond, onTrue, onFalse} 20 | outputs = []*Node{ 21 | onnxWhere([]*Node{cond, onTrue, onFalse}), 22 | onnxWhere([]*Node{Const(g, true), onTrue, onFalse}), 23 | onnxWhere([]*Node{Const(g, false), onTrue, onFalse}), 24 | onnxWhere([]*Node{cond, Const(g, float32(100)), onFalse}), 25 | onnxWhere([]*Node{cond, onTrue, Const(g, []float32{100, 1000})}), 26 | } 27 | return 28 | }, []any{ 29 | [][]float32{{-1, 2}, {-3, 4}, {-5, 6}}, 30 | [][]float32{{1, 2}, {3, 4}, {5, 6}}, 31 | [][]float32{{-1, -2}, {-3, -4}, {-5, -6}}, 32 | [][]float32{{-1, 100}, {-3, 100}, {-5, 100}}, 33 | [][]float32{{100, 2}, {100, 4}, {100, 6}}, 34 | }, -1) 35 | } 36 | 37 | func TestONNXGather(t *testing.T) { 38 | graphtest.RunTestGraphFn(t, "onnxGather(axis=0)", func(g *Graph) (inputs, outputs []*Node) { 39 | data := Const(g, [][]float32{{1.0, 1.2}, {2.3, 3.4}, {4.5, 5.7}}) 40 | indices := Const(g, [][]int32{{0, 1}, {1, 2}}) 41 | inputs = []*Node{data, indices} 42 | outputs = []*Node{onnxGather(data, indices, 0)} 43 | return 44 | }, []any{ 45 | [][][]float32{ 46 | { 47 | {1.0, 1.2}, 48 | {2.3, 3.4}, 49 | }, 50 | { 51 | {2.3, 3.4}, 52 | {4.5, 5.7}, 53 | }, 54 | }, 55 | }, -1) 56 | 57 | graphtest.RunTestGraphFn(t, "onnxGather(axis=1)", func(g *Graph) (inputs, outputs []*Node) { 58 | data := Const(g, [][]float32{ 59 | {1.0, 1.2, 1.9}, 60 | {2.3, 3.4, 3.9}, 61 | {4.5, 5.7, 5.9}, 62 | }) 63 | indices := Const(g, [][]int32{{0, 2}}) 64 | inputs = []*Node{data, indices} 65 | outputs = []*Node{onnxGather(data, indices, 1)} 66 | return 67 | }, []any{ 68 | [][][]float32{ 69 | {{1.0, 1.9}}, 70 | {{2.3, 3.9}}, 71 | {{4.5, 5.9}}, 72 | }, 73 | }, -1) 74 | } 75 | 76 | func TestTile(t *testing.T) { 77 | graphtest.RunTestGraphFn(t, "Tile 1D", func(g *Graph) (inputs, outputs []*Node) { 78 | operand := Const(g, []float32{1, 2}) 79 | inputs = []*Node{operand} 80 | outputs = []*Node{onnxTile(operand, []int{2})} 81 | return 82 | }, []any{ 83 | []float32{1, 2, 1, 2}, 84 | }, -1) 85 | 86 | graphtest.RunTestGraphFn(t, "Tile 2D", func(g *Graph) (inputs, outputs []*Node) { 87 | operand := Const(g, [][]float32{{1.0, 1.2}, {2.3, 3.4}, {4.5, 5.7}}) 88 | inputs = []*Node{operand} 89 | outputs = []*Node{onnxTile(operand, []int{1, 2})} 90 | return 91 | }, []any{ 92 | [][]float32{ 93 | {1.0, 1.2, 1.0, 1.2}, 94 | {2.3, 3.4, 2.3, 3.4}, 95 | {4.5, 5.7, 4.5, 5.7}, 96 | }, 97 | }, -1) 98 | } 99 | 100 | func TestRangeCount(t *testing.T) { 101 | backend := graphtest.BuildTestBackend() 102 | testFn := func(start, limit, delta any, want int) { 103 | startT := tensors.FromAnyValue(start) 104 | limitT := tensors.FromAnyValue(limit) 105 | deltaT := tensors.FromAnyValue(delta) 106 | got := rangeCount(backend, startT, limitT, deltaT) 107 | fmt.Printf("\trangeCount(start=%s, limit=%s, delta=%s) = %d (want %d)\n", startT, limitT, deltaT, got, want) 108 | assert.Equal(t, want, got) 109 | } 110 | 111 | testFn(uint8(3), uint8(9), uint8(3), 2) 112 | testFn(uint8(3), uint8(8), uint8(3), 
2) 113 | testFn(uint8(3), uint8(7), uint8(3), 2) 114 | testFn(float32(3), float32(9.1), float32(3), 3) 115 | testFn(int32(10), int32(4), int32(-2), 3) 116 | testFn(int32(10), int32(5), int32(-2), 3) 117 | testFn(float64(10), float64(3.9), float64(-2), 4) 118 | } 119 | 120 | func TestOnnxGatherElements(t *testing.T) { 121 | graphtest.RunTestGraphFn(t, "GatherElements", func(g *Graph) (inputs, outputs []*Node) { 122 | data := Const(g, [][]float32{{1, 2}, {3, 4}}) 123 | indices := Const(g, [][]int32{{0, 0}, {1, 0}}) 124 | inputs = []*Node{data, indices} 125 | outputs = []*Node{ 126 | onnxGatherElements(data, indices, 0), 127 | onnxGatherElements(data, indices, 1), 128 | } 129 | return 130 | }, []any{ 131 | [][]float32{{1, 2}, {3, 2}}, 132 | [][]float32{{1, 1}, {4, 3}}, 133 | }, -1) 134 | 135 | graphtest.RunTestGraphFn(t, "GatherElements w/ incomplete indices", func(g *Graph) (inputs, outputs []*Node) { 136 | data := OnePlus(IotaFull(g, shapes.Make(dtypes.Float64, 3, 2))) 137 | indices0 := Const(g, [][]int8{{1, 2}}) 138 | indices1 := Const(g, [][]int8{{0}, {0}, {1}}) 139 | outputs = []*Node{ 140 | onnxGatherElements(data, indices0, 0), 141 | onnxGatherElements(data, indices1, 1), 142 | } 143 | return 144 | }, []any{ 145 | [][]float64{{3, 6}}, 146 | [][]float64{{1}, {3}, {6}}, 147 | }, -1) 148 | 149 | graphtest.RunTestGraphFn(t, "GatherElements: shape test with larger shapes", func(g *Graph) (inputs, outputs []*Node) { 150 | data := IotaFull(g, shapes.Make(dtypes.Float64, 3, 2, 512)) 151 | indices := Iota(g, shapes.Make(dtypes.Int64, 3, 2, 7), 0) 152 | outputs = []*Node{ 153 | Const(g, onnxGatherElements(data, indices, 2).Shape().Dimensions), 154 | } 155 | return 156 | }, []any{ 157 | []int64{3, 2, 7}, 158 | }, -1) 159 | } 160 | 161 | func TestONNXCumSum(t *testing.T) { 162 | graphtest.RunTestGraphFn(t, "CumSum", func(g *Graph) (inputs, outputs []*Node) { 163 | operand := Const(g, []float32{1, 2, 3}) 164 | inputs = []*Node{operand} 165 | outputs = []*Node{ 166 | onnxCumSum(operand, 0, false, false), 167 | onnxCumSum(operand, 0, true, false), 168 | onnxCumSum(operand, 0, false, true), 169 | onnxCumSum(operand, 0, true, true), 170 | } 171 | return 172 | }, []any{ 173 | []float32{1, 3, 6}, 174 | []float32{0, 1, 3}, 175 | []float32{6, 5, 3}, 176 | []float32{5, 3, 0}, 177 | }, -1) 178 | } 179 | 180 | func TestONNXFlatten(t *testing.T) { 181 | backend := graphtest.BuildTestBackend() 182 | testIdx := 0 183 | flattenFn := func(shape shapes.Shape, splitAxis int) shapes.Shape { 184 | g := NewGraph(backend, fmt.Sprintf("Flatten #%d", testIdx)) 185 | testIdx++ 186 | operand := IotaFull(g, shape) 187 | newShape := onnxFlatten(operand, splitAxis).Shape() 188 | g.Finalize() 189 | return newShape 190 | } 191 | 192 | // Scalar becomes a 1x1 matrix. 193 | flattenFn(shapes.Make(dtypes.Float32), 0).Assert(dtypes.Float32, 1, 1) 194 | 195 | // Vector can be split in 2 different ways. 196 | flattenFn(shapes.Make(dtypes.Int32, 7), 0).Assert(dtypes.Int32, 1, 7) 197 | flattenFn(shapes.Make(dtypes.Int32, 7), 1).AssertDims(7, 1) 198 | 199 | // Higher-dimensional tensor. 
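// Per ONNX Flatten semantics, the output is rank-2 with dimensions (prod(dims[:axis]), prod(dims[axis:])):
// with axis=2 below, that is (7*2, 3*4) = (14, 12), which the assertion checks.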
200 | flattenFn(shapes.Make(dtypes.Float32, 7, 2, 3, 4), 2).AssertDims(14, 12) 201 | } 202 | -------------------------------------------------------------------------------- /onnx/prettyprint.go: -------------------------------------------------------------------------------- 1 | package onnx 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "github.com/gomlx/gomlx/types" 7 | "github.com/gomlx/gomlx/types/shapes" 8 | "github.com/gomlx/onnx-gomlx/internal/protos" 9 | "github.com/pkg/errors" 10 | "io" 11 | "maps" 12 | "slices" 13 | "strings" 14 | ) 15 | 16 | // String implements fmt.Stringer, and pretty prints model information. 17 | func (m *Model) String() string { 18 | // Convenient writing to buffer that will hold result. 19 | var buf bytes.Buffer 20 | w := func(format string, args ...any) { buf.WriteString(fmt.Sprintf(format, args...)) } 21 | 22 | // Model header: 23 | w("ONNX Model:\n") 24 | if m.Proto.DocString != "" { 25 | w("# %s\n", m.Proto.DocString) 26 | } 27 | if m.Proto.ModelVersion != 0 { 28 | w("\tVersion:\t%d\n", m.Proto.ModelVersion) 29 | } 30 | if m.Proto.ProducerName != "" { 31 | w("\tProducer:\t%s / %s\n", m.Proto.ProducerName, m.Proto.ProducerVersion) 32 | } 33 | 34 | // Graph information: 35 | w("\t# inputs:\t%d\n", len(m.Proto.Graph.Input)) 36 | for ii, input := range m.Proto.Graph.Input { 37 | w("\t\t[#%d] %s\n", ii, ppValueInfo(input)) 38 | } 39 | w("\t# outputs:\t%d\n", len(m.Proto.Graph.Output)) 40 | for ii, output := range m.Proto.Graph.Output { 41 | w("\t\t[#%d] %s\n", ii, ppValueInfo(output)) 42 | } 43 | w("\t# nodes:\t%d\n", len(m.Proto.Graph.Node)) 44 | 45 | // Tensors (variables): 46 | w("\t# tensors (variables):\t%d\n", len(m.Proto.Graph.Initializer)) 47 | w("\t# sparse tensors (variables):\t%d\n", len(m.Proto.Graph.SparseInitializer)) 48 | 49 | // List op-types used. 50 | opTypesSet := types.MakeSet[string]() 51 | for _, n := range m.Proto.Graph.Node { 52 | opTypesSet.Insert(n.GetOpType()) 53 | } 54 | w("\tOp types:\t%#v\n", slices.Sorted(maps.Keys(opTypesSet))) 55 | 56 | // Training Info. 57 | if len(m.Proto.TrainingInfo) > 0 { 58 | w("\t# training info:\t%d\n", len(m.Proto.TrainingInfo)) 59 | } 60 | 61 | // Extra functions: 62 | if len(m.Proto.Functions) > 0 { 63 | fnSet := types.MakeSet[string]() 64 | for _, f := range m.Proto.Functions { 65 | fnSet.Insert(f.Name) 66 | } 67 | w("\tFunctions:\t%#v\n", slices.Sorted(maps.Keys(fnSet))) 68 | } 69 | 70 | // Versions. 71 | w("\tIR Version:\t%d\n", m.Proto.IrVersion) 72 | w("\tOperator Sets:\t[") 73 | for ii, opSetId := range m.Proto.OpsetImport { 74 | if ii > 0 { 75 | w(", ") 76 | } 77 | if opSetId.Domain != "" { 78 | w("v%d (%s)", opSetId.Version, opSetId.Domain) 79 | } else { 80 | w("v%d", opSetId.Version) 81 | } 82 | } 83 | w("]\n") 84 | 85 | // Extra meta-data. 86 | if len(m.Proto.MetadataProps) > 0 { 87 | w("\tMetadata: [") 88 | for ii, prop := range m.Proto.MetadataProps { 89 | if ii > 0 { 90 | w(", ") 91 | } 92 | w("%s=%s", prop.Key, prop.Value) 93 | } 94 | w("]\n") 95 | } 96 | return buf.String() 97 | } 98 | 99 | func ppValueInfo(vi *protos.ValueInfoProto) string { 100 | if vi.DocString != "" { 101 | return fmt.Sprintf("%q: %s # %s", vi.Name, ppType(vi.Type), vi.DocString) 102 | } 103 | return fmt.Sprintf("%q: %s", vi.Name, ppType(vi.Type)) 104 | } 105 | 106 | func ppType(t *protos.TypeProto) string { 107 | if seq := t.GetSequenceType(); seq != nil { 108 | return ppSeqType(seq) 109 | } else if tensor := t.GetTensorType(); tensor != nil { 110 | return ppTensorType(tensor) 111 | } 112 | return "??type??" 
113 | }
114 | 
115 | func ppSeqType(seq *protos.TypeProto_Sequence) string {
116 | 	return fmt.Sprintf("(%s...)", ppType(seq.ElemType))
117 | }
118 | 
119 | func ppTensorType(t *protos.TypeProto_Tensor) string {
120 | 	dshape, err := makeDynamicShapeFromProto(t)
121 | 	if err != nil {
122 | 		return "(invalid dtype)"
123 | 	}
124 | 	return dshape.String()
125 | }
126 | 
127 | // PrintGraph prints a more or less human-readable (at least debuggable) description of the graph to the given writer.
128 | func (m *Model) PrintGraph(writer io.Writer) error {
129 | 	var err error
130 | 	w := func(format string, args ...any) {
131 | 		if err != nil {
132 | 			return
133 | 		}
134 | 		_, err = fmt.Fprintf(writer, format, args...)
135 | 		if err != nil {
136 | 			err = errors.Wrapf(err, "Model.PrintGraph() failed to write")
137 | 		}
138 | 	}
139 | 
140 | 	w("Model Graph %q:\n", m.Proto.Graph.Name)
141 | 	// Print each node: its op type, inputs/outputs, and attributes.
142 | 	for _, n := range m.Proto.Graph.Node {
143 | 		w("%q:\t[%s]\n", n.GetName(), n.GetOpType())
144 | 		w("\tInputs: %q\n", n.GetInput())
145 | 		w("\tOutputs: %q\n", n.GetOutput())
146 | 		if len(n.Attribute) > 0 {
147 | 			w("\tAttributes: ")
148 | 			for ii, attr := range n.Attribute {
149 | 				if ii > 0 {
150 | 					w(", ")
151 | 				}
152 | 				w("%s (%s", attr.Name, attr.Type)
153 | 				switch attr.Type {
154 | 				case protos.AttributeProto_TENSOR:
155 | 					shape, err := Shape(attr.T)
156 | 					if err != nil {
157 | 						w(" - unparseable shape: %v", err)
158 | 					} else {
159 | 						w(": %s", shape)
160 | 					}
161 | 				case protos.AttributeProto_INT:
162 | 					w(": %d", attr.I)
163 | 				case protos.AttributeProto_INTS:
164 | 					if len(attr.Ints) < 20 {
165 | 						w(": %v", attr.Ints)
166 | 					}
167 | 				case protos.AttributeProto_FLOAT:
168 | 					w(": %f", attr.F)
169 | 				default:
170 | 				}
171 | 				w(")")
172 | 			}
173 | 			w("\n")
174 | 		}
175 | 	}
176 | 	return err
177 | }
178 | 
179 | // nodeToString converts a NodeProto to a one-line string that can be used for debugging.
180 | func nodeToString(n *protos.NodeProto) string {
181 | 	var buf bytes.Buffer
182 | 	w := func(format string, args ...any) { _, _ = fmt.Fprintf(&buf, format, args...) }
183 | 
184 | 	w("Node %q [%s]", n.Name, n.OpType)
185 | 	w("(%s)", strings.Join(n.Input, ", ")) // Inputs
186 | 	w(" -> %s", strings.Join(n.Output, ", ")) // Output(s)
187 | 	if len(n.Attribute) > 0 {
188 | 		w(" - attrs[")
189 | 		for ii, attr := range n.Attribute {
190 | 			if ii > 0 {
191 | 				w(", ")
192 | 			}
193 | 			w("%s (%s)", attr.Name, attr.Type)
194 | 		}
195 | 		w("]")
196 | 	}
197 | 	return buf.String()
198 | }
199 | 
200 | // PrintVariables prints the model's tensors (ONNX initializers, i.e., its variables) and their shapes to the given writer.
201 | func (m *Model) PrintVariables(writer io.Writer) error {
202 | 	var err error
203 | 	w := func(format string, args ...any) {
204 | 		if err != nil {
205 | 			return
206 | 		}
207 | 		_, err = fmt.Fprintf(writer, format, args...)
208 | 		if err != nil {
209 | 			err = errors.Wrapf(err, "Model.PrintVariables() failed to write")
210 | 		}
211 | 	}
212 | 
213 | 	w("%d tensors (variables)", len(m.Proto.Graph.Initializer))
214 | 	if len(m.Proto.Graph.Initializer) > 0 {
215 | 		w(":")
216 | 	}
217 | 	w("\n")
218 | 	for _, t := range m.Proto.Graph.Initializer {
219 | 		shape, _ := Shape(t)
220 | 		w("\t%q: %s", t.Name, shape)
221 | 		if t.DocString != "" {
222 | 			w(" # %s", t.DocString)
223 | 		}
224 | 		w("\n")
225 | 	}
226 | 	w("%d sparse tensors (variables)", len(m.Proto.Graph.SparseInitializer))
227 | 	if len(m.Proto.Graph.SparseInitializer) > 0 {
228 | 		w(":")
229 | 	}
230 | 	w("\n")
231 | 	for _, st := range m.Proto.Graph.SparseInitializer {
232 | 		shape, _ := SparseShape(st)
233 | 		w("\t\t%q: dense shape=%s\n", st.Values.Name, shape)
234 | 	}
235 | 	return err
236 | }
237 | 
238 | // PrintGraphviz outputs the model graph using the "dot" language, starting from the target nodes towards
239 | // their dependencies.
240 | //
241 | // If targets is left empty, it takes the default graph outputs as targets.
242 | func (m *Model) PrintGraphviz(writer io.Writer, targets ...string) error {
243 | 	if len(targets) == 0 {
244 | 		targets = m.OutputsNames
245 | 	}
246 | 
247 | 	var err error
248 | 	w := func(format string, args ...any) {
249 | 		if err != nil {
250 | 			return
251 | 		}
252 | 		_, err = fmt.Fprintf(writer, format, args...)
253 | 		if err != nil {
254 | 			err = errors.Wrapf(err, "Model.PrintGraphviz() failed to write")
255 | 		}
256 | 	}
257 | 
258 | 	w("digraph %s {\n", m.Name())
259 | 	visited := types.MakeSet[string]()
260 | 	for _, target := range targets {
261 | 		if err != nil {
262 | 			break
263 | 		}
264 | 		err = m.recursiveGraphviz(writer, visited, target)
265 | 	}
266 | 	w("}")
267 | 	return err
268 | }
269 | 
270 | var (
271 | 	GraphvizInputColor = "#FFF59E"
272 | 	GraphvizVarColor = "#E0E0E0"
273 | )
274 | 
275 | func (m *Model) recursiveGraphviz(writer io.Writer, visited types.Set[string], target string) error {
276 | 	if visited.Has(target) {
277 | 		return nil
278 | 	}
279 | 	visited.Insert(target)
280 | 
281 | 	// Define w.
282 | 	var err error
283 | 	w := func(format string, args ...any) {
284 | 		if err != nil {
285 | 			return
286 | 		}
287 | 		_, err = fmt.Fprintf(writer, format, args...)
288 | 		if err != nil {
289 | 			err = errors.Wrapf(err, "Model.PrintGraphviz() failed to write")
290 | 		}
291 | 	}
292 | 
293 | 	// target is an input.
294 | 	if m.inputsNameSet.Has(target) {
295 | 		w("\t%q [shape=box, style=filled, fillcolor=%q];\n", target, GraphvizInputColor)
296 | 		return err
297 | 	}
298 | 
299 | 	// target is a variable.
299 | if v, found := m.variableNameToValue[target]; found { 300 | var vShape shapes.Shape 301 | vShape, err = Shape(v) 302 | w("\t%q [shape=box, style=filled, fillcolor=%q, tooltip=%q];\n", target, GraphvizVarColor, vShape) 303 | return err 304 | } 305 | 306 | node, found := m.nodeOutputToNode[target] 307 | if !found { 308 | err = errors.Errorf("couldn't find target %q in model graph!?", target) 309 | return err 310 | } 311 | 312 | for _, input := range node.Input { 313 | w("\t%q -> %q\n", input, target) 314 | if err != nil { 315 | return err 316 | } 317 | err = m.recursiveGraphviz(writer, visited, input) 318 | } 319 | return err 320 | } 321 | -------------------------------------------------------------------------------- /onnx/tensor.go: -------------------------------------------------------------------------------- 1 | package onnx 2 | 3 | import ( 4 | "github.com/gomlx/gomlx/types/shapes" 5 | "github.com/gomlx/gomlx/types/tensors" 6 | "github.com/gomlx/gopjrt/dtypes" 7 | "github.com/gomlx/onnx-gomlx/internal/protos" 8 | "github.com/pkg/errors" 9 | ) 10 | 11 | // Shape converts an ONNX data type and shape to GoMLX shapes.Shape (it includes the dtype). 12 | func Shape(proto *protos.TensorProto) (shape shapes.Shape, err error) { 13 | if proto == nil { 14 | err = errors.New("ONNX TensorProto is nil") 15 | return 16 | } 17 | shape.DType, err = dtypeForONNX(protos.TensorProto_DataType(proto.DataType)) 18 | if err != nil { 19 | return 20 | } 21 | shape.Dimensions = make([]int, len(proto.Dims)) 22 | for axis, dim := range proto.Dims { 23 | shape.Dimensions[axis] = int(dim) 24 | } 25 | if proto.Segment != nil { 26 | err = errors.Errorf("segmented tensor not supported (%v)", proto.Segment) 27 | return 28 | } 29 | return 30 | } 31 | 32 | // SparseShape returns what would be the dense shape of an ONNX SparseTensor. 33 | func SparseShape(proto *protos.SparseTensorProto) (shape shapes.Shape, err error) { 34 | if proto == nil || proto.Values == nil || proto.Indices == nil { 35 | err = errors.New("ONNX SparseTensorProto or its components are nil") 36 | return 37 | } 38 | shape.DType, err = dtypeForONNX(protos.TensorProto_DataType(proto.Values.DataType)) 39 | if err != nil { 40 | return 41 | } 42 | shape.Dimensions = make([]int, len(proto.Dims)) 43 | for axis, dim := range proto.Dims { 44 | shape.Dimensions[axis] = int(dim) 45 | } 46 | return 47 | } 48 | 49 | // checkAndCreateTensor implements the generic check and copy of the ONNX proto data to a tensor for the supported data type. 50 | func checkAndCreateTensor[T interface { 51 | float32 | float64 | int32 | int64 | uint64 52 | }](proto *protos.TensorProto, onnxData []T, shape shapes.Shape) (*tensors.Tensor, error) { 53 | if onnxData == nil { 54 | // Not this type of data. 55 | return nil, nil 56 | } 57 | if shape.DType != dtypes.FromGenericsType[T]() { 58 | return nil, errors.Errorf("tensor %q shaped %s provided data as %T!?", proto.Name, shape, onnxData) 59 | } 60 | if len(onnxData) != shape.Size() { 61 | return nil, errors.Errorf("tensor %q shaped %s has size %d , but ONNX model provided a slice with %d values!?", 62 | proto.Name, shape, shape.Size(), len(onnxData)) 63 | } 64 | return tensors.FromFlatDataAndDimensions[T](onnxData, shape.Dimensions...), nil 65 | } 66 | 67 | // tensorToGoMLX converts a protos.TensorProto object to a tensors.Tensor object, handling errors and different data types. 
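//
// A minimal illustration (hypothetical values, not taken from this repository): a TensorProto with
// DataType=FLOAT, Dims=[2, 2] and FloatData=[1, 2, 3, 4] becomes a GoMLX tensor shaped (Float32)[2 2]
// holding the same flat data:
//
//	t, err := tensorToGoMLX(proto) // t.Shape() -> (Float32)[2 2]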
68 | func tensorToGoMLX(proto *protos.TensorProto) (t *tensors.Tensor, err error) { 69 | var shape shapes.Shape 70 | shape, err = Shape(proto) 71 | if err != nil { 72 | err = errors.WithMessagef(err, "while parsing tensor %q", proto.Name) 73 | return 74 | } 75 | 76 | // If data is provided as RawData: check that the size of the data is the same used in GoMLX. 77 | if proto.RawData != nil { 78 | t = tensors.FromShape(shape) 79 | t.MutableBytes(func(data []byte) { 80 | if len(data) != len(proto.RawData) { 81 | err = errors.Errorf("tensor %q shaped %s uses %d bytes, but ONNX model provided %d bytes of raw-data!?", 82 | proto.Name, shape, len(data), len(proto.RawData)) 83 | } else { 84 | copy(data, proto.RawData) 85 | } 86 | }) 87 | if err != nil { 88 | t.FinalizeAll() 89 | t = nil 90 | return nil, err 91 | } 92 | return 93 | } 94 | 95 | // Tries to convert to each data type. 96 | t, err = checkAndCreateTensor(proto, proto.FloatData, shape) 97 | if t != nil || err != nil { 98 | return 99 | } 100 | t, err = checkAndCreateTensor(proto, proto.DoubleData, shape) 101 | if t != nil || err != nil { 102 | return 103 | } 104 | t, err = checkAndCreateTensor(proto, proto.Int32Data, shape) 105 | if t != nil || err != nil { 106 | return 107 | } 108 | t, err = checkAndCreateTensor(proto, proto.Int64Data, shape) 109 | if t != nil || err != nil { 110 | return 111 | } 112 | t, err = checkAndCreateTensor(proto, proto.Uint64Data, shape) 113 | if t != nil || err != nil { 114 | return 115 | } 116 | // Unknown tensor data type!? 117 | return nil, errors.Errorf("tensor %q shaped %s has no supported format of data in the ONNX model!?", proto.Name, shape) 118 | } 119 | 120 | // checkAndCopyTensor implements the generic check and copy of the tensor to the ONNX proto data. 121 | func checkAndCopyTensor[T interface { 122 | float32 | float64 | int32 | int64 | uint64 123 | }](t *tensors.Tensor, proto *protos.TensorProto, onnxData []T) error { 124 | shape := t.Shape() 125 | if shape.DType != dtypes.FromGenericsType[T]() { 126 | return errors.Errorf("tensor %q shaped %s provided data as %T!?", proto.Name, shape, onnxData) 127 | } 128 | if len(onnxData) != shape.Size() { 129 | return errors.Errorf("tensor %q shaped %s has size %d , but ONNX model provided a slice with %d values!?", 130 | proto.Name, shape, shape.Size(), len(onnxData)) 131 | } 132 | tensors.ConstFlatData(t, func(tensorData []T) { 133 | copy(onnxData, tensorData) // Copy data to ONNX proto. 134 | }) 135 | return nil 136 | } 137 | 138 | // TensorValueToONNX copies the value of a GoMLX tensors.Tensor to the ONNX protos.TensorProto object handling errors and different data types. 139 | // 140 | // Both tensors (GoMLX and ONNX) must already have the same shape. 141 | func TensorValueToONNX(t *tensors.Tensor, proto *protos.TensorProto) (err error) { 142 | var shape shapes.Shape 143 | shape, err = Shape(proto) 144 | if err != nil { 145 | return errors.WithMessagef(err, "while parsing tensor %q", proto.Name) 146 | } 147 | if !shape.Equal(t.Shape()) { 148 | return errors.Errorf("TensorValueToONNX: cannot copy value of GoMLX tensor shaped %s to ONNX tensor shaped %s", 149 | t.Shape(), shape) 150 | } 151 | 152 | // Raw data tensor. 
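// Note: per the ONNX spec, raw_data holds the tensor bytes in fixed-width little-endian layout,
// so the byte-for-byte copy below assumes a little-endian host.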
153 | 	if proto.RawData != nil {
154 | 		t.ConstBytes(func(data []byte) {
155 | 			if len(data) != len(proto.RawData) {
156 | 				err = errors.Errorf("tensor %q shaped %s uses %d bytes, but ONNX model provided %d bytes of raw-data!?",
157 | 					proto.Name, shape, len(data), len(proto.RawData))
158 | 			} else {
159 | 				copy(proto.RawData, data) // Copy data to ONNX proto only when the sizes match.
160 | 			}
161 | 		})
162 | 		return err
163 | 	}
164 | 
165 | 	if proto.FloatData != nil { // Typed data: try each supported field in turn.
166 | 		return checkAndCopyTensor(t, proto, proto.FloatData)
167 | 	}
168 | 	if proto.DoubleData != nil {
169 | 		return checkAndCopyTensor(t, proto, proto.DoubleData)
170 | 	}
171 | 	if proto.Int32Data != nil {
172 | 		return checkAndCopyTensor(t, proto, proto.Int32Data)
173 | 	}
174 | 	if proto.Int64Data != nil {
175 | 		return checkAndCopyTensor(t, proto, proto.Int64Data)
176 | 	}
177 | 	if proto.Uint64Data != nil {
178 | 		return checkAndCopyTensor(t, proto, proto.Uint64Data)
179 | 	}
180 | 	return errors.Errorf("tensor %q shaped %s has no supported format of data in the ONNX model!?", proto.Name, shape)
181 | }
--------------------------------------------------------------------------------
/onnx/variables.go:
--------------------------------------------------------------------------------
1 | package onnx
2 | 
3 | import (
4 | 	"github.com/gomlx/exceptions"
5 | 	"github.com/gomlx/gomlx/ml/context"
6 | 	"github.com/pkg/errors"
7 | 	"strings"
8 | )
9 | 
10 | // This file handles importing variables from an ONNX model and saving them back to the ONNX model proto
11 | // (see ContextToONNX).
12 | 
13 | // ModelScope is the default scope to use for the ONNX model variables when converting to GoMLX.
14 | var ModelScope = "ONNX"
15 | 
16 | // VariablesToContext creates variables in the context (within scope ModelScope) for
17 | // all variables present in the model initializer list.
18 | //
19 | // Call this once on your context, before using the model with Model.CallGraph.
20 | // Alternatively, if you have already checkpointed your model, load the variables from the checkpoint
21 | // and don't call this.
22 | //
23 | // See also ContextToONNX if, after converting and fine-tuning an ONNX model, you want to update its weights.
24 | func (m *Model) VariablesToContext(ctx *context.Context) error {
25 | 	if len(m.Proto.Graph.SparseInitializer) > 0 {
26 | 		exceptions.Panicf("onnx.VariablesToContext does not support ONNX SparseTensors")
27 | 	}
28 | 	ctx = ctx.In(ModelScope).Checked(false)
29 | 	for _, tensorProto := range m.Proto.Graph.Initializer {
30 | 		tensor, err := tensorToGoMLX(tensorProto)
31 | 		if err != nil {
32 | 			return errors.WithMessagef(err, "Model.VariablesToContext()")
33 | 		}
34 | 		tensorName := SafeVarName(tensorProto.Name)
35 | 		ctx.VariableWithValue(tensorName, tensor)
36 | 	}
37 | 	return nil
38 | }
39 | 
40 | // SafeVarName converts an ONNX variable name to a GoMLX-safe variable name by replacing the scope separator with a "|".
41 | func SafeVarName(onnxName string) (gomlxName string) {
42 | 	return strings.ReplaceAll(onnxName, context.ScopeSeparator, "|")
43 | }
44 | 
45 | // ContextToONNX converts the variables in the context back to the ONNX model proto.
46 | // Do this before saving the ONNX model back to disk.
47 | //
48 | // It's the inverse of VariablesToContext, and the given context must be set to the same scope as when
49 | // VariablesToContext was first called.
50 | //
51 | // Only those variables present in the original ONNX model are converted back -- new variables (e.g., optimizers'
52 | // (ADAM) moving averages) are not converted.
53 | func (m *Model) ContextToONNX(ctx *context.Context) error {
54 | 	if len(m.Proto.Graph.SparseInitializer) > 0 {
55 | 		exceptions.Panicf("onnx.ContextToONNX does not support ONNX SparseTensors")
56 | 	}
57 | 	ctx = ctx.In(ModelScope)
58 | 	for _, tensorProto := range m.Proto.Graph.Initializer {
59 | 		tensorName := SafeVarName(tensorProto.Name)
60 | 		gomlxVar := ctx.GetVariable(tensorName)
61 | 		if gomlxVar == nil {
62 | 			return errors.Errorf("ONNX variable %q not found in context scope %q --"+
63 | 				" maybe you used a different scope when Model.VariablesToContext() was called?",
64 | 				tensorName, ctx.Scope())
65 | 		}
66 | 		err := TensorValueToONNX(gomlxVar.Value(), tensorProto)
67 | 		if err != nil {
68 | 			return errors.WithMessagef(err, "Model.ContextToONNX() converting tensor %q", tensorName)
69 | 		}
70 | 	}
71 | 	return nil
72 | }
73 | 
--------------------------------------------------------------------------------
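Usage sketch (not part of the repository files above): a minimal end-to-end flow tying together the
APIs shown in this section -- ReadFile, VariablesToContext, CallGraph, ContextToONNX and SaveToFile.
The file names and the input name "X" are hypothetical placeholders, and backends.New() assumes a
GoMLX backend is linked into the binary.

	package main

	import (
		"github.com/gomlx/gomlx/backends"
		"github.com/gomlx/gomlx/graph"
		"github.com/gomlx/gomlx/ml/context"
		"github.com/gomlx/onnx-gomlx/onnx"
	)

	func main() {
		// Parse the ONNX model ("model.onnx" is a placeholder name).
		model, err := onnx.ReadFile("model.onnx")
		if err != nil {
			panic(err)
		}

		// Import the ONNX initializers as GoMLX variables, under scope onnx.ModelScope.
		ctx := context.New()
		if err = model.VariablesToContext(ctx); err != nil {
			panic(err)
		}

		// Build and execute the converted graph; "X" is a placeholder input name.
		backend := backends.New()
		y := context.ExecOnce(backend, ctx, func(ctx *context.Context, x *graph.Node) *graph.Node {
			return model.CallGraph(ctx, x.Graph(), map[string]*graph.Node{"X": x})[0]
		}, [][]float32{{1, 2, 3, 4, 5}})
		_ = y // Inference result.

		// After fine-tuning, copy the variables back into the ONNX proto and save it.
		if err = model.ContextToONNX(ctx); err != nil {
			panic(err)
		}
		if err = model.SaveToFile("model-tuned.onnx"); err != nil {
			panic(err)
		}
	}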