├── .github └── workflows │ └── publish-to-test-pypi.yml ├── .gitignore ├── LICENSE ├── MANIFEST.in ├── README.md ├── examples ├── README.md ├── example_01.py ├── example_02.py ├── example_03.py ├── example_04.py ├── example_05.py ├── example_06.py ├── example_merge.py ├── images │ ├── 1-BIG.JPG │ ├── 1-Resized.JPG │ ├── 1.JPG │ ├── 10.JPG │ ├── 11.JPG │ ├── 12.JPG │ ├── 13.JPG │ ├── 14.JPG │ ├── 2.JPG │ ├── 3.JPG │ ├── 4.JPG │ ├── 5.JPG │ ├── 6.JPG │ ├── 7.JPG │ ├── 8.JPG │ ├── 9.JPG │ ├── empty-average.JPG │ └── horse5.png └── onnx │ ├── add-scalars.onnx │ ├── check-container-resize.onnx │ ├── check-container.onnx │ ├── cifar10-resnet20-augmented.onnx │ ├── cifar10-resnet20-clean.onnx │ ├── cifar10-resnet20.onnx │ ├── resize-image-450x600-300x400.onnx │ ├── tf-keras-dynamic.onnx │ └── tf-keras-static.onnx ├── pyproject.toml ├── sclblonnx ├── __init__.py ├── __main__.py ├── _globals.py ├── constant.py ├── input.py ├── main.py ├── merge.py ├── node.py ├── output.py ├── supported_onnx.json ├── utils.py ├── validate.py └── version.py ├── setup.py └── test ├── files ├── add.onnx ├── example01.onnx ├── example02.onnx └── example03.onnx ├── test_constant.py ├── test_input.py ├── test_main.py ├── test_merge.py ├── test_node.py ├── test_output.py ├── test_utils.py └── test_validate.py /.github/workflows/publish-to-test-pypi.yml: -------------------------------------------------------------------------------- 1 | name: PyPI Release 2 | 3 | # https://help.github.com/en/actions/reference/events-that-trigger-workflows 4 | on: 5 | # Trigger the workflow on push or pull request, 6 | # but only for the master branch 7 | push: 8 | branches: 9 | - master 10 | release: 11 | types: 12 | - created 13 | 14 | # based on https://github.com/pypa/gh-action-pypi-publish 15 | 16 | jobs: 17 | build: 18 | runs-on: ubuntu-latest 19 | 20 | steps: 21 | - uses: actions/checkout@master 22 | - name: Set up Python 3.7 23 | uses: actions/setup-python@v1 24 | with: 25 | python-version: 3.7 26 | 27 | - name: Install dependencies 28 | run: >- 29 | python -m pip install --user --upgrade setuptools wheel 30 | - name: Build 31 | run: >- 32 | python setup.py sdist bdist_wheel 33 | 34 | # ---------------- not testing on test.pypi, for now ---------------------- 35 | # - name: Publish to Test PyPI 36 | # if: startsWith(github.event.ref, 'refs/tags') || github.event_name == 'release' 37 | # uses: pypa/gh-action-pypi-publish@master 38 | # with: 39 | # user: __token__ 40 | # password: ${{ secrets.test_pypi_password }} 41 | # repository_url: https://test.pypi.org/legacy/ 42 | 43 | - name: Publish distribution 📦 to PyPI 44 | if: startsWith(github.event.ref, 'refs/tags') || github.event_name == 'release' 45 | uses: pypa/gh-action-pypi-publish@release/v1 46 | with: 47 | user: __token__ 48 | password: ${{ secrets.pypi_password }} 49 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL source 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these source are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed source 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | .idea/ 132 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2018 The Python Packaging Authority 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in all 11 | copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | SOFTWARE. -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include LICENSE 2 | include sclblonnx/supported_onnx.json 3 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # sclblonnx 2 | 3 | [![PyPI Release](https://github.com/scailable/sclblonnx/workflows/PyPI%20Release/badge.svg)](https://pypi.org/project/sclblonnx/) 4 | 5 | 6 | The `sclblonnx` package provides a high level API to construct and alter ONNX graphs. 7 | 8 | The basic usage is as follows: 9 | ```python 10 | 11 | import sclblonnx as so 12 | 13 | g = so.empty_graph() 14 | n1 = so.node('Add', inputs=['x1', 'x2'], outputs=['sum']) 15 | g = so.add_node(g, n1) 16 | # etc. 17 | 18 | ``` 19 | Please see the `examples/` folder in this repo for examples. 20 | 21 | -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | # SclblONNX examples. 2 | 3 | The following examples are provided: 4 | 5 | * [Example 01](example_01.py) - **Add**: 6 | Create an ONNX graph to add two numbers from scratch and evaluate it using Scailable tools. 7 | This tutorial covers the basic usage of the `sclblonnx` package. 8 | * [Example_02](example_02.py) - **Image**: 9 | Create an ONNX graph that checks whether or not an image is empty. 10 | * [Example_03](example_03.py) - **PyTorch**: 11 | This example imports a resnet model trained using PyTorch and show how to use it with image input. 12 | * [Example_04](example_04.py) - **TensorFlow**: 13 | This example imports a model trained using TensorFlow (including the model training code.) 14 | The example shows how to rename the inputs and outputs of an existing model and fix dynamic inputs. 15 | * [Example_05](example_05.py) - **Post-processing**: 16 | This more elaborate example builds on example 2, but extends the onnx file generated in this example 17 | to include post-processing. It demonstrates how to extend an existing graph and it demonstrates the use 18 | of If statements. 19 | * [Example_06](example_06.py) - **Graph Merge**: 20 | This example shows how to merge two existing ONNX graphs using `so.merge()`. -------------------------------------------------------------------------------- /examples/example_01.py: -------------------------------------------------------------------------------- 1 | import json 2 | import requests 3 | import numpy as np 4 | import sclblpy as sp # Import the sclblpy package 5 | import sclblonnx as so 6 | 7 | """ 8 | EXAMPLE 1: Adding two scalars. 9 | 10 | This example shows the basic usage of the sclblonnx package by creating an onnx graph from scratch that adds two 11 | scalars together. 12 | """ 13 | 14 | # Use the empty_graph() method to create a named xpb2.GraphProto object: 15 | g = so.empty_graph() 16 | 17 | # Add a node to the graph. 18 | # Please note the list of available operators at: https://github.com/onnx/onnx/blob/master/docs/Operators.md 19 | # Run so.list_data_types() and / or so.list_operators() so. to see all Scailable supported data types. 20 | n1 = so.node('Add', inputs=['x1', 'x2'], outputs=['sum']) 21 | g = so.add_node(g, n1) 22 | 23 | # We should explicitly specify the named inputs to the graph -- note that the names determine the graph topology. 24 | # Also, we should specify the data type and dimensions of any input. 25 | # Use so.list_data_types() to see available data types. 26 | g = so.add_input(g, 'x1', "FLOAT", [1]) 27 | g = so.add_input(g, 'x2', "FLOAT", [1]) 28 | 29 | # Similarly, we add the named output with its corresponding type and dimension. 30 | # Note that types will need to "match", as do dimensions. Please see the operator docs for more info. 31 | g = so.add_output(g, 'sum', "FLOAT", [1]) 32 | 33 | # so.check() checks the current graph to see if it matches Scailable's upload criteria for .wasm conversion. 34 | so.check(g) 35 | 36 | # Now, a few tricks to sanitize the graph which are always useful. 37 | # so.clean() provides lossless reduction of the graph. If successful cleaned graph is returned. 38 | g = so.clean(g) 39 | 40 | # so.display() tries to open the graph using Netron to inspect it. This worsk on most systems if Netron is installed. 41 | # Get Netron at https://github.com/onnx/onnx/blob/master/docs/Operators.md 42 | so.display(g) 43 | 44 | # Now, use the default ONNX runtime to do a test run of the graph. 45 | # Note that the inputs dimensions and types need to match the specification of the graph. 46 | # The outputs returns all the outputs named in the list. 47 | example = {"x1": np.array([1.2]).astype(np.float32), "x2": np.array([2.5]).astype(np.float32)} 48 | result = so.run(g, 49 | inputs=example, 50 | outputs=["sum"] 51 | ) 52 | print(result) 53 | 54 | # We can easily store the graph to a file for upload at http://admin.sclbl.net: 55 | so.graph_to_file(g, "onnx/add-scalars.onnx") 56 | 57 | # Or, we can upload it to Scailable using the sclblpy package, 58 | # See the sclblpy package docs for more details. https://pypi.org/project/sclblpy/ 59 | # sp.upload_onnx("onnx/add-scalars.onnx", docs={"name": "Example_01: Add", "documentation": "Empty.."}) 60 | 61 | 62 | # so.sclbl_input(inputs) converts an example input to the input that can be used on the device: 63 | example_input = so.sclbl_input(example, "pb") 64 | print(example_input) 65 | 66 | 67 | # You can use the example to setup your own REST call: 68 | def do_REST_call(cfid: str, data: str): 69 | """ Do rest call calls a REST endpoint on the Scailable cloud""" 70 | # This does work 71 | url = "https://taskmanager.sclbl.net:8080/task/" + cfid 72 | payload = "{\"input\":{\"content-type\":\"json\",\"location\":\"embedded\",\"data\":" \ 73 | + json.dumps(data) + \ 74 | "},\"output\":{\"content-type\":\"json\",\"location\":\"echo\"}," \ 75 | "\"control\":1,\"properties\":{\"language\":\"WASM\"}}" 76 | headers = { 77 | 'Content-Type': 'application/x-www-form-urlencoded' 78 | } 79 | response = requests.request("POST", url, headers=headers, data=payload) 80 | return response.text.encode('utf8') 81 | 82 | 83 | # Do the actual call using requests: 84 | print(do_REST_call("403cd8a0-a10f-11eb-9acc-9600004e79cc", example_input)) 85 | 86 | # Or again use the sclblpy package, this time use the .run() function to execute the graph on the Scailable cloud: 87 | print(sp.run("403cd8a0-a10f-11eb-9acc-9600004e79cc", example_input)) 88 | -------------------------------------------------------------------------------- /examples/example_02.py: -------------------------------------------------------------------------------- 1 | import sclblonnx as so 2 | import numpy as np 3 | from PIL import Image 4 | 5 | """ 6 | EXAMPLE 2: Rudimentary image analysis using ONNX. 7 | 8 | This example is a reworked version of the tutorial presented at: 9 | https://towardsdatascience.com/onnx-for-image-processing-from-scratch-6694f9b141b0 10 | 11 | This example relies on the image "source/images/empty-average.JPG" which provides an average of several 12 | pictures of an empty container. 13 | 14 | The logic that we build is simple: 15 | - Given an image of an object (for example 1.JPG in the source/images/ folder 16 | - Subtract the empty-image (which is encoded as a constant in the ONNX graph) 17 | - Compute absolute values 18 | - Sum all elements of the result into a single scalar 19 | - Compare the scalar to a threshold (another constant) 20 | If the threshold is reached, we conclude that the container is filled. 21 | """ 22 | 23 | # Start with the empty graph: 24 | g = so.empty_graph() 25 | 26 | # Create the constant node encoding the empty image and add it to the graph: 27 | # Note the type encoding as np.int64. 28 | reference_image = np.array(Image.open("images/empty-average.JPG"), dtype=np.int32) 29 | g = so.add_constant(g, "c1", reference_image, "INT32") 30 | 31 | # Add the first input (note, same shape): 32 | g = so.add_input(g, 'in', "INT32", reference_image.shape) 33 | 34 | # Add the Subtract, Absolute, ReduceSum, and Less nodes 35 | # Node how the names again enforce the topology of the graph 36 | n1 = so.node("Sub", inputs=['in', 'c1'], outputs=['sub']) 37 | n2 = so.node("Abs", inputs=['sub'], outputs=['abs']) 38 | n3 = so.node("ReduceSum", inputs=['abs'], outputs=['sum'], keepdims=0) # Note the keepdims additional parameter. 39 | g = so.add_nodes(g, [n1, n2, n3]) 40 | 41 | # And, we need to add the threshold (constant c2): 42 | threshold = np.array([3000000]).astype(np.int32) 43 | g = so.add_constant(g, "c2", threshold, "INT32") 44 | 45 | # Add the less node. Please note that the nodes have to be added in topological order: 46 | n4 = so.node("Less", inputs=['sum', 'c2'], outputs=['result']) 47 | g = so.add_node(g, n4) 48 | 49 | # Check provides an error stating that no outputs have been specified (which is true at this point) 50 | so.check(g) 51 | 52 | # Add output: 53 | g = so.add_output(g, "result", "BOOL", [1]) 54 | 55 | # After which is passes all the checks 56 | so.check(g) 57 | 58 | # Let's inspect: 59 | so.display(g) 60 | 61 | # Let's clean: 62 | g = so.clean(g) 63 | 64 | # Let's try it out for the first image: 65 | img_data = np.array(Image.open("images/1.JPG"), dtype=np.int32) 66 | example = {"in": img_data.astype(np.int32)} 67 | result = so.run(g, 68 | inputs=example, 69 | outputs=['result']) 70 | 71 | # Print the result 72 | if result[0]: 73 | print("The container is empty.") 74 | else: 75 | print("The container is filled.") 76 | 77 | # Store the graph 78 | so.graph_to_file(g, "onnx/check-container.onnx") 79 | 80 | 81 | ''' 82 | Additional usage of sclblpy for upload and evaluation: 83 | 84 | # Import sclblpy 85 | import sclblpy as sp 86 | 87 | # Upload model 88 | sp.upload_onnx("onnx/check-container.onnx", docs={"name": "Example_02: Image", "documentation": "None provided."}) 89 | 90 | # Example input for a Scailable runtime: 91 | input_str = so.sclbl_input(example, _verbose=False) 92 | 93 | # Run 94 | sp.run("5622645a-a10f-11eb-9acc-9600004e79cc", input_str) 95 | ''' 96 | 97 | 98 | 99 | -------------------------------------------------------------------------------- /examples/example_03.py: -------------------------------------------------------------------------------- 1 | import sclblonnx as so 2 | import numpy as np 3 | from PIL import Image 4 | 5 | """ 6 | EXAMPLE 3: Using a previously exported pyTorch model 7 | 8 | Here we open an existing and pre-trained Resnet model (trained on the cifar data). 9 | 10 | For training details see: 11 | https://github.com/scailable/sclbl-tutorials/tree/master/sclbl-pytorch-onnx 12 | 13 | Here we simply evaluate one specific image. 14 | """ 15 | 16 | # Retrieve the graph from the stored .onnx model: 17 | g = so.graph_from_file("onnx/cifar10-resnet20.onnx") 18 | 19 | # Clean, check, and display (this model passes all the checks). 20 | g = so.clean(g) 21 | so.check(g) 22 | so.display(g) 23 | 24 | 25 | # To open an image we write a small utility function using Pillow to transform an image to a numpy array. 26 | def process_image(image_path): 27 | # Load Image 28 | img = Image.open(image_path) 29 | 30 | # Get the dimensions of the image 31 | width, height = img.size 32 | 33 | # Turn image into numpy array 34 | img = np.array(img) 35 | 36 | # Make the color channel dimension first instead of last 37 | img = img.transpose((2, 0, 1)) 38 | 39 | # Make all values between 0 and 1 40 | img = img / 255 41 | 42 | # Normalize based on the preset mean and standard deviation 43 | img[0] = (img[0] - 0.4914) / 0.2023 44 | img[1] = (img[1] - 0.4822) / 0.1994 45 | img[2] = (img[2] - 0.4465) / 0.2010 46 | 47 | # Add a fourth dimension to the beginning to indicate batch size 48 | # img = img[np.newaxis,:].astype(np.float16) 49 | img = img[np.newaxis, :] 50 | 51 | return img 52 | 53 | 54 | # Open the image and execute the graph: 55 | img_data = process_image("images/horse5.png").astype(np.float32) 56 | example = {"input": img_data} 57 | out = so.run(g, 58 | inputs=example, 59 | outputs=['output'] 60 | ) 61 | 62 | # Pretty printing 63 | classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck') 64 | print("The ONNX model predicts the image is a", classes[np.argmax(out[0])] + ".") 65 | 66 | # And store the file (since we did clean it): 67 | so.graph_to_file(g, "onnx/cifar10-resnet20-clean.onnx") 68 | 69 | 70 | ''' 71 | Additional usage of sclblpy for upload and evaluation: 72 | 73 | # Import sclblpy 74 | import sclblpy as sp 75 | 76 | # Upload model 77 | sp.upload_onnx("onnx/cifar10-resnet20-clean.onnx", docs={"name": "Example_03: Cifar", "documentation": "None provided."}) 78 | 79 | # Example input for a Scailable runtime: 80 | input_str = so.sclbl_input(example, _verbose=False) 81 | 82 | # Run 83 | sp.run("11928b2a-a110-11eb-9acc-9600004e79cc", input_str) 84 | ''' 85 | 86 | -------------------------------------------------------------------------------- /examples/example_04.py: -------------------------------------------------------------------------------- 1 | #import tensorflow as tf 2 | #import keras2onnx 3 | #from tensorflow.keras import layers 4 | from sklearn import datasets 5 | import sclblonnx as so 6 | import numpy as np 7 | 8 | """ 9 | EXAMPLE 4: Converting a model to ONNX from TensorFlow and fixing the dynamic input & output. 10 | 11 | Here we first train a simple tf model (using keras) using the sklearn diabetes dataset. 12 | We store the model to onnx using the keras2onnx package (please note that this is in flux). 13 | See https://pypi.org/project/keras2onnx/ 14 | 15 | Next, we open and inspect the model graph using Scailable tools. Cleaning the graph gives a warning 16 | regarding the input being dynamic; we show how to fix this and how to evaluate the graph. 17 | 18 | """ 19 | 20 | # Get training data from sklearn 21 | X, y = datasets.load_diabetes(return_X_y=True) 22 | 23 | train = False # Prevent training on every run; the model is stored, /onnx/tf-keras-dynamic.onnx 24 | if train: # Don't retrain everytime, the model is stored. 25 | 26 | # Create the model 27 | dnn_model = tf.keras.Sequential() 28 | dnn_model.add(layers.Dense(64, activation='relu')) 29 | dnn_model.add(layers.Dense(64, activation='relu')) 30 | dnn_model.add(layers.Dense(1)) 31 | 32 | dnn_model.compile(loss='mean_absolute_error', optimizer=tf.keras.optimizers.SGD()) 33 | 34 | # train the model (use .predict for local predictions) 35 | history = dnn_model.fit( 36 | X, y, 37 | validation_split=0.2, 38 | verbose=0, epochs=300) 39 | 40 | # Save model (note, the convert_keras() function is undergoing change in different 41 | # versions of tf / onnx). 42 | # You might need: tf.compat.v1.disable_eager_execution() 43 | # or use the tf2onnx tool at https://github.com/onnx/tensorflow-onnx 44 | onnx_model = keras2onnx.convert_keras(dnn_model, dnn_model.name) 45 | keras2onnx.save_model(onnx_model, "onnx/tf-keras-dynamic.onnx") 46 | 47 | 48 | # load the model using sclblonnx 49 | g = so.graph_from_file("onnx/tf-keras-dynamic.onnx") 50 | # so.display(g) 51 | 52 | # check() and clean() 53 | so.check(g) 54 | g = so.clean(g) # Fails due to dynamic size 55 | 56 | # Note, while this model passes check(), clean() provides a warning message due to the dynamic input (Nx10). 57 | # This occurs because the training data is N long. However, for inference we would like it to be 1x10 58 | # Let's fix this by changing the input to static. 59 | so.list_inputs(g) 60 | g = so.replace_input(g, "input_1", "FLOAT", [1, 10]) 61 | 62 | # And do the same for the output 63 | output = so.replace_output(g, "output_1", "FLOAT", [1, 1]) # Check this one... 64 | 65 | # Now we do pass all checks, and we can look at the graph 66 | so.check(g) 67 | g = so.clean(g) 68 | so.display(g) 69 | 70 | # However, we might not like the tf default input and output names: 71 | g = so.rename_input(g, "input_1", "in") 72 | g = so.rename_output(g, "output_1", "result") 73 | so.display(g) 74 | 75 | # And now we can call it locally: 76 | input_example = np.array([X[1, ]]).astype(np.float32) # Note the extra brackets to create 1x10 77 | example = {"in": input_example} 78 | result = so.run(g, 79 | inputs=example, 80 | outputs=["result"] 81 | ) 82 | print(result) 83 | 84 | # Finally, we can store the changed graph: 85 | so.graph_to_file(g, "onnx/tf-keras-static.onnx") 86 | 87 | 88 | ''' 89 | Additional usage of sclblpy for upload and evaluation: 90 | 91 | # Import sclblpy 92 | import sclblpy as sp 93 | 94 | # Upload model 95 | sp.upload_onnx("onnx/tf-keras-static.onnx", docs={"name": "Example_04: TF-Keras-static", "documentation": "None provided."}) 96 | 97 | # Example input for a Scailable runtime: 98 | input_str = so.sclbl_input(example, _verbose=False) 99 | 100 | # Run 101 | sp.run("0d7db3c7-a111-11eb-9acc-9600004e79cc", input_str) 102 | ''' 103 | -------------------------------------------------------------------------------- /examples/example_05.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import numpy as np 3 | import sclblonnx as so 4 | """ 5 | EXAMPLE 5: 6 | 7 | Editing an existing graph: while the previous examples already showed some graph editing, here we demonstrate more 8 | elaborate editing of the graph; we demonstrate changing inputs, and we demonstrate the use of the If operator for 9 | post-processing. 10 | 11 | We will build on the PyTorch example (see example_03.py) to, instead of outputting the scores from the 12 | cifar model, output 99 if there is a horse on the image and 0 otherwise. 13 | """ 14 | 15 | 16 | # First, we load the the example image: 17 | # To open an image we write a small utility function using Pillow to transform an image to a numpy array. 18 | def process_image(image_path): 19 | # Load Image 20 | img = Image.open(image_path) 21 | 22 | # Get the dimensions of the image 23 | width, height = img.size 24 | 25 | # Turn image into numpy array 26 | img = np.array(img) 27 | 28 | # Make the color channel dimension first instead of last 29 | img = img.transpose((2, 0, 1)) 30 | 31 | # Make all values between 0 and 1 32 | img = img / 255 33 | 34 | # Normalize based on the preset mean and standard deviation 35 | img[0] = (img[0] - 0.4914) / 0.2023 36 | img[1] = (img[1] - 0.4822) / 0.1994 37 | img[2] = (img[2] - 0.4465) / 0.2010 38 | 39 | # Add a fourth dimension to the beginning to indicate batch size 40 | # img = img[np.newaxis,:].astype(np.float16) 41 | img = img[np.newaxis, :] 42 | 43 | return img 44 | 45 | 46 | # Open the image 47 | img_input = process_image("images/horse5.png").astype(np.float32) 48 | 49 | # Load the cifar model: 50 | g = so.graph_from_file("onnx/cifar10-resnet20-clean.onnx") 51 | 52 | # Check its output, we see the name, type, and dimensions 53 | so.list_outputs(g) 54 | 55 | # Run the model to see the outputs: 56 | result = so.run(g, inputs={"input": img_input}, outputs=["output"]) 57 | print(result) 58 | 59 | # Add and arg_max node to find the highest output in the output vector 60 | # Note the keepdims and axis; the output of the Argmax node should align with the defined output. 61 | n1 = so.node('ArgMax', inputs=['output'], outputs=['argmax'], keepdims=0, axis=1) 62 | g = so.add_node(g, n1) # Note, this adds the node, but the output is still "output" 63 | g = so.delete_output(g, "output") # Remove the old output 64 | g = so.add_output(g, 'argmax', "INT64", [1]) # Add the new output (for testing only) 65 | 66 | # Test: 67 | result = so.run(g, inputs={"input": img_input}, outputs=["argmax"]) 68 | print(result) 69 | 70 | # So, this works. Let's remove the output argmax again before we continue: 71 | g = so.delete_output(g, 'argmax') 72 | 73 | # Because the if statement to switch between values of 100 and 0 requires a boolean input condition, we add 74 | # a constant node with the value of 7, and add an equals node: 75 | g = so.add_constant(g, "cut", np.array([7]), "INT64") 76 | n2 = so.node("Equal", inputs=['argmax', 'cut'], outputs=['seven']) 77 | g = so.add_node(g, n2) 78 | 79 | # Lets again test: 80 | g = so.add_output(g, 'seven', "BOOL", [1]) 81 | result = so.run(g, inputs={"input": img_input}, outputs=["seven"]) 82 | print(result) # Prints true... we are getting closer! 83 | g = so.delete_output(g, 'seven') 84 | 85 | 86 | # Here we build an if statement. Note that the if "switches" between two graphs, so let's first create the 87 | # two graphs (which can obviously be much more complex). We start with the if: 88 | then_graph = so.empty_graph("then-graph") 89 | then_graph = so.add_constant(then_graph, "then_value", np.array([100]), "FLOAT") 90 | then_graph = so.add_output(then_graph, "then_value", "FLOAT", [1]) 91 | so.display(then_graph) # See, this is a very small graph, no input, only output 92 | 93 | # Same for else 94 | else_graph = so.empty_graph("else-graph") 95 | else_graph = so.add_constant(else_graph, "iff_value", np.array([0]), "FLOAT") 96 | else_graph = so.add_output(else_graph, "iff_value", "FLOAT", [1]) 97 | 98 | 99 | # Now, the If node which switches between the if and the else graph 100 | n3 = so.node("If", inputs=['seven'], outputs=['horse'], then_branch=then_graph, else_branch=else_graph) 101 | g = so.add_node(g, n3) 102 | 103 | # Add the output 104 | g = so.add_output(g, "horse", "FLOAT", [1]) 105 | result = so.run(g, inputs={"input": img_input}, outputs=["horse"]) 106 | print(result) # Prints 100! 107 | 108 | # Store the augmented graph 109 | g = so.clean(g) 110 | so.check(g) 111 | so.graph_to_file(g, "onnx/cifar10-resnet20-augmented.onnx") 112 | 113 | -------------------------------------------------------------------------------- /examples/example_06.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import sclblonnx as so 4 | import numpy as np 5 | from PIL import Image 6 | """ 7 | EXAMPLE 6: Merging two existing graphs. 8 | 9 | This example combines two (sub) graphs into a single graph describing a longer pipeline. 10 | 11 | The setup builds on example II (example_02.py). We create 2 separate graphs: 12 | 1. Graph that resizes an image from 600x450 to 400x300 13 | 2. The empty container graph (check_container.onnx) which takes a 400x300 image 14 | 15 | Next, we merge the two graphs into one single ONNX file. 16 | """ 17 | 18 | # Let's open the large image and inspect the shape: 19 | large_img = np.array(Image.open("images/1-BIG.JPG"), dtype=np.int32) 20 | print(large_img.shape) # 450x600x3 21 | 22 | # First subgraph for resize: 23 | sg1 = so.empty_graph("resize_graph") 24 | sg1 = so.add_input(sg1, "large_image", "INT32", [450, 600, 3]) # Add the input 25 | 26 | # The resize node: 27 | e1 = so.constant("roi", np.array([]), "FLOAT") # Note the empty fields for roi and scales. 28 | e2 = so.constant("scales", np.array([]), "FLOAT") 29 | c1 = so.constant("size", np.array([300, 400, 3]), "INT64") 30 | n1 = so.node("Resize", inputs=['large_image', 'roi', 'scales', 'size'], outputs=['small_image']) 31 | sg1 = so.add_nodes(sg1, [e1, e2, c1, n1]) 32 | sg1 = so.add_output(sg1, "small_image", "INT32", [300, 400, 3]) 33 | 34 | # Check and clean 35 | sg1 = so.clean(sg1) 36 | so.check(sg1) 37 | 38 | # Test the resize graph: 39 | large_input = {"large_image": large_img.astype(np.int32)} 40 | result = so.run(sg1, inputs=large_input, outputs=['small_image']) 41 | 42 | # Round values in array and cast as 8-bit integer to store back as JPG: 43 | img_arr = np.array(np.round(result[0]), dtype=np.uint8) 44 | out = Image.fromarray(img_arr, mode="RGB") 45 | out.save("images/1-Resized.JPG") # Yes, this works. 46 | 47 | # Store the resize onnx: 48 | so.graph_to_file(sg1, "onnx/resize-image-450x600-300x400.onnx") 49 | 50 | # So, now we have a working (sub)graph that resizes an image (which obviously we can just load next time) 51 | # Now, we open up the original image processing graph 52 | sg2 = so.graph_from_file("onnx/check-container.onnx") 53 | 54 | # The outputs of sg1 and the inputs of sg2 need to match; lets examine them 55 | so.list_outputs(sg1) 56 | so.list_inputs(sg2) 57 | 58 | # Merge the two graphs, the outputs will be merged with the inputs in order of appearance: 59 | g = so.merge(sg1, sg2, outputs=["small_image"], inputs=["in"]) 60 | so.check(g) 61 | so.display(g) 62 | 63 | # And now it works with the large image: 64 | result = so.run(g, inputs=large_input, outputs=['result']) 65 | # Print the result 66 | if result[0]: 67 | print("The container in the large image is empty.") 68 | else: 69 | print("The container in the large image is filled.") 70 | 71 | # Store the merged graph 72 | g = so.graph_to_file(g, "onnx/check-container-resize.onnx") 73 | 74 | -------------------------------------------------------------------------------- /examples/example_merge.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import sclblonnx as so 3 | import numpy as np 4 | """ 5 | EXAMPLE MERGE: a number of examples usages of the merge, join, split, and concat functions. 6 | 7 | Note that merge(), join(), and split() are high level wrappers around concat(), each effectively assuming that the 8 | resulting graph is "complete" (i.e., it is a valid onnx graph including input and output). Concat itself is more 9 | flexible and can be used for intermediate merging/concatenation of partial graphs (i.e., graphs that are not yet 10 | finished). 11 | 12 | Below we provide a number of examples of each of the functions. We recommend using so.display() throughout to visualize 13 | the resulting graphs and truly understand how the graphs are joined together. Examples are all very simple (small graphs, 14 | scalar operations, etc.), but don't underestimate the complexities involved; with larger graphs the behavior of 15 | the concat function can be challenging. 16 | """ 17 | 18 | # # Lets start by creating a few simple (and complete) graphs which we will use throughout: 19 | # # Simple absolute value graph: 20 | g1 = so.empty_graph("g_1") 21 | n1 = so.node('Abs', inputs=['in_1_1'], outputs=['out_1_1'], name="node_1_1") 22 | g1 = so.add_input(g1, 'in_1_1', "FLOAT", [1]) 23 | g1 = so.add_output(g1, 'out_1_1', "FLOAT", [1]) 24 | g1 = so.add_node(g1, n1) 25 | # so.display(g1) 26 | # data = {"in_1_1": np.array([2]).astype(np.float32)} 27 | # print(so.run(g1, inputs=data, outputs=["out_1_1"])) 28 | 29 | # # Simple max value graph: 30 | g2= so.empty_graph("g_2") 31 | n2 = so.node('Max', inputs=['in_2_1', 'in_2_2'], outputs=['out_2_1'], name="node_2_1") 32 | g2 = so.add_input(g2, 'in_2_1', "FLOAT", [1]) 33 | g2 = so.add_constant(g2, "in_2_2", np.array([10]), "FLOAT") 34 | g2 = so.add_output(g2, 'out_2_1', "FLOAT", [1]) 35 | g2 = so.add_node(g2, n2) 36 | # so.display(g2) 37 | # data = {"in_2_1": np.array([2]).astype(np.float32)} 38 | # print(so.run(g2, inputs=data, outputs=["out_2_1"])) 39 | 40 | # # Simple add two values graph: 41 | g3 = so.empty_graph("g_3") 42 | n3 = so.node('Add', inputs=['in_3_1', 'in_3_2'], outputs=['out_3_1'], name="node_3_1") 43 | g3 = so.add_input(g3, 'in_3_1', "FLOAT", [1]) 44 | g3 = so.add_input(g3, 'in_3_2', "FLOAT", [1]) 45 | g3 = so.add_output(g3, 'out_3_1', "FLOAT", [1]) 46 | g3 = so.add_node(g3, n3) 47 | # so.display(g3) 48 | # data = { 49 | # "in_3_1": np.array([2]).astype(np.float32), 50 | # "in_3_2": np.array([5]).astype(np.float32)} 51 | # print(so.run(g3, inputs=data, outputs=["out_3_1"])) 52 | 53 | 54 | # # MERGE: 55 | # # Merge takes two complete graphs and links the output of the parent to the inputs of the child. 56 | # # Merge assumes the result is complete. 57 | g_merge = so.merge(sg1=g1, sg2=g2, io_match=[("out_1_1", "in_2_1")]) 58 | # so.display(g_merge) 59 | # data = {"in_1_1": np.array([2]).astype(np.float32)} 60 | # print(so.run(g_merge, inputs=data, outputs=["out_2_1"])) 61 | 62 | 63 | # # JOIN: 64 | # # Join takes two parents and links their outputs to one child 65 | # # Join assumes the result is complete. 66 | g_join = so.join(pg1=g1, pg2=g2, cg=g3, pg1_match=[("out_1_1", "in_3_1")], pg2_match=[("out_2_1", "in_3_2")]) 67 | # so.display(g_join) 68 | # data = { 69 | # "in_1_1": np.array([2]).astype(np.float32), 70 | # "in_2_1": np.array([2]).astype(np.float32)} 71 | # print(so.run(g_join, inputs=data, outputs=["out_3_1"])) 72 | 73 | 74 | # # SPLIT: 75 | # # Split takes a single parent and links its output to the inputs of two children. 76 | # # Split assumes the result is complete. 77 | g_split = so.split(pg=g3, cg1=g1, cg2=g2, cg1_match=[("out_3_1", "in_1_1")], cg2_match=[("out_3_1", "in_2_1")]) 78 | # so.display(g_split) 79 | # data = { 80 | # "in_3_1": np.array([2]).astype(np.float32), 81 | # "in_3_2": np.array([5]).astype(np.float32)} 82 | # print(so.run(g_split, inputs=data, outputs=["out_1_1", "out_2_1"])) 83 | 84 | 85 | # # CONCAT 86 | # # Here we provide a number of uses of concat, please inspect the resulting graphs 87 | # # Note, these result are by default not checked for completeness. Hence, the returned graph need not contain 88 | # # valid inputs and outputs. 89 | g_c1 = so.concat(g1, g2) # Note, these are just the two graphs "side-by-side" 90 | g_c2 = so.concat(g1, g2, io_match=[("out_1_1", "in_2_1")]) # Merge 91 | g_c3 = so.concat(g1, g2, io_match=[("out_2_1", "in_1_1")]) # No merge 92 | g_c4 = so.concat(g2, g1, io_match=[("out_2_1", "in_1_1")]) # Merge flipped, the order matters 93 | g_c5 = so.concat(g1, g2, io_match=[("out_1_1", "in_2_1")], rename_nodes=False) # Akin g_c2, but without the node names changed 94 | 95 | g4 = copy.deepcopy(g1) # an exact copy of g1 96 | g_c6 = so.concat(g1, g4) # Ugly... 97 | g_c7 = so.concat(g1, g4, rename_edges=True, rename_io=True) # Side by side 98 | 99 | g5 = copy.deepcopy(g4) # Another exact copy, 100 | g5 = so.delete_input(g5, "in_1_1") # Removing input and output 101 | g5 = so.delete_output(g5, "out_1_1") 102 | g_c8 = so.concat(g1, g5) # Edge created, but unable to link a single output to two named edges 103 | 104 | g6 = so.empty_graph("g_6") 105 | n4 = so.node('Add', inputs=['in_1_1', 'in_6_2'], outputs=['out_6_1'], name="node_6_1") 106 | g6 = so.add_node(g6, n4) 107 | g_c9 = so.concat(g1, g6) # Similarly named edges are also linked 108 | g_c10 = so.concat(g1, g6, rename_edges=True) # All edges renamed, but not i/o broken 109 | g_c11 = so.concat(g1, g6, rename_edges=True, rename_io=True) # g6 did not have inputs and outputs 110 | g_c12 = so.concat(g1, g6, edge_match=[("out_1_1", "in_6_2")]) # Explicit edge matching (akin io_match but for internal edges) 111 | 112 | # # Again, please use so.display(g..) to see the results of the above uses of concat. 113 | 114 | -------------------------------------------------------------------------------- /examples/images/1-BIG.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scailable/sclblonnx/6146f3b486833d4f094e73278891caafe2216baa/examples/images/1-BIG.JPG -------------------------------------------------------------------------------- /examples/images/1-Resized.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scailable/sclblonnx/6146f3b486833d4f094e73278891caafe2216baa/examples/images/1-Resized.JPG -------------------------------------------------------------------------------- /examples/images/1.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scailable/sclblonnx/6146f3b486833d4f094e73278891caafe2216baa/examples/images/1.JPG -------------------------------------------------------------------------------- /examples/images/10.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scailable/sclblonnx/6146f3b486833d4f094e73278891caafe2216baa/examples/images/10.JPG -------------------------------------------------------------------------------- /examples/images/11.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scailable/sclblonnx/6146f3b486833d4f094e73278891caafe2216baa/examples/images/11.JPG -------------------------------------------------------------------------------- /examples/images/12.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scailable/sclblonnx/6146f3b486833d4f094e73278891caafe2216baa/examples/images/12.JPG -------------------------------------------------------------------------------- /examples/images/13.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scailable/sclblonnx/6146f3b486833d4f094e73278891caafe2216baa/examples/images/13.JPG -------------------------------------------------------------------------------- /examples/images/14.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scailable/sclblonnx/6146f3b486833d4f094e73278891caafe2216baa/examples/images/14.JPG -------------------------------------------------------------------------------- /examples/images/2.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scailable/sclblonnx/6146f3b486833d4f094e73278891caafe2216baa/examples/images/2.JPG -------------------------------------------------------------------------------- /examples/images/3.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scailable/sclblonnx/6146f3b486833d4f094e73278891caafe2216baa/examples/images/3.JPG -------------------------------------------------------------------------------- /examples/images/4.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scailable/sclblonnx/6146f3b486833d4f094e73278891caafe2216baa/examples/images/4.JPG -------------------------------------------------------------------------------- /examples/images/5.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scailable/sclblonnx/6146f3b486833d4f094e73278891caafe2216baa/examples/images/5.JPG -------------------------------------------------------------------------------- /examples/images/6.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scailable/sclblonnx/6146f3b486833d4f094e73278891caafe2216baa/examples/images/6.JPG -------------------------------------------------------------------------------- /examples/images/7.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scailable/sclblonnx/6146f3b486833d4f094e73278891caafe2216baa/examples/images/7.JPG -------------------------------------------------------------------------------- /examples/images/8.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scailable/sclblonnx/6146f3b486833d4f094e73278891caafe2216baa/examples/images/8.JPG -------------------------------------------------------------------------------- /examples/images/9.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scailable/sclblonnx/6146f3b486833d4f094e73278891caafe2216baa/examples/images/9.JPG -------------------------------------------------------------------------------- /examples/images/empty-average.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scailable/sclblonnx/6146f3b486833d4f094e73278891caafe2216baa/examples/images/empty-average.JPG -------------------------------------------------------------------------------- /examples/images/horse5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scailable/sclblonnx/6146f3b486833d4f094e73278891caafe2216baa/examples/images/horse5.png -------------------------------------------------------------------------------- /examples/onnx/add-scalars.onnx: -------------------------------------------------------------------------------- 1 |  sclblonnx:i 2 | $ 3 | x1 4 | x2sumsclbl-onnx-node1"Add 5 | sclblgraphZ 6 | x1 7 | 8 |  9 | Z 10 | x2 11 | 12 |  13 | b 14 | sum 15 | 16 |  17 | B -------------------------------------------------------------------------------- /examples/onnx/check-container-resize.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scailable/sclblonnx/6146f3b486833d4f094e73278891caafe2216baa/examples/onnx/check-container-resize.onnx -------------------------------------------------------------------------------- /examples/onnx/check-container.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scailable/sclblonnx/6146f3b486833d4f094e73278891caafe2216baa/examples/onnx/check-container.onnx -------------------------------------------------------------------------------- /examples/onnx/cifar10-resnet20-augmented.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scailable/sclblonnx/6146f3b486833d4f094e73278891caafe2216baa/examples/onnx/cifar10-resnet20-augmented.onnx -------------------------------------------------------------------------------- /examples/onnx/cifar10-resnet20-clean.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scailable/sclblonnx/6146f3b486833d4f094e73278891caafe2216baa/examples/onnx/cifar10-resnet20-clean.onnx -------------------------------------------------------------------------------- /examples/onnx/cifar10-resnet20.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scailable/sclblonnx/6146f3b486833d4f094e73278891caafe2216baa/examples/onnx/cifar10-resnet20.onnx -------------------------------------------------------------------------------- /examples/onnx/resize-image-450x600-300x400.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scailable/sclblonnx/6146f3b486833d4f094e73278891caafe2216baa/examples/onnx/resize-image-450x600-300x400.onnx -------------------------------------------------------------------------------- /examples/onnx/tf-keras-dynamic.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scailable/sclblonnx/6146f3b486833d4f094e73278891caafe2216baa/examples/onnx/tf-keras-dynamic.onnx -------------------------------------------------------------------------------- /examples/onnx/tf-keras-static.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scailable/sclblonnx/6146f3b486833d4f094e73278891caafe2216baa/examples/onnx/tf-keras-static.onnx -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools >= 40.6.0", "wheel"] 3 | build-backend = "setuptools.build_meta" -------------------------------------------------------------------------------- /sclblonnx/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | The sclblonnx package provides onnx tools 3 | """ 4 | # Ran on import of the package. Check version: 5 | import sys 6 | 7 | if sys.version_info < (3, 0): 8 | print('Sclblonnx requires Python 3, while Python ' + str(sys.version[0] + ' was detected. Terminating... ')) 9 | sys.exit(1) 10 | 11 | from .version import __version__ 12 | 13 | from .main import \ 14 | empty_graph, \ 15 | graph_from_file, \ 16 | graph_to_file, \ 17 | run, \ 18 | display, \ 19 | sclbl_input, \ 20 | list_data_types, \ 21 | list_operators 22 | 23 | from .validate import \ 24 | clean, \ 25 | check 26 | 27 | from .node import \ 28 | node, \ 29 | add_node, \ 30 | add_nodes, \ 31 | delete_node 32 | 33 | from .constant import \ 34 | constant, \ 35 | add_constant 36 | 37 | from .input import \ 38 | list_inputs, \ 39 | add_input, \ 40 | rename_input, \ 41 | replace_input, \ 42 | delete_input 43 | 44 | from .output import \ 45 | list_outputs, \ 46 | add_output, \ 47 | rename_output, \ 48 | replace_output, \ 49 | delete_output 50 | 51 | from .merge import \ 52 | merge, \ 53 | join, \ 54 | split, \ 55 | concat, \ 56 | postfix_names 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | -------------------------------------------------------------------------------- /sclblonnx/__main__.py: -------------------------------------------------------------------------------- 1 | # Executed when running 2 | # > python sclblonnx 3 | # from terminal. 4 | if __name__ == '__main__': 5 | print("The sclblonnx package currently does not have any command line options.") 6 | print("Please import the package into your python project and work from there.") 7 | -------------------------------------------------------------------------------- /sclblonnx/_globals.py: -------------------------------------------------------------------------------- 1 | # Global variables for the sclblonnx package. 2 | import os 3 | 4 | # Dictionary containing details to check support 5 | VERSION_INFO_LOCATION: str = os.path.dirname(os.path.realpath(__file__)) + "/supported_onnx.json" 6 | ONNX_VERSION_INFO: dict = {} 7 | 8 | # Node counter: 9 | NODE_COUNT = 1 10 | 11 | # Optimizer passes: 12 | OPTIMIZER_PASSES = ['eliminate_deadend', 13 | 'eliminate_duplicate_initializer', 14 | 'eliminate_identity', 15 | 'eliminate_if_with_const_cond', 16 | 'eliminate_nop_cast', 17 | 'eliminate_nop_dropout', 18 | 'eliminate_nop_flatten', 19 | 'eliminate_nop_monotone_argmax', 20 | 'eliminate_nop_pad', 21 | 'eliminate_nop_transpose', 22 | 'eliminate_unused_initializer', 23 | 'extract_constant_to_initializer', 24 | 'fuse_add_bias_into_conv', 25 | 'fuse_bn_into_conv', 26 | 'fuse_consecutive_concats', 27 | 'fuse_consecutive_log_softmax', 28 | 'fuse_consecutive_reduce_unsqueeze', 29 | 'fuse_consecutive_squeezes', 30 | 'fuse_consecutive_transposes', 31 | 'fuse_matmul_add_bias_into_gemm', 32 | 'fuse_pad_into_conv', 33 | 'fuse_transpose_into_gemm', 34 | 'lift_lexical_references'] 35 | 36 | # Data types, see https://deeplearning4j.org/api/latest/onnx/Onnx.TensorProto.DataType.html 37 | DATA_TYPES = { 38 | "FLOAT": 1, 39 | "UINT8": 2, 40 | "INT8": 3, 41 | "UINT16": 4, 42 | "INT16": 5, 43 | "INT32": 6, 44 | "INT64": 7, 45 | # "STRING" : 8, 46 | "BOOL": 9, 47 | "FLOAT16": 10, 48 | "DOUBLE": 11, 49 | "UINT32": 12, 50 | "UINT64": 13, 51 | "COMPLEX64": 14, 52 | "COMPLEX128": 15 53 | } 54 | -------------------------------------------------------------------------------- /sclblonnx/constant.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from onnx import helper as xhelp 3 | from onnx import onnx_ml_pb2 as xpb2 4 | 5 | from sclblonnx.utils import _data_type, _print 6 | from sclblonnx.node import add_node 7 | 8 | 9 | # constant creates a constant node. 10 | def constant(name: str, 11 | value: np.array, 12 | data_type: str, 13 | **kwargs): 14 | """ Create a constant node 15 | 16 | Args: 17 | name: Name of the (output value of the) constant node to determine the graph topology 18 | value: Values of the node (as a np.array) 19 | data_type: Data type of the node 20 | **kwargs 21 | 22 | Returns: 23 | A constant node. 24 | """ 25 | if not name: 26 | _print("Unable to create unnamed constant.") 27 | return False 28 | 29 | dtype = _data_type(data_type) 30 | if not dtype: 31 | return False 32 | 33 | try: 34 | constant_node = xhelp.make_node('Constant', inputs=[], outputs=[name], name=name + "-constant", 35 | value=xhelp.make_tensor(name=name + "-values", data_type=dtype, 36 | dims=value.shape, vals=value.flatten()), **kwargs) 37 | except Exception as e: 38 | _print("Unable to create the constant node: " + str(e)) 39 | return False 40 | 41 | return constant_node 42 | 43 | 44 | # add_constant adds a constant node to a graph 45 | def add_constant( 46 | graph: xpb2.GraphProto, 47 | name: str, 48 | value: np.array, 49 | data_type: str, 50 | **kwargs): 51 | """ Create and add a constant node to an existing graph. 52 | 53 | Note: use add_node() if you want to add an existing constant node to an existing graph 54 | 55 | Args: 56 | graph: A graph, onnx.onnx_ml_pb2.GraphProto. 57 | name: Name of the (output value of the) constant node to determine the graph topology 58 | value: Values of the node (as a np.array) 59 | data_type: Data type of the node 60 | 61 | Returns: 62 | The extended graph. 63 | """ 64 | if type(graph) is not xpb2.GraphProto: 65 | print("graph is not a valid ONNX graph.") 66 | return False 67 | 68 | dtype = _data_type(data_type) 69 | if not dtype: 70 | return False 71 | 72 | try: 73 | constant_node = xhelp.make_node('Constant', inputs=[], outputs=[name], name=name + "-constant", 74 | value=xhelp.make_tensor(name=name + "-values", data_type=dtype, 75 | dims=value.shape, vals=value.flatten()), **kwargs) 76 | except Exception as e: 77 | _print("Unable to create the constant node: " + str(e)) 78 | return False 79 | 80 | try: 81 | graph = add_node(graph, constant_node, **kwargs) 82 | except Exception as e: 83 | _print("Unable to add the constant node to the graph: " + str(e)) 84 | return False 85 | 86 | if not graph: 87 | _print("Unable to add constant node to graph.") 88 | return False 89 | return graph 90 | 91 | -------------------------------------------------------------------------------- /sclblonnx/input.py: -------------------------------------------------------------------------------- 1 | from onnx import helper as xhelp 2 | from onnx import onnx_ml_pb2 as xpb2 3 | 4 | from sclblonnx.utils import _value, _data_type, _parse_element, _print 5 | import onnx 6 | 7 | # list_inputs lists all inputs of a graph 8 | def list_inputs(graph: xpb2.GraphProto): 9 | """ Tries to list the inputs of a given graph. 10 | 11 | Args: 12 | graph the ONNX graph 13 | """ 14 | if type(graph) is not xpb2.GraphProto: 15 | _print("graph is not a valid ONNX graph.") 16 | return False 17 | 18 | i = 1 19 | for elem in graph.input: 20 | name, dtype, shape = _parse_element(elem) 21 | print("Input {}: Name: '{}', Type: {}, Dimension: {}".format(i, name, dtype, shape)) 22 | i += 1 23 | 24 | if i == 1: 25 | print("No inputs found.") 26 | 27 | return True 28 | 29 | 30 | # add_input adds an input to a graph 31 | def add_input( 32 | graph: xpb2.GraphProto, 33 | name: str, 34 | data_type: str, 35 | dimensions: [], 36 | **kwargs): 37 | """ Add an input to a graph 38 | 39 | Args: 40 | graph: A graph, onnx.onnx_ml_pb2.GraphProto. 41 | name: String, the name of the input as used to determine the graph topology. 42 | data_type: String, the data type of the input. Run list_data_types() for an overview. 43 | dimensions: List[] specifying the dimensions of the input. 44 | **kwargs 45 | 46 | Returns: 47 | The extended graph. 48 | 49 | """ 50 | if type(graph) is not xpb2.GraphProto: 51 | _print("graph is not a valid ONNX graph.") 52 | return False 53 | 54 | dtype = _data_type(data_type) 55 | if not dtype: 56 | return False 57 | 58 | try: 59 | graph.input.append(xhelp.make_tensor_value_info(name, dtype, dimensions, **kwargs), *kwargs) 60 | except Exception as e: 61 | _print("Unable to add the input: " + str(e)) 62 | return False 63 | return graph 64 | 65 | 66 | # rename_input renames an input 67 | def rename_input(graph, current_name, new_name): 68 | """ Rename an input to a graph 69 | 70 | Args: 71 | graph: A graph, onnx.onnx_ml_pb2.GraphProto. 72 | current_name: String, the current input name. 73 | new_name: String, the name desired input name. 74 | 75 | Returns: 76 | The changed graph. 77 | """ 78 | if type(graph) is not xpb2.GraphProto: 79 | _print("graph is not a valid ONNX graph.") 80 | return False 81 | 82 | found = False 83 | for input in graph.input: 84 | if input.name == current_name: 85 | input.name = new_name 86 | found = True 87 | if not found: 88 | _print("Unable to find the input to rename.") 89 | return False 90 | 91 | # And rename it in every nodes that takes this as input: 92 | for node in graph.node: 93 | for index, name in enumerate(node.input): 94 | if name == current_name: 95 | node.input[index] = new_name 96 | 97 | return graph 98 | 99 | 100 | # rename_input_image renames an image input 101 | def rename_input_image(graph, image_input_name): 102 | """ Rename an image input 103 | 104 | Args: 105 | graph: A graph, onnx.onnx_ml_pb2.GraphProto. 106 | image_input_name: String, the current image input name. 107 | 108 | Returns: 109 | The changed graph. 110 | """ 111 | if type(graph) is not xpb2.GraphProto: 112 | _print("graph is not a valid ONNX graph.") 113 | return False 114 | 115 | found = False 116 | for input in graph.input: 117 | if input.name == image_input_name: 118 | input.name = "image-" 119 | found = True 120 | if not found: 121 | _print("Unable to find the input to rename.") 122 | return False 123 | 124 | # And rename it in every nodes that takes this as input: 125 | for node in graph.node: 126 | for index, name in enumerate(node.input): 127 | if name == image_input_name: 128 | node.input[index] = "image-" 129 | return graph 130 | 131 | # rename_input_image renames a binary mask input 132 | def rename_input_mask(graph, mask_input_name): 133 | """ Rename a binary mask input 134 | 135 | Args: 136 | graph: A graph, onnx.onnx_ml_pb2.GraphProto. 137 | mask_input_name: String, the current mask input name. 138 | 139 | Returns: 140 | The changed graph. 141 | """ 142 | if type(graph) is not xpb2.GraphProto: 143 | _print("graph is not a valid ONNX graph.") 144 | return False 145 | 146 | found = False 147 | for input in graph.input: 148 | if input.name == mask_input_name: 149 | input.name = "mask-" 150 | found = True 151 | if not found: 152 | _print("Unable to find the input to rename.") 153 | return False 154 | 155 | # And rename it in every nodes that takes this as input: 156 | for node in graph.node: 157 | for index, name in enumerate(node.input): 158 | if name == mask_input_name: 159 | node.input[index] = "mask-" 160 | return graph 161 | 162 | 163 | # rename_input_image renames a threshold input 164 | def rename_input_threshold(graph, threshold_input_name, class_list): 165 | """ Enable a model to raise an alarm when number of 166 | occurrences of an object is above the threshold 167 | 168 | Args: 169 | graph: A graph, onnx.onnx_ml_pb2.GraphProto. 170 | threshold_input_name: String, the current input name of threshold. 171 | class_list: List of classes. 172 | 173 | Returns: 174 | The changed graph. 175 | """ 176 | if type(graph) is not xpb2.GraphProto: 177 | _print("graph is not a valid ONNX graph.") 178 | return False 179 | 180 | found = False 181 | new_name = "thresholds-" 182 | for index, name in enumerate(class_list): 183 | new_name = new_name + str(index) + ':' + name + ";" 184 | new_name = new_name[0:-1] 185 | 186 | for input in graph.input: 187 | if input.name == threshold_input_name: 188 | input.name = new_name 189 | found = True 190 | if not found: 191 | _print("Unable to find the input to rename.") 192 | return False 193 | 194 | # And rename it in every nodes that takes this as input: 195 | for node in graph.node: 196 | for index, name in enumerate(node.input): 197 | if name == threshold_input_name: 198 | node.input[index] = new_name 199 | return graph 200 | 201 | def rename_input_sensor(graph, sensor_input_name): 202 | """ Rename a sensor input 203 | 204 | Args: 205 | graph: A graph, onnx.onnx_ml_pb2.GraphProto. 206 | current_name: String, the current input name. 207 | 208 | Returns: 209 | The changed graph. 210 | """ 211 | if type(graph) is not xpb2.GraphProto: 212 | _print("graph is not a valid ONNX graph.") 213 | return False 214 | 215 | found = False 216 | for input in graph.input: 217 | if input.name == sensor_input_name: 218 | input.name = "sensor-" 219 | found = True 220 | if not found: 221 | _print("Unable to find the input to rename.") 222 | return False 223 | 224 | # And rename it in every nodes that takes this as input: 225 | for node in graph.node: 226 | for index, name in enumerate(node.input): 227 | if name == sensor_input_name: 228 | node.input[index] = "sensor-" 229 | return graph 230 | 231 | 232 | # replace input replaces and existing input 233 | def replace_input( 234 | graph: xpb2.GraphProto, 235 | name: str, 236 | data_type: str, 237 | dimensions: [], 238 | **kwargs): 239 | """ Changes an existing input in a graph 240 | 241 | Args: 242 | graph: A graph, onnx.onnx_ml_pb2.GraphProto. 243 | name: String, the name of the input as used to determine the graph topology. 244 | data_type: String, the data type of the input. Run list_data_types() for an overview. 245 | dimensions: List[] specifying the dimensions of the input., 246 | **kwargs 247 | 248 | Returns: 249 | The extended graph. 250 | 251 | """ 252 | if type(graph) is not xpb2.GraphProto: 253 | _print("graph is not a valid ONNX graph.") 254 | return graph 255 | 256 | # Remove the named input 257 | found = False 258 | try: 259 | for elem in graph.input: 260 | if elem.name == name: 261 | graph.input.remove(elem) 262 | found = True 263 | except Exception as e: 264 | _print("Unable to iterate the inputs. " + str(e)) 265 | return False 266 | if not found: 267 | _print("Unable to find the input by name.") 268 | 269 | # Create the new value 270 | try: 271 | val = _value(name, data_type, dimensions, **kwargs) 272 | except Exception as e: 273 | _print("Unable to create value. " + str(e)) 274 | return False 275 | 276 | # Add the value to the input 277 | try: 278 | graph.input.append(val, *kwargs) 279 | except Exception as e: 280 | _print("Unable to add the input: " + str(e)) 281 | return False 282 | 283 | return graph 284 | 285 | 286 | # delete_input deletes an existing input 287 | def delete_input( 288 | graph: xpb2.GraphProto, 289 | name: str): 290 | """ Removes an existing input of a graph by name 291 | 292 | Args: 293 | graph: A graph, onnx.onnx_ml_pb2.GraphProto. 294 | name: String, the name of the input as used to determine the graph topology. 295 | 296 | Returns: 297 | The extended graph. 298 | 299 | """ 300 | if type(graph) is not xpb2.GraphProto: 301 | return graph 302 | 303 | # Remove the named output 304 | found = False 305 | try: 306 | for elem in graph.input: 307 | if elem.name == name: 308 | graph.input.remove(elem) 309 | found = True 310 | except Exception as e: 311 | _print("Unable to iterate the inputs. " + str(e)) 312 | return False 313 | if not found: 314 | _print("Unable to find the input by name.") 315 | return False 316 | 317 | return graph 318 | -------------------------------------------------------------------------------- /sclblonnx/main.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import json 3 | import os 4 | import subprocess 5 | import onnxruntime as xrt 6 | from onnx import ModelProto as xmp 7 | from onnx import helper as xhelp 8 | from onnx import onnx_ml_pb2 as xpb2 9 | from onnx import save as xsave 10 | from onnx import numpy_helper as xnp 11 | import onnx 12 | import sclblonnx._globals as glob 13 | from sclblonnx.utils import _print 14 | 15 | 16 | # empty_graph creates an empty graph 17 | def empty_graph( 18 | _default_name: str = "sclblgraph"): 19 | """ empty_graph returns an empty graph 20 | 21 | Note, an empty graph does not pass the check() as it does not contain input and output. 22 | 23 | Args: 24 | _default_name: Graph name, default sclblgraph 25 | 26 | Returns: 27 | An empty graph 28 | """ 29 | try: 30 | graph = xpb2.GraphProto(name=_default_name) 31 | except Exception as e: 32 | _print("Unable to create graph: " + str(e)) 33 | return False 34 | return graph 35 | 36 | 37 | # graph_from_file opens an existing onnx file and returns the graph 38 | def graph_from_file( 39 | filename: str): 40 | """ Retrieve a graph object from an onnx file 41 | 42 | Function attempts to open a .onnx file and returns its graph. 43 | 44 | Args: 45 | filename: String indicating the filename / relative location. 46 | 47 | Returns: 48 | An ONNX graph or False if unable to open. 49 | 50 | """ 51 | mod_temp = xmp() 52 | try: 53 | with open(filename, 'rb') as fid: 54 | content = fid.read() 55 | mod_temp.ParseFromString(content) 56 | graph = mod_temp.graph 57 | except Exception as e: 58 | _print("Unable to open the file: " + str(e)) 59 | return False 60 | return graph 61 | 62 | 63 | # graph_to_file saves a graph to a file 64 | def graph_to_file( 65 | graph: xpb2.GraphProto, 66 | filename: str, 67 | _producer: str = "sclblonnx", 68 | onnx_opset_version = 12, 69 | **kwargs): 70 | """ graph_to_file stores an onnx graph to a .onnx file 71 | 72 | Stores a graph to a file 73 | 74 | Args: 75 | graph: An onnx graph 76 | filename: The filename of the resulting file 77 | _producer: Optional string with producer name. Default 'sclblonnx' 78 | onnx_opset_version: Optional version number for ONNX opset. Default 12 79 | Returns: 80 | True if successful, False otherwise. 81 | """ 82 | if not filename: 83 | _print("Unable to save: Please specify a filename.") 84 | return False 85 | 86 | if type(graph) is not xpb2.GraphProto: 87 | _print("Unable to save: Graph is not an ONNX graph") 88 | 89 | try: 90 | if not 'opset_imports' in kwargs: 91 | op = onnx.OperatorSetIdProto() 92 | op.version = onnx_opset_version 93 | mod = xhelp.make_model(graph, producer_name=_producer, opset_imports=[op], **kwargs) 94 | else: 95 | mod = xhelp.make_model(graph, producer_name=_producer, **kwargs) 96 | except Exception as e: 97 | print("Unable to convert graph to model: " + str(e)) 98 | return False 99 | 100 | try: 101 | xsave(mod, filename, **kwargs) 102 | except Exception as e: 103 | print("Unable to save the model: " + str(e)) 104 | return False 105 | 106 | return True 107 | 108 | 109 | # run executes a given graph and returns its result 110 | def run( 111 | graph: xpb2.GraphProto, 112 | inputs: {}, 113 | outputs: [], 114 | _tmpfile: str = ".tmp.onnx", 115 | onnx_opset_version = 12, 116 | **kwargs): 117 | """ run executes a give graph with the given input and returns the output 118 | 119 | Args: 120 | graph: The onnx graph 121 | inputs: an object with the named inputs; please check the data types 122 | outputs: list of named outputs 123 | _tmpfile: String the temporary filename for the onnx file to run. 124 | onnx_opset_version: Optional version number for ONNX opset. Default 12 125 | 126 | Returns: 127 | The result (or False if it fails somewhere) 128 | """ 129 | 130 | store = graph_to_file(graph, _tmpfile, onnx_opset_version=onnx_opset_version) 131 | if not store: 132 | _print("Unable to store model for evaluation.") 133 | return False 134 | 135 | try: 136 | sess = xrt.InferenceSession(_tmpfile, **kwargs) 137 | out = sess.run(outputs, inputs) 138 | except Exception as e: 139 | _print("Failed to run the model: " + str(e)) 140 | return False 141 | 142 | try: 143 | os.remove(_tmpfile) 144 | except Exception as e: 145 | print("We were unable to delete the file " + _tmpfile, "MSG") 146 | 147 | return out 148 | 149 | 150 | # display uses Netron to display a graph 151 | def display( 152 | graph: xpb2.GraphProto, 153 | _tmpfile: str = '.tmp.onnx'): 154 | """ display a onnx graph using netron. 155 | 156 | Pass a graph to the display function to open it in Netron. 157 | Note: Due to the complexities of cross platform opening of source and the potential lack of 158 | a Netron installation this function might not always behave properly. 159 | Note2: This function might leave a file called .temp.onnx if it fails to remove the file. 160 | 161 | Args: 162 | graph: an ONNX graph 163 | _tmpfile: an optional string with the temporary file name. Default .tmp.onnx 164 | 165 | Returns: 166 | True if one of the 3 methods to open the file did not raise any warnings. 167 | 168 | Raises: 169 | SclblONNXError 170 | """ 171 | if type(graph) is not xpb2.GraphProto: 172 | _print("graph is not a valid ONNX graph.") 173 | return False 174 | 175 | # store as tmpfile 176 | graph_to_file(graph, _tmpfile) 177 | 178 | file_open = False 179 | # try open on unix: 180 | if not file_open: 181 | try: 182 | subprocess.run(['xdg-open', _tmpfile]) 183 | file_open = True 184 | except Exception: 185 | file_open = False 186 | 187 | # try open on mac: 188 | if not file_open: 189 | try: 190 | subprocess.run(['open', _tmpfile]) 191 | file_open = True 192 | except Exception: 193 | file_open = False 194 | 195 | # try open on windows: 196 | if not file_open: 197 | try: 198 | os.startfile(_tmpfile) 199 | file_open = True 200 | except Exception: 201 | file_open = False 202 | 203 | # Result: 204 | return file_open 205 | 206 | 207 | # sclbl_input generates the example input for a Scailable runtime 208 | def sclbl_input( 209 | inputs: {}, 210 | example_type: str = "pb", 211 | _verbose: bool = True): 212 | """ input_str returns an example input for a Scailable runtime 213 | 214 | The method takes a valid input object to an onnx graph (i.e., one used for the "inputs" argument 215 | in the run() function, and returns and prints an example input to a Scailable runtime / REST endpoint 216 | 217 | Args: 218 | inputs: The input object as supplied to the run() function to test an ONNX grph 219 | example_type: The type of example string ("raw" for base64 encoded, or "pb" for protobuf, default pb) 220 | _verbose: Print user feedback; default True (note, errors are always printed). 221 | 222 | Returns: 223 | An example input to a Scailable runtime. 224 | """ 225 | if not inputs: 226 | _print("No input provided.") 227 | 228 | if example_type == "raw": 229 | if len(inputs) == 1: 230 | for val in inputs.values(): 231 | bytes = val.tobytes() 232 | encoded = base64.b64encode(bytes) 233 | value_str = '"'+encoded.decode('ascii')+'"' 234 | else: 235 | value_str = '["' 236 | for val in inputs.values(): 237 | bytes = val.tobytes() 238 | encoded = base64.b64encode(bytes) 239 | value_str += (encoded.decode('ascii') + '","') 240 | value_str = value_str.rstrip(',"') 241 | value_str += '"]' 242 | 243 | input_json = '{"input": ' + value_str + ', "type":"raw"}' 244 | if _verbose: 245 | _print("The following input string can be used for the Scailable runtime:", "MSG") 246 | _print(input_json, "LIT") 247 | return input_json 248 | 249 | elif example_type == "pb" or "protobuf": 250 | 251 | if len(inputs) == 1: 252 | for val in inputs.values(): 253 | tensor = xnp.from_array(val) 254 | serialized = tensor.SerializeToString() 255 | encoded = base64.b64encode(serialized) 256 | value_str = '"'+encoded.decode('ascii')+'"' 257 | else: 258 | value_str = '["' 259 | for val in inputs.values(): 260 | tensor = xnp.from_array(val) 261 | serialized = tensor.SerializeToString() 262 | encoded = base64.b64encode(serialized) 263 | value_str += (encoded.decode('ascii') + '","') 264 | value_str = value_str.rstrip(',"') 265 | value_str += '"]' 266 | 267 | input_json = '{"input": ' + value_str + ', "type":"pb"}' 268 | if _verbose: 269 | _print("The following input string can be used for the Scailable runtime:", "MSG") 270 | _print(input_json, "LIT") 271 | _print("The following input string can be used for the web front-end:", "MSG") 272 | _print(value_str, "LIT") 273 | return input_json 274 | 275 | 276 | # list_data_types prints all available data types 277 | def list_data_types(): 278 | """ List all available data types. """ 279 | _print(json.dumps(glob.DATA_TYPES, indent=2), "MSG") 280 | _print("Note: STRINGS are not supported at this time.", "LIT") 281 | return True 282 | 283 | 284 | # list_operators prints all operators available within Scailable 285 | def list_operators(): 286 | """ List all available Scailable ONNX operators. """ 287 | try: 288 | with open(glob.VERSION_INFO_LOCATION, "r") as f: 289 | glob.ONNX_VERSION_INFO = json.load(f) 290 | except FileNotFoundError: 291 | print("Unable to locate the ONNX_VERSION INFO.") 292 | return False 293 | _print(json.dumps(glob.ONNX_VERSION_INFO['operators'], indent=2), "MSG") 294 | return True 295 | 296 | 297 | # No command line options for this script: 298 | if __name__ == '__main__': 299 | print("No command line options available for main.py.") 300 | -------------------------------------------------------------------------------- /sclblonnx/merge.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from onnx import onnx_ml_pb2 as xpb2 3 | from sclblonnx import check, delete_output, delete_input 4 | from sclblonnx.utils import _print 5 | """ 6 | merge.py contains a number of utilities to merge / combine existing graphs. The functions merge(), join(), and split() 7 | provide easy to use wrappers around the actual workhorse concat(). The concat function is versatile; please 8 | see the example_merge.py script in /examples to 9 | """ 10 | 11 | 12 | def merge( 13 | sg1: xpb2.GraphProto, 14 | sg2: xpb2.GraphProto, 15 | outputs: [] = None, 16 | inputs: [] = None, 17 | io_match: [] = None, 18 | complete: bool = True, 19 | _verbose: bool = True, 20 | **kwargs): 21 | """ 22 | merge merges two graphs. 23 | 24 | Given subgraph sg1 and subgraph sg2 merge attempts to link the identified outputs of sg1 to the 25 | inputs of sg2 resulting in a graph in which sg1 is the parent of sg2. 26 | 27 | Merge expects two complete graphs (i.e., it expects sg1 and sg2 to pass check(). If you would like more 28 | flexible merge options or partial merge please see the concat function (merge is merely a constrained wrapper 29 | around concat). 30 | 31 | Note: The args inputs and outputs are present for legacy reasons, we recommend using io_match directly. 32 | 33 | Args: 34 | sg1: Subgraph 1, the parent. 35 | sg2: Subgraph 2, the child. 36 | outputs: (Optional) A list of strings containing the names of the outputs of sg1 that are matched to inputs (in order of the desired match). 37 | inputs: (Optional) A list of strings containing the names of the inputs of sg2 to which the outputs of sg1 are matched. 38 | io_match: (Optional) A list of names pairs [("out1","in1"), ("out2","in2"),...]. This is an alternative for the inputs/outputs arguments. 39 | complete: (Optional) Boolean indicating whether the resulting graph should be complete (i.e., should pass check). Default True. 40 | _verbose: (Optional) Boolean indicating whether or not verbose user feedback should be provided. Default True. 41 | Returns: 42 | The merged graph g, or False (with a printed error message) if something is wrong. 43 | """ 44 | # immutable defaults: 45 | if inputs is None: 46 | inputs = [] 47 | if outputs is None: 48 | outputs = [] 49 | if io_match is None: 50 | io_match = [] 51 | 52 | # prevent changes to original 53 | sg1 = copy.deepcopy(sg1) 54 | sg2 = copy.deepcopy(sg2) 55 | 56 | # Check the inputs: 57 | if type(sg1) is not xpb2.GraphProto: 58 | _print("Graph sg1 is not an ONNX graph.") 59 | return False 60 | if type(sg2) is not xpb2.GraphProto: 61 | _print("Graph sg2 is not an ONNX graph.") 62 | return False 63 | if len(outputs) != len(inputs): 64 | _print("The number of outputs and inputs do not match.") 65 | return False 66 | if len(inputs) > 0 and len(io_match) > 0: 67 | _print("Please use either the inputs/outputs arguments OR the io_match argument (not both).") 68 | return False 69 | 70 | # Construct IO pairs 71 | if len(inputs) > 0: 72 | _print("Constructing the io_match list from your input and output.", "MSG", (not _verbose)) 73 | io_match = [] 74 | for idx, val in enumerate(outputs): 75 | io_match.append((val, inputs[idx])) 76 | 77 | # Use concat to do the merge 78 | g = concat(sg1, sg2, io_match=io_match, complete=complete, **kwargs) 79 | if not g: 80 | _print("Graph merge failed. Please checkout concat for additional options.", "MSG", (not _verbose)) 81 | 82 | return g 83 | 84 | 85 | def join( 86 | pg1: xpb2.GraphProto, 87 | pg2: xpb2.GraphProto, 88 | cg: xpb2.GraphProto, 89 | pg1_match: [] = None, 90 | pg2_match: [] = None, 91 | complete: bool = True, 92 | _verbose: bool = True, 93 | **kwargs): 94 | """ 95 | join takes two parent graphs (pg1 & pg2) and merges them with a child graph cg. 96 | 97 | Join matches the outputs of pg1 to the inputs of cg specified in pg1_match, and similarly for pg2 and pg2_match. 98 | Desired matches are specified in pairs: [("out1","in1"), ("out2","in2"),...]. 99 | 100 | Join by default assumes the resulting joined graph to be complete. Join is merely a wrapper around concat (used 101 | twice). For more flexible combinations of graphs please see concat(). 102 | 103 | Note: ONNX concat operations might give unexpected results if names of elements collide, please use postfix_names() 104 | to prevent this (and always critically inspect the resulting graph). 105 | 106 | Args: 107 | pg1: Parent graph 1. 108 | pg2: Parent graph 2. 109 | cg: Child graph, the graph that will join together pg1 and pg2. 110 | pg1_match: (Optional) List of pairs matching outputs of pg1 to inputs of cg. Default []. 111 | pg2_match: (Optional) List of pairs matching outputs of pg2 to inputs of cg. Default []. 112 | complete: (Optional) Boolean indicating whether the resulting graph should be complete (i.e., should pass check). Default True. 113 | _verbose: (Optional) Boolean indicating whether or not verbose user feedback should be provided. Default True. 114 | Returns: 115 | The joined graph g (of False is something fails along the way). 116 | """ 117 | # immutable defaults: 118 | if pg1_match is None: 119 | pg1_match = [] 120 | if pg2_match is None: 121 | pg2_match = [] 122 | 123 | # prevent changes to original 124 | pg1 = copy.deepcopy(pg1) 125 | pg2 = copy.deepcopy(pg2) 126 | cg = copy.deepcopy(cg) 127 | 128 | if type(pg1) is not xpb2.GraphProto: 129 | _print("Graph pg1 is not an ONNX graph.") 130 | return False 131 | if type(pg2) is not xpb2.GraphProto: 132 | _print("Graph pg2 is not an ONNX graph.") 133 | return False 134 | if type(cg) is not xpb2.GraphProto: 135 | _print("Graph cg is not an ONNX graph.") 136 | return False 137 | 138 | # Construct the match list 139 | io_match = pg1_match 140 | io_match.extend(pg2_match) 141 | 142 | # Do the joint (2x concat) 143 | g1 = concat(pg1, pg2, rename_nodes=True, complete=False, _verbose=False, **kwargs) 144 | g = concat(g1, cg, rename_nodes=True, io_match=io_match, complete=complete, _verbose=False, **kwargs) 145 | if not g: 146 | _print("Graph merge failed. Please checkout concat for additional options.", "MSG", (not _verbose)) 147 | 148 | return g 149 | 150 | 151 | def split( 152 | pg: xpb2.GraphProto, 153 | cg1: xpb2.GraphProto, 154 | cg2: xpb2.GraphProto, 155 | cg1_match: [] = None, 156 | cg2_match: [] = None, 157 | complete: bool = True, 158 | _verbose: bool = True, 159 | **kwargs): 160 | """ 161 | split takes takes a single parent and matches the outputs to the inputs of two childs (cg1 & cg2) 162 | 163 | Split matches the outputs of pg to the inputs of cg1 and cg2 as specified in cg1_match and cg2_match. 164 | Desired matches are specified in pairs: [("out1","in1"), ("out2","in2"),...]. 165 | 166 | Split by default assumes the resulting joined graph to be complete. Split is merely a wrapper around concat (used 167 | twice). For more flexible combinations of graphs please see concat(). 168 | 169 | Note: ONNX concat operations might give unexpected results if names of elements collide, please use postfix_names() 170 | to prevent this (and always critically inspect the resulting graph). 171 | 172 | Args: 173 | pg: The parent graph 174 | cg1: The left child. 175 | cg2: The right child. 176 | cg1_match: (Optional) List of pairs matching outputs of pg to inputs of cg1. Default []. 177 | cg2_match: (Optional) List of pairs matching outputs of pg to inputs of cg2. Default []. 178 | complete: (Optional) Boolean indicating whether the resulting graph should be complete (i.e., should pass check). Default True. 179 | _verbose: (Optional) Boolean indicating whether or not verbose user feedback should be provided. Default True. 180 | Returns: 181 | The joined graph g (of False is something fails along the way). 182 | """ 183 | # immutable defaults: 184 | if cg1_match is None: 185 | cg1_match = [] 186 | if cg2_match is None: 187 | cg2_match = [] 188 | 189 | # prevent changes to original 190 | pg = copy.deepcopy(pg) 191 | cg1 = copy.deepcopy(cg1) 192 | cg2 = copy.deepcopy(cg2) 193 | 194 | if type(pg) is not xpb2.GraphProto: 195 | _print("Graph pg is not an ONNX graph.") 196 | return False 197 | if type(cg1) is not xpb2.GraphProto: 198 | _print("Graph cg1 is not an ONNX graph.") 199 | return False 200 | if type(cg2) is not xpb2.GraphProto: 201 | _print("Graph cg2 is not an ONNX graph.") 202 | return False 203 | 204 | # Create the split (using concat 2x) 205 | g1 = concat(pg, cg1, rename_nodes=True, io_match=cg1_match, complete=False, _verbose=False, **kwargs) 206 | g = concat(g1, cg2, rename_nodes=True, io_match=cg2_match, complete=complete, _verbose=False, **kwargs) 207 | if not g: 208 | _print("Graph merge failed. Please checkout concat() for additional options.", "MSG", (not _verbose)) 209 | 210 | return g 211 | 212 | 213 | def concat( 214 | sg1: xpb2.GraphProto, 215 | sg2: xpb2.GraphProto, 216 | complete: bool = False, 217 | rename_nodes: bool = True, 218 | io_match: [] = None, 219 | rename_io: bool = False, 220 | edge_match: [] = None, 221 | rename_edges: bool = False, 222 | rename_init: bool = False, 223 | _verbose: bool = True, 224 | **kwargs): 225 | """ 226 | concat concatenates two graphs. 227 | 228 | Concat is the flexible (but also rather complex) workhorse for the merge, join, and split functions and 229 | can be used to quite flexibly paste together two (sub)graphs. Contrary to merge, join, and split, concat 230 | does not by default assume the resulting onnx graph to be complete (i.e., to contain inputs and outputs and to 231 | pass check()), and it can thus be used as an intermediate function when constructing larger graphs. 232 | 233 | Concat is flexible and versatile, but it takes time to master. See example_merge.py in the examples folder 234 | for a number of examples. 235 | 236 | Args: 237 | sg1: Subgraph 1, the parent. 238 | sg2: Subgraph 2, the child. 239 | complete: (Optional) Boolean indicating whether the resulting graph should be checked using so.check(). Default False. 240 | rename_nodes: (Optional) Boolean indicating whether the names of the nodes in the graph should be made unique. Default True. 241 | io_match: (Optional) Dict containing pairs of outputs of sg1 that should be matched to inputs of sg2. Default []. 242 | rename_io: (Optional) Boolean indicating whether the inputs and outputs of the graph should be renamed. Default False. 243 | edge_match: (Optional) Dict containing pairs edge names of sg1 (i.e., node outputs) that should be matched to edges of sg2 (i.e., node inputs). Default []. 244 | rename_edges: (Optional) Boolean indicating whether the edges should be renamed (default False) 245 | _verbose: (Optional) Boolean indicating whether verbose output should be printed (default False) 246 | Returns: 247 | The concatenated graph g, or False if something goes wrong along the way. 248 | """ 249 | # immutable defaults: 250 | if io_match is None: 251 | io_match = [] 252 | if edge_match is None: 253 | edge_match = [] 254 | 255 | # prevent changes to original 256 | sg1 = copy.deepcopy(sg1) 257 | sg2 = copy.deepcopy(sg2) 258 | 259 | # Check input types: 260 | if type(sg1) is not xpb2.GraphProto: 261 | _print("Graph sg1 is not an ONNX graph. Abort.") 262 | return False 263 | if type(sg2) is not xpb2.GraphProto: 264 | _print("Graph sg2 is not an ONNX graph. Abort.") 265 | return False 266 | 267 | # Rename node names if requested (default True) 268 | if rename_nodes: 269 | _print("Renaming node names in graph.", "MSG", (not _verbose)) 270 | sg1 = postfix_names(sg1, "_sg1", "node") 271 | sg2 = postfix_names(sg2, "_sg2", "node") 272 | 273 | if io_match: 274 | _print("Matching specified inputs and outputs..", "MSG", (not _verbose)) 275 | for io_pair in io_match: 276 | for outputs in sg1.output: 277 | if outputs.name == io_pair[0]: 278 | sg1 = delete_output(sg1, io_pair[0]) 279 | for inputs in sg2.input: 280 | if inputs.name == io_pair[1]: 281 | sg2 = delete_input(sg2, io_pair[1]) 282 | for item in sg2.node: 283 | for index, name in enumerate(item.input): 284 | if name == io_pair[1]: 285 | item.input[index] = io_pair[0] 286 | 287 | if rename_io: 288 | _print("Renaming inputs and outputs.", "MSG", (not _verbose)) 289 | sg1 = postfix_names(sg1, "_sg1", "io") 290 | sg2 = postfix_names(sg2, "_sg2", "io") 291 | 292 | if edge_match: 293 | _print("Matching edges.", "MSG", (not _verbose)) 294 | for edge_pair in edge_match: 295 | for item in sg2.node: 296 | for index, name in enumerate(item.input): 297 | if name == edge_pair[1]: 298 | item.input[index] = edge_pair[0] 299 | 300 | if rename_edges: 301 | _print("Renaming edges.", "MSG", (not _verbose)) 302 | sg1 = postfix_names(sg1, "_sg1", "edge") 303 | sg2 = postfix_names(sg2, "_sg2", "edge") 304 | 305 | if rename_init: 306 | _print("Renaming init.", "MSG", (not _verbose)) 307 | sg1 = postfix_names(sg1, "_sg1", "init") 308 | sg2 = postfix_names(sg2, "_sg2", "init") 309 | 310 | # Paste graphs together: 311 | _print("Pasting graphs.", "MSG", (not _verbose)) 312 | g = _paste_graphs(sg1, sg2) 313 | 314 | if complete: 315 | if not check(g, _verbose=_verbose, **kwargs): 316 | _print("The end result does not pass check(). Are you sure you want a complete result? Set complete=False " 317 | "to continue concat without checking.") 318 | return False 319 | 320 | return g 321 | 322 | 323 | def postfix_names( 324 | g: xpb2.GraphProto, 325 | postfix: str = "_g1", 326 | elem: str = "node"): 327 | """ 328 | postfix_names is a utility function used by concat() to rename parts of an onnx graph. 329 | 330 | When merging (or otherwise manipulating) onnx graphs it is often useful to create unique names of the 331 | various elements of the graph. This function postfixes each name in supplied graph g of elements of type elem 332 | by the supplied postfix. 333 | 334 | Args: 335 | g: The graph 336 | postfix: (Optional) The postfix for the names of the elements. Default "_g1". 337 | elem: (Optional) The type of element. Options are "node", "init", "edge", "input", "output", "io", and "all". Default "node". 338 | """ 339 | if elem == 'node': 340 | for item in g.node: 341 | item.name = item.name + postfix 342 | return g 343 | elif elem == 'init': 344 | for init in g.initializer: 345 | init.name = init.name + postfix 346 | return g 347 | elif elem == 'edge': 348 | for init in g.node: 349 | for index, name in enumerate(init.input): 350 | init.input[index] = init.input[index] + postfix 351 | for index, name in enumerate(init.output): 352 | init.output[index] = init.output[index] + postfix 353 | return g 354 | elif elem == 'input': 355 | for item in g.input: 356 | item.name = item.name + postfix 357 | return g 358 | elif elem == 'output': 359 | for item in g.output: 360 | item.name = item.name + postfix 361 | return g 362 | elif elem == 'io': 363 | cg = postfix_names(g, postfix, "input") 364 | cg = postfix_names(cg, postfix, "output") 365 | return cg 366 | elif elem == 'all': 367 | cg = postfix_names(g, postfix, "node") 368 | cg = postfix_names(cg, postfix, "init") 369 | cg = postfix_names(cg, postfix, "edge") 370 | cg = postfix_names(cg, postfix, "input") 371 | cg = postfix_names(cg, postfix, "output") 372 | return cg 373 | else: 374 | _print("No names have been changed; did you select the right element?", "MSG") 375 | 376 | return g 377 | 378 | 379 | def _paste_graphs( 380 | sg1: xpb2.GraphProto, 381 | sg2: xpb2.GraphProto): 382 | """ 383 | _paste_graphs takes two subgraphs and pastes all of their objects into a single graph. 384 | 385 | Note, _paste_graphs does not conduct any checks, it just blindly copies. It is used internally 386 | by the concat() function. 387 | 388 | Args: 389 | sg1: The first subgraph 390 | sg2: The second subgraph 391 | 392 | Returns: 393 | g: the joined graph 394 | """ 395 | g = copy.deepcopy(sg1) 396 | 397 | # Copy initializers from sg2 398 | for init in sg2.initializer: 399 | g.initializer.append(init) 400 | 401 | # Copy nodes from sg2 402 | for node in sg2.node: 403 | g.node.append(node) 404 | 405 | # Copy inputs and outputs from sg2 406 | for item in sg2.input: 407 | g.input.append(item) 408 | for item in sg2.output: 409 | g.output.append(item) 410 | 411 | return g 412 | -------------------------------------------------------------------------------- /sclblonnx/node.py: -------------------------------------------------------------------------------- 1 | from onnx import helper as xhelp 2 | from onnx import onnx_ml_pb2 as xpb2 3 | import sclblonnx._globals as glob 4 | from sclblonnx.utils import _print 5 | 6 | 7 | # Node creates a new node 8 | def node( 9 | op_type: str, 10 | inputs: [], 11 | outputs: [], 12 | name: str = "", 13 | **kwargs): 14 | """ Create a new node 15 | 16 | Args: 17 | op_type: Operator type, see https://github.com/onnx/onnx/blob/master/docs/Operators.md 18 | inputs: [] list of inputs (names to determine the graph topology) 19 | outputs: [] list of outputs (names to determine the graph topology) 20 | name: The name of this node (Optional) 21 | **kwargs 22 | """ 23 | if not name: 24 | name = "sclbl-onnx-node" + str(glob.NODE_COUNT) 25 | glob.NODE_COUNT += 1 26 | 27 | try: 28 | node = xhelp.make_node(op_type, inputs, outputs, name, **kwargs) 29 | except Exception as e: 30 | _print("Unable to create node: " + str(e)) 31 | return False 32 | return node 33 | 34 | 35 | # add_node adds a node to a graph 36 | def add_node( 37 | graph: xpb2.GraphProto, 38 | node: xpb2.NodeProto, 39 | **kwargs): 40 | """ Add node appends a node to graph g and returns the extended graph 41 | 42 | Prints a message and returns False if fails. 43 | 44 | Args: 45 | graph: A graph, onnx.onnx_ml_pb2.GraphProto. 46 | node: A node, onnx.onnx_ml_pb2.NodeProto. 47 | **kwargs 48 | 49 | Returns: 50 | The extended graph. 51 | """ 52 | if type(graph) is not xpb2.GraphProto: 53 | _print("The graph is not a valid ONNX graph.") 54 | return False 55 | 56 | if type(node) is not xpb2.NodeProto: 57 | _print("The node is not a valid ONNX node.") 58 | return False 59 | 60 | try: 61 | graph.node.append(node, **kwargs) 62 | except Exception as e: 63 | _print("Unable to extend graph: " + str(e)) 64 | return False 65 | return graph 66 | 67 | 68 | def add_nodes( 69 | graph: xpb2.GraphProto, 70 | nodes: [xpb2.NodeProto], 71 | **kwargs): 72 | """ Add a list of nodes appends a node to graph g and returns the extended graph 73 | 74 | Prints a message and returns False if fails. 75 | 76 | Args: 77 | graph: A graph, onnx.onnx_ml_pb2.GraphProto. 78 | nodes: A list of nodes, [onnx.onnx_ml_pb2.NodeProto]. 79 | **kwargs 80 | 81 | Returns: 82 | The extended graph. 83 | """ 84 | if type(graph) is not xpb2.GraphProto: 85 | print("graph is not a valid ONNX graph.") 86 | return False 87 | 88 | for node in nodes: # error handling in add_node 89 | graph = add_node(graph, node, **kwargs) 90 | if not graph: 91 | return False 92 | 93 | return graph 94 | 95 | 96 | # delete_node deletes a node from a graph 97 | def delete_node( 98 | graph: xpb2.GraphProto, 99 | node_name: str = "", 100 | **kwargs): 101 | """ Add node appends a node to graph g and returns the extended graph 102 | 103 | Prints a message and returns False if fails. 104 | 105 | Args: 106 | graph: A graph, onnx.onnx_ml_pb2.GraphProto. 107 | node_name: Name of the node to remove. 108 | **kwargs 109 | 110 | Returns: 111 | The extended graph. 112 | """ 113 | if type(graph) is not xpb2.GraphProto: 114 | _print("The graph is not a valid ONNX graph.") 115 | return False 116 | 117 | if not node_name: 118 | _print("Please specify a node name.") 119 | return False 120 | 121 | found = False 122 | try: 123 | for elem in graph.node: 124 | if elem.name == node_name: 125 | graph.node.remove(elem) 126 | found = True 127 | except Exception as e: 128 | _print("Unable to iterate the nodes. " + str(e)) 129 | return False 130 | if not found: 131 | _print("Unable to find the node by name.") 132 | return False 133 | 134 | return graph 135 | -------------------------------------------------------------------------------- /sclblonnx/output.py: -------------------------------------------------------------------------------- 1 | from onnx import helper as xhelp 2 | from onnx import onnx_ml_pb2 as xpb2 3 | from sclblonnx.utils import _parse_element, _value, _data_type, _print 4 | 5 | 6 | # list_outputs list all outputs in a graph 7 | def list_outputs(graph: xpb2.GraphProto): 8 | """ Tries to list the outputs of a given graph. 9 | 10 | Args: 11 | graph the ONNX graph 12 | """ 13 | if type(graph) is not xpb2.GraphProto: 14 | _print("graph is not a valid ONNX graph.") 15 | return False 16 | 17 | i = 1 18 | for elem in graph.output: 19 | name, dtype, shape = _parse_element(elem) 20 | print("Output {}: Name: '{}', Type: {}, Dimension: {}".format(i, name, dtype, shape)) 21 | i += 1 22 | 23 | if i == 1: 24 | print("No outputs found.") 25 | 26 | return True 27 | 28 | 29 | # add_output adds an output to a graph 30 | def add_output( 31 | graph: xpb2.GraphProto, 32 | name: str, 33 | data_type: str, 34 | dimensions: [], 35 | **kwargs): 36 | """ Add an output to a graph 37 | 38 | Args: 39 | graph: A graph, onnx.onnx_ml_pb2.GraphProto. 40 | name: String, the name of the input as used to determine the graph topology. 41 | data_type: String, the data type of the input. Run list_data_types() for an overview. 42 | dimensions: List[] specifying the dimensions of the input. 43 | **kwargs 44 | 45 | Returns: 46 | The extended graph. 47 | 48 | """ 49 | if type(graph) is not xpb2.GraphProto: 50 | _print("graph is not a valid ONNX graph.") 51 | return False 52 | 53 | dtype = _data_type(data_type) 54 | if not dtype: 55 | return False 56 | 57 | try: 58 | graph.output.append(xhelp.make_tensor_value_info(name, dtype, dimensions, **kwargs), **kwargs) 59 | except Exception as e: 60 | _print("Unable to add the input: " + str(e)) 61 | return False 62 | return graph 63 | 64 | 65 | # rename_output renames an existing output 66 | def rename_output(graph, current_name, new_name): 67 | """ Rename an output to a graph 68 | 69 | Args: 70 | graph: A graph, onnx.onnx_ml_pb2.GraphProto. 71 | current_name: String, the current output name. 72 | new_name: String, the name desired output name. 73 | 74 | Returns: 75 | The changed graph. 76 | """ 77 | if type(graph) is not xpb2.GraphProto: 78 | _print("graph is not a valid ONNX graph.") 79 | return False 80 | 81 | found = False 82 | for output in graph.output: 83 | if output.name == current_name: 84 | output.name = new_name 85 | found = True 86 | if not found: 87 | _print("Unable to found the output by name.") 88 | return False 89 | 90 | for node in graph.node: 91 | for index, name in enumerate(node.output): 92 | if name == current_name: 93 | node.output[index] = new_name 94 | 95 | # Handle the case when the output is fed to another node 96 | for index, name in enumerate(node.input): 97 | if name == current_name: 98 | node.input[index] = new_name 99 | 100 | return graph 101 | 102 | 103 | def rename_bbox_output(graph, bboxes_output_name, format, class_list): 104 | """ Rename a bbox output of a graph 105 | 106 | Args: 107 | graph: A graph, onnx.onnx_ml_pb2.GraphProto. 108 | bboxes_output_name: String, the current output name of bounding boxes. 109 | format: Format of output, choose among : 110 | "xy" if (x1, y1, x2, y2) 111 | "xyc" if (x1, y1, x2, y2, class) 112 | "xysc" if (x1, y1, x2, y2, score, class) 113 | class_list: List of classes. 114 | Returns: 115 | The changed graph. 116 | """ 117 | 118 | if type(graph) is not xpb2.GraphProto: 119 | _print("graph is not a valid ONNX graph.") 120 | return False 121 | 122 | found = False 123 | new_name = "" 124 | if format == "xy": 125 | new_name = "bboxes-format:xyxy;" 126 | elif format == "xyc": 127 | new_name = "bboxes-format:xyxyc;" 128 | elif format == "xysc": 129 | new_name = "bboxes-format:xyxysc;" 130 | else: 131 | print("Format input is incorrect, it must be 'xy', 'xyc' or 'xysc'") 132 | return False 133 | for index, name in enumerate(class_list): 134 | new_name = new_name + str(index) + ':' + name + ";" 135 | new_name = new_name[0:-1] 136 | 137 | for output in graph.output: 138 | if output.name == bboxes_output_name: 139 | output.name = new_name 140 | found = True 141 | if not found: 142 | _print("Unable to found the output by name.") 143 | return False 144 | 145 | for node in graph.node: 146 | for index, name in enumerate(node.output): 147 | if name == bboxes_output_name: 148 | node.output[index] = new_name 149 | 150 | # Handle the case when the output is fed to another node 151 | for index, name in enumerate(node.input): 152 | if name == bboxes_output_name: 153 | node.input[index] = new_name 154 | return graph 155 | 156 | 157 | def rename_barcode_output(graph, barcode_output_name): 158 | """ Rename a barcode bbox output of a graph 159 | 160 | Args: 161 | graph: A graph, onnx.onnx_ml_pb2.GraphProto. 162 | barcode_output_name: String, the current name of bounding-boxes output for barcodes 163 | 164 | Returns: 165 | The changed graph. 166 | """ 167 | 168 | if type(graph) is not xpb2.GraphProto: 169 | _print("graph is not a valid ONNX graph.") 170 | return False 171 | 172 | found = False 173 | for output in graph.output: 174 | if output.name == barcode_output_name: 175 | output.name = "barcode_bboxes-format:xyxy" 176 | found = True 177 | if not found: 178 | _print("Unable to found the output by name.") 179 | return False 180 | 181 | for node in graph.node: 182 | for index, name in enumerate(node.output): 183 | if name == barcode_output_name: 184 | node.output[index] = "barcode_bboxes-format:xyxy" 185 | 186 | # Handle the case when the output is fed to another node 187 | for index, name in enumerate(node.input): 188 | if name == barcode_output_name: 189 | node.input[index] = "barcode_bboxes-format:xyxy" 190 | return graph 191 | 192 | 193 | def rename_licenseplate_output(graph, licenseplate_output_name, format): 194 | """ Rename a licenseplate bbox output of a graph 195 | 196 | Args: 197 | graph: A graph, onnx.onnx_ml_pb2.GraphProto. 198 | licenseplate_output_name: String, the current output name of licenseplate bounding boxes. 199 | format: Format of output, choose among : 200 | "xy" if (x1, y1, x2, y2) 201 | "xyxyxsxyxyxyxy" if (x1, y1, x2, y2, score, ... landmark coordinates) 202 | Returns: 203 | The changed graph. 204 | """ 205 | 206 | if type(graph) is not xpb2.GraphProto: 207 | _print("graph is not a valid ONNX graph.") 208 | return False 209 | 210 | found = False 211 | new_name = "" 212 | if format == "xy": 213 | new_name = "licenseplate_bboxes-format:xyxy" 214 | elif format == "xys": 215 | new_name = "licenseplate_bboxes-format:xyxyxsxyxyxyxy" 216 | else: 217 | print("Format input is incorrect, it must be 'xy' or 'xys'") 218 | return False 219 | 220 | for output in graph.output: 221 | if output.name == licenseplate_output_name: 222 | output.name = new_name 223 | found = True 224 | if not found: 225 | _print("Unable to found the output by name.") 226 | return False 227 | 228 | for node in graph.node: 229 | for index, name in enumerate(node.output): 230 | if name == licenseplate_output_name: 231 | node.output[index] = new_name 232 | 233 | # Handle the case when the output is fed to another node 234 | for index, name in enumerate(node.input): 235 | if name == licenseplate_output_name: 236 | node.input[index] = new_name 237 | return graph 238 | 239 | 240 | def rename_class_probabilities_output(graph, output_name, class_list): 241 | """ Rename the output of a model that generates probabilities per class 242 | 243 | Args: 244 | graph: A graph, onnx.onnx_ml_pb2.GraphProto. 245 | output_name: String, the current output name of the graph 246 | class_list: List of classes. 247 | Returns: 248 | The changed graph. 249 | """ 250 | 251 | if type(graph) is not xpb2.GraphProto: 252 | _print("graph is not a valid ONNX graph.") 253 | return False 254 | 255 | found = False 256 | new_name = "scores-" 257 | for index, name in enumerate(class_list): 258 | new_name = new_name + str(index) + ':' + name + ";" 259 | new_name = new_name[0:-1] 260 | 261 | for output in graph.output: 262 | if output.name == output_name: 263 | output.name = new_name 264 | found = True 265 | if not found: 266 | _print("Unable to found the output by name.") 267 | return False 268 | 269 | for node in graph.node: 270 | for index, name in enumerate(node.output): 271 | if name == output_name: 272 | node.output[index] = new_name 273 | 274 | # Handle the case when the output is fed to another node 275 | for index, name in enumerate(node.input): 276 | if name == output_name: 277 | node.input[index] = new_name 278 | return graph 279 | 280 | 281 | def rename_object_count_output(graph, output_name, class_list): 282 | """ Rename the output of a model that generates number of objects per class 283 | 284 | Args: 285 | graph: A graph, onnx.onnx_ml_pb2.GraphProto. 286 | output_name: String, the current output name of the graph 287 | class_list: List of classes. 288 | Returns: 289 | The changed graph. 290 | """ 291 | 292 | if type(graph) is not xpb2.GraphProto: 293 | _print("graph is not a valid ONNX graph.") 294 | return False 295 | 296 | found = False 297 | new_name = "counts-" 298 | for index, name in enumerate(class_list): 299 | new_name = new_name + str(index) + ':' + name + ";" 300 | new_name = new_name[0:-1] 301 | 302 | for output in graph.output: 303 | if output.name == output_name: 304 | output.name = new_name 305 | found = True 306 | if not found: 307 | _print("Unable to found the output by name.") 308 | return False 309 | 310 | for node in graph.node: 311 | for index, name in enumerate(node.output): 312 | if name == output_name: 313 | node.output[index] = new_name 314 | 315 | # Handle the case when the output is fed to another node 316 | for index, name in enumerate(node.input): 317 | if name == output_name: 318 | node.input[index] = new_name 319 | return graph 320 | 321 | 322 | def rename_alarm_output(graph, output_name, class_list): 323 | """ Rename the output of a model that generates an alarm based 324 | on number of objects per class 325 | 326 | Args: 327 | graph: A graph, onnx.onnx_ml_pb2.GraphProto. 328 | output_name: String, the current output name of the graph 329 | class_list: List of classes. 330 | Returns: 331 | The changed graph. 332 | """ 333 | 334 | if type(graph) is not xpb2.GraphProto: 335 | _print("graph is not a valid ONNX graph.") 336 | return False 337 | 338 | found = False 339 | new_name = "alarm-" 340 | for index, name in enumerate(class_list): 341 | new_name = new_name + str(index) + ':' + name + ";" 342 | new_name = new_name[0:-1] 343 | 344 | for output in graph.output: 345 | if output.name == output_name: 346 | output.name = new_name 347 | found = True 348 | if not found: 349 | _print("Unable to found the output by name.") 350 | return False 351 | 352 | for node in graph.node: 353 | for index, name in enumerate(node.output): 354 | if name == output_name: 355 | node.output[index] = new_name 356 | 357 | # Handle the case when the output is fed to another node 358 | for index, name in enumerate(node.input): 359 | if name == output_name: 360 | node.input[index] = new_name 361 | return graph 362 | 363 | 364 | def rename_linecrossing_bboxes_output(graph, output_name): 365 | """ Rename the output of a model to be post-processed 366 | using the line-crossing counter 367 | 368 | Args: 369 | graph: A graph, onnx.onnx_ml_pb2.GraphProto. 370 | output_name: String, the current output name of the graph 371 | Returns: 372 | The changed graph. 373 | """ 374 | 375 | if type(graph) is not xpb2.GraphProto: 376 | _print("graph is not a valid ONNX graph.") 377 | return False 378 | 379 | found = False 380 | new_name = "linecrossing_bboxes-format:xyxysc" 381 | 382 | for output in graph.output: 383 | if output.name == output_name: 384 | output.name = new_name 385 | found = True 386 | if not found: 387 | _print("Unable to found the output by name.") 388 | return False 389 | 390 | for node in graph.node: 391 | for index, name in enumerate(node.output): 392 | if name == output_name: 393 | node.output[index] = new_name 394 | 395 | # Handle the case when the output is fed to another node 396 | for index, name in enumerate(node.input): 397 | if name == output_name: 398 | node.input[index] = new_name 399 | return graph 400 | 401 | 402 | # replace_output replaces an existing output 403 | def replace_output( 404 | graph: xpb2.GraphProto, 405 | name: str, 406 | data_type: str, 407 | dimensions: [], 408 | **kwargs): 409 | """ Changes an existing output of a graph 410 | 411 | Args: 412 | graph: A graph, onnx.onnx_ml_pb2.GraphProto. 413 | name: String, the name of the output as used to determine the graph topology. 414 | data_type: String, the data type of the output. Run list_data_types() for an overview. 415 | dimensions: List[] specifying the dimensions of the input. 416 | **kwargs 417 | 418 | Returns: 419 | The extended graph. 420 | 421 | """ 422 | if type(graph) is not xpb2.GraphProto: 423 | _print("graph is not a valid ONNX graph.") 424 | return graph 425 | 426 | # Remove the named output 427 | found = False 428 | try: 429 | for elem in graph.output: 430 | if elem.name == name: 431 | graph.output.remove(elem) 432 | found = True 433 | except Exception as e: 434 | _print("Unable to iterate the outputs. " + str(e)) 435 | return False 436 | if not found: 437 | _print("Unable to find the output by name.") 438 | 439 | # Create the new value 440 | try: 441 | val = _value(name, data_type, dimensions, **kwargs) 442 | except Exception as e: 443 | _print("Unable to create value. " + str(e)) 444 | return False 445 | 446 | # Add the value to the output 447 | try: 448 | graph.output.append(val, **kwargs) 449 | except Exception as e: 450 | _print("Unable to add the output: " + str(e)) 451 | return False 452 | 453 | return graph 454 | 455 | 456 | # delete_output deletes an existing output 457 | def delete_output( 458 | graph: xpb2.GraphProto, 459 | name: str): 460 | """ Removes an existing output of a graph by name 461 | 462 | Args: 463 | graph: A graph, onnx.onnx_ml_pb2.GraphProto. 464 | name: String, the name of the output as used to determine the graph topology. 465 | 466 | 467 | Returns: 468 | The extended graph. 469 | 470 | """ 471 | if type(graph) is not xpb2.GraphProto: 472 | _print("graph is not a valid ONNX graph.") 473 | return graph 474 | 475 | # Remove the named output 476 | found = False 477 | try: 478 | for elem in graph.output: 479 | if elem.name == name: 480 | graph.output.remove(elem) 481 | found = True 482 | except Exception as e: 483 | _print("Unable to iterate the outputs. " + str(e)) 484 | return False 485 | if not found: 486 | _print("Unable to find the output by name.") 487 | return False 488 | 489 | return graph 490 | -------------------------------------------------------------------------------- /sclblonnx/supported_onnx.json: -------------------------------------------------------------------------------- 1 | { 2 | "onnx_version" : { 3 | "version_min" : "1.7.0", 4 | "version_max" : "1.17.0", 5 | "ir_version_min" : 7, 6 | "ir_version_max" : 10, 7 | "opset_min" : 12, 8 | "opset_max" : 18 9 | }, 10 | "operators" : [ 11 | "Abs", 12 | "Acos", 13 | "Acosh", 14 | "Add", 15 | "ai.onnx.ml.FeatureVectorizer", 16 | "ai.onnx.ml.LinearRegressor", 17 | "ai.onnx.ml.Scaler", 18 | "And", 19 | "ArgMax", 20 | "ArgMin", 21 | "Asin", 22 | "Asinh", 23 | "Atan", 24 | "Atanh", 25 | "AveragePool", 26 | "BatchNormalization", 27 | "Cast", 28 | "Ceil", 29 | "Celu", 30 | "Clip", 31 | "Compress", 32 | "Concat", 33 | "Constant", 34 | "ConstantOfShape", 35 | "Conv", 36 | "ConvInteger", 37 | "ConvTranspose", 38 | "Cos", 39 | "Cosh", 40 | "CumSum", 41 | "Div", 42 | "Dropout", 43 | "DynamicQuantizeLinear", 44 | "Elu", 45 | "Equal", 46 | "Exp", 47 | "Expand", 48 | "Flatten", 49 | "Floor", 50 | "Gather", 51 | "Gemm", 52 | "GlobalAveragePool", 53 | "Greater", 54 | "GreaterOrEqual", 55 | "Identity", 56 | "If", 57 | "IsNaN", 58 | "IsInf", 59 | "LeakyRelu", 60 | "Less", 61 | "LessOrEqual", 62 | "Log", 63 | "LogSoftmax", 64 | "LRN", 65 | "LSTM", 66 | "MatMul", 67 | "MatMulInteger", 68 | "Max", 69 | "MaxPool", 70 | "Min", 71 | "Mul", 72 | "Neg", 73 | "Not", 74 | "NonMaxSuppression", 75 | "Or", 76 | "Pad", 77 | "Pow", 78 | "PRelu", 79 | "QuantizeLinear", 80 | "Range", 81 | "Reciprocal", 82 | "ReduceL1", 83 | "ReduceL2", 84 | "ReduceLogSum", 85 | "ReduceLogSumExp", 86 | "ReduceMax", 87 | "ReduceMean", 88 | "ReduceMin", 89 | "ReduceProd", 90 | "ReduceSum", 91 | "ReduceSumSquare", 92 | "Relu", 93 | "Reshape", 94 | "Resize", 95 | "Round", 96 | "Scatter", 97 | "ScatterElements", 98 | "ScatterND", 99 | "Selu", 100 | "Shape", 101 | "Sigmoid", 102 | "Sign", 103 | "Sin", 104 | "Sinh", 105 | "Size", 106 | "Slice", 107 | "Softmax", 108 | "Softplus", 109 | "Softsign", 110 | "Split", 111 | "Sqrt", 112 | "Squeeze", 113 | "Sub", 114 | "Sum", 115 | "Tan", 116 | "Tanh", 117 | "ThresholdedRelu", 118 | "Transpose", 119 | "Unsqueeze", 120 | "Where", 121 | "Xor" 122 | ], 123 | "types" : [ 124 | "float", 125 | "int" 126 | ] 127 | } 128 | -------------------------------------------------------------------------------- /sclblonnx/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | from onnx import helper as xhelp 3 | from onnx import onnx_ml_pb2 as xpb2 4 | import sclblonnx._globals as glob 5 | 6 | 7 | # _parse_element parses a graph input or output element and return its properties for printing. 8 | def _parse_element(elem: xpb2.ValueInfoProto): 9 | """ Parse a graph input or output element and return its contents. 10 | 11 | Utility. 12 | 13 | Args: 14 | elem, a ValueInfoProto. 15 | 16 | Returns: 17 | name The name of the element 18 | data_type The data type of the element 19 | shape_str The dimensions of the element 20 | """ 21 | name = getattr(elem, 'name', "None") 22 | data_type = "NA" 23 | shape_str = "NA" 24 | etype = getattr(elem, 'type', False) 25 | if etype: 26 | ttype = getattr(etype, 'tensor_type', False) 27 | if ttype: 28 | data_type = _data_string(getattr(ttype, 'elem_type', 0)) 29 | shape = getattr(elem.type.tensor_type, "shape", False) 30 | if shape: 31 | shape_str = "[" 32 | dims = getattr(shape, 'dim', []) 33 | for dim in dims: 34 | vals = getattr(dim, 'dim_value', "?") 35 | shape_str += (str(vals) + ",") 36 | shape_str = shape_str.rstrip(",") 37 | shape_str += "]" 38 | return name, data_type, shape_str 39 | 40 | 41 | # _value creates a new value description 42 | def _value(name: str, 43 | data_type: str, 44 | dimensions: [], 45 | **kwargs): 46 | """_value creates a tensor value 47 | 48 | Args: 49 | name: TensorValue name. 50 | data_type: String data type. 51 | dimensions: List with dimensions. 52 | **kwargs 53 | 54 | 55 | """ 56 | dtype = _data_type(data_type) 57 | if not dtype: # Error printed by _data_type() 58 | return False 59 | 60 | try: 61 | val = xhelp.make_tensor_value_info(name, dtype, dimensions, **kwargs) 62 | except Exception as e: 63 | _print("Unable to create tensor value: " + str(e)) 64 | return False 65 | 66 | return val 67 | 68 | 69 | def _input_details(graph: xpb2.GraphProto): 70 | """ List details of the inputs of the graph by name. 71 | 72 | Args: 73 | The graph object. 74 | 75 | Returns: 76 | A dict containing the input details. 77 | 78 | """ 79 | if type(graph) is not xpb2.GraphProto: 80 | _print("graph is not a valid ONNX graph.") 81 | return False 82 | 83 | names = {} 84 | for elem in graph.input: 85 | name, dtype, shape = _parse_element(elem) 86 | desc = {"data_type": dtype, "shape": shape} 87 | names[name] = desc 88 | 89 | return names 90 | 91 | 92 | def _output_details(graph: xpb2.GraphProto): 93 | """ List details of the outputs of the graph by name. 94 | 95 | Args: 96 | The graph object. 97 | 98 | Returns: 99 | A dict containing the output details. 100 | 101 | """ 102 | if type(graph) is not xpb2.GraphProto: 103 | _print("graph is not a valid ONNX graph.") 104 | return False 105 | 106 | names = {} 107 | for elem in graph.output: 108 | name, dtype, shape = _parse_element(elem) 109 | desc = {"data_type": dtype, "shape": shape} 110 | names[name] = desc 111 | 112 | return names 113 | 114 | 115 | # bcolors, used for printing. 116 | class bcolors: 117 | HEADER = '\033[95m' 118 | OKBLUE = '\033[94m' 119 | OKGREEN = '\033[92m' 120 | WARNING = '\033[93m' 121 | FAIL = '\033[91m' 122 | ENDC = '\033[0m' 123 | BOLD = '\033[1m' 124 | UNDERLINE = '\033[4m' 125 | 126 | 127 | # _print controls the printing of user feedback throughout the package 128 | def _print( 129 | msg: str, 130 | print_type: str = "ERR", 131 | silent: bool = False): 132 | """ print controls the printing throughout the package. 133 | 134 | Args: 135 | msg: The message to print 136 | print_type: "MSG", "LIT" or "ERR", default "ERR" 137 | silent: Suppress printing, default False 138 | """ 139 | if not silent: 140 | if print_type == "MSG": 141 | print(msg) 142 | elif print_type == "ERR": 143 | print(f"{bcolors.FAIL}ERROR: "+msg+f"{bcolors.ENDC}") 144 | elif print_type == "LIT": 145 | print(f"{bcolors.OKGREEN}" + msg + f"{bcolors.ENDC}") 146 | else: 147 | print(msg) 148 | 149 | 150 | # _load_version_info loads info of supported ONNX versions (see _globals.py) 151 | def _load_version_info() -> bool: 152 | """Loads the version info json. 153 | Function opens and parses the file supported_onnx.json in the current 154 | package folder to check current Scailable toolchain requirements. 155 | 156 | Note: the supported models are loaded into the glob.ONNX_VERSION_INFO 157 | dictionary to make them available to the whole package. 158 | 159 | Args: 160 | 161 | Returns: 162 | True if the supported version info is successfully loaded. 163 | """ 164 | try: 165 | with open(glob.VERSION_INFO_LOCATION, "r") as f: 166 | glob.ONNX_VERSION_INFO = json.load(f) 167 | except FileNotFoundError: 168 | _print("Unable to locate the ONNX_VERSION INFO.") 169 | return False 170 | return True 171 | 172 | 173 | # _data_type converts a data_string to the data_type int (see _globals.py) 174 | def _data_type(data_string: str): 175 | """ convert the data type string (i.e., FLOAT, INT16, etc.) to the appropriate int. 176 | 177 | See: https://deeplearning4j.org/api/latest/onnx/Onnx.TensorProto.DataType.html 178 | """ 179 | for key, val in glob.DATA_TYPES.items(): 180 | if key == data_string: 181 | return val 182 | _print("Data string not found. Use `list_data_types()` to list all supported data strings.") 183 | return False 184 | 185 | 186 | # _data_string converts a data_type int to a data string 187 | def _data_string(data_type: int): 188 | """ convert the data type number to the appropriate string 189 | 190 | See: https://deeplearning4j.org/api/latest/onnx/Onnx.TensorProto.DataType.html 191 | """ 192 | for key, val in glob.DATA_TYPES.items(): 193 | if val == data_type: 194 | return key 195 | 196 | _print("Data type not found. Use `list_data_types()` to list all supported data types.") 197 | return False 198 | -------------------------------------------------------------------------------- /sclblonnx/validate.py: -------------------------------------------------------------------------------- 1 | import onnxoptimizer 2 | from onnx import __version__ as xversion 3 | from onnx import checker 4 | import onnx 5 | from onnx import helper as xhelp 6 | from onnx import onnx_ml_pb2 as xpb2 7 | from onnxsim import simplify 8 | from packaging import version 9 | 10 | import sclblonnx._globals as glob 11 | from sclblonnx.utils import _load_version_info, _print 12 | 13 | 14 | # clean cleans a graph if possible (but also provides a stringent check) 15 | def clean( 16 | graph: xpb2.GraphProto, 17 | _optimize: bool = True, 18 | _simplify: bool = True, 19 | _remove_initializer: bool = True, 20 | _producer: str = "sclblonnx", 21 | _verbose: bool = True, 22 | **kwargs): 23 | """ clean cleans an ONNX graph using onnx tooling 24 | 25 | This method will attempt to clean the supplied graph by 26 | a. Removing initializers from input 27 | b. Optimizing it using onnxoptimizer.optimize 28 | c. Simplifying it using onnxsim.simplify 29 | 30 | If one of these fails the method will print an error message and return the unaltered graph. 31 | 32 | Args: 33 | graph: An ONNX graph 34 | _optimize: Boolean, default True. Optimize the model using onnxoptimizer. 35 | _simplify: Boolean, default True. Simplify the model using simplify. 36 | _remove_initializer: Boolean, default True. Remove initializers from input. 37 | _producer: Optional string with producer name. Default 'sclblonnx' (used for internal conversion) 38 | _verbose: Print user feedback; default True (note, errors are always printed). 39 | **kwargs 40 | 41 | Returns: 42 | The cleaned ONNX graph, or the old graph if an error occurs. 43 | """ 44 | try: 45 | if not 'opset_imports' in kwargs: 46 | op = onnx.OperatorSetIdProto() 47 | op.version = 12 48 | mod = xhelp.make_model(graph, producer_name=_producer, opset_imports=[op], **kwargs) 49 | else: 50 | mod = xhelp.make_model(graph, producer_name=_producer, **kwargs) 51 | except Exception as e: 52 | _print("Unable to create the model: " + str(e)) 53 | return graph 54 | 55 | if _optimize: 56 | try: 57 | mod = onnxoptimizer.optimize(mod, glob.OPTIMIZER_PASSES, **kwargs) 58 | except Exception as e: 59 | _print("Unable to optimize your model: " + str(e)) 60 | return graph 61 | 62 | if _simplify: 63 | try: 64 | mod, _ = simplify(mod, **kwargs) 65 | except Exception as e: 66 | _print("Unable to simplify your model: " + str(e)) 67 | return graph 68 | 69 | # From: onnxruntime/tools/python/remove_initializer_from_input.py 70 | graph = mod.graph 71 | if _remove_initializer: 72 | inputs = graph.input 73 | name_to_input = {} 74 | for input in inputs: 75 | name_to_input[input.name] = input 76 | for initializer in graph.initializer: 77 | if initializer.name in name_to_input: 78 | inputs.remove(name_to_input[initializer.name]) 79 | 80 | _print("The graph was successfully cleaned.", "MSG", (not _verbose)) 81 | return graph 82 | 83 | 84 | # check checks the graph and inspects whether it is valid. 85 | def check( 86 | graph: xpb2.GraphProto, 87 | _producer: str = "sclblonnx", 88 | _onnx_check: bool = True, 89 | _sclbl_check: bool = True, 90 | _verbose: bool = True, 91 | **kwargs): 92 | """ check whether or not an existing graph can be converted using the Scailable platform 93 | 94 | We assume that a user will use graph_to_file() in this package to store the model. This 95 | 96 | Args: 97 | graph: an ONNX graph 98 | _producer: String optional 99 | _onnx_check: Bool, default True. Run ONNX checker.check(). 100 | _sclbl_check: Bool, default True. Run Scailable checks. 101 | _verbose: Print user feedback; default True (note, errors are always printed). 102 | **kwargs 103 | 104 | Returns: 105 | True if the graph passes all the test. False otherwise. 106 | """ 107 | # Check if this is a valid graph: 108 | if type(graph) is not xpb2.GraphProto: 109 | _print("Graph is not a valid ONNX graph.") 110 | return False 111 | 112 | # Convert to model: 113 | try: 114 | if not 'opset_imports' in kwargs: 115 | op = onnx.OperatorSetIdProto() 116 | op.version = 12 117 | mod = xhelp.make_model(graph, producer_name=_producer, opset_imports=[op], **kwargs) 118 | else: 119 | mod = xhelp.make_model(graph, producer_name=_producer, **kwargs) 120 | except Exception as e: 121 | _print("Unable to create the model: " + str(e)) 122 | return False 123 | 124 | # Standard ONNX checking: 125 | if _onnx_check and False: 126 | try: 127 | checker.check_model(mod, **kwargs) 128 | except Exception as e: 129 | _print("Model fails on standard ONNX checker: " + str(e)) 130 | return False 131 | 132 | if _sclbl_check: 133 | 134 | # User feedback 135 | _print("Running Scailable specific checks for WASM conversion. \nUse _sclbl_check=False to turn off", "MSG", (not _verbose)) 136 | 137 | # input / output checking: 138 | if not graph.input: 139 | _print("This graph does not contain any inputs.") 140 | return False 141 | 142 | if not graph.output: 143 | _print("This graph does not contain any outputs.") 144 | return False 145 | 146 | # Sclbl checking: 147 | if not glob.ONNX_VERSION_INFO: 148 | if not _load_version_info(): 149 | _print("Unable to load the ONNX_VERSION INFO.") 150 | 151 | # Check general ONNX version: 152 | if version.parse(xversion) < version.parse(glob.ONNX_VERSION_INFO['onnx_version']['version_min']): 153 | _print("Your current onnx version is lower then our support minimum. Please update your ONNX to {}".format( 154 | glob.ONNX_VERSION_INFO['onnx_version']['version_min'])) 155 | return False 156 | 157 | if version.parse(xversion) > version.parse(glob.ONNX_VERSION_INFO['onnx_version']['version_max']): 158 | _print( 159 | "Your current onnx version is higher then our support max. Please downgrade your ONNX version to {}".format( 160 | glob.ONNX_VERSION_INFO['onnx_version']['version_max'])) 161 | return False 162 | 163 | if mod.ir_version < glob.ONNX_VERSION_INFO['onnx_version']['ir_version_min']: 164 | _print("Your current IR version is lower then our support minimum. Please update to {}".format( 165 | glob.ONNX_VERSION_INFO['onnx_version']['ir_version_min'])) 166 | return False 167 | 168 | if mod.ir_version > glob.ONNX_VERSION_INFO['onnx_version']['ir_version_max']: 169 | _print( 170 | "Your current IR version is higher then our support max. Please downgrade to {}".format( 171 | glob.ONNX_VERSION_INFO['onnx_version']['ir_version_max'])) 172 | return False 173 | 174 | # Interate through opset and check: 175 | for key in mod.opset_import: 176 | v = key.version 177 | if v < glob.ONNX_VERSION_INFO['onnx_version']['opset_min']: 178 | _print("One or more operators use an opset version that is too low. Please update to {}".format( 179 | glob.ONNX_VERSION_INFO['onnx_version']['opset_min'])) 180 | return False 181 | 182 | if v > glob.ONNX_VERSION_INFO['onnx_version']['opset_max']: 183 | _print( 184 | "One or more operators use an opset version that is too high. Please downgrade to {}".format( 185 | glob.ONNX_VERSION_INFO['onnx_version']['opset_max'])) 186 | return False 187 | 188 | # Check individual nodes: 189 | not_supported = [] 190 | for n in graph.node: 191 | op = n.op_type 192 | if op not in glob.ONNX_VERSION_INFO['operators']: 193 | not_supported.append(op) 194 | if not_supported: 195 | _print("The operator(s) {} are currently not supported.".format(not_supported)) 196 | return False 197 | 198 | # Check dynamic 199 | for inputs in graph.input: 200 | if not inputs.type.tensor_type.shape.dim: 201 | _print("Your graph contains dynamically sized inputs, this is currently not supported.") 202 | return False 203 | for elem in inputs.type.tensor_type.shape.dim: 204 | if elem.dim_value == 0 or elem.dim_value == "": 205 | _print("Your graph contains dynamically size inputs, this is currently not supported.") 206 | 207 | if not _sclbl_check and not _onnx_check: 208 | _print("Set _sclbl_check or _onnx_check to True to run any checks.") 209 | 210 | _print("Your graph was successfully checked.", "MSG", (not _verbose)) 211 | return True 212 | -------------------------------------------------------------------------------- /sclblonnx/version.py: -------------------------------------------------------------------------------- 1 | # Central place for version numbering: 2 | __version__ = "0.3.0" 3 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | # get long description 4 | with open("README.md", "r") as fh: 5 | long_description = fh.read() 6 | 7 | exec(open('sclblonnx/version.py').read()) 8 | 9 | setuptools.setup( 10 | name="sclblonnx", 11 | version=__version__, 12 | author="Maurits Kaptein", 13 | author_email="maurits.kaptein@scailable.net", 14 | description="Python package containing Scailable ONNX tools.", 15 | long_description=long_description, 16 | long_description_content_type="text/markdown", 17 | include_package_data=True, 18 | url="https://github.com/scailable/sclblonnx/", 19 | packages=setuptools.find_packages(), 20 | classifiers=[ 21 | "Programming Language :: Python :: 3", 22 | "License :: OSI Approved :: MIT License", 23 | "Operating System :: OS Independent", 24 | ], 25 | install_requires=[ 26 | 'numpy', 27 | 'onnxruntime', 28 | 'onnx>=1.7.0', 29 | 'requests', 30 | 'onnxoptimizer', 31 | 'onnxsim', 32 | 'packaging' 33 | ], 34 | python_requires='>=3.7', 35 | ) 36 | -------------------------------------------------------------------------------- /test/files/add.onnx: -------------------------------------------------------------------------------- 1 |  sclblonnx:i 2 | $ 3 | x1 4 | x2sumsclbl-onnx-node1"Add 5 | sclblgraphZ 6 | x1 7 | 8 |  9 | Z 10 | x2 11 | 12 |  13 | b 14 | sum 15 | 16 |  17 | B -------------------------------------------------------------------------------- /test/files/example01.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scailable/sclblonnx/6146f3b486833d4f094e73278891caafe2216baa/test/files/example01.onnx -------------------------------------------------------------------------------- /test/files/example02.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scailable/sclblonnx/6146f3b486833d4f094e73278891caafe2216baa/test/files/example02.onnx -------------------------------------------------------------------------------- /test/files/example03.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scailable/sclblonnx/6146f3b486833d4f094e73278891caafe2216baa/test/files/example03.onnx -------------------------------------------------------------------------------- /test/test_constant.py: -------------------------------------------------------------------------------- 1 | from sclblonnx import constant, empty_graph, node, add_node, add_constant, add_output, run, display, clean 2 | import numpy as np 3 | 4 | 5 | def test_constant(): 6 | c = constant("", np.array([1,2,]), "FLOAT") 7 | assert not c, "Constant creation should have failed without a name." 8 | c = constant("constant", np.array([1,2,]), "NONE") 9 | assert not c, "Constant creation should have failed without a valid data type." 10 | c = constant("constant", np.array([1,2,]), "FLOAT") 11 | check = getattr(c, "output", False) 12 | assert check[0] == "constant", "Constant creation should have worked." 13 | 14 | 15 | def test_add_constant(): 16 | 17 | # Simple add graph 18 | g = empty_graph() 19 | n1 = node('Add', inputs=['x1', 'x2'], outputs=['sum']) 20 | g = add_node(g, n1) 21 | 22 | # Add input and constant 23 | g = add_constant(g, 'x1', np.array([1]), "INT64") 24 | g = add_constant(g, 'x2', np.array([5]), "INT64") 25 | 26 | # Output: 27 | g = add_output(g, 'sum', "INT64", [1]) 28 | 29 | # This works, but seems to fail for other data types... 30 | result = run(g, inputs={}, outputs=["sum"]) 31 | assert result[0] == 6, "Add constant failed." 32 | # todo(McK): Does not work for INT16 / INT8, check? -------------------------------------------------------------------------------- /test/test_input.py: -------------------------------------------------------------------------------- 1 | from sclblonnx import empty_graph, list_inputs, add_input, rename_input, replace_input, delete_input 2 | 3 | 4 | def test_list_inputs(): 5 | g = empty_graph() 6 | assert list_inputs(g), "No inputs listed." 7 | g = add_input(g, "test", "FLOAT", [0]) 8 | list_inputs(g) 9 | assert not list_inputs(False), "List inputs should be false." 10 | 11 | 12 | def test_add_input(): 13 | g = empty_graph() 14 | g = add_input(g, "test", "FLOAT", [0]) 15 | name = getattr(g.input[0], "name", False) # get the first input name: 16 | assert name == "test", "'test' should be in list of inputs." 17 | 18 | 19 | def test_rename_input(): 20 | g = empty_graph() 21 | g = add_input(g, "test", "FLOAT", [0]) 22 | g = rename_input(g, "test", "new_name") 23 | name = getattr(g.input[0], "name", False) # get the first input name: 24 | assert name == "new_name", "New name should be in list of inputs." 25 | 26 | 27 | def test_replace_input(): 28 | g = empty_graph() 29 | g = add_input(g, "test", "FLOAT", [0]) 30 | g = replace_input(g, "test", "FLOAT", [10,10]) 31 | 32 | type = getattr(g.input[0], "type", False) # get the input type 33 | ttype = getattr(type, "tensor_type", False) 34 | shape = getattr(ttype, "shape", False) 35 | dim = getattr(shape, "dim", False) 36 | dim_val = getattr(dim[0], "dim_value", False) 37 | 38 | assert dim_val == 10, "New dimension should be 10" 39 | 40 | 41 | def test_delete_input(): 42 | g = empty_graph() 43 | g = add_input(g, "test", "FLOAT", [0]) 44 | g = delete_input(g, "test") 45 | assert len(g.input) == 0, "There should not be any inputs after delete." -------------------------------------------------------------------------------- /test/test_main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from onnx import onnx_ml_pb2 as xpb2 4 | from sclblonnx import empty_graph, graph_from_file, graph_to_file, run, list_data_types, list_operators, sclbl_input 5 | 6 | 7 | def test_empty_graph(): 8 | g = empty_graph() 9 | assert type(g) is xpb2.GraphProto, "Failed to create empty graph." 10 | 11 | 12 | def test_graph_from_file(): 13 | g = graph_from_file("files/non-existing-file.onnx") 14 | assert not g, "Graph from file failed to check emtpy file." 15 | g = graph_from_file("files/example01.onnx") 16 | assert type(g) is xpb2.GraphProto, "Graph from file failed to open file." 17 | 18 | 19 | def test_graph_to_file(): 20 | g = empty_graph() 21 | check1 = graph_to_file(g, "") 22 | assert not check1, "Graph to file failed should have failed." 23 | check2 = graph_to_file(g, "files/test_graph_to_file.onnx") 24 | assert check2, "Graph to file failed to write file." 25 | os.remove("files/test_graph_to_file.onnx") 26 | 27 | 28 | def test_run(): 29 | g = graph_from_file("files/add.onnx") 30 | example = {"x1": np.array([2]).astype(np.float32), "x2": np.array([5]).astype(np.float32)} 31 | result = run(g, 32 | inputs=example, 33 | outputs=["sum"] 34 | ) 35 | assert result[0] == 7, "Add output not correct." 36 | result = run(g, inputs="", outputs="sum") 37 | assert not result, "Model with this input should not run." 38 | 39 | 40 | def test_display(): 41 | from onnx import TensorProto 42 | print(TensorProto.DOUBLE) 43 | 44 | return True # No test for display 45 | 46 | 47 | def test_scblbl_input(): 48 | example = {"in": np.array([1,2,3,4]).astype(np.int32)} 49 | result = sclbl_input(example, _verbose=False) 50 | assert result == '{"input": "CAQQBkoQAQAAAAIAAAADAAAABAAAAA==", "type":"pb"}', "PB output not correct." 51 | 52 | example = {"x1": np.array([1,2,3,4]).astype(np.int32), "x2": np.array([1,2,3,4]).astype(np.int32)} 53 | result = sclbl_input(example, _verbose=False) 54 | assert result == '{"input": ["CAQQBkoQAQAAAAIAAAADAAAABAAAAA==","CAQQBkoQAQAAAAIAAAADAAAABAAAAA=="], "type":"pb"}',\ 55 | "PB output 2 not correct. " 56 | 57 | example = {"in": np.array([1,2,3,4]).astype(np.int32)} 58 | result = sclbl_input(example, "raw", _verbose=False) 59 | assert result == '{"input": "AQAAAAIAAAADAAAABAAAAA==", "type":"raw"}', "Raw output not correct." 60 | 61 | example = {"x1": np.array([1,2,3,4]).astype(np.int32), "x2": np.array([1,2,3,4]).astype(np.int32)} 62 | result = sclbl_input(example, "raw", _verbose=False) 63 | assert result == '{"input": ["AQAAAAIAAAADAAAABAAAAA==","AQAAAAIAAAADAAAABAAAAA=="], "type":"raw"}',\ 64 | "Raw output 2 not correct. " 65 | 66 | example = {"x1": np.array([1.2]).astype(np.float32), "x2": np.array([2.5]).astype(np.float32)} 67 | input = sclbl_input(example, _verbose=False) 68 | print(input) 69 | 70 | 71 | def test_list_data_types(): 72 | test = list_data_types() 73 | assert test, "Data types should be listed." 74 | 75 | 76 | def test_list_operators(): 77 | test = list_operators() 78 | assert test, "Operators should be listed." -------------------------------------------------------------------------------- /test/test_merge.py: -------------------------------------------------------------------------------- 1 | from sclblonnx import add_output, add_input, add_node, node, empty_graph, add_constant, run, merge, split, display, \ 2 | join, concat 3 | import numpy as np 4 | """ 5 | Some rudimentary tests of the functions in merge.py; should be extended. 6 | 7 | For example usage see example_merge.py in /examples. 8 | """ 9 | 10 | 11 | def test_merge(): 12 | """ 13 | Functional test of merge(). 14 | """ 15 | # Subgraph 1 16 | sg1 = empty_graph("Graph 1") 17 | n1 = node('Add', inputs=['x1', 'x2'], outputs=['sum']) 18 | sg1 = add_node(sg1, n1) 19 | sg1 = add_input(sg1, 'x1', "FLOAT", [1]) 20 | sg1 = add_input(sg1, 'x2', "FLOAT", [1]) 21 | sg1 = add_output(sg1, 'sum', "FLOAT", [1]) 22 | 23 | # Subgraph 2 24 | sg2 = empty_graph("Graph 2") 25 | sg2 = add_constant(sg2, "const", np.array([7]), "FLOAT") 26 | n2 = node("Equal", inputs=['sum', 'const'], outputs=['equal']) 27 | sg2 = add_node(sg2, n2) 28 | 29 | sg2 = add_input(sg2, 'sum', "FLOAT", [1]) 30 | sg2 = add_output(sg2, 'equal', "BOOL", [1]) 31 | 32 | g = merge(sg1, sg2, outputs=["sum"], inputs=["sum"]) 33 | 34 | data = {"x1": np.array([2]).astype(np.float32), "x2": np.array([5]).astype(np.float32)} 35 | result = run(g, inputs=data, outputs=["equal"]) 36 | assert result[0], "Sum of 2 and 5 should be equal to constant 7. Merged failed." 37 | 38 | 39 | def test_join(): 40 | """ 41 | Functional test for join: 42 | """ 43 | g1 = empty_graph("G1") 44 | n1 = node('Add', inputs=['x1_1', 'x1_2'], outputs=['sum_1'], name="n1") 45 | g1 = add_input(g1, 'x1_1', "FLOAT", [1]) 46 | g1 = add_input(g1, 'x1_2', "FLOAT", [1]) 47 | g1 = add_output(g1, 'sum_1', "FLOAT", [1]) 48 | g1 = add_node(g1, n1) 49 | 50 | g2 = empty_graph("G2") 51 | n2 = node('Add', inputs=['x2_1', 'x2_2'], outputs=['sum_2'], name="n2") 52 | g2 = add_input(g2, 'x2_1', "FLOAT", [1]) 53 | g2 = add_input(g2, 'x2_2', "FLOAT", [1]) 54 | g2 = add_output(g2, 'sum_2', "FLOAT", [1]) 55 | g2 = add_node(g2, n2) 56 | 57 | g3 = empty_graph("G3") 58 | n3 = node('Add', inputs=['x3_1', 'x3_2'], outputs=['sum_3'], name="n3") 59 | g3 = add_input(g3, 'x3_1', "FLOAT", [1]) 60 | g3 = add_input(g3, 'x3_2', "FLOAT", [1]) 61 | g3 = add_output(g3, 'sum_3', "FLOAT", [1]) 62 | g3 = add_node(g3, n3) 63 | 64 | g = join(g1, g2, g3, [("sum_1", "x3_1")], [("sum_2", "x3_2")]) 65 | 66 | data = { 67 | "x1_1": np.array([1]).astype(np.float32), 68 | "x1_2": np.array([2]).astype(np.float32), 69 | "x2_1": np.array([3]).astype(np.float32), 70 | "x2_2": np.array([4]).astype(np.float32), 71 | } 72 | result = run(g, inputs=data, outputs=["sum_3"]) 73 | assert result[0], "Sum of 1,2, 3, and 4 should be equal to constant 10. Join failed." 74 | 75 | 76 | def test_split(): 77 | """ 78 | Functional test for split 79 | """ 80 | g1 = empty_graph("G1") 81 | n1 = node('Add', inputs=['x1_1', 'x1_2'], outputs=['sum_1'], name="n1") 82 | g1 = add_input(g1, 'x1_1', "FLOAT", [1]) 83 | g1 = add_input(g1, 'x1_2', "FLOAT", [1]) 84 | g1 = add_output(g1, 'sum_1', "FLOAT", [1]) 85 | g1 = add_node(g1, n1) 86 | 87 | g2 = empty_graph("G2") 88 | n2 = node('Add', inputs=['x2_1', 'x2_2'], outputs=['sum_2'], name="n2") 89 | g2 = add_input(g2, 'x2_1', "FLOAT", [1]) 90 | g2 = add_input(g2, 'x2_2', "FLOAT", [1]) 91 | g2 = add_output(g2, 'sum_2', "FLOAT", [1]) 92 | g2 = add_node(g2, n2) 93 | 94 | g3 = empty_graph("G3") 95 | n3 = node('Add', inputs=['x3_1', 'x3_2'], outputs=['sum_3'], name="n3") 96 | g3 = add_input(g3, 'x3_1', "FLOAT", [1]) 97 | g3 = add_input(g3, 'x3_2', "FLOAT", [1]) 98 | g3 = add_output(g3, 'sum_3', "FLOAT", [1]) 99 | g3 = add_node(g3, n3) 100 | 101 | g = split(g1, g2, g3, cg1_match=[("sum_1", "x2_2")], cg2_match=[("sum_1", "x3_1")]) 102 | 103 | data = { 104 | "x1_1": np.array([1]).astype(np.float32), 105 | "x1_2": np.array([2]).astype(np.float32), 106 | "x2_1": np.array([3]).astype(np.float32), 107 | "x3_2": np.array([4]).astype(np.float32), 108 | } 109 | result = run(g, inputs=data, outputs=["sum_2", "sum_3"]) 110 | assert result[0], "Sum of 1,2, and 3 should be equal to constant 6. Split failed." 111 | 112 | 113 | def test_concat(): 114 | """ 115 | Functional test for concat 116 | """ 117 | g1 = empty_graph("G1") 118 | n1 = node('Add', inputs=['x1_1', 'x1_2'], outputs=['sum_1'], name="node_name") 119 | g1 = add_input(g1, 'x1_1', "FLOAT", [1]) 120 | g1 = add_input(g1, 'x1_2', "FLOAT", [1]) 121 | g1 = add_output(g1, 'sum_1', "FLOAT", [1]) 122 | g1 = add_node(g1, n1) 123 | 124 | g2 = empty_graph("G2") 125 | n2 = node('Add', inputs=['x2_1', 'x2_2'], outputs=['sum_2'], name="node_name") 126 | g2 = add_input(g2, 'x2_2', "FLOAT", [1]) 127 | g2 = add_output(g2, 'sum_2', "FLOAT", [1]) 128 | g2 = add_node(g2, n2) 129 | 130 | g = concat(g1, g2, False, True, edge_match=[("x1_2", "x2_1")]) 131 | 132 | data = { 133 | "x1_1": np.array([2]).astype(np.float32), 134 | "x1_2": np.array([5]).astype(np.float32), 135 | "x2_2": np.array([5]).astype(np.float32)} 136 | result = run(g, inputs=data, outputs=["sum_1", "sum_2"]) 137 | assert result[0], "Sum of 2 and 5 should be equal to constant 7. Concat failed." 138 | 139 | 140 | # Run tests, all passes: 141 | test_merge() 142 | test_join() 143 | test_split() 144 | test_concat() 145 | 146 | 147 | 148 | 149 | 150 | 151 | 152 | -------------------------------------------------------------------------------- /test/test_node.py: -------------------------------------------------------------------------------- 1 | from onnx import onnx_ml_pb2 as xpb2 2 | from sclblonnx import node, empty_graph, add_node, delete_node 3 | 4 | 5 | def test_node(): 6 | n1 = node('Add', inputs=['x1', 'x2'], outputs=['sum']) 7 | assert type(n1) is xpb2.NodeProto, "Error creating node." 8 | n2 = node('Add', inputs=['x1', 'x2'], outputs=['sum'], name="node_name") 9 | name = getattr(n2, "name", False) 10 | assert name == "node_name", "Node name should be node_name." 11 | 12 | 13 | def test_add_node(): 14 | g = empty_graph() 15 | n = node('Add', inputs=['x1', 'x2'], outputs=['sum']) 16 | g = add_node(g, n) 17 | assert len(g.node) == 1, "Node not properly added." 18 | 19 | 20 | def test_add_nodes(): 21 | g = empty_graph() 22 | for i in range(10): 23 | n = node('Add', inputs=['x1', 'x2'], outputs=['sum']) 24 | g = add_node(g, n) 25 | assert len(g.node) == 10, "Nodes not properly added." 26 | check = add_node(False, True) 27 | assert not check, "Incorrectly able to add to non-graph." 28 | check = add_node(g, False) 29 | assert not check, "Incorrectly able to add a non-node." 30 | 31 | 32 | def test_delete_node(): 33 | g = empty_graph() 34 | n = node('Add', inputs=['x1', 'x2'], outputs=['sum'], name="node_name") 35 | g = add_node(g, n) 36 | assert len(g.node) == 1, "Node not properly added." 37 | g = delete_node(g, "node_name") 38 | assert len(g.node) == 0, "Node not properly deleted." 39 | -------------------------------------------------------------------------------- /test/test_output.py: -------------------------------------------------------------------------------- 1 | from sclblonnx import empty_graph, list_outputs, add_output, rename_output, replace_output, delete_output 2 | 3 | 4 | def test_list_outputs(): 5 | g = empty_graph() 6 | assert list_outputs(g), "No outputs listed." 7 | g = add_output(g, "test", "FLOAT", [0]) 8 | list_outputs(g) 9 | assert not list_outputs(False), "List outputs should be false." 10 | 11 | 12 | def test_add_output(): 13 | g = empty_graph() 14 | g = add_output(g, "test", "FLOAT", [0]) 15 | name = getattr(g.output[0], "name", False) # get the first output name: 16 | assert name == "test", "'test' should be in list of outputs." 17 | 18 | 19 | def test_rename_output(): 20 | g = empty_graph() 21 | g = add_output(g, "test", "FLOAT", [0]) 22 | g = rename_output(g, "test", "new_name") 23 | name = getattr(g.output[0], "name", False) # get the first output name: 24 | assert name == "new_name", "New name should be in list of outputs." 25 | 26 | 27 | def test_replace_output(): 28 | g = empty_graph() 29 | g = add_output(g, "test", "FLOAT", [0]) 30 | g = replace_output(g, "test", "FLOAT", [10,10]) 31 | 32 | type = getattr(g.output[0], "type", False) # get the output type 33 | ttype = getattr(type, "tensor_type", False) 34 | shape = getattr(ttype, "shape", False) 35 | dim = getattr(shape, "dim", False) 36 | dim_val = getattr(dim[0], "dim_value", False) 37 | 38 | assert dim_val == 10, "New dimension should be 10" 39 | 40 | 41 | def test_delete_output(): 42 | g = empty_graph() 43 | g = add_output(g, "test", "FLOAT", [0]) 44 | g = delete_output(g, "test") 45 | assert len(g.output) == 0, "There should not be any outputs after delete." -------------------------------------------------------------------------------- /test/test_utils.py: -------------------------------------------------------------------------------- 1 | from sclblonnx import empty_graph, add_output, add_input 2 | from sclblonnx.utils import _parse_element, _value, _input_details, _output_details, _print, _load_version_info, \ 3 | _data_type, _data_string 4 | from sclblonnx._globals import ONNX_VERSION_INFO 5 | 6 | def test__parse_element(): 7 | g = empty_graph() 8 | dims = [4, 3, 7] 9 | g = add_output(g, 'sum', "FLOAT", dims) 10 | for elem in g.output: 11 | _, elem_type, _ = _parse_element(elem) 12 | assert elem_type == "FLOAT", "Element not properly parsed." 13 | 14 | 15 | def test__value(): 16 | v = _value("test_value", "FLOAT", [1,2,3]) 17 | v_name = getattr(v, "name", False) 18 | assert v_name == "test_value", "Wrong value created." 19 | 20 | 21 | def test__input_details(): 22 | g = empty_graph() 23 | g = add_input(g, 'sum', "FLOAT", [4, 3, 7]) 24 | in_det = _input_details(g) 25 | assert in_det['sum']['data_type'] == "FLOAT", "Input details not correct." 26 | 27 | 28 | def test__output_details(): 29 | g = empty_graph() 30 | g = add_output(g, 'sum', "FLOAT", [4, 3, 7]) 31 | out_det = _output_details(g) 32 | assert out_det['sum']['data_type'] == "FLOAT", "Output details not correct." 33 | 34 | 35 | def test__print(): 36 | print("\n") 37 | _print("Red warning.") 38 | _print("Normal feedback", "MSG") 39 | _print("Green literal", "LIT") 40 | pass 41 | 42 | 43 | def test__load_version_info(): 44 | assert not ONNX_VERSION_INFO, "Should not be loaded." 45 | _load_version_info() 46 | assert not ONNX_VERSION_INFO, "Should be loaded." 47 | 48 | 49 | def test__data_type(): 50 | assert _data_type("FLOAT") == 1, "Float should be 1." 51 | assert not _data_type("BLA"), "Bla should not be a data type." 52 | 53 | 54 | def test__data_string(): 55 | assert _data_string(1) == "FLOAT", "Float should be 1." 56 | assert not _data_string(99), "99 should not be a data string." -------------------------------------------------------------------------------- /test/test_validate.py: -------------------------------------------------------------------------------- 1 | from sclblonnx import empty_graph, node, add_node, add_input, add_output, check, clean 2 | from onnx import onnx_ml_pb2 as xpb2 3 | 4 | 5 | def test_clean(): 6 | g = empty_graph() 7 | n1 = node('Add', inputs=['x1', 'x2'], outputs=['sum']) 8 | g = add_node(g, n1) 9 | g = add_input(g, 'x1', "FLOAT", [1]) 10 | g = add_input(g, 'x2', "FLOAT", [1]) 11 | g = add_output(g, 'sum', "FLOAT", [1]) 12 | g = clean(g) 13 | assert type(g) == xpb2.GraphProto, "Clean failed." 14 | 15 | 16 | def test_check(): 17 | 18 | # Invalid, no input/output: 19 | g = empty_graph() 20 | n1 = node('Add', inputs=['x1', 'x2'], outputs=['sum']) 21 | g = add_node(g, n1) 22 | assert not check(g), "Graph is not complete." 23 | 24 | # Valid graph 25 | g = add_input(g, 'x1', "FLOAT", [1]) 26 | g = add_input(g, 'x2', "FLOAT", [1]) 27 | g = add_output(g, 'sum', "FLOAT", [1]) 28 | assert check(g), "Graph should pass checks." 29 | 30 | # Invalid: None operator: 31 | g = empty_graph() 32 | n1 = node('None', inputs=['x1', 'x2'], outputs=['sum']) 33 | g = add_node(g, n1) 34 | g = add_input(g, 'x1', "FLOAT", [1]) 35 | g = add_input(g, 'x2', "FLOAT", [1]) 36 | g = add_output(g, 'sum', "FLOAT", [1]) 37 | assert not check(g), "Graph should not pass checks." 38 | 39 | # Invalid: Dynamic size input 40 | g = empty_graph() 41 | n1 = node('Add', inputs=['x1', 'x2'], outputs=['sum']) 42 | g = add_node(g, n1) 43 | g = add_input(g, 'x1', "FLOAT", []) 44 | g = add_input(g, 'x2', "FLOAT", [1]) 45 | g = add_output(g, 'sum', "FLOAT", [1]) 46 | assert not check(g), "Graph should not pass checks." 47 | 48 | check(g, _sclbl_check=False, _onnx_check=False) 49 | check(g, _onnx_check=False) # Operator check. --------------------------------------------------------------------------------