├── .gitignore ├── LICENSE ├── README.md ├── pyproject.toml ├── requirements.txt ├── src └── onnx_shrink_ray │ ├── __init__.py │ └── shrink.py └── tests └── shrink_test.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.onnx 2 | *.pyc 3 | dist/ 4 | __pycache__/ 5 | *.egg-info -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ONNX Shrink Ray 2 | 3 | Shrinks the size of ONNX files by quantizing large float constants into eight bit equivalents, while leaving all calculations in floating point. 
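The same functionality can be called from Python as a library. A minimal sketch, based on the `quantize_weights` function defined in `src/onnx_shrink_ray/shrink.py` later in this repository (treat that source as the authority on the exact signature):

```python
import onnx
from onnx_shrink_ray.shrink import quantize_weights

# Load a model, quantize its large float weights, and save the smaller version.
model = onnx.load("myfile.onnx")
shrunk = quantize_weights(model)  # pass float_quantization=True for the "float_weights" method
onnx.save(shrunk, "myfile_quantized_weights.onnx")
```

The command-line equivalents of these calls are described in the Usage section below.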
4 | 5 | - [Installation](#installation) 6 | - [Usage](#usage) 7 | - [To reduce the size of a single file](#to-reduce-the-size-of-a-single-file) 8 | - [To reduce the compressed size of a file](#to-reduce-the-compressed-size-of-a-file) 9 | - [To print information about the weights in a file](#to-print-information-about-the-weights-in-a-file) 10 | - [What Shrink Ray does](#what-shrink-ray-does) 11 | - [Results](#results) 12 | - [Moonshine Tiny](#moonshine-tiny) 13 | - [Moonshine Base](#moonshine-base) 14 | - [Notes](#notes) 15 | - [Other Models](#other-models) 16 | 17 | ## Installation 18 | 19 | The easiest way to get started is to install this package in Python using pip: 20 | 21 | ```bash 22 | pip install onnx_shrink_ray 23 | ``` 24 | 25 | You can also download this repository and run the `shrink.py` script directly. 26 | 27 | ## Usage 28 | 29 | ### To reduce the size of a single file 30 | 31 | ```bash 32 | python -m onnx_shrink_ray.shrink myfile.onnx 33 | ``` 34 | 35 | This will convert all of the weights in the ONNX file from 32-bit floating point to 8-bit integers, followed by operations that linearly scale those back into approximations of the original values for later calculations (the same effect as a `DequantizeLinear` op, implemented as `Cast`, `Mul`, and `Add` because `DequantizeLinear` is surprisingly slow in the ONNX Runtime). The resulting ONNX file is typically less than 30% of the input's size. 36 | 37 | ### To reduce the compressed size of a file 38 | 39 | ```bash 40 | python -m onnx_shrink_ray.shrink --method "float_weights" --float_levels 256 myfile.onnx 41 | ``` 42 | 43 | A lot of downloads and app bundles are automatically compressed using a standard like `gzip` or `brotli`. Neural network weights often don't compress well when they're stored as floating point numbers, since there is very little repetition in the values; they're usually all slightly different from one another. If we know our model will be compressed for delivery, we can reduce the actual download size by making the weight values (which normally make up the majority of the file) easier to compress. 44 | 45 | This tool does this by rounding all the float values in a weight array to the nearest of a limited number of quantized steps, but then storing the results back into a 32-bit floating point tensor. This means the uncompressed size on disk remains the same, but the compressed version is often several times smaller. This is because there's now only a limited number of distinct values in each weight tensor, so there's a lot more repetition in the byte stream for the compression algorithm to take advantage of. 46 | 47 | By default, each weight tensor is quantized to 256 levels, but since the results are stored as floating point values, you can modify this to trade off compressed file size against accuracy. For example, increasing the `--float_levels` argument to 1,000 can improve accuracy at the cost of a larger compressed file, whereas 100 would shrink the size further, but could negatively impact quality. 48 | 49 | ### To print information about the weights in a file 50 | 51 | ```bash 52 | python -m onnx_shrink_ray.shrink --info myfile.onnx 53 | ``` 54 | 55 | This will analyze the file and output information about the weight arrays stored in it, including their shape, type, and size in bytes. It will also show how much of the file size comes from weights, and how much from other information. Ideally, the weights should be the majority of the file size. Here is some example output: 56 | 57 | ```bash 58 | Model: decoder_model_merged.onnx 59 | Initializer: onnx::MatMul_2282_merged_0_quantized: [288, 288] - 82,944 elements, uint8, 82,944 bytes 60 | ...
61 | Initializer: onnx::MatMul_2444_merged_0_quantized: [1152, 288] - 331,776 elements, uint8, 331,776 bytes 62 | Initializer: model.decoder.embed_tokens.weight_merged_0_quantized: [32768, 288] - 9,437,184 elements, int8, 9,437,184 bytes 63 | Total nodes: 0 64 | Total initializers: 61 65 | Total bytes from weights: 19,475,173 bytes, 9,819,391 bytes from other data 66 | ------------------------------------------- 67 | ``` 68 | 69 | ## What Shrink Ray does 70 | 71 | Standard ONNX quantization is focused on converting all calculations to eight bit, which can reduce latency dramatically on some platforms. However, this approach can also cause accuracy problems, and often requires some manual work to achieve the best results. 72 | 73 | Sometimes, though, the biggest problem is not speeding up the execution of a network, but reducing the size of the model data. This can be the case when a model has to be downloaded, where the size determines the loading time before it can be used, or when it's part of a mobile app bundle or stored on another edge device with limited storage space. 74 | 75 | The standard ONNX quantization does offer some file size benefits, but the potential impact on accuracy means it can take time and effort to achieve these savings. As an alternative, this module implements "weight-only quantization", where all calculations and activation layers are left in their initial precision, and only the weights are stored in a lower-fidelity format. 76 | 77 | This approach has the advantage that it is much less likely to significantly impact accuracy, and so can usually be applied quickly, with no manual tweaking or fixups required. It will not reduce latency (and some of the methods may actually slow execution by a small amount) but it can offer significant file size savings. 78 | 79 | Though this method is designed to have a minimal impact on the accuracy of the model, there are networks that may be adversely affected. The heuristic used to identify weights simply searches for constants or initializers that are larger than 16,384 elements, with the assumption that smaller constants are more likely to be non-weight parameters, and won't contribute much to the overall size of the model on disk. 80 | 81 | ## Results 82 | 83 | The initial reason for creating this project was to reduce the download size for the [Moonshine](https://github.com/usefulsensors/moonshine) models on the web, so I've done the most extensive testing on those networks. Here are the size and accuracy results when running against the LibriSpeech clean English-language dataset.
84 | 85 | ### Moonshine Tiny 86 | 87 | | | WER | File Size | GZIP Size | Brotli Size | Latency | 88 | |------------------------------|--------|-----------|-----------|-------------|---------| 89 | | Original | 4.51% | 272MB | 251MB | 226MB | 307ms | 90 | | Integer Weights | 4.69% | 69MB | 53MB | 46MB | 466ms | 91 | | Float Weights (100 levels) | 11.34% | 272MB | 60MB | 46MB | 188ms | 92 | | Float Weights (256 levels) | 4.69% | 272MB | 75MB | 59MB | 329ms | 93 | | Float Weights (1,000 levels) | 4.47% | 272MB | 108MB | 79MB | 296ms | 94 | | ONNX Dynamic Quantization | 30.99% | 113MB | 95MB | 71MB | 317ms | 95 | 96 | ### Moonshine Base 97 | 98 | | | WER | File Size | GZIP Size | Brotli Size | Latency | 99 | |------------------------------|--------|-----------|-----------|-------------|---------| 100 | | Original | 3.29% | 556MB | 515MB | 469MB | 420ms | 101 | | Integer Weights | 3.28% | 141MB | 105MB | 92MB | 729ms | 102 | | Float Weights (100 levels) | 3.55% | 556MB | 120MB | 94MB | 402ms | 103 | | Float Weights (256 levels) | 3.28% | 556MB | 155MB | 121MB | 407ms | 104 | | Float Weights (1,000 levels) | 3.29% | 556MB | 217MB | 161MB | 411ms | 105 | | ONNX Dynamic Quantization | 19.06% | 264MB | 225MB | 180MB | 221ms | 106 | 107 | ### Notes 108 | 109 | The compressed file sizes were calculated by checking the archive size after running `tar --use-compress-program="<compressor> --best" -cvf archive.tbz <model file>`. The `--best` flag is used here to ensure the compression is as effective as possible by running multiple passes. 110 | 111 | Latency values were calculated by running a ten-second audio clip through each model on a Microsoft Surface Pro with an x86 CPU, using the `moonshine_onnx.benchmark()` function included in the library. 112 | 113 | ONNX dynamic quantization results are included for reference. These are models produced by the [`onnxruntime.quantization.quantize_dynamic()`](https://iot-robotics.github.io/ONNXRuntime/docs/performance/quantization.html#quantization-api) function with default arguments. For convenience you can invoke this through the `--method "integer_activations"` option. 114 | 115 | Some interesting patterns are visible: 116 | 117 | - The float weight quantization has no effect on the uncompressed file size, but dramatically decreases the compressed file size, as expected. It also makes no statistically significant difference to the latency. 118 | 119 | - The integer weight quantization is a lot slower than float weights. This is a bit surprising, since the only difference is a DequantizeLinear operation for each weight constant, but my best guess is that the op hasn't been optimized, on this platform at least. 120 | 121 | - ONNX quantization produces models that are fast, but much less accurate. In my experience this is a common outcome, and can be fixed with some investigation into exactly where the accuracy loss is occurring, but it tends to be a time-consuming process, hence my desire for something easier when file size is the biggest obstacle. 122 | 123 | - ONNX quantization doesn't shrink the raw files as much as I'd expect. If the weights were being stored as 8-bit integers, I'd expect the file size to be the same as the `integer_weights` version, but they're about twice as large. I wonder if the weights are actually stored as 16-bit in this case, or if there's somehow an extra copy? 124 | 125 | - Different models can tolerate different levels of float quantization. The base model only loses a fraction of a percent at 100 levels, whereas the tiny model loses several points.
126 | 127 | - Brotli does a better job at compressing these files than gzip, though the compression process takes significantly longer in my experience. Since brotli is now widely supported by browsers, it seems like the best method to use overall. 128 | 129 | - Apart from the integer weights, most of the float weights versions have similar latencies to the original model. This is expected, since the overall network architecture isn't changed, just the values stored in constants. The only exception is the tiny float weights with 100 levels, which is unexpectedly fast. I don't have a good explanation for this yet; it will require deeper profiling. 130 | 131 | ## Other Models 132 | 133 | I haven't done widespread testing with other models to see what the quality, size, and performance impact is. I'll be maintaining this repository on a best-effort basis, so while there are no guarantees on fixes, please [file an issue](https://github.com/usefulsensors/onnx_shrink_ray/issues) if you hit problems with your own models and I'll take a look. 134 | 135 | Pete Warden, pete@usefulsensors.com -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.0"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "onnx_shrink_ray" 7 | version = "0.0.8" 8 | authors = [ 9 | { name="Pete Warden", email="pete@petewarden.com" }, 10 | ] 11 | description = "Shrinks the size of ONNX files by quantizing large float constants into eight bit equivalents, while leaving all calculations in floating point." 12 | readme = "README.md" 13 | requires-python = ">=3.8" 14 | classifiers = [ 15 | "Programming Language :: Python :: 3", 16 | "Operating System :: OS Independent", 17 | ] 18 | dependencies = [ 19 | "numpy", 20 | "onnx", 21 | "onnx_graphsurgeon", 22 | "onnxruntime", 23 | ] 24 | [project.urls] 25 | Homepage = "https://github.com/usefulsensors/onnx_shrink_ray" 26 | Issues = "https://github.com/usefulsensors/onnx_shrink_ray/issues" -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | onnx 3 | onnx_graphsurgeon 4 | onnxruntime 5 | -------------------------------------------------------------------------------- /src/onnx_shrink_ray/__init__.py: -------------------------------------------------------------------------------- 1 | # __init__.py 2 | 3 | # This file marks the directory as a Python package and can be used to initialize package-level variables or import submodules. -------------------------------------------------------------------------------- /src/onnx_shrink_ray/shrink.py: -------------------------------------------------------------------------------- 1 | import onnx_graphsurgeon as gs 2 | import numpy as np 3 | import onnx 4 | from onnxruntime.quantization import quantize_dynamic, QuantType 5 | import os  # Needed at module level by print_weight_info() when given a filename. 6 | 7 | # Don't quantize constants smaller than this. 8 | DEFAULT_MIN_ELEMENTS = 16 * 1024 9 | 10 | 11 | def replace_tensor_for_subgraph(graph, original_tensor_name, new_tensor): 12 | """Replace a tensor in a graph with a new tensor. 13 | 14 | Args: 15 | graph: The graph to modify. 16 | original_tensor_name: The name of the tensor to replace. 17 | new_tensor: The tensor to replace it with.
18 | """ 19 | for node in graph.nodes: 20 | for subgraph in node.attrs.values(): 21 | if isinstance(subgraph, gs.Graph): 22 | replace_tensor_for_subgraph(subgraph, original_tensor_name, new_tensor) 23 | for i, tensor in enumerate(node.inputs): 24 | if tensor.name == original_tensor_name: 25 | node.inputs[i] = new_tensor 26 | 27 | for i, tensor in enumerate(graph.outputs): 28 | if tensor.name == original_tensor_name: 29 | graph.outputs[i] = new_tensor 30 | 31 | def gather_initializers_in_graph(graph, all_initializers): 32 | 33 | for initializer in graph.initializer: 34 | all_initializers[initializer.name] = initializer 35 | 36 | graph.initializer.clear() 37 | 38 | for node in graph.node: 39 | if node.op_type == "If": 40 | for attr in node.attribute: 41 | if attr.name == "then_branch": 42 | all_initializers = gather_initializers_in_graph(attr.g, all_initializers) 43 | elif attr.name == "else_branch": 44 | all_initializers = gather_initializers_in_graph(attr.g, all_initializers) 45 | 46 | return all_initializers 47 | 48 | def hoist_subgraph_initializers(onnx_model): 49 | """GraphSurgeon seems to leave duplicated initializers in the graph, so remove them.""" 50 | 51 | all_initializers = {} 52 | gather_initializers_in_graph(onnx_model.graph, all_initializers) 53 | 54 | for name, initializer in all_initializers.items(): 55 | onnx_model.graph.initializer.append(initializer) 56 | 57 | return onnx_model 58 | 59 | def quantize_tensor(name, value_tensor, original_output_tensor_name, graph, root_graph): 60 | """Quantize a constant tensor to int8 using the DequantizeLinear op. 61 | 62 | Args: 63 | name: The name of the tensor to quantize. 64 | value_tensor: The tensor to quantize. 65 | original_output_tensor_name: The name of the original tensor in the graph. 66 | graph: The graph to modify. 67 | root_graph: The root graph of the model. 68 | """ 69 | float_values = value_tensor.values 70 | min_val = np.min(float_values) 71 | max_val = np.max(float_values) 72 | range_val = max_val - min_val 73 | inverse_range = 1.0 / range_val 74 | zero_point = round(-min_val * inverse_range * 255.0) - 128 75 | quantized_values = np.round(float_values * inverse_range * 255.0) + zero_point 76 | quantized_values = np.clip(quantized_values, -128, 127).astype(np.int8) 77 | 78 | quantized_tensor = gs.Constant( 79 | name=f"{name}_quantized", 80 | values=quantized_values) 81 | 82 | scale_value = range_val / 255.0 83 | scale_tensor = gs.Constant( 84 | name=f"{name}_scale", 85 | values=np.array([scale_value], dtype=np.float32)) 86 | 87 | zero_point_tensor = gs.Constant( 88 | name=f"{name}_zero_point", 89 | values=np.array([-zero_point * scale_value], dtype=np.float32)) 90 | 91 | # DequantizeLinear is surprisingly slow in the OnnxRuntime, so achieve the 92 | # same effect with a Cast, Mul, and Add. 
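    # The Cast/Mul/Add chain below reconstructs each weight as
    # cast(quantized) * scale + (-zero_point * scale), which is the same
    # (q - zero_point) * scale affine mapping that DequantizeLinear applies.
    # That is also why the zero point constant above holds -zero_point * scale_value
    # rather than the raw zero point.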
93 | cast_tensor_name = f"{name}_cast_tensor" 94 | cast_tensor = gs.Variable( 95 | name=cast_tensor_name, 96 | dtype=np.float32, 97 | shape=value_tensor.shape) 98 | cast_node = gs.Node( 99 | op="Cast", 100 | name=f"{name}_cast_node", 101 | inputs=[quantized_tensor], 102 | outputs=[cast_tensor], 103 | attrs={"to": np.float32}) 104 | 105 | mul_tensor_name = f"{name}_mul_tensor" 106 | mul_tensor = gs.Variable( 107 | name=mul_tensor_name, 108 | dtype=np.float32, 109 | shape=value_tensor.shape) 110 | mul_node = gs.Node( 111 | op="Mul", 112 | name=f"{name}_mul_node", 113 | inputs=[cast_tensor, scale_tensor], 114 | outputs=[mul_tensor]) 115 | 116 | add_tensor_name = f"{name}_add_tensor" 117 | add_tensor = gs.Variable( 118 | name=add_tensor_name, 119 | dtype=np.float32, 120 | shape=value_tensor.shape) 121 | add_node = gs.Node( 122 | op="Add", 123 | name=f"{name}_add_node", 124 | inputs=[mul_tensor, zero_point_tensor], 125 | outputs=[add_tensor]) 126 | 127 | replace_tensor_for_subgraph(root_graph, original_output_tensor_name, add_tensor) 128 | 129 | root_graph.nodes.append(cast_node) 130 | root_graph.nodes.append(mul_node) 131 | root_graph.nodes.append(add_node) 132 | 133 | 134 | def float_quantize_node(name, value_tensor, original_output_tensor_name, root_graph, levels=256): 135 | """Quantize a constant tensor to a small number of float values. 136 | 137 | Args: 138 | name: The name of the tensor to quantize. 139 | value_tensor: The tensor to quantize. 140 | original_output_tensor_name: The name of the original tensor in the graph. 141 | graph: The graph to modify. 142 | levels: The number of levels to quantize to. 143 | """ 144 | float_values = value_tensor.values 145 | min_val = np.min(float_values) 146 | max_val = np.max(float_values) 147 | range_val = max_val - min_val 148 | inverse_range = 1.0 / range_val 149 | half_levels = (levels / 2) 150 | zero_point = round(-min_val * inverse_range * (levels - 1)) - half_levels 151 | scale_value = range_val / (levels - 1) 152 | quantized_values = np.round(float_values * inverse_range * (levels - 1)) + zero_point 153 | quantized_values = np.clip(quantized_values, -half_levels, (half_levels - 1)) 154 | dequantized_values = ((quantized_values.astype(np.int32) - zero_point) * scale_value).astype(np.float32) 155 | 156 | dequantized_tensor = gs.Constant( 157 | name=f"{name}_dequantized", 158 | values=dequantized_values) 159 | 160 | replace_tensor_for_subgraph(root_graph, original_output_tensor_name, dequantized_tensor) 161 | 162 | def quantize_weights_for_graph(graph, root_graph, already_processed, min_elements=DEFAULT_MIN_ELEMENTS, float_quantization=False, float_levels=256, verbose=False): 163 | for node in graph.nodes: 164 | for subgraph in node.attrs.values(): 165 | if isinstance(subgraph, gs.Graph): 166 | if verbose: 167 | print(f"Processing subgraph {subgraph.name}") 168 | already_processed = quantize_weights_for_graph( 169 | subgraph, root_graph, already_processed, min_elements, float_quantization, float_levels) 170 | if node.op != "Constant": 171 | continue 172 | name = node.name 173 | value_tensor = node.attrs["value"] 174 | if value_tensor.dtype != np.float32 and value_tensor.dtype != np.float64: 175 | continue 176 | original_output_tensor_name = node.outputs[0].name 177 | if original_output_tensor_name in already_processed: 178 | continue 179 | already_processed.add(original_output_tensor_name) 180 | elements = np.prod(value_tensor.shape) 181 | if elements < min_elements: 182 | continue 183 | if verbose: 184 | print(f"Processing node {name}") 185 | 
if float_quantization: 186 | float_quantize_node(name, value_tensor, original_output_tensor_name, root_graph, levels=float_levels) 187 | else: 188 | quantize_tensor(name, value_tensor, original_output_tensor_name, graph, root_graph) 189 | 190 | for name, value_tensor in graph.tensors().items(): 191 | if value_tensor.dtype != np.float32 and value_tensor.dtype != np.float64: 192 | continue 193 | if value_tensor.__class__ != gs.Constant: 194 | continue 195 | original_output_tensor_name = name 196 | if original_output_tensor_name in already_processed: 197 | continue 198 | already_processed.add(original_output_tensor_name) 199 | elements = np.prod(value_tensor.shape) 200 | if elements < min_elements: 201 | continue 202 | if verbose: 203 | print(f"Processing initializer {name}") 204 | if float_quantization: 205 | float_quantize_node(name, value_tensor, original_output_tensor_name, root_graph, levels=float_levels) 206 | else: 207 | quantize_tensor(name, value_tensor, original_output_tensor_name, graph, root_graph) 208 | 209 | return already_processed 210 | 211 | def quantize_weights(input_data, min_elements=DEFAULT_MIN_ELEMENTS, float_quantization=False, float_levels=256, verbose=False): 212 | """Quantize the weights of an ONNX model. 213 | 214 | Args: 215 | input_data: The path or contents of the ONNX model to quantize. 216 | min_elements: The minimum number of elements a tensor must have to be quantized. 217 | float_quantization: If True, store the quantized values as float, not integers. 218 | float_levels: The number of levels to quantize to if using float quantization. 219 | verbose: If True, log detailed information about the weight processing. 220 | """ 221 | if verbose: 222 | print(f"quantize_weights(input_data, min_elements={min_elements}, float_quantization={float_quantization}, float_levels={float_levels})") 223 | 224 | graph = gs.import_onnx(input_data) 225 | 226 | already_processed = set() 227 | quantize_weights_for_graph(graph, graph, already_processed, min_elements, float_quantization, float_levels, verbose) 228 | 229 | graph.cleanup(remove_unused_graph_inputs=False).toposort(recurse_subgraphs=True) 230 | 231 | no_shape_model = gs.export_onnx(graph) 232 | deduped_model = hoist_subgraph_initializers(no_shape_model) 233 | new_model = onnx.shape_inference.infer_shapes(deduped_model) 234 | 235 | onnx.checker.check_model(new_model) 236 | 237 | return new_model 238 | 239 | def print_weight_info_for_graph(onnx_graph, total_bytes, node_count, initializer_count, already_processed, min_elements=DEFAULT_MIN_ELEMENTS): 240 | for node in onnx_graph.node: 241 | value_tensor = None 242 | for attribute in node.attribute: 243 | if attribute.name == "value": 244 | value_tensor = attribute.t 245 | if attribute.HasField("g"): 246 | subgraph = attribute.g 247 | total_bytes, node_count, initializer_count, already_processed = print_weight_info_for_graph( 248 | subgraph, total_bytes, node_count, initializer_count, already_processed, min_elements) 249 | if node.op_type != "Constant": 250 | continue 251 | output_tensor_name = node.output[0] 252 | if output_tensor_name in already_processed: 253 | continue 254 | already_processed.add(output_tensor_name) 255 | name = node.name 256 | elements = np.prod(value_tensor.dims) 257 | np_dtype = onnx.helper.tensor_dtype_to_np_dtype(value_tensor.data_type) 258 | byte_count = int(elements * np_dtype.itemsize) 259 | total_bytes += byte_count 260 | if elements < min_elements: 261 | continue 262 | node_count += 1 263 | print(f"Node: {name}: {value_tensor.dims} - {elements} 
elements, {np_dtype}, {byte_count:,} bytes") 264 | 265 | duplicate_names = set() 266 | for value_tensor in onnx_graph.initializer: 267 | name = value_tensor.name 268 | if name in already_processed: 269 | duplicate_names.add(name) 270 | continue 271 | already_processed.add(name) 272 | elements = np.prod(value_tensor.dims) 273 | np_dtype = onnx.helper.tensor_dtype_to_np_dtype(value_tensor.data_type) 274 | byte_count = int(elements * np_dtype.itemsize) 275 | total_bytes += byte_count 276 | if elements < min_elements: 277 | continue 278 | initializer_count += 1 279 | print(f"Initializer: {name}: {value_tensor.dims} - {elements:,} elements, {np_dtype}, {byte_count:,} bytes") 280 | 281 | if len(duplicate_names) > 0: 282 | print(f"Duplicate initializers: {duplicate_names}") 283 | 284 | return total_bytes, node_count, initializer_count, already_processed 285 | 286 | def print_weight_info(filename_or_model, min_elements=DEFAULT_MIN_ELEMENTS): 287 | """Return information about the size of the weights in an ONNX model. 288 | 289 | Args: 290 | model: The ONNX model to inspect. 291 | """ 292 | if isinstance(filename_or_model, str): 293 | filename = filename_or_model 294 | onnx_model = onnx.load(filename) 295 | file_byte_count = os.path.getsize(filename) 296 | print(f"Model: {filename}") 297 | else: 298 | onnx_model = filename_or_model 299 | file_byte_count = onnx_model.ByteSize() 300 | 301 | total_bytes = 0 302 | node_count = 0 303 | initializer_count = 0 304 | already_processed = set() 305 | 306 | total_bytes, node_count, initializer_count, already_processed = print_weight_info_for_graph( 307 | onnx_model.graph, total_bytes, node_count, initializer_count, already_processed, min_elements) 308 | 309 | print(f"Total nodes: {node_count}") 310 | print(f"Total initializers: {initializer_count}") 311 | print(f"Total bytes from weights: {total_bytes:,} bytes, {file_byte_count - total_bytes:,} bytes from other data") 312 | print("-------------------------------------------") 313 | 314 | 315 | if __name__ == "__main__": 316 | """Command line utility to quantize ONNX models.""" 317 | import argparse 318 | import glob 319 | import os 320 | import sys 321 | 322 | def get_list_arg(arg): 323 | if arg is None: 324 | return None 325 | return arg.split(",") 326 | 327 | parser = argparse.ArgumentParser( 328 | prog=sys.argv[0], 329 | description="Quantization utility for ONNX models", 330 | ) 331 | parser.add_argument( 332 | "--method", "-m", 333 | help="How to quantize the models", 334 | default="integer_weights", 335 | choices=["integer_weights", "float_weights", "integer_activations"], 336 | ) 337 | parser.add_argument( 338 | "--float_levels", "-l", 339 | help="Number of levels to use for float quantization.", 340 | default=256, 341 | type=int, 342 | ) 343 | parser.add_argument( 344 | "--output_dir", "-o", 345 | help="Folder to write the quantized models to. 
If not specified, uses the same folder as the input models.", 346 | default=None, 347 | ) 348 | parser.add_argument( 349 | "--output_suffix", "-s", 350 | help="Suffix to add to the output model filenames.", 351 | default="_quantized_weights.onnx", 352 | ) 353 | parser.add_argument( 354 | "--op_types_to_quantize", "-q", 355 | help="Comma-separated list of op types to quantize (default is all supported).", 356 | default=None, 357 | ) 358 | parser.add_argument( 359 | "--nodes_to_quantize", "-t", 360 | help="Comma-separated list of node names to quantize (default is all).", 361 | default=None, 362 | ) 363 | parser.add_argument( 364 | "--nodes_to_exclude", "-n", 365 | help="Comma-separated list of node names not to quantize (default is none).", 366 | default=None, 367 | ) 368 | parser.add_argument( 369 | "--info", "-i", 370 | help="Whether to print information about the weights in the model.", 371 | default=False, 372 | action="store_true", 373 | ) 374 | parser.add_argument( 375 | "--verbose", "-v", 376 | help="Log detailed information about the weight processing.", 377 | default=False, 378 | action="store_true", 379 | ) 380 | parser.add_argument( 381 | "--save-protos", "-p", 382 | help="Write out the input and output ONNX files as text protobufs.", 383 | default=False, 384 | action="store_true", 385 | ) 386 | parser.add_argument("globs", nargs="*") 387 | args = parser.parse_args() 388 | if len(args.globs) == 0: 389 | args.globs = ["*.onnx"] 390 | 391 | if args.output_dir is not None and not os.path.isdir(args.output_dir): 392 | os.makedirs(args.output_dir) 393 | 394 | op_types_to_quantize = get_list_arg(args.op_types_to_quantize) 395 | nodes_to_quantize = get_list_arg(args.nodes_to_quantize) 396 | nodes_to_exclude = get_list_arg(args.nodes_to_exclude) 397 | 398 | for input_glob in args.globs: 399 | if os.path.isdir(input_glob): 400 | input_glob = os.path.join(input_glob, "*.onnx") 401 | input_filenames = list(glob.glob(input_glob)) 402 | if len(input_filenames) == 0: 403 | print(f"No files found matching '{input_glob}'.") 404 | sys.exit(1) 405 | 406 | for input_filename in input_filenames: 407 | if args.info: 408 | print_weight_info(input_filename) 409 | continue 410 | if args.output_suffix != ".onnx" and input_filename.endswith(args.output_suffix): 411 | print(f"Skipping '{input_filename}' as it is already quantized.") 412 | continue 413 | if args.verbose: 414 | print(f"Processing '{input_filename}'") 415 | input_base = os.path.basename(input_filename) 416 | input_dir = os.path.dirname(input_filename) 417 | output_base = os.path.splitext(input_base)[0] + args.output_suffix 418 | if args.output_dir is None: 419 | output_filename = os.path.join(input_dir, output_base) 420 | else: 421 | output_filename = os.path.join(args.output_dir, output_base) 422 | if args.verbose: 423 | print(f"Writing to '{output_filename}'") 424 | if output_filename == input_filename: 425 | print(f"Skipping '{input_filename}' as the output filename is the same and it would be overwritten.") 426 | continue 427 | if args.verbose: 428 | input_file_length = os.path.getsize(input_filename) 429 | if args.method == "float_weights" or args.method == "integer_weights": 430 | original_model = onnx.load(input_filename) 431 | if args.save_protos: 432 | with open(input_filename + ".txt", "w") as f: 433 | f.write(str(original_model)) 434 | float_quantization = (args.method == "float_weights") 435 | new_model = quantize_weights(original_model, float_quantization=float_quantization, float_levels=args.float_levels, verbose=args.verbose) 
436 | onnx.save(new_model, output_filename) 437 | if args.save_protos: 438 | with open(output_filename + ".txt", "w") as f: 439 | f.write(str(new_model)) 440 | elif args.method == "integer_activations": 441 | if args.verbose: 442 | print(f"quantize_dynamic('{input_filename}', '{output_filename}', weight_type=QuantType.QUInt8, op_types_to_quantize={op_types_to_quantize}, nodes_to_quantize={nodes_to_quantize}, nodes_to_exclude={nodes_to_exclude}, extra_options={{'EnableSubgraph': True}})") 443 | quantize_dynamic( 444 | input_filename, 445 | output_filename, 446 | weight_type=QuantType.QUInt8, 447 | op_types_to_quantize=op_types_to_quantize, 448 | nodes_to_quantize=nodes_to_quantize, 449 | nodes_to_exclude=nodes_to_exclude, 450 | extra_options={"EnableSubgraph": True}) 451 | else: 452 | print(f"Unknown quantization method: {args.method}") 453 | sys.exit(1) 454 | if args.verbose: 455 | output_file_length = os.path.getsize(output_filename) 456 | print(f"Original file size: {input_file_length:,} bytes, quantized file size: {output_file_length:,} bytes") -------------------------------------------------------------------------------- /tests/shrink_test.py: -------------------------------------------------------------------------------- 1 | import onnx 2 | from onnx import helper as h, TensorProto as tp 3 | import onnxruntime as ort 4 | import numpy as np 5 | 6 | from onnx_shrink_ray.shrink import quantize_weights 7 | 8 | def check_quantization(float_model): 9 | onnx.checker.check_model(float_model) 10 | 11 | quantized_model = quantize_weights(float_model, min_elements=2) 12 | quantized_float_model = quantize_weights(float_model, min_elements=2, float_quantization=True) 13 | 14 | float_session = ort.InferenceSession(float_model.SerializeToString()) 15 | actual_float_output = np.array( 16 | float_session.run(None, {})[0], 17 | dtype=np.float32) 18 | 19 | quantized_session = ort.InferenceSession(quantized_model.SerializeToString()) 20 | actual_quantized_output = np.array( 21 | quantized_session.run(None, {})[0], 22 | dtype=np.float32) 23 | 24 | quantized_float_session = ort.InferenceSession(quantized_float_model.SerializeToString()) 25 | actual_quantized_float_output = np.array( 26 | quantized_float_session.run(None, {})[0], 27 | dtype=np.float32) 28 | 29 | output_min = np.min(actual_float_output) 30 | output_max = np.max(actual_float_output) 31 | output_range = output_max - output_min 32 | output_bin_size = output_range / 255.0 33 | 34 | # print(f"actual_float_output: {actual_float_output}") 35 | # print(f"actual_quantized_output: {actual_quantized_output}") 36 | 37 | output_diff = np.abs(actual_float_output - actual_quantized_output) 38 | max_diff = np.max(output_diff) 39 | if max_diff > output_bin_size: 40 | raise Exception(f"Max difference {max_diff} is greater than output bin size {output_bin_size}.") 41 | 42 | output_qf_diff = np.abs(actual_float_output - actual_quantized_float_output) 43 | max_qf_diff = np.max(output_qf_diff) 44 | if max_qf_diff > output_bin_size: 45 | raise Exception(f"Max difference {max_qf_diff} is greater than output bin size {output_bin_size}.") 46 | 47 | def test_single_constant(): 48 | weights_shape = (1, 1, 2, 2) 49 | weights_values = np.array([[[[0.0, 2.5], [5.0, 10.0]]]], dtype=np.float32) 50 | 51 | weights_tensor = h.make_tensor(name="weights_tensor", data_type=tp.FLOAT, 52 | dims=weights_shape, 53 | vals=weights_values) 54 | 55 | weights_node = h.make_node("Constant", inputs=[], outputs=["weights_output"], name="weights_node", 56 | value=weights_tensor) 57 | 58 | 
float_graph = h.make_graph([weights_node], "test_graph", 59 | [], 60 | [h.make_tensor_value_info("weights_output", tp.FLOAT, weights_shape)]) 61 | 62 | float_model = h.make_model(float_graph, producer_name="quantization_test") 63 | 64 | check_quantization(float_model) 65 | 66 | def test_identity(): 67 | weights_shape = (1, 1, 2, 2) 68 | weights_values = np.array([[[[0.0, 2.5], [5.0, 10.0]]]], dtype=np.float32) 69 | 70 | weights_tensor = h.make_tensor(name="weights_tensor", data_type=tp.FLOAT, 71 | dims=weights_shape, 72 | vals=weights_values) 73 | 74 | weights_node = h.make_node("Constant", inputs=[], outputs=["weights_output"], name="weights_node", 75 | value=weights_tensor) 76 | 77 | identity_node = h.make_node("Identity", inputs=["weights_output"], outputs=["identity_output"], name="identity_node") 78 | 79 | float_graph = h.make_graph([weights_node, identity_node], "test_graph", 80 | [], 81 | [h.make_tensor_value_info("identity_output", tp.FLOAT, weights_shape)]) 82 | 83 | float_model = h.make_model(float_graph, producer_name="quantization_test") 84 | 85 | check_quantization(float_model) 86 | 87 | def test_mul(): 88 | weights_shape = (1, 1, 2, 2) 89 | weights_values = np.array([[[[0.0, 2.5], [5.0, 10.0]]]], dtype=np.float32) 90 | weights_tensor = h.make_tensor(name="weights_tensor", data_type=tp.FLOAT, 91 | dims=weights_shape, 92 | vals=weights_values) 93 | weights_node = h.make_node("Constant", inputs=[], outputs=["weights_output"], name="weights_node", 94 | value=weights_tensor) 95 | 96 | two_shape = (1, ) 97 | two_values = np.array([2.0], dtype=np.float32) 98 | two_tensor = h.make_tensor(name="two_tensor", data_type=tp.FLOAT, 99 | dims=two_shape, 100 | vals=two_values) 101 | two_node = h.make_node("Constant", inputs=[], outputs=["two_output"], name="two_node", 102 | value=two_tensor) 103 | 104 | mul_node = h.make_node("Mul", inputs=["weights_output", "two_output"], outputs=["mul_output"], name="mul_node") 105 | 106 | float_graph = h.make_graph([weights_node, two_node, mul_node], "test_graph", 107 | [], 108 | [h.make_tensor_value_info("mul_output", tp.FLOAT, weights_shape)]) 109 | 110 | float_model = h.make_model(float_graph, producer_name="quantization_test") 111 | 112 | check_quantization(float_model) 113 | 114 | def test_large_constant(): 115 | weights_width = 256 116 | weights_height = 256 117 | weights_shape = (1, 1, weights_height, weights_width) 118 | rng = np.random.default_rng(7528840384) 119 | weights_values = rng.random((weights_shape)).astype(np.float32) 120 | 121 | weights_tensor = h.make_tensor(name="weights_tensor", data_type=tp.FLOAT, 122 | dims=weights_shape, 123 | vals=weights_values) 124 | 125 | weights_node = h.make_node("Constant", inputs=[], outputs=["weights_output"], name="weights_node", 126 | value=weights_tensor) 127 | 128 | float_graph = h.make_graph([weights_node], "test_graph", 129 | [], 130 | [h.make_tensor_value_info("weights_output", tp.FLOAT, weights_shape)]) 131 | 132 | float_model = h.make_model(float_graph, producer_name="quantization_test") 133 | 134 | check_quantization(float_model) 135 | 136 | def test_signed_constant(): 137 | weights_shape = (1, 1, 2, 2) 138 | weights_values = np.array([[[[-5.0, -2.5], [0.0, 5.0]]]], dtype=np.float32) 139 | 140 | weights_tensor = h.make_tensor(name="weights_tensor", data_type=tp.FLOAT, 141 | dims=weights_shape, 142 | vals=weights_values) 143 | 144 | weights_node = h.make_node("Constant", inputs=[], outputs=["weights_output"], name="weights_node", 145 | value=weights_tensor) 146 | 147 | float_graph = 
h.make_graph([weights_node], "test_graph", 148 | [], 149 | [h.make_tensor_value_info("weights_output", tp.FLOAT, weights_shape)]) 150 | 151 | float_model = h.make_model(float_graph, producer_name="quantization_test") 152 | 153 | check_quantization(float_model) 154 | 155 | 156 | def test_unbalanced_constant(): 157 | weights_shape = (1, 1, 2, 2) 158 | weights_values = np.array([[[[-2.0, 0.5], [3.0, 8.0]]]], dtype=np.float32) 159 | 160 | weights_tensor = h.make_tensor(name="weights_tensor", data_type=tp.FLOAT, 161 | dims=weights_shape, 162 | vals=weights_values) 163 | 164 | weights_node = h.make_node("Constant", inputs=[], outputs=["weights_output"], name="weights_node", 165 | value=weights_tensor) 166 | 167 | float_graph = h.make_graph([weights_node], "test_graph", 168 | [], 169 | [h.make_tensor_value_info("weights_output", tp.FLOAT, weights_shape)]) 170 | 171 | float_model = h.make_model(float_graph, producer_name="quantization_test") 172 | 173 | check_quantization(float_model) 174 | 175 | if __name__ == "__main__": 176 | test_single_constant() 177 | test_identity() 178 | test_mul() 179 | test_large_constant() 180 | test_signed_constant() 181 | test_unbalanced_constant() 182 | print("All tests passed.") 183 | --------------------------------------------------------------------------------
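The test suite above can be run directly, since `shrink_test.py` ends with a `__main__` block that executes every test and prints "All tests passed." Assuming the package and its dependencies are installed as described in the README, something like this should work:

```bash
python tests/shrink_test.py
```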