├── .gitignore ├── LICENSE ├── README.md ├── mx2onnx_converter ├── __init__.py ├── conversion_helpers.py ├── mx2onnx_converter.py └── mx2onnx_converter_functions.py ├── requirements.txt ├── setup.cfg ├── setup.py └── tests └── test_convert_lenet5.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | 
-------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | MxNet-to-ONNX exporter 2 | ========================== 3 | 4 | **NOTE:** This repository is deprecated, since MXNet now (since v. 1.3) has an officially integrated exporter, in part based on this repository. For the latest version of the exporter, please install the latest version of MXNet. The official exporter can now be found [here](https://github.com/apache/incubator-mxnet/tree/master/python/mxnet/contrib/onnx/mx2onnx). 5 | 6 | What is this? 7 | ------------------ 8 | 9 | 10 | This is the repository for the [MxNet](https://github.com/apache/incubator-mxnet)-to-[ONNX](https://github.com/onnx/onnx) converter, which takes a trained MxNet model, represented in serialized form as the .json/.params file pair, and converts that model to ONNX. Please note that this is a file-to-file conversion - the input is a checkpointed MxNet model, NOT the [NNVM](https://github.com/dmlc/nnvm) graph. 
11 | 12 | Installation 13 | ---------------------------------------- 14 | 15 | Note that --force will force an upgrade if a previous version was installed. This is equivalent to first uninstalling and then installing again. Without force, an upgrade will not be performed. 16 | 17 | ```python setup.py install --force``` 18 | 19 | Also note that since this project depends on ONNX, and ONNX depends on the [Protobuf](https://github.com/google/protobuf) compiler, the installation of the ONNX [pip](https://packaging.python.org/tutorials/installing-packages/#use-pip-for-installing) package will require the compiler. The installation of the native component will depend on your operating system, but on Ubuntu 16.04, you can simply do 20 | 21 | ```sudo apt-get install protobuf-compiler libprotoc-dev``` 22 | 23 | See the [details](https://github.com/onnx/onnx/blob/master/README.md) as to what is required to install ONNX. Note that even though the ONNX pip package can be fetched from [PyPI](https://pypi.python.org/pypi), it will still depend on the Protobuf compiler. Hence, even though ONNX is listed in requirements.txt, its installation will depend on the aforementioned native components. 24 | 25 | 26 | Tests 27 | ---------------------------------------- 28 | 29 | To run the test that: 30 | 31 | 1. trains [LeNet-5](http://yann.lecun.com/exdb/publis/pdf/lecun-01a.pdf) on [MNIST](http://yann.lecun.com/exdb/mnist/) 32 | 2. checkpoints the trained MxNet model to the .json/.params file pair that represents a serialized MxNet model 33 | 3. loads the serialized MxNet model and runs inference on test data 34 | 4. converts the serialized MxNet model to ONNX 35 | 5. loads the ONNX model and runs inference on test data 36 | 6. 
asserts that all 10,000 predictions match 37 | 38 | please run: 39 | 40 | ```python setup.py test``` 41 | 42 | -------------------------------------------------------------------------------- /mx2onnx_converter/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Redistribution and use in source and binary forms, with or without 4 | # modification, are permitted provided that the following conditions 5 | # are met: 6 | # * Redistributions of source code must retain the above copyright 7 | # notice, this list of conditions and the following disclaimer. 8 | # * Redistributions in binary form must reproduce the above copyright 9 | # notice, this list of conditions and the following disclaimer in the 10 | # documentation and/or other materials provided with the distribution. 11 | # * Neither the name of NVIDIA CORPORATION nor the names of its 12 | # contributors may be used to endorse or promote products derived 13 | # from this software without specific prior written permission. 14 | # 15 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY 16 | # EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 17 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 18 | # PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR 19 | # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 20 | # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 21 | # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 22 | # PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY 23 | # OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 24 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
26 | 27 | from __future__ import absolute_import 28 | 29 | import mx2onnx_converter.conversion_helpers 30 | import mx2onnx_converter.mx2onnx_converter 31 | import mx2onnx_converter.mx2onnx_converter_functions 32 | -------------------------------------------------------------------------------- /mx2onnx_converter/conversion_helpers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Redistribution and use in source and binary forms, with or without 4 | # modification, are permitted provided that the following conditions 5 | # are met: 6 | # * Redistributions of source code must retain the above copyright 7 | # notice, this list of conditions and the following disclaimer. 8 | # * Redistributions in binary form must reproduce the above copyright 9 | # notice, this list of conditions and the following disclaimer in the 10 | # documentation and/or other materials provided with the distribution. 11 | # * Neither the name of NVIDIA CORPORATION nor the names of its 12 | # contributors may be used to endorse or promote products derived 13 | # from this software without specific prior written permission. 14 | # 15 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY 16 | # EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 17 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 18 | # PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR 19 | # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 20 | # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 21 | # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 22 | # PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY 23 | # OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 24 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 26 | 27 | from __future__ import absolute_import 28 | from __future__ import division 29 | from __future__ import print_function 30 | from __future__ import unicode_literals 31 | 32 | from onnx import defs, checker, helper, numpy_helper, mapping 33 | from .mx2onnx_converter import MxNetToONNXConverter 34 | 35 | import json 36 | 37 | import mxnet as mx 38 | import numpy as np 39 | 40 | def from_mxnet(model_file, weight_file, input_shape, input_type, log=False): 41 | mx_weights = mx.ndarray.load(weight_file) 42 | with open(model_file, 'r') as f: 43 | graph = json.loads(f.read())["nodes"] 44 | converter = MxNetToONNXConverter() 45 | onnx_graph = converter.convert_mx2onnx_graph(graph, mx_weights, input_shape, mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(input_type)], log=log) 46 | onnx_model = helper.make_model(onnx_graph) 47 | return onnx_model 48 | 49 | 50 | -------------------------------------------------------------------------------- /mx2onnx_converter/mx2onnx_converter.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved. 
2 | # 3 | # Redistribution and use in source and binary forms, with or without 4 | # modification, are permitted provided that the following conditions 5 | # are met: 6 | # * Redistributions of source code must retain the above copyright 7 | # notice, this list of conditions and the following disclaimer. 8 | # * Redistributions in binary form must reproduce the above copyright 9 | # notice, this list of conditions and the following disclaimer in the 10 | # documentation and/or other materials provided with the distribution. 11 | # * Neither the name of NVIDIA CORPORATION nor the names of its 12 | # contributors may be used to endorse or promote products derived 13 | # from this software without specific prior written permission. 14 | # 15 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY 16 | # EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 17 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 18 | # PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR 19 | # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 20 | # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 21 | # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 22 | # PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY 23 | # OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 24 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
26 | 27 | from __future__ import absolute_import 28 | from __future__ import division 29 | from __future__ import print_function 30 | from __future__ import unicode_literals 31 | 32 | import numpy as np 33 | 34 | import sys 35 | 36 | from onnx import (defs, checker, helper, numpy_helper, mapping, onnx_pb, 37 | ModelProto, GraphProto, NodeProto, AttributeProto, TensorProto) 38 | 39 | from onnx.helper import make_tensor, make_tensor_value_info 40 | 41 | class MxNetToONNXConverter: 42 | 43 | registry_ = {} 44 | input_output_maps_ = {} 45 | 46 | def __init__(self): 47 | # topologically sorted nodes 48 | self.nodes = [] 49 | self.input_tensors = [] 50 | self.output_tensors = [] 51 | 52 | @staticmethod 53 | def register(op_name): 54 | 55 | def wrapper(func): 56 | MxNetToONNXConverter.registry_[op_name] = func 57 | return func 58 | 59 | return wrapper 60 | 61 | @staticmethod 62 | def convert_layer(node, **kwargs): 63 | op = str(node["op"]) 64 | if op not in MxNetToONNXConverter.registry_: 65 | raise AttributeError("No conversion function registered for op type %s yet." % op) 66 | convert_fun = MxNetToONNXConverter.registry_[op] 67 | return convert_fun(node, **kwargs) 68 | 69 | # Add transpose? 
70 | @staticmethod 71 | def convert_weights_to_numpy(weights_dict): 72 | return dict([(k.replace("arg:", "").replace("aux:", ""), v.asnumpy()) for k, v in weights_dict.items()]) 73 | 74 | def convert_mx2onnx_graph(self, mx_graph, mx_weights, in_shape, in_type, log=False): 75 | print("\nconverting weights from MxNet NDArrays to NumPy arrays.\n") 76 | weights = MxNetToONNXConverter.convert_weights_to_numpy(mx_weights) 77 | 78 | onnx_graph = GraphProto() 79 | 80 | initializer = [] 81 | all_processed_nodes = [] 82 | onnx_processed_nodes = [] 83 | onnx_processed_inputs = [] 84 | onnx_processed_outputs = [] 85 | 86 | for idx, node in enumerate(mx_graph): 87 | op = node["op"] 88 | name = node["name"] 89 | if log: 90 | print("Converting idx: %d, op: %s, name: %s" % (idx, op, name)) 91 | converted = MxNetToONNXConverter.convert_layer( 92 | node, 93 | mx_graph = mx_graph, 94 | weights = weights, 95 | in_shape = in_shape, 96 | in_type = in_type, 97 | proc_nodes = all_processed_nodes, 98 | initializer = initializer 99 | ) 100 | 101 | if isinstance(converted, onnx_pb.ValueInfoProto): 102 | if idx < (len(mx_graph) - 1): 103 | onnx_processed_inputs.append(converted) 104 | else: 105 | onnx_processed_outputs.append(converted) 106 | elif isinstance(converted, onnx_pb.NodeProto): 107 | if idx < (len(mx_graph) - 1): 108 | onnx_processed_nodes.append(converted) 109 | else: 110 | onnx_processed_nodes.append(converted) 111 | onnx_processed_outputs.append( 112 | make_tensor_value_info( 113 | name=converted.name, 114 | elem_type=mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype('float32')], 115 | shape=(in_shape[0], -1) 116 | ) 117 | ) 118 | if log: 119 | print("Output node is: %s" % converted.name) 120 | elif isinstance(converted, onnx_pb.TensorProto): 121 | raise ValueError("Did not expect TensorProto") 122 | if idx < (len(mx_graph) - 1): 123 | onnx_processed_inputs.append(converted) 124 | else: 125 | onnx_processed_outputs.append(converted) 126 | else: 127 | print(converted) 128 | raise 
ValueError("node is of an unrecognized type: %s" % type(node)) 129 | 130 | all_processed_nodes.append(converted) 131 | 132 | graph = helper.make_graph( 133 | onnx_processed_nodes, 134 | "main", 135 | onnx_processed_inputs, 136 | onnx_processed_outputs 137 | ) 138 | 139 | graph.initializer.extend(initializer) 140 | 141 | checker.check_graph(graph) 142 | return graph 143 | 144 | -------------------------------------------------------------------------------- /mx2onnx_converter/mx2onnx_converter_functions.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Redistribution and use in source and binary forms, with or without 4 | # modification, are permitted provided that the following conditions 5 | # are met: 6 | # * Redistributions of source code must retain the above copyright 7 | # notice, this list of conditions and the following disclaimer. 8 | # * Redistributions in binary form must reproduce the above copyright 9 | # notice, this list of conditions and the following disclaimer in the 10 | # documentation and/or other materials provided with the distribution. 11 | # * Neither the name of NVIDIA CORPORATION nor the names of its 12 | # contributors may be used to endorse or promote products derived 13 | # from this software without specific prior written permission. 14 | # 15 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY 16 | # EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 17 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 18 | # PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR 19 | # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 20 | # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 21 | # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 22 | # PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY 23 | # OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 24 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 26 | 27 | """ 28 | mx_to_uff_converter_functions.py 29 | 30 | Conversion Functions for common layers. 31 | Add new functions here with a decorator. 32 | """ 33 | from __future__ import absolute_import 34 | from __future__ import division 35 | from __future__ import print_function 36 | from __future__ import unicode_literals 37 | 38 | from onnx import defs, checker, helper, numpy_helper, mapping 39 | 40 | from .mx2onnx_converter import MxNetToONNXConverter as mx2onnx 41 | 42 | import numpy as np 43 | 44 | import re 45 | 46 | import sys 47 | 48 | def looks_like_weight(name): 49 | """Internal helper to figure out if node should be hidden with `hide_weights`. 
50 | """ 51 | if name.endswith("_weight"): 52 | return True 53 | if name.endswith("_bias"): 54 | return True 55 | if name.endswith("_beta") or name.endswith("_gamma") or name.endswith("_moving_var") or name.endswith("_moving_mean"): 56 | return True 57 | return False 58 | 59 | 60 | @mx2onnx.register("null") 61 | def convert_weights_and_inputs(node, **kwargs): 62 | name = node["name"] 63 | if looks_like_weight(name): 64 | weights = kwargs["weights"] 65 | initializer = kwargs["initializer"] 66 | weights = kwargs["weights"] 67 | np_arr = weights[name] 68 | data_type = mapping.NP_TYPE_TO_TENSOR_TYPE[np_arr.dtype] 69 | dims = np.shape(np_arr) 70 | 71 | tensor_node = helper.make_tensor_value_info(name, data_type, dims) 72 | 73 | initializer.append( 74 | helper.make_tensor( 75 | name=name, 76 | data_type=data_type, 77 | dims=dims, 78 | vals=np_arr.flatten().tolist(), 79 | raw=False, 80 | ) 81 | ) 82 | 83 | return tensor_node 84 | else: 85 | tval_node = helper.make_tensor_value_info(name, kwargs["in_type"], kwargs["in_shape"]) 86 | return tval_node 87 | 88 | 89 | @mx2onnx.register("Deconvolution") 90 | def convert_deconvolution(node, **kwargs): 91 | name = node["name"] 92 | inputs = node["inputs"] 93 | 94 | num_inputs = len(inputs) 95 | 96 | proc_nodes = kwargs["proc_nodes"] 97 | input_node = proc_nodes[inputs[0][0]].name 98 | weights_node = proc_nodes[inputs[1][0]].name 99 | 100 | if num_inputs > 2: 101 | bias_node = proc_nodes[inputs[2][0]].name 102 | 103 | attrs = node.get("attrs") 104 | tuple_re = re.compile('\([0-9|,| ]+\)') 105 | 106 | def parse_helper(attrs_name, alt_value=None): 107 | if attrs is None: 108 | return alt_value 109 | attrs_str = attrs.get(attrs_name) 110 | if attrs_str is None: 111 | return alt_value 112 | attrs_match = tuple_re.search(attrs_str) 113 | if attrs_match is not None: 114 | if attrs_match.span() == (0, len(attrs_str)): 115 | dims = eval(attrs_str) 116 | return dims 117 | else: 118 | raise AttributeError("Malformed %s dimensions: %s" % 
(attrs_name, str(attrs_str))) 119 | return alt_value 120 | 121 | num_filter = int(attrs["num_filter"]) 122 | kernel_dims = list(parse_helper("kernel")) 123 | stride_dims = list(parse_helper("stride", [1, 1])) 124 | pad_dims = parse_padding(attrs) 125 | num_group = int(attrs.get("num_group", 1)) 126 | 127 | # Not sure why this is included, it seems to change what the graphs is doing. 128 | # TODO(kellens): Ask Marek if this is requried. 129 | # if len(pad_dims) < 2 * len(kernel_dims): 130 | # pad_dims = [0] * (2 * len(kernel_dims) - len(pad_dims)) + pad_dims 131 | 132 | input_nodes = [input_node, weights_node] 133 | if num_inputs > 2: 134 | input_nodes.append(bias_node) 135 | 136 | deconv_node = helper.make_node( 137 | "ConvTranspose", 138 | inputs=input_nodes, 139 | outputs=[name], 140 | kernel_shape=kernel_dims, 141 | strides=stride_dims, 142 | pads=pad_dims, 143 | group=num_group, 144 | name=name 145 | ) 146 | 147 | return deconv_node 148 | 149 | 150 | @mx2onnx.register("Convolution") 151 | def convert_convolution(node, **kwargs): 152 | name = node["name"] 153 | inputs = node["inputs"] 154 | 155 | num_inputs = len(inputs) 156 | 157 | proc_nodes = kwargs["proc_nodes"] 158 | input_node = proc_nodes[inputs[0][0]].name 159 | weights_node = proc_nodes[inputs[1][0]].name 160 | 161 | if num_inputs > 2: 162 | bias_node = proc_nodes[inputs[2][0]].name 163 | 164 | attrs = node.get("attrs") 165 | tuple_re = re.compile('\([0-9|,| ]+\)') 166 | 167 | def parse_helper(attrs_name, alt_value=None): 168 | if attrs is None: 169 | return alt_value 170 | attrs_str = attrs.get(attrs_name) 171 | if attrs_str is None: 172 | return alt_value 173 | attrs_match = tuple_re.search(attrs_str) 174 | if attrs_match is not None: 175 | if attrs_match.span() == (0, len(attrs_str)): 176 | dims = eval(attrs_str) 177 | return dims 178 | else: 179 | raise AttributeError("Malformed %s dimensions: %s" % (attrs_name, str(attrs_str))) 180 | return alt_value 181 | 182 | num_filter = 
int(attrs["num_filter"]) 183 | kernel_dims = list(parse_helper("kernel")) 184 | stride_dims = list(parse_helper("stride", [1, 1])) 185 | pad_dims = parse_padding(attrs) 186 | num_group = int(attrs.get("num_group", 1)) 187 | 188 | # Not sure why this is included, it seems to change what the graphs is doing. 189 | # TODO(kellens): Ask Marek if this is requried. 190 | # if len(pad_dims) < 2 * len(kernel_dims): 191 | # pad_dims = [0] * (2 * len(kernel_dims) - len(pad_dims)) + pad_dims 192 | 193 | input_nodes = [input_node, weights_node] 194 | if num_inputs > 2: 195 | input_nodes.append(bias_node) 196 | 197 | conv_node = helper.make_node( 198 | "Conv", 199 | inputs=input_nodes, 200 | outputs=[name], 201 | kernel_shape=kernel_dims, 202 | strides=stride_dims, 203 | pads=pad_dims, 204 | group=num_group, 205 | name=name, 206 | ) 207 | 208 | return conv_node 209 | 210 | 211 | @mx2onnx.register("FullyConnected") 212 | def convert_fully_connected(node, **kwargs): 213 | name = node["name"] 214 | inputs = node["inputs"] 215 | input_node_id = inputs[0][0] 216 | weight_node_id = inputs[1][0] 217 | bias_node_id = inputs[2][0] 218 | proc_nodes = kwargs["proc_nodes"] 219 | input_node = proc_nodes[input_node_id] 220 | weights_node = proc_nodes[weight_node_id] 221 | bias_node = proc_nodes[bias_node_id] 222 | 223 | input_name = input_node.name 224 | weights_name = weights_node.name 225 | bias_name = bias_node.name 226 | 227 | node = helper.make_node( 228 | "Gemm", 229 | [input_name, weights_name, bias_name], # input (A, B, C) - C can be in place 230 | [name], # output 231 | alpha=1.0, 232 | beta=1.0, 233 | transA=False, 234 | transB=True, 235 | name=name 236 | ) 237 | 238 | return node 239 | 240 | @mx2onnx.register("BatchNorm") 241 | def convert_batchnorm(node, **kwargs): 242 | name = node["name"] 243 | proc_nodes = kwargs["proc_nodes"] 244 | inputs = node["inputs"] 245 | 246 | attrs = node["attrs"] 247 | # Default momentum is 0.9 248 | try: 249 | momentum = float(attrs["momentum"]) 250 
| except: 251 | momentum = 0.9 252 | # Default eps is 0.001 253 | try: 254 | eps = float(attrs["eps"]) 255 | except: 256 | eps = 0.001 257 | 258 | data_idx = inputs[0][0] 259 | gamma_idx = inputs[1][0] 260 | beta_idx = inputs[2][0] 261 | moving_mean_idx = inputs[3][0] 262 | moving_var_idx = inputs[4][0] 263 | 264 | data_node = proc_nodes[data_idx].name 265 | gamma_node = proc_nodes[gamma_idx].name 266 | beta_node = proc_nodes[beta_idx].name 267 | 268 | mov_mean_node = proc_nodes[moving_mean_idx] 269 | mov_mean_node = mov_mean_node.name 270 | mov_var_node = proc_nodes[moving_var_idx].name 271 | 272 | bn_node = helper.make_node( 273 | "BatchNormalization", 274 | [data_node, 275 | gamma_node, # scale 276 | beta_node, # bias 277 | mov_mean_node, 278 | mov_var_node 279 | ], 280 | [name], 281 | name=name, 282 | epsilon=eps, 283 | momentum=momentum, 284 | is_test=1, 285 | spatial=1, 286 | consumed_inputs=(0, 0, 0, 1, 1) 287 | ) 288 | 289 | return bn_node 290 | 291 | 292 | @mx2onnx.register("Activation") 293 | def convert_activation(node, **kwargs): 294 | name = node["name"] 295 | 296 | proc_nodes = kwargs["proc_nodes"] 297 | attrs = node["attrs"] 298 | act_type = attrs["act_type"] 299 | 300 | inputs = node["inputs"] 301 | input_node_idx = inputs[0][0] 302 | input_node = proc_nodes[input_node_idx].output[0] 303 | 304 | # Creating a dictionary here, but if this titlecase pattern 305 | # is consistent for other activations, this can be changed to 306 | # mxnet_name.title() 307 | act_types = { 308 | "tanh": "Tanh", 309 | "relu": "Relu", 310 | "sigmoid": "Sigmoid", 311 | "softrelu": "Softplus", 312 | "softsign": "Softsign" 313 | } 314 | 315 | act_name = act_types.get(act_type) 316 | if act_name: 317 | node = helper.make_node( 318 | act_name, 319 | [input_node], 320 | [name], 321 | name=name 322 | ) 323 | else: 324 | raise AttributeError( 325 | "Activation %s not implemented or recognized in the converter" % act_type 326 | ) 327 | 328 | return node 329 | 330 | 331 | def 
parse_padding(attrs): 332 | tuple_re = re.compile('\([0-9|,| ]+\)') 333 | 334 | def parse_helper(attrs_name, alt_value=None): 335 | if attrs is None: 336 | return alt_value 337 | attrs_str = attrs.get(attrs_name) 338 | if attrs_str is None: 339 | return alt_value 340 | attrs_match = tuple_re.search(attrs_str) 341 | if attrs_match is not None: 342 | if attrs_match.span() == (0, len(attrs_str)): 343 | dims = eval(attrs_str) 344 | return dims 345 | else: 346 | raise AttributeError("Malformed %s dimensions: %s" % (attrs_name, str(attrs_str))) 347 | return alt_value 348 | 349 | symetric_pads = list(parse_helper("pad", [0, 0])) 350 | result = [] 351 | 352 | # Each padding in MXNet is assumed to be symetric in dim1, dim2 ... 353 | # In ONNX we need to have a start_dim1, start_dim2, ..., end_dim1, end_dim2 354 | for pad in symetric_pads: 355 | result.append(pad) 356 | for pad in symetric_pads: 357 | result.append(pad) 358 | return result 359 | 360 | 361 | @mx2onnx.register("Pooling") 362 | def convert_pooling(node, **kwargs): 363 | proc_nodes = kwargs["proc_nodes"] 364 | attrs = node["attrs"] 365 | kernel = eval(attrs["kernel"]) 366 | pool_type = attrs["pool_type"] 367 | 368 | # Default stride in MXNet for pooling is (1,1) 369 | stride = eval(attrs["stride"]) if attrs.get("stride") else (1, 1) 370 | 371 | # Global pooling is set explicitly with an attr on the op. 
372 | global_pool = eval(attrs["global"]) if attrs.get("global") else None 373 | 374 | node_inputs = node["inputs"] 375 | input_node_idx = node_inputs[0][0] 376 | input_node = proc_nodes[input_node_idx] 377 | name = node["name"] 378 | 379 | pad_dims = parse_padding(attrs) 380 | 381 | pool_types = {"max": "MaxPool", "avg": "AveragePool"} 382 | global_pool_types = {"max": "GlobalMaxPool", "avg": "GlobalAveragePool"} 383 | 384 | if global_pool: 385 | node = helper.make_node( 386 | global_pool_types[pool_type], 387 | [input_node.output[0]], 388 | [name], 389 | name=name, 390 | pads=pad_dims 391 | ) 392 | else: 393 | node = helper.make_node( 394 | pool_types[pool_type], 395 | [input_node.output[0]], # input 396 | [name], 397 | # dilations = [0, 0], 398 | kernel_shape=kernel, 399 | pads=pad_dims, 400 | strides=stride, 401 | name=name 402 | ) 403 | return node 404 | 405 | 406 | @mx2onnx.register("exp") 407 | def convert_exp(node, **kwargs): 408 | raise NotImplementedError 409 | 410 | 411 | # There's also mx.sym.softmax(), which doesn't do cross-entropy loss, 412 | # just softmax for inference - hence the name convert_softmax_output. 
@mx2onnx.register("SoftmaxOutput")
def convert_softmax_output(node, **kwargs):
    """Convert MXNet SoftmaxOutput to a plain ONNX Softmax.

    The cross-entropy-loss half of SoftmaxOutput has no ONNX counterpart;
    only the inference-time softmax over axis 1 is exported.
    """
    inputs = node["inputs"]
    proc_nodes = kwargs["proc_nodes"]
    data = proc_nodes[inputs[0][0]]
    name = node["name"]

    return helper.make_node(
        "Softmax",
        [data.output[0]],
        [name],
        axis=1,
        name=name,
    )


@mx2onnx.register("Crop")
def convert_crop(node, **kwargs):
    # Renamed from convert_concat: the original name collided with the
    # actual Concat handler defined right below.
    """Convert MXNet Crop to the ONNX Crop op."""
    name = node["name"]
    inputs = node["inputs"]
    proc_nodes = kwargs["proc_nodes"]
    input_names = [proc_nodes[i[0]].name for i in inputs]
    attrs = node["attrs"]

    offset = list(eval(attrs['offset']))
    border = [0, 0, 0, 0]
    if len(inputs) == 2:
        # NOTE(review): this assigns the raw [node_id, output_index] input
        # reference as the border values, which looks wrong -- confirm intent.
        border = inputs[1]

    return helper.make_node(
        "Crop",
        input_names,
        [name],
        border=border,
        # NOTE(review): MXNet's "offset" attr is passed as ONNX "scale" --
        # verify against the Crop operator spec.
        scale=offset,
        name=name,
    )


@mx2onnx.register("Concat")
def convert_concat(node, **kwargs):
    """Convert MXNet Concat to ONNX Concat (MXNet's default axis is 1)."""
    name = node["name"]
    proc_nodes = kwargs["proc_nodes"]
    input_names = [proc_nodes[i[0]].name for i in node["inputs"]]
    axis = int(node.get("attrs", {}).get("axis", 1))
    return helper.make_node(
        "Concat",
        input_names,
        [name],
        axis=axis,
        name=name,
    )


@mx2onnx.register("Dropout")
def convert_dropout(node, **kwargs):
    """Convert MXNet Dropout to ONNX Dropout (is_test=0, i.e. active)."""
    name = node["name"]
    input_name = kwargs["proc_nodes"][node["inputs"][0][0]].name
    ratio = float(node["attrs"]["p"])
    return helper.make_node(
        "Dropout",
        [input_name],
        [name],
        ratio=ratio,
        is_test=0,
        name=name,
    )

@mx2onnx.register("Flatten")
def convert_flatten(node, **kwargs):
    """Convert MXNet Flatten to ONNX Flatten (batch dim kept, axis=1)."""
    name = node["name"]
    proc_nodes = kwargs["proc_nodes"]
    input_name = proc_nodes[node["inputs"][0][0]].name

    return helper.make_node(
        "Flatten",
        [input_name],
        [name],
        name=name,
        axis=1,
    )


@mx2onnx.register("_mul_scalar")
def convert_mul_scalar(node, **kwargs):
    raise NotImplementedError


@mx2onnx.register("elemwise_add")
def convert_elementwise_add(node, **kwargs):
    """Convert MXNet elemwise_add to ONNX Add."""
    name = node["name"]
    proc_nodes = kwargs["proc_nodes"]
    inputs = node["inputs"]
    # fixed: the original also fetched kwargs["weights"] here but never used it
    lhs = proc_nodes[inputs[0][0]].name
    rhs = proc_nodes[inputs[1][0]].name

    return helper.make_node(
        "Add",
        [lhs, rhs],
        [name],
        name=name,
    )


# The ops below are registered so unsupported graphs fail loudly,
# but they are not implemented yet.

@mx2onnx.register("_sub")
def convert_elementwise_sub(node, **kwargs):
    raise NotImplementedError


@mx2onnx.register("abs")
def convert_abs(node, **kwargs):
    raise NotImplementedError


@mx2onnx.register("_mul")
def convert_mul(node, **kwargs):
    # fixed: signature was (node, proc_nodes), inconsistent with every
    # other registered converter, which takes (node, **kwargs)
    raise NotImplementedError


@mx2onnx.register("_div")
def convert_div(node, **kwargs):
    raise NotImplementedError


@mx2onnx.register("log")
def convert_log(node, **kwargs):
    raise NotImplementedError


@mx2onnx.register("max")
def convert_max(node, **kwargs):
    raise NotImplementedError


@mx2onnx.register("_maximum")
def convert_maximum(node, **kwargs):
    raise NotImplementedError


@mx2onnx.register("min")
def convert_min(node, **kwargs):
    raise NotImplementedError

@mx2onnx.register("_minimum") 573 | def convert_minimum(node, **kwargs): 574 | raise NotImplementedError 575 | 576 | 577 | @mx2onnx.register("_power") 578 | def convert_power(node, **kwargs): 579 | raise NotImplementedError 580 | 581 | 582 | @mx2onnx.register("sqrt") 583 | def convert_sqrt(node, **kwargs): 584 | raise NotImplementedError 585 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy>=1.13.0 2 | mxnet>=1.2.0 3 | onnx>=1.0.0 4 | nose>=1.3.7 5 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [nosetests] 2 | verbosity=1 3 | detailed-errors=1 4 | nocapture=True 5 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Redistribution and use in source and binary forms, with or without 4 | # modification, are permitted provided that the following conditions 5 | # are met: 6 | # * Redistributions of source code must retain the above copyright 7 | # notice, this list of conditions and the following disclaimer. 8 | # * Redistributions in binary form must reproduce the above copyright 9 | # notice, this list of conditions and the following disclaimer in the 10 | # documentation and/or other materials provided with the distribution. 11 | # * Neither the name of NVIDIA CORPORATION nor the names of its 12 | # contributors may be used to endorse or promote products derived 13 | # from this software without specific prior written permission. 
14 | # 15 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY 16 | # EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 17 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 18 | # PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR 19 | # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 20 | # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 21 | # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 22 | # PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY 23 | # OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 24 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 26 | 27 | import itertools as it 28 | import os 29 | import re 30 | from setuptools import setup 31 | from subprocess import call 32 | import sys 33 | 34 | match_mxnet_req = re.compile(r"mxnet>?=?=\d+.\d+\d*") 35 | extract_major_minor = re.compile(r"\D*(\d+.\d+)\D*") 36 | 37 | def check_mxnet_version(min_ver): 38 | if not int(os.environ.get('UPDATE_MXNET_FOR_ONNX_EXPORTER', '1')): 39 | print("Env var set to not upgrade MxNet for ONNX exporter. Skipping.") 40 | return False 41 | try: 42 | print("Checking if MxNet is installed.") 43 | import mxnet as mx 44 | except ImportError: 45 | print("MxNet is not installed. Installing version from requirements.txt") 46 | return False 47 | ver = float(re.match(extract_major_minor, mx.__version__).group(1)) 48 | min_ver = float(re.match(extract_major_minor, min_ver).group(1)) 49 | if ver < min_ver: 50 | print("MxNet is installed, but installed version (%s) is older than expected (%s). Upgrading." % (str(ver).rstrip('0'), str(min_ver).rstrip('0'))) 51 | return False 52 | print("Installed MxNet version (%s) meets the requirement of >= (%s). No need to install." 
% (str(ver).rstrip('0'), str(min_ver).rstrip('0'))) 53 | return True 54 | 55 | if __name__ == '__main__': 56 | 57 | with open('requirements.txt') as f: 58 | required = f.read().splitlines() 59 | 60 | mx_match_str = lambda x: re.match(match_mxnet_req, x) is None 61 | mx_str, new_reqs = tuple([list(i[1]) for i in it.groupby(sorted(required, key = mx_match_str), key = mx_match_str)]) 62 | 63 | if not check_mxnet_version(mx_str[0]): 64 | new_reqs += mx_str 65 | 66 | setup( 67 | install_requires = new_reqs, 68 | name = 'mx2onnx', 69 | description = 'MxNet to ONNX converter', 70 | author = 'NVIDIA Corporation', 71 | packages = ['mx2onnx_converter'], 72 | classifiers = [ 73 | 'Programming Language :: Python :: 2.7', 74 | 'Programming Language :: Python :: 3.5' 75 | ], 76 | keywords = 'mxnet onnx', 77 | zip_safe = False, 78 | test_suite='nose.collector', 79 | tests_require=['nose'], 80 | version = '0.1' 81 | ) 82 | 83 | call("rm -rf dist".split()) 84 | call("rm -rf build".split()) 85 | call("rm -rf mx2onnx.egg-info".split()) 86 | -------------------------------------------------------------------------------- /tests/test_convert_lenet5.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Redistribution and use in source and binary forms, with or without 4 | # modification, are permitted provided that the following conditions 5 | # are met: 6 | # * Redistributions of source code must retain the above copyright 7 | # notice, this list of conditions and the following disclaimer. 8 | # * Redistributions in binary form must reproduce the above copyright 9 | # notice, this list of conditions and the following disclaimer in the 10 | # documentation and/or other materials provided with the distribution. 
#  * Neither the name of NVIDIA CORPORATION nor the names of its
#    contributors may be used to endorse or promote products derived
#    from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import os.path
import subprocess
from unittest import TestCase

import mxnet as mx
import numpy as np

# needed by both the exporter and importer
import onnx

# MxNet exporter
from mx2onnx_converter.conversion_helpers import from_mxnet

# MxNet importer
# Needed for ONNX -> NNVM -> MxNet conversion
# to validate the results of the export
from mxnet.contrib.onnx import import_model


def check_gpu_id(gpu_id):
    """Return True when ``gpu_id`` refers to a GPU visible to nvidia-smi."""
    try:
        result = subprocess.check_output(
            "nvidia-smi --query-gpu=gpu_bus_id --format=csv,noheader",
            shell=True).strip()
    except OSError:
        # nvidia-smi not present -> no usable GPU
        return False
    if not isinstance(result, str):
        result = str(result.decode("ascii"))
    gpu_ct = len(result.split("\n"))
    # GPU ids are zero-based, so a valid id is strictly below the count.
    exists = gpu_id < gpu_ct
    print("\nChecked for GPU ID %d. Less than GPU count (%d)? %s\n" % (gpu_id, gpu_ct, exists))
    return exists


# MxNet LeNet-5 implementation
def lenet5():
    """Build the LeNet-5 symbol graph: two conv/tanh/pool stages, two FC layers."""
    data = mx.sym.var('data')
    # first conv layer
    conv1 = mx.sym.Convolution(data=data, kernel=(5, 5), num_filter=20)
    tanh1 = mx.sym.Activation(data=conv1, act_type="tanh")
    pool1 = mx.sym.Pooling(data=tanh1, pool_type="max", kernel=(2, 2), stride=(2, 2))
    # second conv layer
    conv2 = mx.sym.Convolution(data=pool1, kernel=(5, 5), num_filter=50)
    tanh2 = mx.sym.Activation(data=conv2, act_type="tanh")
    pool2 = mx.sym.Pooling(data=tanh2, pool_type="max", kernel=(2, 2), stride=(2, 2))
    # first fullc layer
    flatten = mx.sym.flatten(data=pool2)
    fc1 = mx.symbol.FullyConnected(data=flatten, num_hidden=500)
    tanh3 = mx.sym.Activation(data=fc1, act_type="tanh")
    # second fullc
    fc2 = mx.sym.FullyConnected(data=tanh3, num_hidden=10)
    # softmax loss
    return mx.sym.SoftmaxOutput(data=fc2, name='softmax')


# train LeNet-5 model on MNIST data
def train_lenet5(num_epochs, gpu_id, train_iter, val_iter, test_iter, batch_size):
    """Train LeNet-5 on MNIST and return the fitted Module.

    NOTE(review): the iterator arguments are rebuilt from the cached MNIST
    data below rather than used as passed -- kept for signature compatibility.
    """
    ctx = mx.gpu(gpu_id) if gpu_id is not None else mx.cpu()
    print("\nUsing %s to train" % str(ctx))
    lenet_model = mx.mod.Module(lenet5(), context=ctx)

    # This is cached so download will only take place if needed
    mnist = mx.test_utils.get_mnist()
    train_iter = mx.io.NDArrayIter(mnist['train_data'], mnist['train_label'], batch_size, shuffle=True)
    val_iter = mx.io.NDArrayIter(mnist['test_data'], mnist['test_label'], batch_size)

    lenet_model.fit(train_iter,
                    eval_data=val_iter,
                    optimizer='sgd',
                    optimizer_params={'learning_rate': 0.1, 'momentum': 0.9},
                    eval_metric='acc',
                    batch_end_callback=mx.callback.Speedometer(batch_size, 100),
                    num_epoch=num_epochs)

    test_iter = mx.io.NDArrayIter(mnist['test_data'], mnist['test_label'], batch_size)

    # predict accuracy for lenet
    acc = mx.metric.Accuracy()
    lenet_model.score(test_iter, acc)
    accuracy = acc.get()[1]
    print("Training accuracy: %.2f" % accuracy)
    assert accuracy > 0.98, "Accuracy was too low"
    return lenet_model


class LeNet5Test(TestCase):
    """End-to-end test: train (or load) LeNet-5, export to ONNX, re-import,
    and check that predictions survive the round trip unchanged."""

    def __init__(self, *args, **kwargs):
        TestCase.__init__(self, *args, **kwargs)
        # self.tearDown = lambda: subprocess.call("rm -f *.gz *-symbol.json *.params *.onnx", shell=True)

    def test_convert_and_compare_prediction(self):
        # get data iterators and set basic hyperparams
        num_epochs = 10
        mnist = mx.test_utils.get_mnist()
        batch_size = 1000
        train_iter = mx.io.NDArrayIter(mnist['train_data'], mnist['train_label'], batch_size, shuffle=True)
        val_iter = mx.io.NDArrayIter(mnist['test_data'], mnist['test_label'], batch_size)
        test_iter = mx.io.NDArrayIter(mnist['test_data'], mnist['test_label'], batch_size)
        model_name = 'lenet5'
        model_file = '%s-symbol.json' % model_name
        params_file = '%s-%04d.params' % (model_name, num_epochs)
        onnx_file = "%s.onnx" % model_name
        test_gpu_id = 0
        gpu_id = check_gpu_id(test_gpu_id)
        if not gpu_id:
            print("\nWARNING: GPU id %d is invalid on this machine" % test_gpu_id)
            gpu_id = None

        # If trained model exists, re-use cached version. Otherwise, train model.
        if not (os.path.exists(model_file) and os.path.exists(params_file)):
            print("\n\nTraining LeNet-5 on MNIST data")
            trained_lenet = train_lenet5(num_epochs, gpu_id, train_iter, val_iter, test_iter, batch_size)
            print("Training finished. Saving model")
            trained_lenet.save_checkpoint(model_name, num_epochs)
            # delete object so we can verify correct loading of the checkpoint from disk
            del trained_lenet
        else:
            print("\n\nTrained model exists. Skipping training.")

        # Load serialized MxNet model (model-symbol.json + model-epoch.params)
        trained_lenet = mx.mod.Module.load(model_name, num_epochs)
        trained_lenet.bind(data_shapes=test_iter.provide_data, label_shapes=None, for_training=False, force_rebind=True)

        # Run inference in MxNet from json/params serialized model
        test_iter = mx.io.NDArrayIter(mnist['test_data'], mnist['test_label'], batch_size)
        pred_softmax = trained_lenet.predict(test_iter).asnumpy()
        pred_classes = np.argmax(pred_softmax, axis=1)

        # Create and save ONNX model
        print("\nConverting trained MxNet model to ONNX")
        model = from_mxnet(model_file, params_file, [1, 1, 28, 28], np.float32, log=True)
        with open(onnx_file, "wb") as f:
            f.write(model.SerializeToString())
        print("\nONNX file %s serialized to disk" % onnx_file)

        print("\nLoading ONNX file and comparing results to original MxNet output.")

        # ONNX load and inference step
        onnx_sym, onnx_arg_params, onnx_aux_params = import_model(onnx_file)
        onnx_mod = mx.mod.Module(symbol=onnx_sym, data_names=['data'], context=mx.cpu(), label_names=None)

        # Need to rename data argument from 'data' to 'input_0' because that's how
        # the MxNet ONNX importer expects it by default
        test_iter = mx.io.NDArrayIter(data={'data': mnist['test_data']}, label=None, batch_size=batch_size)

        onnx_mod.bind(data_shapes=test_iter.provide_data, label_shapes=None, for_training=False, force_rebind=True)
        onnx_mod.set_params(arg_params=onnx_arg_params, aux_params=onnx_aux_params, allow_missing=True)

        onnx_pred_softmax = onnx_mod.predict(test_iter).asnumpy()
        # fixed: this previously took argmax of pred_softmax (the MxNet
        # output), so the test compared MxNet predictions with themselves
        # and could never fail.
        onnx_pred_classes = np.argmax(onnx_pred_softmax, axis=1)

        pred_matches = onnx_pred_classes == pred_classes
        pred_match_ct = pred_matches.sum()
        pred_total_ct = np.size(pred_matches)
        pct_match = 100.0 * pred_match_ct / pred_total_ct

        print("\nOriginal MxNet predictions and ONNX-based predictions after export and re-import:")
        print("Total examples tested: %d" % pred_total_ct)
        print("Matches: %d" % pred_match_ct)
        print("Percent match: %.2f\n" % pct_match)

        assert pred_match_ct == pred_total_ct, "Not all predictions from the ONNX representation match"