├── stubs ├── .gitkeep ├── numpy │ ├── random.pyi │ └── __init__.pyi └── click │ ├── globals.pyi │ ├── formatting.pyi │ ├── exceptions.pyi │ ├── utils.pyi │ ├── parser.pyi │ ├── termui.pyi │ ├── __init__.pyi │ ├── decorators.pyi │ ├── types.pyi │ └── core.pyi ├── setup.pyi ├── travis ├── after_failure.sh ├── after_success.sh ├── install.sh ├── setup.sh ├── script.sh └── before_install.sh ├── tests ├── test_data │ ├── test_bn │ │ ├── input.npy │ │ └── output.npy │ ├── test_conv │ │ ├── input.npy │ │ └── output.npy │ ├── test_gemm │ │ ├── input.npy │ │ └── output.npy │ ├── test_lrn │ │ ├── input.npy │ │ └── output.npy │ ├── test_concat │ │ ├── input.npy │ │ └── output.npy │ ├── test_gather │ │ ├── input.npy │ │ └── output.npy │ ├── test_shape │ │ ├── input.npy │ │ └── output.npy │ ├── test_squeeze │ │ ├── input.npy │ │ └── output.npy │ ├── test_constant │ │ ├── input.npy │ │ └── output.npy │ ├── test_unsqueeze │ │ ├── input.npy │ │ └── output.npy │ ├── test_avg_pool_1 │ │ ├── input.npy │ │ └── output.npy │ ├── test_avg_pool_2 │ │ ├── input.npy │ │ └── output.npy │ ├── test_max_pool_1 │ │ ├── input.npy │ │ └── output.npy │ ├── test_max_pool_2 │ │ ├── input.npy │ │ └── output.npy │ ├── test_conv_transpose │ │ ├── input.npy │ │ └── output.npy │ ├── test_pixel_shuffle │ │ ├── input.npy │ │ └── output.npy │ ├── test_conv_without_pads │ │ ├── input.npy │ │ └── output.npy │ ├── test_gemm_transB_off │ │ ├── input.npy │ │ └── output.npy │ ├── test_reshape_dynamic │ │ ├── input.npy │ │ └── output.npy │ ├── test_reshape_same_rank │ │ ├── input.npy │ │ └── output.npy │ ├── test_transpose_default │ │ ├── input.npy │ │ └── output.npy │ ├── test_transpose_permute │ │ ├── input.npy │ │ └── output.npy │ ├── test_slice_axis_0_rank_2 │ │ ├── input.npy │ │ ├── output.npy │ │ └── test_cd.npy │ ├── test_slice_axis_3_rank_4 │ │ ├── input.npy │ │ └── output.npy │ ├── test_split_axis_0_rank_3 │ │ ├── input.npy │ │ └── output.npy │ └── test_reshape_same_rank_infer_shape │ │ ├── input.npy │ │ 
└── output.npy ├── __init__.py ├── test_mlmodel_passes.py ├── onnx_backend_models_test.py ├── graph_test.py ├── convert_test.py ├── custom_layers_test.py ├── transformers_test.py ├── _test_utils.py └── operators_test.py ├── .gitmodules ├── examples ├── README.md └── utils.py ├── onnx_coreml ├── bin │ ├── __init__.py │ └── convert.py ├── __init__.py ├── _error_utils.py ├── graph_viz.py ├── _backend_rep.py ├── _backend.py └── _graph.py ├── .github └── ISSUE_TEMPLATE │ ├── -question---help.md │ ├── ---feature-request.md │ └── ---bug-report.md ├── .travis.yml ├── setup.cfg ├── LICENSE.txt ├── README.md ├── install.sh ├── install-develop.sh ├── .gitignore ├── setup.py └── contributing.md /stubs/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /setup.pyi: -------------------------------------------------------------------------------- 1 | # This file exists to make mypy ignore setup.py 2 | -------------------------------------------------------------------------------- /travis/after_failure.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | source "${0%/*}/setup.sh" 4 | -------------------------------------------------------------------------------- /travis/after_success.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | source "${0%/*}/setup.sh" 4 | -------------------------------------------------------------------------------- /tests/test_data/test_bn/input.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/onnx/onnx-coreml/HEAD/tests/test_data/test_bn/input.npy -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule 
"third_party/onnx"] 2 | path = third_party/onnx 3 | url = https://github.com/onnx/onnx.git 4 | -------------------------------------------------------------------------------- /tests/test_data/test_bn/output.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/onnx/onnx-coreml/HEAD/tests/test_data/test_bn/output.npy -------------------------------------------------------------------------------- /tests/test_data/test_conv/input.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/onnx/onnx-coreml/HEAD/tests/test_data/test_conv/input.npy -------------------------------------------------------------------------------- /tests/test_data/test_gemm/input.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/onnx/onnx-coreml/HEAD/tests/test_data/test_gemm/input.npy -------------------------------------------------------------------------------- /tests/test_data/test_lrn/input.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/onnx/onnx-coreml/HEAD/tests/test_data/test_lrn/input.npy -------------------------------------------------------------------------------- /tests/test_data/test_lrn/output.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/onnx/onnx-coreml/HEAD/tests/test_data/test_lrn/output.npy -------------------------------------------------------------------------------- /tests/test_data/test_concat/input.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/onnx/onnx-coreml/HEAD/tests/test_data/test_concat/input.npy -------------------------------------------------------------------------------- /tests/test_data/test_concat/output.npy: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/onnx/onnx-coreml/HEAD/tests/test_data/test_concat/output.npy -------------------------------------------------------------------------------- /tests/test_data/test_conv/output.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/onnx/onnx-coreml/HEAD/tests/test_data/test_conv/output.npy -------------------------------------------------------------------------------- /tests/test_data/test_gather/input.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/onnx/onnx-coreml/HEAD/tests/test_data/test_gather/input.npy -------------------------------------------------------------------------------- /tests/test_data/test_gather/output.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/onnx/onnx-coreml/HEAD/tests/test_data/test_gather/output.npy -------------------------------------------------------------------------------- /tests/test_data/test_gemm/output.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/onnx/onnx-coreml/HEAD/tests/test_data/test_gemm/output.npy -------------------------------------------------------------------------------- /tests/test_data/test_shape/input.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/onnx/onnx-coreml/HEAD/tests/test_data/test_shape/input.npy -------------------------------------------------------------------------------- /tests/test_data/test_shape/output.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/onnx/onnx-coreml/HEAD/tests/test_data/test_shape/output.npy 
-------------------------------------------------------------------------------- /tests/test_data/test_squeeze/input.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/onnx/onnx-coreml/HEAD/tests/test_data/test_squeeze/input.npy -------------------------------------------------------------------------------- /tests/test_data/test_constant/input.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/onnx/onnx-coreml/HEAD/tests/test_data/test_constant/input.npy -------------------------------------------------------------------------------- /tests/test_data/test_constant/output.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/onnx/onnx-coreml/HEAD/tests/test_data/test_constant/output.npy -------------------------------------------------------------------------------- /tests/test_data/test_squeeze/output.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/onnx/onnx-coreml/HEAD/tests/test_data/test_squeeze/output.npy -------------------------------------------------------------------------------- /tests/test_data/test_unsqueeze/input.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/onnx/onnx-coreml/HEAD/tests/test_data/test_unsqueeze/input.npy -------------------------------------------------------------------------------- /tests/test_data/test_avg_pool_1/input.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/onnx/onnx-coreml/HEAD/tests/test_data/test_avg_pool_1/input.npy -------------------------------------------------------------------------------- /tests/test_data/test_avg_pool_1/output.npy: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/onnx/onnx-coreml/HEAD/tests/test_data/test_avg_pool_1/output.npy -------------------------------------------------------------------------------- /tests/test_data/test_avg_pool_2/input.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/onnx/onnx-coreml/HEAD/tests/test_data/test_avg_pool_2/input.npy -------------------------------------------------------------------------------- /tests/test_data/test_avg_pool_2/output.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/onnx/onnx-coreml/HEAD/tests/test_data/test_avg_pool_2/output.npy -------------------------------------------------------------------------------- /tests/test_data/test_max_pool_1/input.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/onnx/onnx-coreml/HEAD/tests/test_data/test_max_pool_1/input.npy -------------------------------------------------------------------------------- /tests/test_data/test_max_pool_1/output.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/onnx/onnx-coreml/HEAD/tests/test_data/test_max_pool_1/output.npy -------------------------------------------------------------------------------- /tests/test_data/test_max_pool_2/input.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/onnx/onnx-coreml/HEAD/tests/test_data/test_max_pool_2/input.npy -------------------------------------------------------------------------------- /tests/test_data/test_max_pool_2/output.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/onnx/onnx-coreml/HEAD/tests/test_data/test_max_pool_2/output.npy -------------------------------------------------------------------------------- /tests/test_data/test_unsqueeze/output.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/onnx/onnx-coreml/HEAD/tests/test_data/test_unsqueeze/output.npy -------------------------------------------------------------------------------- /tests/test_data/test_conv_transpose/input.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/onnx/onnx-coreml/HEAD/tests/test_data/test_conv_transpose/input.npy -------------------------------------------------------------------------------- /tests/test_data/test_pixel_shuffle/input.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/onnx/onnx-coreml/HEAD/tests/test_data/test_pixel_shuffle/input.npy -------------------------------------------------------------------------------- /tests/test_data/test_pixel_shuffle/output.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/onnx/onnx-coreml/HEAD/tests/test_data/test_pixel_shuffle/output.npy -------------------------------------------------------------------------------- /tests/test_data/test_conv_transpose/output.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/onnx/onnx-coreml/HEAD/tests/test_data/test_conv_transpose/output.npy -------------------------------------------------------------------------------- /tests/test_data/test_conv_without_pads/input.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/onnx/onnx-coreml/HEAD/tests/test_data/test_conv_without_pads/input.npy 
-------------------------------------------------------------------------------- /tests/test_data/test_gemm_transB_off/input.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/onnx/onnx-coreml/HEAD/tests/test_data/test_gemm_transB_off/input.npy -------------------------------------------------------------------------------- /tests/test_data/test_gemm_transB_off/output.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/onnx/onnx-coreml/HEAD/tests/test_data/test_gemm_transB_off/output.npy -------------------------------------------------------------------------------- /tests/test_data/test_reshape_dynamic/input.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/onnx/onnx-coreml/HEAD/tests/test_data/test_reshape_dynamic/input.npy -------------------------------------------------------------------------------- /tests/test_data/test_reshape_dynamic/output.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/onnx/onnx-coreml/HEAD/tests/test_data/test_reshape_dynamic/output.npy -------------------------------------------------------------------------------- /tests/test_data/test_reshape_same_rank/input.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/onnx/onnx-coreml/HEAD/tests/test_data/test_reshape_same_rank/input.npy -------------------------------------------------------------------------------- /tests/test_data/test_transpose_default/input.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/onnx/onnx-coreml/HEAD/tests/test_data/test_transpose_default/input.npy -------------------------------------------------------------------------------- 
/tests/test_data/test_transpose_permute/input.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/onnx/onnx-coreml/HEAD/tests/test_data/test_transpose_permute/input.npy -------------------------------------------------------------------------------- /tests/test_data/test_conv_without_pads/output.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/onnx/onnx-coreml/HEAD/tests/test_data/test_conv_without_pads/output.npy -------------------------------------------------------------------------------- /tests/test_data/test_reshape_same_rank/output.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/onnx/onnx-coreml/HEAD/tests/test_data/test_reshape_same_rank/output.npy -------------------------------------------------------------------------------- /tests/test_data/test_slice_axis_0_rank_2/input.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/onnx/onnx-coreml/HEAD/tests/test_data/test_slice_axis_0_rank_2/input.npy -------------------------------------------------------------------------------- /tests/test_data/test_slice_axis_3_rank_4/input.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/onnx/onnx-coreml/HEAD/tests/test_data/test_slice_axis_3_rank_4/input.npy -------------------------------------------------------------------------------- /tests/test_data/test_split_axis_0_rank_3/input.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/onnx/onnx-coreml/HEAD/tests/test_data/test_split_axis_0_rank_3/input.npy -------------------------------------------------------------------------------- /tests/test_data/test_transpose_default/output.npy: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/onnx/onnx-coreml/HEAD/tests/test_data/test_transpose_default/output.npy -------------------------------------------------------------------------------- /tests/test_data/test_transpose_permute/output.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/onnx/onnx-coreml/HEAD/tests/test_data/test_transpose_permute/output.npy -------------------------------------------------------------------------------- /tests/test_data/test_slice_axis_0_rank_2/output.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/onnx/onnx-coreml/HEAD/tests/test_data/test_slice_axis_0_rank_2/output.npy -------------------------------------------------------------------------------- /tests/test_data/test_slice_axis_0_rank_2/test_cd.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/onnx/onnx-coreml/HEAD/tests/test_data/test_slice_axis_0_rank_2/test_cd.npy -------------------------------------------------------------------------------- /tests/test_data/test_slice_axis_3_rank_4/output.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/onnx/onnx-coreml/HEAD/tests/test_data/test_slice_axis_3_rank_4/output.npy -------------------------------------------------------------------------------- /tests/test_data/test_split_axis_0_rank_3/output.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/onnx/onnx-coreml/HEAD/tests/test_data/test_split_axis_0_rank_3/output.npy -------------------------------------------------------------------------------- /travis/install.sh: -------------------------------------------------------------------------------- 1 | 
#!/bin/bash 2 | 3 | source "${0%/*}/setup.sh" 4 | 5 | time CMAKE_ARGS='-DUSE_ATEN=OFF -DUSE_OPENMP=OFF' "$top_dir/install-develop.sh" 6 | -------------------------------------------------------------------------------- /tests/test_data/test_reshape_same_rank_infer_shape/input.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/onnx/onnx-coreml/HEAD/tests/test_data/test_reshape_same_rank_infer_shape/input.npy -------------------------------------------------------------------------------- /tests/test_data/test_reshape_same_rank_infer_shape/output.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/onnx/onnx-coreml/HEAD/tests/test_data/test_reshape_same_rank_infer_shape/output.npy -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | The examples have been moved from here to the coremltools repo ([link](https://github.com/apple/coremltools/tree/master/examples/neural_network_inference)) 2 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | -------------------------------------------------------------------------------- /onnx_coreml/bin/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | -------------------------------------------------------------------------------- /stubs/numpy/random.pyi: 
-------------------------------------------------------------------------------- 1 | # This stub is incomplete. It only contains the functions we use. 2 | 3 | from . import ndarray 4 | from typing import Optional, Union, Tuple 5 | 6 | 7 | def ranf(size: Optional[Union[int, Tuple[int, ...]]]) -> ndarray[float]: ... 8 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/-question---help.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: "❓Question / Help" 3 | about: Any issue that is not a bug/feature request 4 | title: '' 5 | labels: question 6 | assignees: '' 7 | 8 | --- 9 | 10 | ## ❓Question 11 | 12 | ## System Information 13 | - If applicable 14 | -------------------------------------------------------------------------------- /onnx_coreml/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 6 | from .converter import convert 7 | 8 | # onnx-coreml version 9 | __version__ = '1.3' 10 | 11 | __all__ = ['convert'] 12 | -------------------------------------------------------------------------------- /stubs/click/globals.pyi: -------------------------------------------------------------------------------- 1 | from click.core import Context 2 | from typing import Optional 3 | 4 | 5 | def get_current_context(silent: bool = ...) -> Context: 6 | ... 7 | 8 | 9 | def push_context(ctx: Context) -> None: 10 | ... 11 | 12 | 13 | def pop_context() -> None: 14 | ... 15 | 16 | 17 | def resolve_color_default(color: Optional[bool] = ...) -> Optional[bool]: 18 | ... 
19 | -------------------------------------------------------------------------------- /examples/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def _compute_SNR(x,y, message=''): 4 | noise = x - y 5 | noise_var = np.sum(noise ** 2)/len(noise) + 1e-7 6 | signal_energy = np.sum(y ** 2)/len(y) 7 | max_signal_energy = np.amax(y ** 2) 8 | SNR = 10 * np.log10(signal_energy/noise_var) 9 | PSNR = 10 * np.log10(max_signal_energy/noise_var) 10 | print('{} SNR: {} PSNR: {}'.format(message, SNR, PSNR)) 11 | -------------------------------------------------------------------------------- /travis/setup.sh: -------------------------------------------------------------------------------- 1 | set -exv 2 | 3 | export top_dir=$(dirname ${0%/*}) 4 | 5 | source "${HOME}/virtualenv/bin/activate" 6 | python --version 7 | 8 | # setup ccache 9 | if [ "$TRAVIS_OS_NAME" == "linux" ]; then 10 | export PATH="/usr/lib/ccache:$PATH" 11 | elif [ "$TRAVIS_OS_NAME" == "osx" ]; then 12 | export PATH="/usr/local/opt/ccache/libexec:$PATH" 13 | else 14 | echo Unknown OS: $TRAVIS_OS_NAME 15 | exit 1 16 | fi 17 | ccache --max-size 1G 18 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: generic 2 | 3 | matrix: 4 | include: 5 | - os: linux 6 | sudo: required 7 | env: PYTHON_VERSION=python3 8 | - os: osx 9 | osx_image: xcode11 10 | env: PYTHON_VERSION=python3 11 | 12 | env: 13 | global: 14 | - PB_VERSION=2.6.1 15 | 16 | before_install: 17 | - ./travis/before_install.sh 18 | 19 | install: 20 | - ./travis/install.sh 21 | 22 | script: 23 | - ./travis/script.sh 24 | 25 | after_success: 26 | - ./travis/after_success.sh 27 | 28 | after_failure: 29 | - ./travis/after_failure.sh 30 | 31 | cache: 32 | - timeout: 300 33 | - directories: 34 | - $HOME/.ccache 35 | - $HOME/.cache/pb 36 | 
-------------------------------------------------------------------------------- /travis/script.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | source "${0%/*}/setup.sh" 4 | 5 | # Mypy only works with Python 3 6 | # if [ "${PYTHON_VERSION}" != "python2" ]; then 7 | # time mypy . 8 | # # Also test in python2 mode (but this is still in the python 3 CI 9 | # # instance, because mypy itself needs python 3) 10 | # time mypy --py2 . 11 | # fi 12 | 13 | if [[ $TRAVIS_OS_NAME == 'osx' ]]; then 14 | time python setup.py test 15 | else 16 | # Test cases that need to run CoreML models won't work on Linux, 17 | # only run test cases that don't need it. 18 | time python setup.py test --addopts tests/graph_test.py 19 | time python setup.py test --addopts tests/custom_layers_test.py 20 | fi 21 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/---feature-request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: "\U0001F331 Feature request" 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: feature request 6 | assignees: '' 7 | 8 | --- 9 | 10 | ## 🌱 Describe your Feature Request 11 | - A clear and concise description of what the problem is. 12 | - CoreML / iOS version you are using? 13 | - Are you interested in contributing? 14 | 15 | ## Use cases 16 | - Please describe the use cases 17 | - Please provide examples 18 | 19 | ## Describe alternatives you've considered 20 | A clear and concise description of any alternative solutions or features you've considered. 21 | 22 | ## Additional context 23 | Add any other context or screenshots about the feature request here. 
24 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | description-file = README.md 3 | 4 | [aliases] 5 | test=pytest 6 | 7 | [tool:pytest] 8 | addopts = --cov=onnx_coreml --cov-report term-missing 9 | testpaths = tests 10 | 11 | [mypy] 12 | follow-imports = silent # TODO Remove this 13 | mypy_path = stubs:third_party/onnx 14 | strict_optional = True 15 | warn_return_any = True 16 | warn_no_return = True 17 | warn_unused_ignores = True 18 | warn_redundant_casts = True 19 | warn_incomplete_stub = True 20 | disallow_untyped_calls = True 21 | check_untyped_defs = True 22 | disallow_any_generics = True 23 | no_implicit_optional = True 24 | # TODO Add disallow_untyped_defs = True 25 | # TODO Add disallow_incomplete_defs = True 26 | # TODO Add disallow_subclassing_any = True 27 | disallow_untyped_decorators = True 28 | warn_unused_configs = True 29 | [mypy-onnx.*] 30 | ignore_errors = True 31 | ignore-missing-imports = True 32 | -------------------------------------------------------------------------------- /onnx_coreml/bin/convert.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 6 | import click 7 | from onnx import onnx_pb 8 | from onnx_coreml import convert 9 | from typing import Text, IO 10 | 11 | 12 | @click.command( 13 | help='convert ONNX model to CoreML model', 14 | context_settings={ 15 | 'help_option_names': ['-h', '--help'] 16 | } 17 | ) 18 | @click.argument('onnx_model', type=click.File('rb')) 19 | @click.option('-o', '--output', required=True, 20 | type=str, 21 | help='Output path for the CoreML *.mlmodel file') 22 | def onnx_to_coreml(onnx_model, output): # type: (IO[str], str) -> None 23 | onnx_model_proto = 
onnx_pb.ModelProto() 24 | onnx_model_proto.ParseFromString(onnx_model.read()) 25 | coreml_model = convert(onnx_model_proto) 26 | coreml_model.save(output) 27 | 28 | 29 | if __name__ == '__main__': 30 | onnx_to_coreml() 31 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Copyright 2017 Prisma Labs, Inc. 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 4 | 5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 6 | 7 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 8 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/---bug-report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: "\U0001F41E Bug report" 3 | about: Submit a bug report 4 | title: '' 5 | labels: bug 6 | assignees: '' 7 | 8 | --- 9 | 10 | ## 🐞Describe the bug 11 | A clear and brief description of what the bug is. 12 | 13 | ## Trace 14 | If applicable, please paste the error trace. 
15 | 16 | ## To Reproduce 17 | - If a python script can reproduce the error, please paste the code snippet 18 | ``` 19 | from onnx_coreml import convert 20 | # Paste code snippet here 21 | ``` 22 | - If applicable, please attach ONNX model 23 | - If model cannot be shared publicly, please attach it via filing a bug report at https://developer.apple.com/bug-reporting/ 24 | - If model conversion succeeds, however, there is numerical mismatch between the original and the coreml model, please paste python script used for comparison (pytorch code, onnx runtime code etc.) 25 | 26 | ## System environment (please complete the following information): 27 | - coremltools version (e.g., 3.0b5): 28 | - onnx-coreml version (e.g. 1.0b2): 29 | - OS (e.g., MacOS, Linux): 30 | - macOS version (if applicable): 31 | - How you install python (anaconda, virtualenv, system): 32 | - python version (e.g. 3.7): 33 | - any other relevant information: 34 | 35 | ## Additional context 36 | Add any other context about the problem here. 37 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Convert ONNX models into Apple Core ML format. 2 | 3 | `onnx-coreml` package is **no longer supported** and will **not be maintained**. 4 | 5 | ## PyTorch Models 6 | 7 | For converting PyTorch models to CoreML format, the recommended approach is to use **new** PyTorch to Core ML converter, introduced in the [`coremltools 4.0`](https://github.com/apple/coremltools) python package. 8 | Please read the coremltools documentation on [PyTorch conversion](https://coremltools.readme.io/docs/pytorch-conversion) for example usage. 9 | 10 | ## ONNX Models 11 | 12 | Code for ONNX to Core ML conversion is now available through `coremltools` python package and `coremltools.converters.onnx.convert` is the only supported API for conversion. 
To read more about exporting ONNX models to Core ML format, please visit coremltools documentation on [ONNX conversion.](https://coremltools.readme.io/docs/onnx-conversion) 13 | 14 | Note: ONNX converter is not under any active feature development. For access to bug fixes, community support and requests, please use [coremltools](https://github.com/apple/coremltools) github repository. 15 | 16 | ## Installation 17 | To install coremltools package, please follow [these instructions](https://coremltools.readme.io/docs/installation) in the coremltools documentation. 18 | 19 | ## License 20 | Copyright © 2018 by Apple Inc., Facebook Inc., and Prisma Labs Inc. 21 | 22 | Use of this source code is governed by the [MIT License](https://opensource.org/licenses/MIT) that can be found in the LICENSE.txt file. 23 | -------------------------------------------------------------------------------- /install.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -ex 4 | 5 | # realpath might not be available on MacOS 6 | script_path=$(python -c "import os; import sys; print(os.path.realpath(sys.argv[1]))" "${BASH_SOURCE[0]}") 7 | top_dir=$(dirname "$script_path") 8 | REPOS_DIR="$top_dir/third_party" 9 | BUILD_DIR="$top_dir/build" 10 | 11 | _check_submodule_present() { 12 | if [ ! -f "$REPOS_DIR/$@/setup.py" ]; then 13 | echo Didn\'t find $@ submodule. Please run: git submodule update --recursive --init 14 | exit 1 15 | fi 16 | } 17 | 18 | _check_submodule_present onnx 19 | 20 | _check_compilers_use_ccache() { 21 | COMPILERS_WITHOUT_CCACHE="" 22 | for compiler in gcc g++ cc c++; do 23 | if ! readlink $(which $compiler) | grep ccache; then 24 | COMPILERS_WITHOUT_CCACHE="$COMPILERS_WITHOUT_CCACHE $compiler" 25 | fi 26 | done 27 | 28 | if [ "$COMPILERS_WITHOUT_CCACHE" != "" ]; then 29 | echo Warning: Compilers not set up for ccache: $COMPILERS_WITHOUT_CCACHE. Incremental builds will be slow. 
30 | fi 31 | } 32 | _check_compilers_use_ccache 33 | 34 | 35 | mkdir -p "$BUILD_DIR" 36 | 37 | _pip_install() { 38 | if [[ -n "$CI" ]]; then 39 | ccache -z 40 | fi 41 | if [[ -n "$CI" ]]; then 42 | time pip install "$@" 43 | else 44 | pip install "$@" 45 | fi 46 | if [[ -n "$CI" ]]; then 47 | ccache -s 48 | fi 49 | } 50 | 51 | # Install onnx 52 | _pip_install -b "$BUILD_DIR/onnx" "file://$REPOS_DIR/onnx#egg=onnx" 53 | 54 | # Install onnx-coreml 55 | _pip_install . 56 | -------------------------------------------------------------------------------- /tests/test_mlmodel_passes.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import unittest 3 | import coremltools.models.datatypes as datatypes 4 | from coremltools.models import neural_network as neural_network 5 | from coremltools.converters.nnssa.coreml.graph_pass.mlmodel_passes import remove_disconnected_layers 6 | 7 | 8 | class MLModelPassesTest(unittest.TestCase): 9 | 10 | def test_load_constant_remove(self): 11 | input_features = [('data', datatypes.Array(*(3, 4)))] 12 | output_features = [('out', None)] 13 | builder = neural_network.NeuralNetworkBuilder(input_features, output_features, disable_rank5_shape_mapping=True) 14 | builder.add_activation('relu1', 'RELU', 'data', 'relu1') 15 | builder.add_load_constant_nd('const1', 'c1', constant_value=np.ones((5,)), shape=(5,)) 16 | builder.add_activation('relu2', 'RELU', 'relu1', 'out') 17 | builder.add_load_constant_nd('const2', 'c2', constant_value=np.ones((5,)), shape=(5,)) 18 | builder.add_load_constant_nd('const3', 'c3', constant_value=np.ones((5,)), shape=(5,)) 19 | spec = builder.spec 20 | np.testing.assert_equal(5, len(spec.neuralNetwork.layers)) 21 | remove_disconnected_layers(spec) 22 | np.testing.assert_equal(2, len(spec.neuralNetwork.layers)) 23 | 24 | 25 | if __name__ == '__main__': 26 | RUN_ALL_TESTS = True 27 | if RUN_ALL_TESTS: 28 | unittest.main() 29 | else: 30 | suite = 
unittest.TestSuite() 31 | suite.addTest(MLModelPassesTest('test_load_constant_remove')) 32 | unittest.TextTestRunner().run(suite) 33 | -------------------------------------------------------------------------------- /install-develop.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -ex 4 | 5 | # realpath might not be available on MacOS 6 | script_path=$(python -c "import os; import sys; print(os.path.realpath(sys.argv[1]))" "${BASH_SOURCE[0]}") 7 | top_dir=$(dirname "$script_path") 8 | REPOS_DIR="$top_dir/third_party" 9 | BUILD_DIR="$top_dir/build" 10 | 11 | _check_submodule_present() { 12 | if [ ! -f "$REPOS_DIR/$@/setup.py" ]; then 13 | echo Didn\'t find $@ submodule. Please run: git submodule update --recursive --init 14 | exit 1 15 | fi 16 | } 17 | 18 | _check_submodule_present onnx 19 | 20 | _check_compilers_use_ccache() { 21 | COMPILERS_WITHOUT_CCACHE="" 22 | for compiler in gcc g++ cc c++; do 23 | if ! readlink $(which $compiler) | grep ccache; then 24 | COMPILERS_WITHOUT_CCACHE="$COMPILERS_WITHOUT_CCACHE $compiler" 25 | fi 26 | done 27 | 28 | if [ "$COMPILERS_WITHOUT_CCACHE" != "" ]; then 29 | echo Warning: Compilers not set up for ccache: $COMPILERS_WITHOUT_CCACHE. Incremental builds will be slow. 
30 | read -p "Press enter to continue" 31 | fi 32 | } 33 | _check_compilers_use_ccache 34 | 35 | 36 | mkdir -p "$BUILD_DIR" 37 | 38 | _pip_install() { 39 | if [[ -n "$CI" ]]; then 40 | ccache -z 41 | fi 42 | if [[ -n "$CI" ]]; then 43 | time pip install "$@" 44 | else 45 | pip install "$@" 46 | fi 47 | if [[ -n "$CI" ]]; then 48 | ccache -s 49 | fi 50 | } 51 | 52 | # Install onnx 53 | # _pip_install -e "$REPOS_DIR/onnx" 54 | 55 | cd "$REPOS_DIR/onnx" 56 | python setup.py install 57 | cd - 58 | 59 | # Install onnx-coreml 60 | _pip_install -e .[mypy] 61 | -------------------------------------------------------------------------------- /tests/onnx_backend_models_test.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 6 | import unittest 7 | 8 | import onnx 9 | 10 | import onnx.backend.test 11 | 12 | from onnx_coreml._backend import CoreMLBackend, CoreMLBackendND 13 | from onnx_coreml.converter import SupportedVersion 14 | 15 | from coremltools.models.utils import macos_version 16 | 17 | # Default target iOS 18 | MINIMUM_IOS_DEPLOYMENT_TARGET = '13' 19 | 20 | MIN_MACOS_VERSION_10_15 = (10, 15) 21 | # If MACOS version is less than 10.15 22 | # Then force testing on CoreML 2.0 23 | if macos_version() < MIN_MACOS_VERSION_10_15: 24 | MINIMUM_IOS_DEPLOYMENT_TARGET = '12' 25 | 26 | if not SupportedVersion.ios_support_check(MINIMUM_IOS_DEPLOYMENT_TARGET): 27 | raise ValueError( 28 | "Invalid Target iOS version provided. 
Valid target iOS: {}".format(supported_ios_version) 29 | ) 30 | 31 | # import all test cases at global scope to make them visible to python.unittest 32 | backend_test = onnx.backend.test.BackendTest(CoreMLBackendND if SupportedVersion.is_nd_array_supported(MINIMUM_IOS_DEPLOYMENT_TARGET) else CoreMLBackend, __name__) 33 | 34 | # Only include the big models tests 35 | backend_test.include('test_resnet50') 36 | backend_test.include('test_inception_v1') 37 | backend_test.include('test_inception_v2') 38 | backend_test.include('test_densenet121') 39 | backend_test.include('test_shufflenet') 40 | backend_test.include('test_squeezenet') 41 | backend_test.include('test_bvlc_alexnet') 42 | backend_test.include('test_zfnet512') 43 | backend_test.include('test_vgg19') 44 | 45 | globals().update(backend_test 46 | .enable_report() 47 | .test_cases) 48 | 49 | 50 | if __name__ == '__main__': 51 | unittest.main() 52 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | 27 | # PyInstaller 28 | # Usually these files are written by a python script from a template 29 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
30 | *.manifest 31 | *.spec 32 | 33 | # Installer logs 34 | pip-log.txt 35 | pip-delete-this-directory.txt 36 | 37 | # Unit test / coverage reports 38 | htmlcov/ 39 | .tox/ 40 | .coverage 41 | .coverage.* 42 | .cache 43 | nosetests.xml 44 | coverage.xml 45 | *.cover 46 | .hypothesis/ 47 | 48 | # Translations 49 | *.mo 50 | *.pot 51 | 52 | # Django stuff: 53 | *.log 54 | local_settings.py 55 | 56 | # Flask stuff: 57 | instance/ 58 | .webassets-cache 59 | 60 | # Scrapy stuff: 61 | .scrapy 62 | 63 | # Sphinx documentation 64 | docs/_build/ 65 | 66 | # PyBuilder 67 | target/ 68 | 69 | # Jupyter Notebook 70 | .ipynb_checkpoints 71 | 72 | # pyenv 73 | .python-version 74 | 75 | # celery beat schedule file 76 | celerybeat-schedule 77 | 78 | # SageMath parsed files 79 | *.sage.py 80 | 81 | # Environments 82 | .env 83 | .venv 84 | env/ 85 | venv/ 86 | ENV/ 87 | 88 | # Spyder project settings 89 | .spyderproject 90 | .spyproject 91 | 92 | # Rope project settings 93 | .ropeproject 94 | 95 | # mkdocs documentation 96 | /site 97 | 98 | # mypy 99 | .mypy_cache/ 100 | 101 | # onnx models 102 | *.onnx 103 | *.pb 104 | 105 | # test data 106 | *.npz 107 | 108 | dist/ 109 | build/ 110 | example/fast-neural-style/fast-neural-style/ 111 | 112 | .DS_Store 113 | 114 | # CoreML models 115 | *.mlmodel 116 | 117 | virtualenv 118 | .pytest_cache 119 | 120 | .idea/ 121 | -------------------------------------------------------------------------------- /travis/before_install.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Don't source setup.sh here, because the virtualenv might not be set up yet 4 | 5 | set -ex 6 | 7 | export NUMCORES=`grep -c ^processor /proc/cpuinfo` 8 | if [ ! 
-n "$NUMCORES" ]; then 9 | export NUMCORES=`sysctl -n hw.ncpu` 10 | fi 11 | echo Using $NUMCORES cores 12 | 13 | # Install dependencies 14 | if [ "$TRAVIS_OS_NAME" == "linux" ]; then 15 | # Install protobuf 16 | pb_dir="~/.cache/pb" 17 | mkdir -p "$pb_dir" 18 | wget -qO- "https://github.com/google/protobuf/releases/download/v${PB_VERSION}/protobuf-${PB_VERSION}.tar.gz" | tar -xz -C "$pb_dir" --strip-components 1 19 | ccache -z 20 | cd "$pb_dir" && ./configure && make -j${NUMCORES} && make check && sudo make install && sudo ldconfig 21 | ccache -s 22 | 23 | # Setup Python. 24 | if [ "${PYTHON_VERSION}" == "python3" ]; then 25 | export PYTHON_DIR="$(ls -d /opt/python/3.7.1/bin)" 26 | else 27 | echo Unknown Python Version: ${PYTHON_VERSION} 28 | exit 1 29 | fi 30 | elif [ "$TRAVIS_OS_NAME" == "osx" ]; then 31 | # Setup Python. 32 | export PYTHON_DIR="/usr/local/bin" 33 | brew unlink python 34 | brew install openssl 35 | if [ "${PYTHON_VERSION}" == "python3" ]; then 36 | brew install ccache 37 | brew unlink python 38 | brew install --ignore-dependencies https://raw.githubusercontent.com/Homebrew/homebrew-core/f2a764ef944b1080be64bd88dca9a1d80130c558/Formula/python.rb 39 | else 40 | echo Unknown Python Version: ${PYTHON_VERSION} 41 | exit 1 42 | fi 43 | else 44 | echo Unknown OS: $TRAVIS_OS_NAME 45 | exit 1 46 | fi 47 | 48 | pip install virtualenv 49 | virtualenv -p "${PYTHON_DIR}/${PYTHON_VERSION}" "${HOME}/virtualenv" 50 | source "${HOME}/virtualenv/bin/activate" 51 | python --version 52 | 53 | # Update all existing python packages 54 | for package in $(pip list --outdated --format=freeze | grep -v '^\-e' | cut -d = -f 1); do 55 | pip install -U $package 56 | done 57 | 58 | if [[ $TRAVIS_OS_NAME == 'osx' ]]; then 59 | pip install torch 60 | fi 61 | -------------------------------------------------------------------------------- /stubs/click/formatting.pyi: -------------------------------------------------------------------------------- 1 | from contextlib import 
contextmanager 2 | from typing import Generator, Iterable, List, Optional, Text, Tuple 3 | 4 | 5 | FORCED_WIDTH: Optional[int] 6 | 7 | 8 | def measure_table(rows: Iterable[Iterable[Text]]) -> Tuple[int, ...]: 9 | ... 10 | 11 | 12 | def iter_rows( 13 | rows: Iterable[Iterable[Text]], col_count: int 14 | ) -> Generator[Tuple[Text, ...], None, None]: 15 | ... 16 | 17 | 18 | def wrap_text( 19 | text: Text, 20 | width: int = ..., 21 | initial_indent: Text = ..., 22 | subsequent_indent: Text = ..., 23 | preserve_paragraphs: bool = ... 24 | ) -> Text: 25 | ... 26 | 27 | 28 | class HelpFormatter: 29 | indent_increment: int 30 | width: Optional[int] 31 | current_indent: int 32 | buffer: List[Text] 33 | 34 | def __init__( 35 | self, 36 | indent_increment: int = ..., 37 | width: Optional[int] = ..., 38 | max_width: Optional[int] = ..., 39 | ) -> None: 40 | ... 41 | 42 | def write(self, string: Text) -> None: 43 | ... 44 | 45 | def indent(self) -> None: 46 | ... 47 | 48 | def dedent(self) -> None: 49 | ... 50 | 51 | def write_usage( 52 | self, 53 | prog: Text, 54 | args: Text = ..., 55 | prefix: Text = ..., 56 | ): 57 | ... 58 | 59 | def write_heading(self, heading: Text) -> None: 60 | ... 61 | 62 | def write_paragraph(self) -> None: 63 | ... 64 | 65 | def write_text(self, text: Text) -> None: 66 | ... 67 | 68 | def write_dl( 69 | self, 70 | rows: Iterable[Iterable[Text]], 71 | col_max: int = ..., 72 | col_spacing: int = ..., 73 | ) -> None: 74 | ... 75 | 76 | @contextmanager 77 | def section(self, name) -> Generator[None, None, None]: 78 | ... 79 | 80 | @contextmanager 81 | def indentation(self) -> Generator[None, None, None]: 82 | ... 83 | 84 | def getvalue(self) -> Text: 85 | ... 86 | 87 | 88 | def join_options(options: List[Text]) -> Tuple[Text, bool]: 89 | ... 
90 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | from os import path 3 | import sys 4 | from onnx_coreml import __version__ 5 | 6 | VERSION = __version__ 7 | 8 | here = path.abspath(path.dirname(__file__)) 9 | 10 | try: 11 | import pypandoc 12 | long_description = pypandoc.convert('README.md', 'rst') 13 | except(IOError, ImportError): 14 | with open('README.md', "rb") as f: 15 | long_description = f.read().decode("UTF-8") 16 | 17 | 18 | if sys.version_info[0] == 2: 19 | # Mypy doesn't work with Python 2 20 | mypy = [] 21 | elif sys.version_info[0] == 3: 22 | mypy = ['mypy==0.560'] 23 | 24 | 25 | setup( 26 | name='onnx-coreml', 27 | version=VERSION, 28 | packages=find_packages(exclude=['contrib', 'docs', 'test', 'example']), 29 | description='Convert ONNX (Open Neural Network Exchange)' 30 | 'models into Apple CoreML format.', 31 | long_description=long_description, 32 | long_description_content_type='text/markdown', 33 | url='https://github.com/onnx/onnx-coreml/', 34 | author='ONNX-CoreML Team', 35 | author_email='onnx-coreml@apple.com', 36 | license='MIT', 37 | classifiers=[ 38 | 'Development Status :: 5 - Production/Stable', 39 | 'License :: OSI Approved :: MIT License', 40 | 'Intended Audience :: Developers', 41 | 'Intended Audience :: End Users/Desktop', 42 | 'Operating System :: MacOS :: MacOS X', 43 | 'Programming Language :: Python', 44 | 'Topic :: Scientific/Engineering', 45 | 'Topic :: Software Development' 46 | ], 47 | keywords='onnx coreml machinelearning ml coremltools converter neural', 48 | install_requires=[ 49 | 'click', 50 | 'numpy', 51 | 'sympy', 52 | 'onnx>=1.5.0', 53 | 'typing>=3.6.4', 54 | 'typing-extensions>=3.6.2.1', 55 | 'coremltools>=3.2', 56 | ], 57 | setup_requires=['pytest-runner'], 58 | tests_require=[ 59 | 'pytest', 60 | 'pytest-cov', 61 | 'Pillow' 62 | ], 
63 | extras_require={ 64 | 'mypy': mypy, 65 | }, 66 | entry_points={ 67 | 'console_scripts': [ 68 | 'convert-onnx-to-coreml = onnx_coreml.bin.convert:onnx_to_coreml' 69 | ] 70 | }, 71 | ) 72 | -------------------------------------------------------------------------------- /stubs/click/exceptions.pyi: -------------------------------------------------------------------------------- 1 | from typing import IO, List, Text, Optional, Any 2 | 3 | from click.core import Context, Parameter 4 | 5 | 6 | class ClickException(Exception): 7 | exit_code: int 8 | message: str 9 | 10 | def __init__(self, message: Text) -> None: 11 | ... 12 | 13 | def format_message(self) -> Text: 14 | ... 15 | 16 | def show(self, file=None) -> None: 17 | ... 18 | 19 | 20 | class UsageError(ClickException): 21 | ctx: Optional[Context] 22 | 23 | def __init__(self, message: Text, ctx: Optional[Context] = ...) -> None: 24 | ... 25 | 26 | def show(self, file: Optional[IO[Any]] = ...) -> None: 27 | ... 28 | 29 | 30 | class BadParameter(UsageError): 31 | param: Optional[Parameter] 32 | param_hint: Optional[Text] 33 | 34 | def __init__( 35 | self, 36 | message: Text, 37 | ctx: Optional[Context] = ..., 38 | param: Optional[Parameter] = ..., 39 | param_hint: Optional[Text] = ... 40 | ) -> None: 41 | ... 42 | 43 | 44 | class MissingParameter(BadParameter): 45 | param_type: Text # valid values: 'parameter', 'option', 'argument' 46 | 47 | def __init__( 48 | self, 49 | message: Optional[Text] = ..., 50 | ctx: Optional[Context] = ..., 51 | param: Optional[Parameter] = ..., 52 | param_hint: Optional[Text] = ..., 53 | param_type: Optional[Text] = ... 54 | ) -> None: 55 | ... 56 | 57 | 58 | class NoSuchOption(UsageError): 59 | option_name: Text 60 | possibilities: Optional[List[Text]] 61 | 62 | def __init__( 63 | self, 64 | option_name: Text, 65 | message: Optional[Text] = ..., 66 | possibilities: Optional[List[Text]] = ..., 67 | ctx: Optional[Context] = ... 68 | ) -> None: 69 | ... 
70 | 71 | 72 | class BadOptionUsage(UsageError): 73 | def __init__(self, message: Text, ctx: Optional[Context] = ...) -> None: 74 | ... 75 | 76 | 77 | class BadArgumentUsage(UsageError): 78 | def __init__(self, message: Text, ctx: Optional[Context] = ...) -> None: 79 | ... 80 | 81 | 82 | class FileError(ClickException): 83 | ui_filename: Text 84 | filename: Text 85 | 86 | def __init__(self, filename: Text, hint: Optional[Text] = ...) -> None: 87 | ... 88 | 89 | 90 | class Abort(RuntimeError): 91 | ... 92 | -------------------------------------------------------------------------------- /stubs/click/utils.pyi: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, Iterator, IO, List, Optional, Text, TypeVar, Union 2 | 3 | _T = TypeVar('_T') 4 | _Decorator = Callable[[_T], _T] 5 | 6 | 7 | def _posixify(name: Text) -> Text: 8 | ... 9 | 10 | 11 | def safecall(func: _T) -> _T: 12 | ... 13 | 14 | 15 | def make_str(value: Any) -> Text: 16 | ... 17 | 18 | 19 | def make_default_short_help(help: Text, max_length: int = ...): 20 | ... 21 | 22 | 23 | class LazyFile: 24 | name: Text 25 | mode: Text 26 | encoding: Optional[Text] 27 | errors: Text 28 | atomic: bool 29 | 30 | def __init__( 31 | self, 32 | filename: Text, 33 | mode: Text = ..., 34 | encoding: Optional[Text] = ..., 35 | errors: Text = ..., 36 | atomic: bool = ... 37 | ) -> None: 38 | ... 39 | 40 | def open(self) -> IO[Any]: 41 | ... 42 | 43 | def close(self) -> None: 44 | ... 45 | 46 | def close_intelligently(self) -> None: 47 | ... 48 | 49 | def __enter__(self) -> 'LazyFile': 50 | ... 51 | 52 | def __exit__(self, exc_type, exc_value, tb): 53 | ... 54 | 55 | def __iter__(self) -> Iterator[Any]: 56 | ... 57 | 58 | 59 | class KeepOpenFile: 60 | _file: IO[Any] 61 | 62 | def __init__(self, file: IO[Any]) -> None: 63 | ... 64 | 65 | def __enter__(self) -> 'KeepOpenFile': 66 | ... 67 | 68 | def __exit__(self, exc_type, exc_value, tb): 69 | ... 
70 | 71 | def __iter__(self) -> Iterator[Any]: 72 | ... 73 | 74 | 75 | def echo( 76 | message: Optional[Union[bytes, Text]] = ..., 77 | file: Optional[IO[Any]] = ..., 78 | nl: bool = ..., 79 | err: bool = ..., 80 | color: Optional[bool] = ..., 81 | ) -> None: 82 | ... 83 | 84 | 85 | def get_binary_stream(name: Text) -> IO[bytes]: 86 | ... 87 | 88 | 89 | def get_text_stream( 90 | name: Text, encoding: Optional[Text] = ..., errors: Text = ... 91 | ) -> IO[Text]: 92 | ... 93 | 94 | 95 | def open_file( 96 | filename: Text, 97 | mode: Text = ..., 98 | encoding: Optional[Text] = ..., 99 | errors: Text = ..., 100 | lazy: bool = ..., 101 | atomic: bool = ... 102 | ) -> Union[IO[Any], LazyFile, KeepOpenFile]: 103 | ... 104 | 105 | 106 | def get_os_args() -> List[Text]: 107 | ... 108 | 109 | 110 | def format_filename(filename: Text, shorten: bool = ...) -> Text: 111 | ... 112 | 113 | 114 | def get_app_dir( 115 | app_name: Text, roaming: bool = ..., force_posix: bool = ... 116 | ) -> Text: 117 | ... 118 | -------------------------------------------------------------------------------- /stubs/click/parser.pyi: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Iterable, List, Optional, Set, Text, Tuple 2 | 3 | from click.core import Context 4 | 5 | 6 | def _unpack_args( 7 | args: Iterable[Text], nargs_spec: Iterable[int] 8 | ) -> Tuple[Tuple[Optional[Tuple[Text, ...]], ...], List[Text]]: 9 | ... 10 | 11 | 12 | def split_opt(opt: Text) -> Tuple[Text, Text]: 13 | ... 14 | 15 | 16 | def normalize_opt(opt: Text, ctx: Context) -> Text: 17 | ... 18 | 19 | 20 | def split_arg_Texting(Texting: Text) -> List[Text]: 21 | ... 
22 | 23 | 24 | class Option: 25 | dest: Text 26 | action: Text 27 | nargs: int 28 | const: Any 29 | obj: Any 30 | prefixes: Set[Text] 31 | _short_opts: List[Text] 32 | _long_opts: List[Text] 33 | # properties 34 | takes_value: bool 35 | 36 | def __init__( 37 | self, 38 | opts: Iterable[Text], 39 | dest: Text, 40 | action: Optional[Text] = ..., 41 | nargs: int = ..., 42 | const: Optional[Any] = ..., 43 | obj: Optional[Any] = ... 44 | ) -> None: 45 | ... 46 | 47 | def process(self, value: Any, state: 'ParsingState') -> None: 48 | ... 49 | 50 | 51 | class Argument: 52 | dest: Text 53 | nargs: int 54 | obj: Any 55 | 56 | def __init__(self, dest: Text, nargs: int = ..., obj: Optional[Any] = ...) -> None: 57 | ... 58 | 59 | def process(self, value: Any, state: 'ParsingState') -> None: 60 | ... 61 | 62 | 63 | class ParsingState: 64 | opts: Dict[Text, Any] 65 | largs: List[Text] 66 | rargs: List[Text] 67 | order: List[Any] 68 | 69 | def __init__(self, rargs: List[Text]) -> None: 70 | ... 71 | 72 | 73 | class OptionParser: 74 | ctx: Optional[Context] 75 | allow_interspersed_args: bool 76 | ignore_unknown_options: bool 77 | _short_opt: Dict[Text, Option] 78 | _long_opt: Dict[Text, Option] 79 | _opt_prefixes: Set[Text] 80 | _args: List[Argument] 81 | 82 | def __init__(self, ctx: Optional[Context] = ...) -> None: 83 | ... 84 | 85 | def add_option( 86 | self, 87 | opts: Iterable[Text], 88 | dest: Text, 89 | action: Optional[Text] = ..., 90 | nargs: int = ..., 91 | const: Optional[Any] = ..., 92 | obj: Optional[Any] = ... 93 | ) -> None: 94 | ... 95 | 96 | def add_argument(self, dest: Text, nargs: int = ..., obj: Optional[Any] = ...) -> None: 97 | ... 98 | 99 | def parse_args( 100 | self, args: List[Text] 101 | ) -> Tuple[Dict[Text, Any], List[Text], List[Any]]: 102 | ... 
103 | -------------------------------------------------------------------------------- /tests/graph_test.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 6 | import unittest 7 | 8 | from onnx import helper, numpy_helper, TensorProto 9 | 10 | from tests._test_utils import _onnx_create_single_node_model, \ 11 | _onnx_create_model, _conv_pool_output_size, _random_array 12 | 13 | from onnx_coreml._graph import Node, Graph 14 | 15 | 16 | class NodeTest(unittest.TestCase): 17 | def test_create_node(self): # type: () -> None 18 | model = _onnx_create_single_node_model( 19 | "Elu", 20 | [(1, 3, 224, 224)], 21 | [(1, 3, 224, 224)], 22 | alpha=0.5 23 | ) 24 | graph = model.graph 25 | node = graph.node[0] 26 | node_ = Node.from_onnx(node) 27 | self.assertTrue(len(node_.inputs) == 1) 28 | self.assertTrue(len(node_.outputs) == 1) 29 | self.assertTrue(len(node_.attrs) == 1) 30 | self.assertTrue(node_.attrs["alpha"] == 0.5) 31 | 32 | 33 | class GraphTest(unittest.TestCase): 34 | def test_create_graph(self): # type: () -> None 35 | kernel_shape = (3, 2) 36 | strides = (2, 3) 37 | pads = (4, 2, 4, 2) 38 | dilations = (1, 2) 39 | group = 1 40 | weight = numpy_helper.from_array( 41 | _random_array((16, 3, 3, 2)), name="weight" 42 | ) 43 | 44 | input_shape = (1, 3, 224, 224) 45 | output_size = _conv_pool_output_size(input_shape, dilations, 46 | kernel_shape, pads, strides) 47 | 48 | output_shape = (1, int(weight.dims[0]), output_size[0], output_size[1]) 49 | 50 | inputs = [('input0', input_shape)] 51 | outputs = [('output0', output_shape, TensorProto.FLOAT)] 52 | 53 | conv = helper.make_node( 54 | "Conv", 55 | inputs=[inputs[0][0], "weight"], 56 | outputs=["conv_output"], 57 | dilations=dilations, 58 | group=group, 59 | kernel_shape=kernel_shape, 60 | pads=pads, 61 | strides=strides 
62 | ) 63 | 64 | relu = helper.make_node( 65 | "Relu", 66 | inputs=[conv.output[0]], 67 | outputs=[outputs[0][0]] 68 | ) 69 | 70 | model = _onnx_create_model([conv, relu], inputs, outputs, [weight]) 71 | graph_ = Graph.from_onnx(model.graph, onnx_ir_version=5) 72 | self.assertTrue(len(graph_.inputs) == 1) 73 | self.assertEqual(graph_.inputs[0][2], input_shape) 74 | self.assertTrue(len(graph_.outputs) == 1) 75 | self.assertEqual(graph_.outputs[0][2], output_shape) 76 | self.assertTrue(len(graph_.nodes) == 2) 77 | self.assertEqual(len(graph_.nodes[0].parents), 0) 78 | self.assertEqual(len(graph_.nodes[1].parents), 1) 79 | self.assertEqual(len(graph_.nodes[0].children), 1) 80 | self.assertEqual(len(graph_.nodes[1].children), 0) 81 | 82 | 83 | if __name__ == '__main__': 84 | unittest.main() 85 | -------------------------------------------------------------------------------- /tests/convert_test.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 6 | import unittest 7 | import numpy as np 8 | import numpy.testing as npt # type: ignore 9 | import numpy.random as npr 10 | 11 | from PIL import Image # type: ignore 12 | 13 | from onnx_coreml import convert 14 | from tests._test_utils import _onnx_create_single_node_model 15 | 16 | 17 | class ConvertTest(unittest.TestCase): 18 | def setUp(self): # type: () -> None 19 | self.img_arr = np.uint8(npr.rand(224, 224, 3) * 255) # type: ignore 20 | self.img = Image.fromarray(np.uint8(self.img_arr)) # type: ignore 21 | self.img_arr = np.float32(self.img_arr) # type: ignore 22 | self.onnx_model = _onnx_create_single_node_model( 23 | "Relu", 24 | [(3, 224, 224)], 25 | [(3, 224, 224)] 26 | ) 27 | self.input_names = [i.name for i in self.onnx_model.graph.input] 28 | self.output_names = [o.name for o in self.onnx_model.graph.output] 29 | 30 
| def test_convert_image_input(self): # type: () -> None 31 | coreml_model = convert( 32 | self.onnx_model, 33 | image_input_names=self.input_names 34 | ) 35 | spec = coreml_model.get_spec() 36 | for input_ in spec.description.input: 37 | self.assertEqual(input_.type.WhichOneof('Type'), 'imageType') 38 | 39 | def test_convert_image_output(self): # type: () -> None 40 | coreml_model = convert( 41 | self.onnx_model, 42 | image_output_names=self.output_names 43 | ) 44 | spec = coreml_model.get_spec() 45 | for output in spec.description.output: 46 | self.assertEqual(output.type.WhichOneof('Type'), 'imageType') 47 | 48 | def test_convert_image_input_preprocess(self): # type: () -> None 49 | bias = np.array([100, 90, 80]) 50 | coreml_model = convert( 51 | self.onnx_model, 52 | image_input_names=self.input_names, 53 | preprocessing_args={ 54 | 'is_bgr': True, 55 | 'blue_bias': bias[0], 56 | 'green_bias': bias[1], 57 | 'red_bias': bias[2] 58 | } 59 | ) 60 | output = coreml_model.predict( 61 | { 62 | self.input_names[0]: self.img 63 | } 64 | )[self.output_names[0]] 65 | 66 | expected_output = self.img_arr[:, :, ::-1].transpose((2, 0, 1)) 67 | expected_output[0] = expected_output[0] + bias[0] 68 | expected_output[1] = expected_output[1] + bias[1] 69 | expected_output[2] = expected_output[2] + bias[2] 70 | npt.assert_equal(output.flatten(), expected_output.flatten()) 71 | 72 | def test_convert_image_output_bgr(self): # type: () -> None 73 | coreml_model = convert( 74 | self.onnx_model, 75 | image_input_names=self.input_names, 76 | image_output_names=self.output_names, 77 | deprocessing_args={ 78 | 'is_bgr': True 79 | } 80 | ) 81 | output = coreml_model.predict( 82 | { 83 | self.input_names[0]: self.img 84 | } 85 | )[self.output_names[0]] 86 | output = np.array(output)[:, :, :3].transpose((2, 0, 1)) 87 | expected_output = self.img_arr[:, :, ::-1].transpose((2, 0, 1)) 88 | npt.assert_equal(output, expected_output) 89 | 90 | 91 | if __name__ == '__main__': 92 | unittest.main() 
93 | -------------------------------------------------------------------------------- /stubs/click/termui.pyi: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | from typing import ( 3 | Any, 4 | Callable, 5 | Generator, 6 | Iterable, 7 | IO, 8 | List, 9 | Optional, 10 | Text, 11 | Tuple, 12 | TypeVar, 13 | ) 14 | 15 | 16 | def hidden_prompt_func(prompt: Text) -> Text: 17 | ... 18 | 19 | 20 | def _build_prompt( 21 | text: Text, 22 | suffix: Text, 23 | show_default: bool = ..., 24 | default: Optional[Text] = ..., 25 | ) -> Text: 26 | ... 27 | 28 | 29 | def prompt( 30 | text: Text, 31 | default: Optional[Text] = ..., 32 | hide_input: bool = ..., 33 | confirmation_prompt: bool = ..., 34 | type: Optional[Any] = ..., 35 | value_proc: Optional[Callable[[Optional[Text]], Any]] = ..., 36 | prompt_suffix: Text = ..., 37 | show_default: bool = ..., 38 | err: bool = ..., 39 | ) -> Any: 40 | ... 41 | 42 | 43 | def confirm( 44 | text: Text, 45 | default: bool = ..., 46 | abort: bool = ..., 47 | prompt_suffix: Text = ..., 48 | show_default: bool = ..., 49 | err: bool = ..., 50 | ) -> bool: 51 | ... 52 | 53 | 54 | def get_terminal_size() -> Tuple[int, int]: 55 | ... 56 | 57 | 58 | def echo_via_pager(text: Text, color: Optional[bool] = ...) -> None: 59 | ... 60 | 61 | 62 | _T = TypeVar('_T') 63 | 64 | 65 | @contextmanager 66 | def progressbar( 67 | iterable: Optional[Iterable[_T]] = ..., 68 | length: Optional[int] = ..., 69 | label: Optional[Text] = ..., 70 | show_eta: bool = ..., 71 | show_percent: Optional[bool] = ..., 72 | show_pos: bool = ..., 73 | item_show_func: Optional[Callable[[_T], Text]] = ..., 74 | fill_char: Text = ..., 75 | empty_char: Text = ..., 76 | bar_template: Text = ..., 77 | info_sep: Text = ..., 78 | width: int = ..., 79 | file: Optional[IO[Any]] = ..., 80 | color: Optional[bool] = ..., 81 | ) -> Generator[_T, None, None]: 82 | ... 83 | 84 | 85 | def clear() -> None: 86 | ... 
87 | 88 | 89 | def style( 90 | text: Text, 91 | fg: Optional[Text] = ..., 92 | bg: Optional[Text] = ..., 93 | bold: Optional[bool] = ..., 94 | dim: Optional[bool] = ..., 95 | underline: Optional[bool] = ..., 96 | blink: Optional[bool] = ..., 97 | reverse: Optional[bool] = ..., 98 | reset: bool = ..., 99 | ): 100 | ... 101 | 102 | 103 | def unstyle(text: Text) -> Text: 104 | ... 105 | 106 | 107 | # Styling options copied from style() for nicer type checking. 108 | def secho( 109 | text: Text, 110 | file: Optional[IO[Any]] = ..., 111 | nl: bool = ..., 112 | err: bool = ..., 113 | color: Optional[bool] = ..., 114 | fg: Optional[Text] = ..., 115 | bg: Optional[Text] = ..., 116 | bold: Optional[bool] = ..., 117 | dim: Optional[bool] = ..., 118 | underline: Optional[bool] = ..., 119 | blink: Optional[bool] = ..., 120 | reverse: Optional[bool] = ..., 121 | reset: bool = ..., 122 | ): 123 | ... 124 | 125 | 126 | def edit( 127 | text: Optional[Text] = ..., 128 | editor: Optional[Text] = ..., 129 | env: Optional[Text] = ..., 130 | require_save: bool = ..., 131 | extension: Text = ..., 132 | filename: Optional[Text] = ..., 133 | ) -> Text: 134 | ... 135 | 136 | 137 | def launch(url: Text, wait: bool = ..., locate: bool = ...) -> int: 138 | ... 139 | 140 | 141 | def getchar(echo: bool = ...) -> Text: 142 | ... 143 | 144 | 145 | def pause( 146 | info: Text = ..., err: bool = ... 147 | ) -> None: 148 | ... 
149 | -------------------------------------------------------------------------------- /onnx_coreml/_error_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from typing import Dict, Text, Any, Callable 6 | from coremltools.models.neural_network import NeuralNetworkBuilder #type: ignore 7 | from ._graph import Node, Graph 8 | 9 | class ErrorHandling(object): 10 | ''' 11 | To handle errors and addition of custom layers 12 | ''' 13 | 14 | def __init__(self, 15 | add_custom_layers = False, # type: bool 16 | custom_conversion_functions = dict(), # type: Dict[Text, Any] 17 | custom_layer_nodes = [], # type : List[Node] 18 | ): 19 | # type: (...) -> None 20 | self.add_custom_layers = add_custom_layers 21 | self.custom_conversion_functions = custom_conversion_functions 22 | self.custom_layer_nodes = custom_layer_nodes 23 | 24 | self.rerun_suggestion = '\n Please try converting with higher minimum_ios_deployment_target.\n' \ 25 | 'You can also provide custom function/layer to convert the model.' 26 | 27 | 28 | def unsupported_op(self, 29 | node, # type: Node 30 | ): 31 | # type: (...) -> Callable[[Any, Node, Graph, ErrorHandling], None] 32 | ''' 33 | Either raise an error for an unsupported op type or return custom layer add function 34 | ''' 35 | if self.add_custom_layers: 36 | from ._operators import _convert_custom 37 | return _convert_custom 38 | else: 39 | raise TypeError( 40 | "ONNX node of type {} is not supported. {}\n".format(node.op_type, self.rerun_suggestion) 41 | ) 42 | 43 | 44 | def unsupported_op_configuration(self, 45 | builder, # type: NeuralNetworkBuilder 46 | node, # type: Node 47 | graph, # type: Graph 48 | err_message, # type: Text 49 | ): 50 | # type: (...) -> None 51 | ''' 52 | Either raise an error for an unsupported attribute or add a custom layer. 
53 | ''' 54 | if self.add_custom_layers: 55 | from ._operators import _convert_custom 56 | _convert_custom(builder, node, graph, self) 57 | else: 58 | raise TypeError( 59 | "Error while converting op of type: {}. Error message: {} {}\n".format(node.op_type, err_message, 60 | self.rerun_suggestion) 61 | ) 62 | 63 | 64 | def missing_initializer(self, 65 | node, # type: Node 66 | err_message, # type: Text 67 | ): 68 | # type: (...) -> None 69 | ''' 70 | Missing initializer error 71 | ''' 72 | raise ValueError( 73 | "Missing initializer error in op of type {}, with input name = {}, " 74 | "output name = {}. Error message: {} {}\n". 75 | format(node.op_type, node.inputs[0], node.outputs[0], err_message, self.rerun_suggestion) 76 | ) 77 | 78 | def unsupported_feature_warning(self, 79 | node, # type: Node 80 | warn_message, # type: Text 81 | ): 82 | # type: (...) -> None 83 | ''' 84 | Unsupported feature warning 85 | ''' 86 | print( 87 | "Warning: Unsupported Feature in op of type {}, with input name = {}, " 88 | "output name = {}. Warning message: {}\n". 89 | format(node.op_type, node.inputs[0], node.outputs[0], err_message) 90 | ) 91 | 92 | 93 | 94 | -------------------------------------------------------------------------------- /onnx_coreml/graph_viz.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | def _shape_notation(int_shape): 5 | X = ['S','B','C','H','W'] 6 | return [X[i] for i in int_shape] 7 | 8 | def plot_graph(graph, graph_img_path='graph.png', show_coreml_mapped_shapes=False): 9 | """ 10 | Plot graph using pydot 11 | 12 | It works in two steps: 13 | 1. Add nodes to pydot 14 | 2. connect nodes added in pydot 15 | 16 | :param graph 17 | :return: writes down a png/pdf file using dot 18 | """ 19 | 20 | try: 21 | # pydot-ng is a fork of pydot that is better maintained. 
22 | import pydot_ng as pydot # type: ignore 23 | except: 24 | # pydotplus is an improved version of pydot 25 | try: 26 | import pydotplus as pydot # type: ignore 27 | except: 28 | # Fall back on pydot if necessary. 29 | try: 30 | import pydot # type: ignore 31 | except: 32 | return None 33 | 34 | dot = pydot.Dot() 35 | dot.set('rankdir', 'TB') 36 | dot.set('concentrate', True) 37 | dot.set_node_defaults(shape='record') 38 | 39 | # Add nodes corresponding to graph inputs 40 | graph_inputs = [] 41 | for input_ in graph.inputs: 42 | if show_coreml_mapped_shapes: 43 | if input_[0] in graph.onnx_coreml_shape_mapping: 44 | shape = tuple(_shape_notation(graph.onnx_coreml_shape_mapping[input_[0]])) 45 | else: 46 | shape = 'NA, ' 47 | else: 48 | shape = tuple(input_[2]) 49 | label = '%s\n|{|%s}|{{%s}|{%s}}' % ('Input', 50 | input_[0], 51 | '', 52 | str(shape)) 53 | pydot_node = pydot.Node(input_[0], label=label) 54 | dot.add_node(pydot_node) 55 | graph_inputs.append(input_[0]) 56 | 57 | # Traverse graph and add nodes to pydot 58 | for node in graph.nodes: 59 | inputlabels = '' 60 | for input_ in node.inputs: 61 | if show_coreml_mapped_shapes: 62 | if input_ in graph.onnx_coreml_shape_mapping: 63 | inputlabels += str(tuple(_shape_notation(graph.onnx_coreml_shape_mapping[input_]))) + ', ' 64 | else: 65 | inputlabels += 'NA, ' 66 | else: 67 | if input_ in graph.shape_dict: 68 | inputlabels += str(tuple(graph.shape_dict[input_])) + ', ' 69 | else: 70 | inputlabels += 'NA, ' 71 | outputlabels = '' 72 | for output_ in node.outputs: 73 | if show_coreml_mapped_shapes: 74 | if output_ in graph.onnx_coreml_shape_mapping: 75 | outputlabels += str(tuple(_shape_notation(graph.onnx_coreml_shape_mapping[output_]))) + ', ' 76 | else: 77 | outputlabels += 'NA, ' 78 | else: 79 | if output_ in graph.shape_dict: 80 | outputlabels += str(tuple(graph.shape_dict[output_])) + ', ' 81 | else: 82 | outputlabels += 'NA, ' 83 | output_names = ', '.join([output_ for output_ in node.outputs]) 84 | 
input_names = ', '.join([input_ for input_ in node.inputs]) 85 | label = '%s\n|{{%s}|{%s}}|{{%s}|{%s}}' % (node.op_type, 86 | input_names, 87 | output_names, 88 | inputlabels, 89 | outputlabels) 90 | pydot_node = pydot.Node(node.name, label=label) 91 | dot.add_node(pydot_node) 92 | 93 | # add edges 94 | for node in graph.nodes: 95 | for child in node.children: 96 | # add edge in pydot 97 | dot.add_edge(pydot.Edge(node.name, child.name)) 98 | for input_ in node.inputs: 99 | if input_ in graph_inputs: 100 | dot.add_edge(pydot.Edge(input_, node.name)) 101 | 102 | 103 | # write out the image file 104 | _, extension = os.path.splitext(graph_img_path) 105 | if not extension: 106 | extension = 'pdf' 107 | else: 108 | extension = extension[1:] 109 | dot.write(graph_img_path, format=extension) 110 | 111 | 112 | 113 | -------------------------------------------------------------------------------- /stubs/click/__init__.pyi: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Note: The files in this directory are taken from https://github.com/python/typeshed with the pull request 4 | # https://github.com/python/typeshed/pull/1931 applied. 5 | # TODO This whole stubs/click directory can be deleted once that pull request is incorporated into a new mypy version. 6 | 7 | """ 8 | click 9 | ~~~~~ 10 | 11 | Click is a simple Python module that wraps the stdlib's optparse to make 12 | writing command line scripts fun. Unlike other modules, it's based around 13 | a simple API that does not come with too much magic and is composable. 14 | 15 | In case optparse ever gets removed from the stdlib, it will be shipped by 16 | this module. 17 | 18 | :copyright: (c) 2014 by Armin Ronacher. 19 | :license: BSD, see LICENSE for more details. 
20 | """ 21 | 22 | # Core classes 23 | from .core import ( 24 | Context as Context, 25 | BaseCommand as BaseCommand, 26 | Command as Command, 27 | MultiCommand as MultiCommand, 28 | Group as Group, 29 | CommandCollection as CommandCollection, 30 | Parameter as Parameter, 31 | Option as Option, 32 | Argument as Argument, 33 | ) 34 | 35 | # Globals 36 | from .globals import get_current_context as get_current_context 37 | 38 | # Decorators 39 | from .decorators import ( 40 | pass_context as pass_context, 41 | pass_obj as pass_obj, 42 | make_pass_decorator as make_pass_decorator, 43 | command as command, 44 | group as group, 45 | argument as argument, 46 | option as option, 47 | confirmation_option as confirmation_option, 48 | password_option as password_option, 49 | version_option as version_option, 50 | help_option as help_option, 51 | ) 52 | 53 | # Types 54 | from .types import ( 55 | ParamType as ParamType, 56 | File as File, 57 | Path as Path, 58 | Choice as Choice, 59 | IntRange as IntRange, 60 | Tuple as Tuple, 61 | STRING as STRING, 62 | INT as INT, 63 | FLOAT as FLOAT, 64 | BOOL as BOOL, 65 | UUID as UUID, 66 | UNPROCESSED as UNPROCESSED, 67 | ) 68 | 69 | # Utilities 70 | from .utils import ( 71 | echo as echo, 72 | get_binary_stream as get_binary_stream, 73 | get_text_stream as get_text_stream, 74 | open_file as open_file, 75 | format_filename as format_filename, 76 | get_app_dir as get_app_dir, 77 | get_os_args as get_os_args, 78 | ) 79 | 80 | # Terminal functions 81 | from .termui import ( 82 | prompt as prompt, 83 | confirm as confirm, 84 | get_terminal_size as get_terminal_size, 85 | echo_via_pager as echo_via_pager, 86 | progressbar as progressbar, 87 | clear as clear, 88 | style as style, 89 | unstyle as unstyle, 90 | secho as secho, 91 | edit as edit, 92 | launch as launch, 93 | getchar as getchar, 94 | pause as pause, 95 | ) 96 | 97 | # Exceptions 98 | from .exceptions import ( 99 | ClickException as ClickException, 100 | UsageError as UsageError, 101 
| BadParameter as BadParameter, 102 | FileError as FileError, 103 | Abort as Abort, 104 | NoSuchOption as NoSuchOption, 105 | BadOptionUsage as BadOptionUsage, 106 | BadArgumentUsage as BadArgumentUsage, 107 | MissingParameter as MissingParameter, 108 | ) 109 | 110 | # Formatting 111 | from .formatting import HelpFormatter as HelpFormatter, wrap_text as wrap_text 112 | 113 | # Parsing 114 | from .parser import OptionParser as OptionParser 115 | 116 | 117 | __all__ = [ 118 | # Core classes 119 | 'Context', 'BaseCommand', 'Command', 'MultiCommand', 'Group', 120 | 'CommandCollection', 'Parameter', 'Option', 'Argument', 121 | 122 | # Globals 123 | 'get_current_context', 124 | 125 | # Decorators 126 | 'pass_context', 'pass_obj', 'make_pass_decorator', 'command', 'group', 127 | 'argument', 'option', 'confirmation_option', 'password_option', 128 | 'version_option', 'help_option', 129 | 130 | # Types 131 | 'ParamType', 'File', 'Path', 'Choice', 'IntRange', 'Tuple', 'STRING', 132 | 'INT', 'FLOAT', 'BOOL', 'UUID', 'UNPROCESSED', 133 | 134 | # Utilities 135 | 'echo', 'get_binary_stream', 'get_text_stream', 'open_file', 136 | 'format_filename', 'get_app_dir', 'get_os_args', 137 | 138 | # Terminal functions 139 | 'prompt', 'confirm', 'get_terminal_size', 'echo_via_pager', 140 | 'progressbar', 'clear', 'style', 'unstyle', 'secho', 'edit', 'launch', 141 | 'getchar', 'pause', 142 | 143 | # Exceptions 144 | 'ClickException', 'UsageError', 'BadParameter', 'FileError', 145 | 'Abort', 'NoSuchOption', 'BadOptionUsage', 'BadArgumentUsage', 146 | 'MissingParameter', 147 | 148 | # Formatting 149 | 'HelpFormatter', 'wrap_text', 150 | 151 | # Parsing 152 | 'OptionParser', 153 | ] 154 | 155 | 156 | # Controls if click should emit the warning about the use of unicode 157 | # literals. 
158 | disable_unicode_literals_warning = False 159 | 160 | 161 | __version__ = '6.6' 162 | -------------------------------------------------------------------------------- /contributing.md: -------------------------------------------------------------------------------- 1 | 2 | Contribution Guidelines 3 | ======================= 4 | 5 | How to be a Contributor 6 | --- 7 | 8 | To contribute to Core ML via onnx-coreml repository, a Contributor License Agreement (CLA) must be signed. This can be found [here](https://cla-assistant.io/onnx/onnx-coreml) or will be attached as a comment when your first pull request is created. 9 | 10 | **Core ML Open Source Community** 11 | 12 | The Core ML open source community welcomes all contributions and ideas to grow the product. This can occur within this repo as well as [coremltools](https://github.com/apple/coremltools) or [tf-coreml](https://github.com/tf-coreml/tf-coreml). 13 | 14 | This could be provided in a couple of ways: 15 | 16 | * Discovering and logging a **bug**, submitting an idea for a **feature request** (or **enhancement** to an existing feature) or asking a **question** through the use of the templates: [onnx-coreml issue](https://github.com/onnx/onnx-coreml/issues/new/choose) 17 | 18 | * Submit a **pull request** for additional functionality that you have completed: [onnx-coreml PR](https://github.com/onnx/onnx-coreml/pulls) 19 | 20 | * Resolve an existing **issue** found in any of the repositories: [onnx-coreml open issues](https://github.com/onnx/onnx-coreml/issues) 21 | 22 | Expectations of the Community 23 | --- 24 | 25 | The contributing guidelines and code of conduct are similar to most open source communities. This includes participating in the community through developing, receiving help and answering questions as well as engaging in a highly motivated, positive environment.
26 | 27 | Additionally, this includes but is not limited to: 28 | 29 | * Providing comments that are helpful, motivating and constructive 30 | * Being respectful of others within the community 31 | * Collaborating with others to produce new, useful contributions to the community 32 | 33 | ## Github Best Practices for Core ML Community 34 | 35 | While participating in the Core ML community, to ensure that issues and pull requests are able to be addressed quickly, please ensure that the following is being done: 36 | 37 | * Checking to see if your issue already exists 38 | * Following pre-existing templates 39 | * Promptly replying to any requests or questions posed by others within the community on your issue / PR 40 | 41 | Code of Conduct 42 | --- 43 | Our goal is to house an inclusive, welcoming open source community. This includes treating all members of the community equally 44 | and with respect. We as contributors and maintainers pledge to making participation in our project and our community a harassment-free 45 | experience for everyone, regardless of age, body size, disability, ethnicity, sex characteristics, gender identity and expression, 46 | level of experience, education, socio-economic status, nationality, personal appearance, race, religion, or sexual identity and orientation. 47 | 48 | As project maintainers, we will monitor community behaviour to ensure acceptable behaviour. If instances of abusive, harassing or otherwise 49 | unacceptable behaviour occur, please contact coreml-conduct@group.apple.com. 50 | 51 | The Code of Conduct is adapted from the Contributor Covenant, found [here](https://www.contributor-covenant.org) 52 | 53 | ## What to expect as a Contributor to Core ML 54 | 55 | ### Lifecycle of Issues 56 | 57 | Once an issue has been submitted it will be triaged and appropriate labels will be added.
The issue will then either be slotted in to an upcoming release, commented on for additional information or placed in the backlog for a future release. 58 | 59 | #### Use of templates 60 | There will be a provided template when submitting your issue. Please ensure that this is used and filled in as much as possible to help others in the community in understanding the issue so that they are able to provide a response easily. 61 | 62 | If applicable, please provide the model that was being used when logging the issue, so that the issue is able to be reproduced. 63 | 64 | #### Labels 65 | Please check the labels page under each repository for a further description of each label. These labels will be added by project maintainers depending on need. Typically an issue can be given any of the following labels: 66 | 67 | Status → What stage of the process the issue is in (turquoise) 68 | 69 | * Triaged, awaiting response, duplicate, repro needed, investigation or needs discussion 70 | 71 | Type → Issue is classified based on the category it belongs to (red) 72 | 73 | * bug, clarification, docs, enhancement, feature request, perf, question 74 | 75 | Release → If the issue is scheduled to be resolved in a specific release, it will be added (yellow) 76 | 77 | Other → this may vary depending on the repository 78 | 79 | * good first issue, help wanted 80 | 81 | ### Lifecycle of Pull Requests 82 | 83 | Any pull request submitted to the repositories will be reviewed by a member of the community and upon approval by a Core ML team member can be merged to the master branch. If you are new to GitHub, please find more information regarding creating pull requests here (https://help.github.com/en/articles/creating-a-pull-request). 84 | 85 | Developing in onnx-coreml 86 | --- 87 | 88 | Additional information regarding APIs, installation and dependencies and more can be found in the readme.md file, [here](https://github.com/onnx/onnx-coreml). 
89 | 90 | -------------------------------------------------------------------------------- /onnx_coreml/_backend_rep.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | # from __future__ import unicode_literals 5 | 6 | import numpy as np 7 | from typing import Any, Sequence, List 8 | from onnx.backend.base import BackendRep, namedtupledict 9 | from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE 10 | from coremltools.proto import FeatureTypes_pb2 as ft #type: ignore 11 | from coremltools.models import MLModel #type: ignore 12 | from typing import Dict, Any, Text, Tuple 13 | from onnx import TensorProto 14 | from ._graph import EdgeInfo 15 | from .converter import SupportedVersion 16 | 17 | 18 | def _set_dtypes(input_dict, #type: Dict[Text, np._ArrayLike[Any]] 19 | model, #type: MLModel 20 | ): 21 | # type: (...) -> None 22 | spec = model.get_spec() 23 | for input_ in spec.description.input: 24 | if input_.type.HasField('multiArrayType') and input_.name in input_dict: 25 | if input_.type.multiArrayType.dataType == ft.ArrayFeatureType.INT32: 26 | input_dict[input_.name] = input_dict[input_.name].astype(np.int32) 27 | if input_.type.multiArrayType.dataType == ft.ArrayFeatureType.FLOAT32: 28 | input_dict[input_.name] = input_dict[input_.name].astype(np.float32) 29 | if input_.type.multiArrayType.dataType == ft.ArrayFeatureType.DOUBLE: 30 | input_dict[input_.name] = input_dict[input_.name].astype(np.float64) 31 | 32 | 33 | class CoreMLRep(BackendRep): 34 | def __init__(self, 35 | coreml_model, # type: MLModel 36 | onnx_outputs_info, # type: Dict[Text, EdgeInfo] 37 | useCPUOnly=False, # type: bool 38 | minimum_ios_deployment_target='12' # type: str 39 | ): 40 | # type: (...) 
-> None 41 | super(CoreMLRep, self).__init__() 42 | self.model = coreml_model 43 | self.useCPUOnly = useCPUOnly 44 | self.minimum_ios_deployment_target = minimum_ios_deployment_target 45 | 46 | spec = coreml_model.get_spec() 47 | self.input_names = [str(i.name) for i in spec.description.input] 48 | self.output_names = [str(o.name) for o in spec.description.output] 49 | self.onnx_outputs_info = onnx_outputs_info # type: Dict[Text, EdgeInfo] 50 | 51 | def run(self, 52 | inputs, # type: Any 53 | **kwargs # type: Any 54 | ): 55 | # type: (...) -> Tuple[Any, ...] 56 | super(CoreMLRep, self).run(inputs, **kwargs) 57 | inputs_ = inputs 58 | _reshaped = False 59 | if not SupportedVersion.is_nd_array_supported(self.minimum_ios_deployment_target): 60 | for i, input_ in enumerate(inputs_): 61 | shape = input_.shape 62 | if len(shape) == 4 or len(shape) == 2: 63 | inputs_[i] = input_[np.newaxis, :] 64 | _reshaped = True 65 | elif len(shape) == 3: 66 | spec = self.model.get_spec() 67 | spec_shape = [int(k) for k in spec.description.input[i].type.multiArrayType.shape] 68 | prod = spec_shape[0] * spec_shape[1] * spec_shape[2] 69 | onnx_shape = list(shape) 70 | if onnx_shape != spec_shape: 71 | if onnx_shape[2] == prod: 72 | inputs_[i] = np.reshape(inputs_[i], [onnx_shape[0], onnx_shape[1]] + spec_shape) 73 | elif onnx_shape[1] * onnx_shape[2] == prod: 74 | inputs_[i] = np.reshape(inputs_[i], [1, onnx_shape[0]] + spec_shape) 75 | input_dict = dict( 76 | zip(self.input_names, 77 | map(np.array, inputs_))) 78 | _set_dtypes(input_dict, self.model) #type: ignore 79 | 80 | prediction = self.model.predict(input_dict, self.useCPUOnly) 81 | output_values = [prediction[name] for name in self.output_names] 82 | 83 | if not SupportedVersion.is_nd_array_supported(self.minimum_ios_deployment_target): 84 | for i, output_ in enumerate(output_values): 85 | shape = output_.shape 86 | #reshape the CoreML output to match Onnx's output shape 87 | try: 88 | output_values[i] = np.reshape(output_, 
self.onnx_outputs_info[self.output_names[i]][2]) # type: ignore 89 | except RuntimeError: 90 | print("Output '%s' shape incompatible between CoreML (%s) and onnx (%s)" 91 | %(self.output_names[i], output_.shape, 92 | self.onnx_outputs_info[self.output_names[i]])) 93 | 94 | ## Type Cast to ONNX expected output types 95 | for i, output_ in enumerate(output_values): 96 | output_type = self.onnx_outputs_info[self.output_names[i]][1] 97 | if TENSOR_TYPE_TO_NP_TYPE[output_type] != output_values[i].dtype: 98 | output_values[i] = output_values[i].astype(TENSOR_TYPE_TO_NP_TYPE[output_type]) 99 | 100 | result = namedtupledict('Outputs', 101 | self.output_names)(*output_values) # type: Tuple[Any, ...] 102 | return result -------------------------------------------------------------------------------- /onnx_coreml/_backend.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 6 | from typing import Any, Text, Dict, Tuple 7 | from onnx import ModelProto 8 | from onnx.backend.base import Backend 9 | from onnx_coreml._backend_rep import CoreMLRep 10 | from onnx_coreml import convert 11 | import onnx 12 | from ._graph import _input_from_onnx_input, EdgeInfo 13 | 14 | DEBUG = False 15 | 16 | def _get_onnx_outputs_info(model): # type: (...) 
-> Dict[Text, EdgeInfo] 17 | """ 18 | Takes in an onnx model and returns a dictionary 19 | of onnx output names mapped to a tuple that is (output_name, type, shape) 20 | """ 21 | if isinstance(model, str): 22 | onnx_model = onnx.load(model) 23 | elif isinstance(model, onnx.ModelProto): 24 | onnx_model = model 25 | 26 | graph = onnx_model.graph 27 | onnx_output_dict = {} 28 | for o in graph.output: 29 | out = _input_from_onnx_input(o) 30 | onnx_output_dict[out[0]] = out 31 | return onnx_output_dict 32 | 33 | 34 | class CoreMLBackend(Backend): 35 | @classmethod 36 | def prepare(cls, 37 | model, # type: ModelProto 38 | device='CPU', # type: Text 39 | minimum_ios_deployment_target='12', # type: str 40 | **kwargs # type: Any 41 | ): 42 | # type: (...) -> CoreMLRep 43 | super(CoreMLBackend, cls).prepare(model, device, **kwargs) 44 | if DEBUG: 45 | with open('/tmp/node_model.onnx', 'wb') as f: 46 | s = model.SerializeToString() 47 | f.write(s) 48 | coreml_model = convert(model, minimum_ios_deployment_target=minimum_ios_deployment_target) 49 | if DEBUG: 50 | coreml_model.save('/tmp/node_model.mlmodel') 51 | onnx_outputs_info = _get_onnx_outputs_info(model) 52 | return CoreMLRep(coreml_model, onnx_outputs_info, device == 'CPU', minimum_ios_deployment_target=minimum_ios_deployment_target) 53 | 54 | @classmethod 55 | def is_compatible(cls, 56 | model, # type: ModelProto 57 | device='CPU', # type: Text 58 | **kwargs # type: Any 59 | ): # type: (...) -> bool 60 | # Return whether the model is compatible with CoreML. 61 | ''' 62 | This function will gradually grow to cover more cases. 63 | Need to be careful of false negatives. There are some cases that seemingly 64 | are not supported on CoreML, which the graph transformer optimizes and converts to 65 | a graph that can be converted to CoreML. 66 | 67 | 1. Check whether the layers for which CoreML expects constant weights are in 68 | the list of initializers in the onnx graph 69 | 2. 
unsupported ops like "And", "Or" etc 70 | 71 | ''' 72 | 73 | node_set = set() 74 | initializer_set = set() 75 | graph = model.graph 76 | for t in graph.initializer: 77 | initializer_set.add(t.name) 78 | for node in graph.node: 79 | if node.op_type in ['ConvTranspose', 80 | 'Conv', 81 | 'BatchNormalization', 82 | 'InstanceNormalization', 83 | 'PRelu']: 84 | if len(node.input) > 1 and node.input[1] not in initializer_set: 85 | return False 86 | node_set.add(node.op_type) 87 | 88 | # unsupported ops remove 89 | for node in graph.node: 90 | if node.op_type in ['Cast', 91 | 'And', 92 | 'Or', 93 | 'Xor', 94 | 'Not', 95 | 'Less', 96 | 'Greater', 97 | 'Equal', 98 | 'Ceil', 99 | 'Floor']: 100 | return False 101 | 102 | return True 103 | 104 | @classmethod 105 | def supports_device(cls, 106 | device, # type: Text 107 | ): 108 | # type: (...) -> bool 109 | return device == 'CPU' 110 | 111 | 112 | class CoreMLBackendND(Backend): 113 | @classmethod 114 | def prepare(cls, 115 | model, # type: ModelProto 116 | device='CPU', # type: Text 117 | minimum_ios_deployment_target='13', # type: str 118 | **kwargs # type: Any 119 | ): 120 | # type: (...) -> CoreMLRep 121 | super(CoreMLBackendND, cls).prepare(model, device, **kwargs) 122 | if DEBUG: 123 | with open('/tmp/node_model.onnx', 'wb') as f: 124 | s = model.SerializeToString() 125 | f.write(s) 126 | coreml_model = convert(model, minimum_ios_deployment_target=minimum_ios_deployment_target) 127 | if DEBUG: 128 | coreml_model.save('/tmp/node_model.mlmodel') 129 | onnx_outputs_info = _get_onnx_outputs_info(model) 130 | return CoreMLRep(coreml_model, onnx_outputs_info, device == 'CPU', minimum_ios_deployment_target=minimum_ios_deployment_target) 131 | 132 | @classmethod 133 | def is_compatible(cls, 134 | model, # type: ModelProto 135 | device='CPU', # type: Text 136 | **kwargs # type: Any 137 | ): # type: (...) -> bool 138 | # Return whether the model is compatible with CoreML. 
139 | ''' 140 | This function will gradually grow to cover more cases. 141 | Need to be careful of false negatives. There are some cases that seemingly 142 | are not supported on CoreML, which the graph transformer optimizes and converts to 143 | a graph that can be converted to CoreML. 144 | 145 | 2. Unsupported ops: If graph has one of unsupported op, exit 146 | 147 | ''' 148 | ## TODO: Add un-supported ops 149 | unsupported_ops = [] 150 | graph = model.graph 151 | for node in graph.node: 152 | if node.op_type in unsupported_ops: 153 | return False 154 | return True 155 | 156 | @classmethod 157 | def supports_device(cls, 158 | device, # type: Text 159 | ): 160 | # type: (...) -> bool 161 | return device == 'CPU' 162 | -------------------------------------------------------------------------------- /stubs/click/decorators.pyi: -------------------------------------------------------------------------------- 1 | from distutils.version import Version 2 | from typing import Any, Callable, Dict, List, Optional, Text, Type, TypeVar, Union 3 | 4 | from click.core import Command, Group, Argument, Option, Parameter, Context 5 | from click.types import ParamType 6 | 7 | _T = TypeVar('_T') 8 | _Decorator = Callable[[_T], _T] 9 | 10 | _Callback = Callable[ 11 | [Context, Union[Option, Parameter], Union[bool, int, Text]], 12 | Any 13 | ] 14 | 15 | def pass_context(_T) -> _T: 16 | ... 17 | 18 | 19 | def pass_obj(_T) -> _T: 20 | ... 21 | 22 | 23 | def make_pass_decorator( 24 | object_type: type, ensure: bool = ... 25 | ) -> Callable[[_T], _T]: 26 | ... 27 | 28 | 29 | # NOTE: Decorators below have **attrs converted to concrete constructor 30 | # arguments from core.pyi to help with type checking. 
31 | 32 | def command( 33 | name: Optional[Text] = ..., 34 | cls: Optional[Type[Command]] = ..., 35 | # Command 36 | context_settings: Optional[Dict[Text, Any]] = ..., 37 | help: Optional[Text] = ..., 38 | epilog: Optional[Text] = ..., 39 | short_help: Optional[Text] = ..., 40 | options_metavar: Text = ..., 41 | add_help_option: bool = ..., 42 | ) -> _Decorator[Any]: 43 | ... 44 | 45 | 46 | # This inherits attrs from Group, MultiCommand and Command. 47 | 48 | def group( 49 | name: Optional[Text] = ..., 50 | cls: Type[Command] = ..., 51 | # Group 52 | commands: Optional[Dict[Text, Command]] = ..., 53 | # MultiCommand 54 | invoke_without_command: bool = ..., 55 | no_args_is_help: Optional[bool] = ..., 56 | subcommand_metavar: Optional[Text] = ..., 57 | chain: bool = ..., 58 | result_callback: Optional[Callable[..., Any]] = ..., 59 | # Command 60 | help: Optional[Text] = ..., 61 | epilog: Optional[Text] = ..., 62 | short_help: Optional[Text] = ..., 63 | options_metavar: Text = ..., 64 | add_help_option: bool = ..., 65 | # User-defined 66 | **kwargs: Any, 67 | ) -> _Decorator[Any]: 68 | ... 69 | 70 | 71 | def argument( 72 | *param_decls: Text, 73 | cls: Type[Argument] = ..., 74 | # Argument 75 | required: Optional[bool] = ..., 76 | # Parameter 77 | type: Optional[Union[type, ParamType]] = ..., 78 | default: Optional[Any] = ..., 79 | callback: Optional[_Callback] = ..., 80 | nargs: Optional[int] = ..., 81 | metavar: Optional[Text] = ..., 82 | expose_value: bool = ..., 83 | is_eager: bool = ..., 84 | envvar: Optional[Union[Text, List[Text]]] = ... 85 | ) -> _Decorator[Any]: 86 | ... 
87 | 88 | 89 | def option( 90 | *param_decls: Text, 91 | cls: Type[Option] = ..., 92 | # Option 93 | show_default: bool = ..., 94 | prompt: Union[bool, Text] = ..., 95 | confirmation_prompt: bool = ..., 96 | hide_input: bool = ..., 97 | is_flag: Optional[bool] = ..., 98 | flag_value: Optional[Any] = ..., 99 | multiple: bool = ..., 100 | count: bool = ..., 101 | allow_from_autoenv: bool = ..., 102 | type: Optional[Union[type, ParamType]] = ..., 103 | help: Optional[Text] = ..., 104 | # Parameter 105 | default: Optional[Any] = ..., 106 | required: bool = ..., 107 | callback: Optional[_Callback] = ..., 108 | nargs: Optional[int] = ..., 109 | metavar: Optional[Text] = ..., 110 | expose_value: bool = ..., 111 | is_eager: bool = ..., 112 | envvar: Optional[Union[Text, List[Text]]] = ... 113 | ) -> _Decorator[Any]: 114 | ... 115 | 116 | 117 | def confirmation_option( 118 | *param_decls: Text, 119 | cls: Type[Option] = ..., 120 | # Option 121 | show_default: bool = ..., 122 | prompt: Union[bool, Text] = ..., 123 | confirmation_prompt: bool = ..., 124 | hide_input: bool = ..., 125 | is_flag: bool = ..., 126 | flag_value: Optional[Any] = ..., 127 | multiple: bool = ..., 128 | count: bool = ..., 129 | allow_from_autoenv: bool = ..., 130 | type: Optional[Union[type, ParamType]] = ..., 131 | help: Text = ..., 132 | # Parameter 133 | default: Optional[Any] = ..., 134 | callback: Optional[_Callback] = ..., 135 | nargs: Optional[int] = ..., 136 | metavar: Optional[Text] = ..., 137 | expose_value: bool = ..., 138 | is_eager: bool = ..., 139 | envvar: Optional[Union[Text, List[Text]]] = ... 140 | ) -> _Decorator[Any]: 141 | ... 
142 | 143 | 144 | def password_option( 145 | *param_decls: Text, 146 | cls: Type[Option] = ..., 147 | # Option 148 | show_default: bool = ..., 149 | prompt: Union[bool, Text] = ..., 150 | confirmation_prompt: bool = ..., 151 | hide_input: bool = ..., 152 | is_flag: Optional[bool] = ..., 153 | flag_value: Optional[Any] = ..., 154 | multiple: bool = ..., 155 | count: bool = ..., 156 | allow_from_autoenv: bool = ..., 157 | type: Optional[Union[type, ParamType]] = ..., 158 | help: Optional[Text] = ..., 159 | # Parameter 160 | default: Optional[Any] = ..., 161 | callback: Optional[_Callback] = ..., 162 | nargs: Optional[int] = ..., 163 | metavar: Optional[Text] = ..., 164 | expose_value: bool = ..., 165 | is_eager: bool = ..., 166 | envvar: Optional[Union[Text, List[Text]]] = ... 167 | ) -> _Decorator[Any]: 168 | ... 169 | 170 | 171 | def version_option( 172 | version: Optional[Union[Text, Version]] = ..., 173 | *param_decls: Text, 174 | cls: Type[Option] = ..., 175 | # Option 176 | prog_name: Optional[Text] = ..., 177 | message: Optional[Text] = ..., 178 | show_default: bool = ..., 179 | prompt: Union[bool, Text] = ..., 180 | confirmation_prompt: bool = ..., 181 | hide_input: bool = ..., 182 | is_flag: bool = ..., 183 | flag_value: Optional[Any] = ..., 184 | multiple: bool = ..., 185 | count: bool = ..., 186 | allow_from_autoenv: bool = ..., 187 | type: Optional[Union[type, ParamType]] = ..., 188 | help: Text = ..., 189 | # Parameter 190 | default: Optional[Any] = ..., 191 | callback: Optional[_Callback] = ..., 192 | nargs: Optional[int] = ..., 193 | metavar: Optional[Text] = ..., 194 | expose_value: bool = ..., 195 | is_eager: bool = ..., 196 | envvar: Optional[Union[Text, List[Text]]] = ... 197 | ) -> _Decorator[Any]: 198 | ... 
199 | 200 | 201 | def help_option( 202 | *param_decls: Text, 203 | cls: Type[Option] = ..., 204 | # Option 205 | show_default: bool = ..., 206 | prompt: Union[bool, Text] = ..., 207 | confirmation_prompt: bool = ..., 208 | hide_input: bool = ..., 209 | is_flag: bool = ..., 210 | flag_value: Optional[Any] = ..., 211 | multiple: bool = ..., 212 | count: bool = ..., 213 | allow_from_autoenv: bool = ..., 214 | type: Optional[Union[type, ParamType]] = ..., 215 | help: Text = ..., 216 | # Parameter 217 | default: Optional[Any] = ..., 218 | callback: Optional[_Callback] = ..., 219 | nargs: Optional[int] = ..., 220 | metavar: Optional[Text] = ..., 221 | expose_value: bool = ..., 222 | is_eager: bool = ..., 223 | envvar: Optional[Union[Text, List[Text]]] = ... 224 | ) -> _Decorator[Any]: 225 | ... 226 | -------------------------------------------------------------------------------- /stubs/click/types.pyi: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, IO, Iterable, List, Optional, Text, TypeVar, Union 2 | import uuid 3 | 4 | from click.core import Context, Parameter 5 | 6 | 7 | class ParamType: 8 | name: Text 9 | is_composite: bool 10 | envvar_list_splitter: Optional[Text] 11 | 12 | def __call__( 13 | self, 14 | value: Optional[Text], 15 | param: Optional[Parameter] = ..., 16 | ctx: Optional[Context] = ..., 17 | ) -> Any: 18 | ... 19 | 20 | def get_metavar(self, param: Parameter) -> Text: 21 | ... 22 | 23 | def get_missing_message(self, param: Parameter) -> Text: 24 | ... 25 | 26 | def convert( 27 | self, 28 | value: Text, 29 | param: Optional[Parameter], 30 | ctx: Optional[Context], 31 | ) -> Any: 32 | ... 33 | 34 | def split_envvar_value(self, rv: Text) -> List[Text]: 35 | ... 36 | 37 | def fail(self, message: Text, param: Optional[Parameter] = ..., ctx: Optional[Context] = ...) -> None: 38 | ... 
class BoolParamType(ParamType):
    # Coerces true/false-style strings to ``bool``.
    def __call__(
        self,
        value: Optional[Text],
        param: Optional[Parameter] = ...,
        ctx: Optional[Context] = ...,
    ) -> bool: ...

    def convert(
        self,
        value: Text,
        param: Optional[Parameter],
        ctx: Optional[Context],
    ) -> bool: ...


class CompositeParamType(ParamType):
    # Number of values this composite type consumes.
    arity: int


class Choice(ParamType):
    # Restricts the value to a fixed set of strings.
    choices: Iterable[Text]

    def __init__(self, choices: Iterable[Text]) -> None: ...


class FloatParamType(ParamType):
    def __call__(
        self,
        value: Optional[Text],
        param: Optional[Parameter] = ...,
        ctx: Optional[Context] = ...,
    ) -> float: ...

    def convert(
        self,
        value: Text,
        param: Optional[Parameter],
        ctx: Optional[Context],
    ) -> float: ...


class FloatRange(FloatParamType): ...


class File(ParamType):
    # Opens the argument as a file object, optionally lazily/atomically.
    def __init__(
        self,
        mode: Text = ...,
        encoding: Optional[Text] = ...,
        errors: Optional[Text] = ...,
        lazy: Optional[bool] = ...,
        atomic: Optional[bool] = ...,
    ) -> None: ...

    def __call__(
        self,
        value: Optional[Text],
        param: Optional[Parameter] = ...,
        ctx: Optional[Context] = ...,
    ) -> IO[Any]: ...

    def convert(
        self,
        value: Text,
        param: Optional[Parameter],
        ctx: Optional[Context],
    ) -> IO[Any]: ...

    def resolve_lazy_flag(self, value: Text) -> bool: ...


_F = TypeVar('_F')  # result type of the wrapped conversion function
_Func = Callable[[Optional[Text]], _F]


class FuncParamType(ParamType):
    # Adapts an arbitrary conversion callable into a ParamType.
    func: _Func[Any]

    def __init__(self, func: _Func[Any]) -> None: ...

    def __call__(
        self,
        value: Optional[Text],
        param: Optional[Parameter] = ...,
        ctx: Optional[Context] = ...,
    ) -> _F: ...

    def convert(
        self,
        value: Text,
        param: Optional[Parameter],
        ctx: Optional[Context],
    ) -> _F: ...


class IntParamType(ParamType):
    def __call__(
        self,
        value: Optional[Text],
        param: Optional[Parameter] = ...,
        ctx: Optional[Context] = ...,
    ) -> int: ...

    def convert(
        self,
        value: Text,
        param: Optional[Parameter],
        ctx: Optional[Context],
    ) -> int: ...


class IntRange(IntParamType):
    # Integer restricted to [min, max]; ``clamp`` snaps out-of-range values.
    def __init__(
        self, min: Optional[int] = ..., max: Optional[int] = ..., clamp: bool = ...
    ) -> None: ...


_PathType = TypeVar('_PathType', Text, bytes)


class Path(ParamType):
    # Filesystem path with existence/permission checks.
    def __init__(
        self,
        exists: bool = ...,
        file_okay: bool = ...,
        dir_okay: bool = ...,
        writable: bool = ...,
        readable: bool = ...,
        resolve_path: bool = ...,
        allow_dash: bool = ...,
        path_type: Optional[_PathType] = ...,
    ) -> None: ...

    def coerce_path_result(self, rv: Union[Text, bytes]) -> _PathType: ...

    def __call__(
        self,
        value: Optional[Text],
        param: Optional[Parameter] = ...,
        ctx: Optional[Context] = ...,
    ) -> _PathType: ...

    def convert(
        self,
        value: Text,
        param: Optional[Parameter],
        ctx: Optional[Context],
    ) -> _PathType: ...


class StringParamType(ParamType):
    def __call__(
        self,
        value: Optional[Text],
        param: Optional[Parameter] = ...,
        ctx: Optional[Context] = ...,
    ) -> Text: ...

    def convert(
        self,
        value: Text,
        param: Optional[Parameter],
        ctx: Optional[Context],
    ) -> Text: ...


class Tuple(CompositeParamType):
    # Fixed-arity tuple of heterogeneous ParamTypes.
    # NOTE: the class name intentionally shadows ``typing.Tuple`` here,
    # matching the real ``click.types`` module.
    types: List[ParamType]

    def __init__(self, types: Iterable[Any]) -> None: ...

    def __call__(
        self,
        value: Optional[Text],
        param: Optional[Parameter] = ...,
        ctx: Optional[Context] = ...,
    ) -> Tuple: ...

    def convert(
        self,
        value: Text,
        param: Optional[Parameter],
        ctx: Optional[Context],
    ) -> Tuple: ...


class UnprocessedParamType(ParamType): ...


class UUIDParameterType(ParamType):
    def __call__(
        self,
        value: Optional[Text],
        param: Optional[Parameter] = ...,
        ctx: Optional[Context] = ...,
    ) -> uuid.UUID: ...

    def convert(
        self,
        value: Text,
        param: Optional[Parameter],
        ctx: Optional[Context],
    ) -> uuid.UUID: ...


def convert_type(ty: Any, default: Optional[Any] = ...) -> ParamType: ...
275 | 276 | # parameter type shortcuts 277 | 278 | BOOL = BoolParamType() 279 | FLOAT = FloatParamType() 280 | INT = IntParamType() 281 | STRING = StringParamType() 282 | UNPROCESSED = UnprocessedParamType() 283 | UUID = UUIDParameterType() 284 | -------------------------------------------------------------------------------- /tests/custom_layers_test.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import onnx 6 | import unittest 7 | 8 | from tests._test_utils import _onnx_create_model 9 | from onnx import helper, numpy_helper, ModelProto, TensorProto 10 | from onnx_coreml import convert 11 | from coremltools.proto import NeuralNetwork_pb2 #type: ignore 12 | 13 | def _make_model_acos_exp_topk(): # type: (...) -> ModelProto 14 | ''' 15 | make a very simple model for testing: input->clip->exp->topk->2 outputs 16 | ''' 17 | inputs = [('input0', (10,), TensorProto.FLOAT), ('K', (1,), TensorProto.INT64)] 18 | outputs = [('output_values', (3,), TensorProto.FLOAT), ('output_indices', (3,), TensorProto.INT64)] 19 | acos = helper.make_node("Acos", 20 | inputs=[inputs[0][0]], 21 | outputs=['acos_out']) 22 | exp = helper.make_node("Exp", 23 | inputs=[acos.output[0]], 24 | outputs=['exp_out']) 25 | topk = helper.make_node("TopK", 26 | inputs=[exp.output[0], inputs[1][0]], 27 | outputs=[outputs[0][0], outputs[1][0]], 28 | axis=0) 29 | return _onnx_create_model([acos, exp, topk], inputs, outputs) 30 | 31 | def _make_model_flatten_axis3(): # type: (...) 
-> ModelProto 32 | ''' 33 | make a simple model: 4-D input -> flatten (axis=3)-> output 34 | ''' 35 | inputs = [('input', (1,3,10,20), TensorProto.FLOAT)] 36 | outputs = [('output', (30,20), TensorProto.FLOAT)] 37 | flatten = helper.make_node("Flatten", 38 | inputs=[inputs[0][0]], 39 | outputs=[outputs[0][0]], 40 | axis=3) 41 | return _onnx_create_model([flatten], inputs, outputs) 42 | 43 | 44 | class CustomLayerTest(unittest.TestCase): 45 | 46 | def test_unsupported_ops(self): # type: () -> None 47 | 48 | onnx_model = _make_model_acos_exp_topk() 49 | coreml_model = convert(onnx_model, add_custom_layers=True) 50 | 51 | spec = coreml_model.get_spec() 52 | layers = spec.neuralNetwork.layers 53 | self.assertIsNotNone(layers[0].custom) 54 | self.assertIsNotNone(layers[2].custom) 55 | self.assertEqual('Acos', layers[0].custom.className) 56 | self.assertEqual('TopK', layers[2].custom.className) 57 | 58 | def test_unsupported_ops_provide_functions(self): # type: () -> None 59 | 60 | def convert_acos(builder, node, graph, err): 61 | params = NeuralNetwork_pb2.CustomLayerParams() 62 | params.className = node.op_type 63 | params.description = "Custom layer that corresponds to the ONNX op {}".format(node.op_type, ) 64 | 65 | builder.add_custom( 66 | name=node.name, 67 | input_names=node.inputs, 68 | output_names=node.outputs, 69 | custom_proto_spec=params 70 | ) 71 | 72 | def convert_topk(builder, node, graph, err): 73 | params = NeuralNetwork_pb2.CustomLayerParams() 74 | params.className = node.op_type 75 | params.description = "Custom layer that corresponds to the ONNX op {}".format(node.op_type, ) 76 | params.parameters["axis"].intValue = node.attrs.get('axis', -1) 77 | 78 | builder.add_custom( 79 | name=node.name, 80 | input_names=node.inputs, 81 | output_names=node.outputs, 82 | custom_proto_spec=params 83 | ) 84 | 85 | onnx_model = _make_model_acos_exp_topk() 86 | coreml_model = convert(model=onnx_model, 87 | add_custom_layers=True, 88 | 
custom_conversion_functions={'Acos':convert_acos, 'TopK':convert_topk}) 89 | 90 | spec = coreml_model.get_spec() 91 | layers = spec.neuralNetwork.layers 92 | self.assertIsNotNone(layers[0].custom) 93 | self.assertIsNotNone(layers[2].custom) 94 | self.assertEqual('Acos', layers[0].custom.className) 95 | self.assertEqual('TopK', layers[2].custom.className) 96 | self.assertEqual(0, layers[2].custom.parameters['axis'].intValue) 97 | 98 | def test_node_name_type_custom_functions(self): # type: () -> None 99 | def convert_acos(builder, node, graph, err): 100 | params = NeuralNetwork_pb2.CustomLayerParams() 101 | params.className = node.op_type 102 | params.description = "Custom layer that corresponds to the ONNX op {}".format(node.op_type, ) 103 | 104 | builder.add_custom( 105 | name=node.name, 106 | input_names=node.inputs, 107 | output_names=node.outputs, 108 | custom_proto_spec=params 109 | ) 110 | 111 | def convert_topk_generic(builder, node, graph, err): 112 | params = NeuralNetwork_pb2.CustomLayerParams() 113 | params.className = node.op_type 114 | params.description = "Custom layer that corresponds to the ONNX op {}".format(node.op_type, ) 115 | params.parameters["axis"].intValue = node.attrs.get('axis', -1) 116 | params.parameters["k"].intValue = node.attrs['k'] 117 | 118 | builder.add_custom( 119 | name=node.name, 120 | input_names=node.inputs, 121 | output_names=node.outputs, 122 | custom_proto_spec=params 123 | ) 124 | 125 | def convert_topk_node_specific(builder, node, graph, err): 126 | params = NeuralNetwork_pb2.CustomLayerParams() 127 | params.className = node.op_type 128 | params.description = "Custom layer that corresponds to the ONNX op {}".format(node.op_type, ) 129 | params.parameters["axis"].intValue = node.attrs.get('axis', -1) 130 | 131 | builder.add_custom( 132 | name=node.name, 133 | input_names=node.inputs, 134 | output_names=node.outputs, 135 | custom_proto_spec=params 136 | ) 137 | 138 | onnx_model = _make_model_acos_exp_topk() 139 | 
onnx.save_model(onnx_model, 'acos.onnx') 140 | coreml_model = convert(model=onnx_model, 141 | add_custom_layers=True, 142 | custom_conversion_functions={'Acos':convert_acos, 'TopK':convert_topk_generic, 143 | 'output_values_output_indices':convert_topk_node_specific}) 144 | 145 | spec = coreml_model.get_spec() 146 | layers = spec.neuralNetwork.layers 147 | self.assertIsNotNone(layers[0].custom) 148 | self.assertIsNotNone(layers[2].custom) 149 | self.assertEqual('Acos', layers[0].custom.className) 150 | self.assertEqual('TopK', layers[2].custom.className) 151 | self.assertEqual(0, layers[2].custom.parameters['axis'].intValue) 152 | 153 | def test_unsupported_op_attribute(self): # type: () -> None 154 | onnx_model = _make_model_flatten_axis3() 155 | coreml_model = convert(onnx_model, add_custom_layers=True) 156 | 157 | spec = coreml_model.get_spec() 158 | layers = spec.neuralNetwork.layers 159 | self.assertIsNotNone(layers[0].custom) 160 | self.assertEqual('Flatten', layers[0].custom.className) 161 | 162 | def test_unsupported_op_attribute_provide_functions(self): # type: () -> None 163 | 164 | def convert_flatten(builder, node, graph, err): 165 | params = NeuralNetwork_pb2.CustomLayerParams() 166 | params.className = node.op_type 167 | params.description = "Custom layer that corresponds to the ONNX op {}".format(node.op_type, ) 168 | params.parameters["axis"].intValue = node.attrs['axis'] 169 | 170 | builder.add_custom( 171 | name=node.name, 172 | input_names=node.inputs, 173 | output_names=node.outputs, 174 | custom_proto_spec=params 175 | ) 176 | 177 | def test_conversion(onnx_model, add_custom_layers=False): 178 | coreml_model = convert(onnx_model, add_custom_layers=add_custom_layers, 179 | custom_conversion_functions={'Flatten': convert_flatten}) 180 | 181 | spec = coreml_model.get_spec() 182 | layers = spec.neuralNetwork.layers 183 | self.assertIsNotNone(layers[0].custom) 184 | self.assertEqual('Flatten', layers[0].custom.className) 185 | self.assertEqual(3, 
layers[0].custom.parameters['axis'].intValue) 186 | 187 | onnx_model = _make_model_flatten_axis3() 188 | # Test with add_custom_layers True 189 | convert(onnx_model, add_custom_layers=True, 190 | custom_conversion_functions={'Flatten': convert_flatten}) 191 | 192 | # Test with add_custom_layers False 193 | convert(onnx_model, add_custom_layers=False, 194 | custom_conversion_functions={'Flatten': convert_flatten}) 195 | 196 | if __name__ == '__main__': 197 | unittest.main() 198 | -------------------------------------------------------------------------------- /tests/transformers_test.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 6 | import unittest 7 | import numpy as np 8 | import numpy.testing as npt # type: ignore 9 | 10 | from onnx import helper, numpy_helper, TensorProto 11 | 12 | from onnx_coreml import convert 13 | from onnx_coreml._graph import Graph 14 | from onnx_coreml._transformers import ConvAddFuser, DropoutRemover, ImageScalerRemover 15 | from tests._test_utils import _onnx_create_model, _test_onnx_model, \ 16 | _conv_pool_output_size, _random_array 17 | 18 | 19 | class ConvAddFuserTest(unittest.TestCase): 20 | def test_fuse_conv_without_bias(self): # type: () -> None 21 | kernel_shape = (3, 2) 22 | strides = (2, 3) 23 | pads = (4, 2, 4, 2) 24 | dilations = (1, 2) 25 | group = 1 26 | weight = numpy_helper.from_array( 27 | _random_array((16, 3, 3, 2)), name="weight" 28 | ) 29 | 30 | input_shape = (1, 3, 224, 224) 31 | output_size = _conv_pool_output_size(input_shape, dilations, 32 | kernel_shape, pads, strides) 33 | 34 | output_shape = (1, int(weight.dims[0]), output_size[0], output_size[1]) 35 | 36 | inputs = [('input0', input_shape)] 37 | outputs = [('output0', output_shape, TensorProto.FLOAT)] 38 | 39 | conv = helper.make_node( 40 | "Conv", 41 | 
inputs=[inputs[0][0], "weight"], 42 | outputs=["conv_output"], 43 | dilations=dilations, 44 | group=group, 45 | kernel_shape=kernel_shape, 46 | pads=pads, 47 | strides=strides 48 | ) 49 | 50 | b = _random_array((int(weight.dims[0]),)) 51 | bias = numpy_helper.from_array( 52 | b, name="bias" 53 | ) 54 | 55 | add = helper.make_node( 56 | "Add", 57 | inputs=[conv.output[0], "bias"], 58 | outputs=[outputs[0][0]], 59 | broadcast=1, 60 | axis=1 61 | ) 62 | 63 | model = _onnx_create_model( 64 | [conv, add], inputs, outputs, [weight, bias] 65 | ) 66 | graph_ = Graph.from_onnx(model.graph, onnx_ir_version=5) 67 | fused_graph = graph_.transformed([ConvAddFuser()]) 68 | 69 | self.assertEqual(len(fused_graph.nodes), 1) 70 | node = fused_graph.nodes[0] 71 | self.assertEqual(len(node.inputs), 3) 72 | npt.assert_equal(node.input_tensors[node.inputs[2]], b) 73 | self.assertEqual(fused_graph.nodes[0].outputs[0], outputs[0][0]) 74 | 75 | def test_fuse_conv_with_bias(self): # type: () -> None 76 | kernel_shape = (3, 2) 77 | strides = (2, 3) 78 | pads = (4, 2, 4, 2) 79 | dilations = (1, 2) 80 | group = 1 81 | weight = numpy_helper.from_array( 82 | _random_array((16, 3, 3, 2)), name="weight" 83 | ) 84 | b = _random_array((int(weight.dims[0]),)) 85 | bias = numpy_helper.from_array( 86 | b, name="bias" 87 | ) 88 | 89 | input_shape = (1, 3, 224, 224) 90 | output_size = _conv_pool_output_size(input_shape, dilations, 91 | kernel_shape, pads, strides) 92 | 93 | output_shape = (1, int(weight.dims[0]), output_size[0], output_size[1]) 94 | 95 | inputs = [('input0', input_shape)] 96 | outputs = [('output0', output_shape, TensorProto.FLOAT)] 97 | 98 | conv = helper.make_node( 99 | "Conv", 100 | inputs=[inputs[0][0], "weight", "bias"], 101 | outputs=["conv_output"], 102 | dilations=dilations, 103 | group=group, 104 | kernel_shape=kernel_shape, 105 | pads=pads, 106 | strides=strides 107 | ) 108 | 109 | add = helper.make_node( 110 | "Add", 111 | inputs=[conv.output[0], "bias"], 112 | 
outputs=[outputs[0][0]], 113 | broadcast=1, 114 | axis=1 115 | ) 116 | 117 | model = _onnx_create_model( 118 | [conv, add], inputs, outputs, [weight, bias] 119 | ) 120 | graph_ = Graph.from_onnx(model.graph, onnx_ir_version=5) 121 | fused_graph = graph_.transformed([ConvAddFuser()]) 122 | 123 | self.assertEqual(len(fused_graph.nodes), 1) 124 | node = fused_graph.nodes[0] 125 | self.assertEqual(len(node.inputs), 3) 126 | npt.assert_equal(node.input_tensors[node.inputs[2]], b * 2) 127 | self.assertEqual(fused_graph.nodes[0].outputs[0], outputs[0][0]) 128 | 129 | 130 | class NodeRemoverTests(unittest.TestCase): 131 | 132 | def test_dropout_remover(self): # type: () -> None 133 | inputs = [('input', (1,3,50,50))] 134 | outputs = [('out', (1,5,50,50), TensorProto.FLOAT)] 135 | weight = numpy_helper.from_array(_random_array((5, 3, 1, 1)), name="weight") 136 | conv = helper.make_node( 137 | "Conv", 138 | inputs=["input", "weight"], 139 | outputs=["conv_output"], 140 | kernel_shape=(1,1), 141 | strides=(1,1) 142 | ) 143 | drop = helper.make_node("Dropout", 144 | inputs = ["conv_output"], 145 | outputs = ["drop_output"], 146 | ) 147 | exp = helper.make_node("Exp", 148 | inputs=["drop_output"], 149 | outputs=['out']) 150 | 151 | onnx_model = _onnx_create_model([conv, drop, exp], inputs, outputs) 152 | 153 | graph = Graph.from_onnx(onnx_model.graph, onnx_ir_version=5) 154 | new_graph = graph.transformed([DropoutRemover()]) 155 | self.assertEqual(len(graph.nodes), 3) 156 | self.assertEqual(len(new_graph.nodes), 2) 157 | self.assertEqual(new_graph.nodes[0].inputs[0], 'input') 158 | self.assertEqual(new_graph.nodes[1].inputs[0], new_graph.nodes[0].outputs[0]) 159 | self.assertEqual(new_graph.nodes[1].outputs[0], 'out') 160 | 161 | def test_image_scaler_remover(self): # type: () -> None 162 | inputs = [('input', (1,3,50,50))] 163 | outputs = [('out', (1,3,50,50), TensorProto.FLOAT)] 164 | 165 | im_scaler = helper.make_node("ImageScaler", 166 | inputs = ['input'], 167 | outputs = 
['scaler_out'], 168 | bias = [10,-6,20], scale=3.0) 169 | 170 | exp = helper.make_node("Exp", 171 | inputs=["scaler_out"], 172 | outputs=['out']) 173 | 174 | onnx_model = _onnx_create_model([im_scaler, exp], inputs, outputs) 175 | 176 | graph = Graph.from_onnx(onnx_model.graph, onnx_ir_version=5) 177 | new_graph = graph.transformed([ImageScalerRemover()]) 178 | self.assertEqual(len(graph.nodes), 2) 179 | self.assertEqual(len(new_graph.nodes), 1) 180 | self.assertEqual(new_graph.nodes[0].inputs[0], 'input') 181 | self.assertEqual(new_graph.nodes[0].outputs[0], 'out') 182 | 183 | coreml_model = convert(onnx_model) 184 | spec = coreml_model.get_spec() 185 | 186 | self.assertEqual(spec.neuralNetwork.preprocessing[0].scaler.channelScale, 3.0) 187 | self.assertEqual(spec.neuralNetwork.preprocessing[0].scaler.blueBias, 20.0) 188 | self.assertEqual(spec.neuralNetwork.preprocessing[0].scaler.greenBias, -6.0) 189 | self.assertEqual(spec.neuralNetwork.preprocessing[0].scaler.redBias, 10.0) 190 | 191 | def test_multiple_image_scaler(self): # type : () -> None 192 | inputs = [('input_color', (1,3,10,10)), ('input_gray', (1,1,10,10))] 193 | outputs = [('out', (1,4,10,10), TensorProto.FLOAT)] 194 | 195 | im_scaler1 = helper.make_node("ImageScaler", 196 | inputs = ['input_color'], 197 | outputs = ['scaler_out_1'], 198 | bias = [10,-6,20], scale=3.0) 199 | 200 | im_scaler2 = helper.make_node("ImageScaler", 201 | inputs = ['input_gray'], 202 | outputs = ['scaler_out_2'], 203 | bias = [-13], scale=5.0) 204 | 205 | concat = helper.make_node("Concat", 206 | inputs=['scaler_out_1', 'scaler_out_2'], 207 | outputs=['out'], 208 | axis = 1) 209 | 210 | onnx_model = _onnx_create_model([im_scaler1, im_scaler2, concat], inputs, outputs) 211 | 212 | spec = convert(onnx_model).get_spec() 213 | self.assertEqual(len(spec.neuralNetwork.layers), 1) 214 | self.assertEqual(len(spec.neuralNetwork.preprocessing), 2) 215 | self.assertEqual(spec.neuralNetwork.preprocessing[0].scaler.channelScale, 3.0) 216 
| self.assertEqual(spec.neuralNetwork.preprocessing[0].scaler.blueBias, 20.0) 217 | self.assertEqual(spec.neuralNetwork.preprocessing[0].scaler.greenBias, -6.0) 218 | self.assertEqual(spec.neuralNetwork.preprocessing[0].scaler.redBias, 10.0) 219 | self.assertEqual(spec.neuralNetwork.preprocessing[1].scaler.channelScale, 5.0) 220 | self.assertEqual(spec.neuralNetwork.preprocessing[1].scaler.grayBias, -13.0) 221 | 222 | 223 | 224 | class PixelShuffleFuserTest(unittest.TestCase): 225 | def test_pixel_shuffle(self): # type: () -> None 226 | scale_factor = 2 227 | input_shape = (1, 8, 2, 2) 228 | output_shape = ( 229 | input_shape[0], 230 | int(input_shape[1] / (scale_factor ** 2)), 231 | input_shape[2] * scale_factor, 232 | input_shape[3] * scale_factor 233 | ) 234 | 235 | inputs = [('input0', input_shape)] 236 | outputs = [('output0', output_shape, TensorProto.FLOAT)] 237 | 238 | shape1 = [ 239 | output_shape[0], 240 | output_shape[1], 241 | scale_factor, 242 | scale_factor, 243 | input_shape[2], 244 | input_shape[3] 245 | ] 246 | 247 | shape1 = numpy_helper.from_array(np.asarray(shape1), name="shape1") 248 | shape2 = numpy_helper.from_array(np.asarray(list(output_shape)), name="shape2") 249 | 250 | node_0 = helper.make_node( 251 | "Reshape", 252 | inputs=[inputs[0][0], 'shape1'], 253 | outputs=['node0'], 254 | ) 255 | node_1 = helper.make_node( 256 | "Transpose", 257 | inputs=['node0'], 258 | outputs=['node1'], 259 | perm=[0, 1, 4, 2, 5, 3] 260 | ) 261 | node_2 = helper.make_node( 262 | "Reshape", 263 | inputs=['node1','shape2'], 264 | outputs=[outputs[0][0]], 265 | ) 266 | model = _onnx_create_model( 267 | [node_0, node_1, node_2], inputs, outputs, initializer=[shape1, shape2] 268 | ) 269 | _test_onnx_model(model, decimal=7) 270 | 271 | 272 | if __name__ == '__main__': 273 | unittest.main() 274 | -------------------------------------------------------------------------------- /tests/_test_utils.py: 
-------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 6 | import numpy as np 7 | import numpy.testing as npt # type: ignore 8 | import numpy.random as npr 9 | from onnx import helper, TensorProto, ModelProto, ValueInfoProto, TensorProto, NodeProto 10 | from typing import Any, Sequence, Text, Tuple, Optional, Dict, List, TypeVar 11 | from onnx_coreml import convert 12 | from onnx_coreml.converter import SupportedVersion 13 | from onnx_coreml._graph import Node 14 | import sys 15 | import shutil 16 | import os 17 | 18 | ''' 19 | 0: dynamically generate random inputs, 20 | use caffe2 backend for onnx and 21 | also save out the generated input and output dicts for future use 22 | 1: use the already saved out input and output dicts for testing 23 | ''' 24 | TEST_MODE = 1 25 | 26 | def _forward_onnx_model(model, # type: ModelProto 27 | input_dict, # type: Dict[Text, np._ArrayLike[Any]] 28 | test_name = '' # type: Text 29 | ): 30 | # type: (...) 
-> np.ndarray[Any] 31 | if TEST_MODE: 32 | current_dir_path = os.path.dirname(os.path.realpath(__file__)) 33 | loaded_obj = np.load(current_dir_path + '/test_data/' + test_name + '/output.npy', encoding='bytes', allow_pickle=True) #type: ignore 34 | out = loaded_obj.item() 35 | else: 36 | import caffe2.python.onnx.backend # type: ignore 37 | prepared_backend = caffe2.python.onnx.backend.prepare(model) 38 | out = prepared_backend.run(input_dict) 39 | out_dict = {} 40 | out_names = [v.name for v in model.graph.output] 41 | for out_name in out_names: 42 | out_dict[out_name] = out[out_name] 43 | dir = os.path.dirname(os.path.realpath(__file__)) + '/test_data/' + test_name + '/' 44 | np.save(dir + 'output.npy', out_dict) #type: ignore 45 | 46 | result = [out[v.name] for v in model.graph.output] 47 | output_shapes = [ 48 | _shape_from_onnx_value_info(o) for o in model.graph.output 49 | ] 50 | for i, output in enumerate(result): 51 | result[i] = output.reshape(output_shapes[i]) 52 | return np.array(result) 53 | 54 | 55 | def _onnx_create_model(nodes, # type: Sequence[NodeProto] 56 | inputs, # type: Sequence[Tuple[Text,Tuple[int, ...]]] 57 | outputs, # type: Sequence[Tuple[Text,Tuple[int, ...], int]] 58 | initializer=[], # type: Sequence[TensorProto] 59 | ): 60 | # type: (...) 
-> ModelProto 61 | initializer_inputs = [ 62 | helper.make_tensor_value_info( 63 | t.name, 64 | TensorProto.FLOAT, 65 | t.dims 66 | ) for t in initializer 67 | ] 68 | 69 | graph = helper.make_graph( 70 | nodes=nodes, 71 | name="test", 72 | inputs=initializer_inputs + [ 73 | helper.make_tensor_value_info( 74 | input_[0], 75 | TensorProto.FLOAT, 76 | input_[1] 77 | ) for input_ in inputs 78 | ], 79 | outputs=[ 80 | helper.make_tensor_value_info( 81 | output_[0], 82 | output_[2], 83 | output_[1] 84 | ) for output_ in outputs 85 | ], 86 | initializer=initializer 87 | ) 88 | onnx_model = helper.make_model(graph) 89 | return onnx_model 90 | 91 | 92 | def _onnx_create_single_node_model(op_type, # type: Text 93 | input_shapes, # type: Sequence[Tuple[int, ...]] 94 | output_shapes, # type: Sequence[Tuple[int, ...]] 95 | initializer=[], # type: Sequence[TensorProto] 96 | **kwargs # type: Any 97 | ): 98 | # type: (...) -> ModelProto 99 | inputs = [ 100 | ("input{}".format(i,), input_shapes[i]) 101 | for i in range(len(input_shapes)) 102 | ] 103 | outputs = [ 104 | ("output{}".format(i,), output_shapes[i], TensorProto.FLOAT) 105 | for i in range(len(output_shapes)) 106 | ] 107 | 108 | node = helper.make_node( 109 | op_type, 110 | inputs=[i[0] for i in inputs] + [t.name for t in initializer], 111 | outputs=[o[0] for o in outputs], 112 | **kwargs 113 | ) 114 | return _onnx_create_model([node], inputs, outputs, initializer) 115 | 116 | 117 | def _shape_from_onnx_value_info(v): # type: (ValueInfoProto) -> Sequence[Tuple[int, ...]] 118 | return tuple([d.dim_value for d in v.type.tensor_type.shape.dim]) 119 | 120 | def _coreml_forward_model(model, # type: ModelProto 121 | input_dict, # type: Dict[Text, np._ArrayLike[Any]] 122 | output_names, # type: Sequence[Text] 123 | minimum_ios_deployment_target='12' 124 | ): 125 | # type: (...) 
-> np.ndarray[Any] 126 | if not SupportedVersion.is_nd_array_supported(minimum_ios_deployment_target): 127 | for k, arr in input_dict.items(): 128 | if len(arr.shape) == 4: 129 | input_dict[k] = arr[0] 130 | for k,v in input_dict.items(): 131 | if len(v.shape) == 2 and v.shape[0] == 1: 132 | input_dict[k] = v.flatten() 133 | coreml_out = model.predict(input_dict, useCPUOnly=True) 134 | return np.array([coreml_out[name] for name in output_names]) 135 | 136 | 137 | def _coreml_forward_onnx_model(model, # type: ModelProto 138 | input_dict, # type: Dict[Text, np._ArrayLike[Any]] 139 | onnx_coreml_input_shape_map = {}, # type: Dict[Text, List[int,...]] 140 | minimum_ios_deployment_target='12' 141 | ): 142 | # type: (...) -> np.ndarray[Any] 143 | coreml_model = convert(model, onnx_coreml_input_shape_map=onnx_coreml_input_shape_map, minimum_ios_deployment_target=minimum_ios_deployment_target) 144 | output_names = [o.name for o in model.graph.output] 145 | return _coreml_forward_model(coreml_model, input_dict, output_names, minimum_ios_deployment_target=minimum_ios_deployment_target) 146 | 147 | 148 | def _random_array(shape, random_seed=10): # type: (Tuple[int, ...], Any) -> np._ArrayLike[float] 149 | if random_seed: 150 | npr.seed(random_seed) # type: ignore 151 | return npr.ranf(shape).astype("float32") 152 | 153 | 154 | def _conv_pool_output_size(input_shape, # type: Sequence[int] 155 | dilations, # type: Sequence[int] 156 | kernel_shape, # type: Tuple[int, int] 157 | pads, # type: Sequence[int] 158 | strides, # type: Tuple[int, int] 159 | ): 160 | # type: (...) 
-> Tuple[int, int] 161 | output_height = ( 162 | input_shape[2] + pads[0] + pads[2] - 163 | (dilations[0] * (kernel_shape[0] - 1) + 1) 164 | ) / strides[0] + 1 165 | output_width = ( 166 | input_shape[3] + pads[1] + pads[3] - 167 | (dilations[1] * (kernel_shape[1] - 1) + 1) 168 | ) / strides[1] + 1 169 | 170 | return (int(output_height), int(output_width)) 171 | 172 | 173 | _T = TypeVar('_T') 174 | 175 | 176 | def _assert_outputs(output1, # type: np.ndarray[_T] 177 | output2, # type: np.ndarray[_T] 178 | decimal=7, # type: int 179 | ): 180 | # type: (...) -> None 181 | npt.assert_equal(len(output1), len(output2)) 182 | for o1, o2 in zip(output1, output2): 183 | npt.assert_almost_equal( 184 | o2.flatten(), 185 | o1.flatten(), 186 | decimal=decimal 187 | ) 188 | 189 | 190 | def _prepare_inputs_for_onnx(model, # type: ModelProto 191 | test_name = '', # type: Text 192 | values=None, # type: Optional[List[np._ArrayLike[Any]]] 193 | ): 194 | # type: (...) -> Dict[Text, np._ArrayLike[Any]] 195 | graph = model.graph 196 | initializer_names = {t.name for t in graph.initializer} 197 | input_names = [ 198 | i.name for i in graph.input if i.name not in initializer_names 199 | ] 200 | input_shapes = [ 201 | tuple([d.dim_value for d in i.type.tensor_type.shape.dim]) 202 | for i in graph.input if i.name not in initializer_names 203 | ] 204 | 205 | if TEST_MODE: 206 | dir_path = os.path.dirname(os.path.realpath(__file__)) 207 | loaded_obj = np.load(dir_path + '/test_data/' + test_name + '/input.npy', encoding='bytes', allow_pickle=True) # type: ignore 208 | return loaded_obj.item() # type: ignore 209 | else: 210 | if values is None: 211 | inputs = [_random_array(shape) for shape in input_shapes] 212 | else: 213 | inputs = values 214 | input_dict = dict(zip(input_names, inputs)) 215 | dir = os.path.dirname(os.path.realpath(__file__)) + '/test_data/' + test_name + '/' 216 | if os.path.exists(dir): 217 | shutil.rmtree(dir) 218 | os.makedirs(dir) 219 | np.save(dir + 'input.npy', 
input_dict) # type: ignore 220 | return input_dict 221 | 222 | 223 | def _test_onnx_model(model, # type: ModelProto 224 | test_name='', # type: Text 225 | decimal=5, # type: int 226 | onnx_coreml_input_shape_map = {}, # type: Dict[Text, List[int,...]] 227 | coreml_input_shape = {}, # type: Dict[Text, List[int,...]] 228 | minimum_ios_deployment_target='12' 229 | ): 230 | # type: (...) -> None 231 | if not test_name: 232 | test_name = sys._getframe(1).f_code.co_name 233 | W = _prepare_inputs_for_onnx(model, test_name=test_name) 234 | c2_outputs = _forward_onnx_model(model, W, test_name=test_name) 235 | coreml_input_dict = dict() 236 | # Supported iOS Version 237 | # New OS Version must be added at the end to maintain backward version index 238 | supported_ios_version = ['11.2', '12', '13'] 239 | IOS_13_VERSION = supported_ios_version.index('13') 240 | for key, value in W.items(): 241 | if supported_ios_version.index(minimum_ios_deployment_target) < IOS_13_VERSION and key in coreml_input_shape: 242 | coreml_input_dict[key] = np.reshape(value, coreml_input_shape[key]) 243 | else: 244 | coreml_input_dict[key] = value 245 | coreml_outputs = _coreml_forward_onnx_model(model, coreml_input_dict, onnx_coreml_input_shape_map=onnx_coreml_input_shape_map, 246 | minimum_ios_deployment_target=minimum_ios_deployment_target) 247 | _assert_outputs(c2_outputs, coreml_outputs, decimal=decimal) 248 | 249 | 250 | def _test_single_node(op_type, # type: Text 251 | input_shapes, # type: Sequence[Tuple[int, ...]] 252 | output_shapes, # type: Sequence[Tuple[int, ...]] 253 | initializer=[], # type: Sequence[TensorProto] 254 | decimal=5, # type: int 255 | test_name = '', # type: Text 256 | onnx_coreml_input_shape_map = {}, # type: Dict[Text, List[int,...]] 257 | coreml_input_shape = {}, # type: Dict[Text, List[int,...]] 258 | minimum_ios_deployment_target='12', 259 | **kwargs # type: Any 260 | ): 261 | # type: (...) 
-> None 262 | model = _onnx_create_single_node_model( 263 | op_type, input_shapes, output_shapes, initializer, **kwargs 264 | ) 265 | if not test_name: 266 | test_name = sys._getframe(1).f_code.co_name 267 | _test_onnx_model(model, test_name=test_name, decimal=decimal, 268 | onnx_coreml_input_shape_map=onnx_coreml_input_shape_map, 269 | coreml_input_shape = coreml_input_shape, 270 | minimum_ios_deployment_target=minimum_ios_deployment_target) 271 | -------------------------------------------------------------------------------- /onnx_coreml/_graph.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 6 | from onnx import numpy_helper, ValueInfoProto, AttributeProto, GraphProto, NodeProto, TensorProto, TensorShapeProto 7 | from typing import Any, Text, Iterable, List, Dict, Sequence, Optional, Tuple, Union 8 | from typing_extensions import Protocol 9 | import numpy as np 10 | 11 | 12 | class Transformer(Protocol): 13 | def __call__(self, graph): # type: (Graph) -> Graph 14 | pass 15 | 16 | 17 | EdgeInfo = Tuple[Text, Any, TensorShapeProto] 18 | AttributeValue = Any # TODO Union[Sequence[float], Sequence[int], Sequence[Text], Sequence[TensorProto], Sequence[GraphProto]] 19 | 20 | def _input_from_onnx_input(input): # type: (ValueInfoProto) -> EdgeInfo 21 | name = input.name 22 | type = input.type.tensor_type.elem_type 23 | shape = tuple([d.dim_value for d in input.type.tensor_type.shape.dim]) 24 | return (name, type, shape) 25 | 26 | 27 | def _convertAttributeProto(onnx_arg): # type: (AttributeProto) -> AttributeValue 28 | """ 29 | Convert an ONNX AttributeProto into an appropriate Python object 30 | for the type. 
31 | NB: Tensor attribute gets returned as numpy array 32 | """ 33 | if onnx_arg.HasField('f'): 34 | return onnx_arg.f 35 | elif onnx_arg.HasField('i'): 36 | return onnx_arg.i 37 | elif onnx_arg.HasField('s'): 38 | return onnx_arg.s 39 | elif onnx_arg.HasField('t'): 40 | return numpy_helper.to_array(onnx_arg.t) 41 | elif len(onnx_arg.floats): 42 | return list(onnx_arg.floats) 43 | elif len(onnx_arg.ints): 44 | return list(onnx_arg.ints) 45 | elif len(onnx_arg.strings): 46 | return list(onnx_arg.strings) 47 | else: 48 | return None 49 | 50 | def _extract_node_names(graph): # type : (Graph) -> List[Text] 51 | node_names = [] 52 | for node in graph.nodes: 53 | node_names.append(node.name) 54 | return node_names 55 | 56 | def _apply_graph_transformations(graph, transformers): # (Graph, Iterable[Transformer]) -> Graph 57 | old_node_names = _extract_node_names(graph) # type: ignore 58 | while True: 59 | for transformer in transformers: 60 | graph = transformer(graph) 61 | new_node_names = _extract_node_names(graph) # type: ignore 62 | if new_node_names == old_node_names: 63 | break 64 | old_node_names = new_node_names 65 | return graph 66 | 67 | class Attributes(Dict[Text, Any]): 68 | @staticmethod 69 | def from_onnx(args): # type: (Iterable[AttributeProto]) -> Attributes 70 | d = Attributes() 71 | for arg in args: 72 | val = _convertAttributeProto(arg) 73 | if val is not None: 74 | d[arg.name] = val 75 | return d 76 | 77 | 78 | class Node(object): 79 | def __init__(self, 80 | name, # type: Optional[Text] 81 | op_type, # type: Text 82 | attrs, # type: Dict[Text, AttributeValue] 83 | inputs, # type: List[Text] 84 | outputs, # type: List[Text] 85 | ): 86 | # type: (...) 
-> None 87 | self.name = name 88 | self.op_type = op_type 89 | self.attrs = attrs 90 | self.inputs = inputs 91 | self.outputs = outputs 92 | self.input_tensors = {} # type: Dict[Text, np._ArrayLike[Any]] 93 | self.parents = [] # type: List[Node] 94 | self.children = [] # type: List[Node] 95 | self.metadata = {} # type: Dict[Any, Any] 96 | 97 | def add_parent(self, parent_node): # type: (Node) -> None 98 | assert parent_node not in self.parents 99 | self.parents.append(parent_node) 100 | if self not in parent_node.children: 101 | parent_node.children.append(self) 102 | 103 | def add_child(self, child_node): # type: (Node) -> None 104 | assert child_node not in self.children 105 | self.children.append(child_node) 106 | if self not in child_node.parents: 107 | child_node.parents.append(self) 108 | 109 | def get_only_parent(self): # type: () -> Node 110 | if len(self.parents) != 1: 111 | raise ValueError('Node ({}) expected to have 1 parent. Found {}.' 112 | .format(self, len(self.parents))) 113 | return self.parents[0] 114 | 115 | @staticmethod 116 | def from_onnx(node): # type: (NodeProto) -> Node 117 | attrs = Attributes.from_onnx(node.attribute) 118 | name = Text(node.name) 119 | if len(name) == 0: 120 | name = "_".join(node.output) 121 | return Node( 122 | name, node.op_type, attrs, list(node.input), list(node.output) 123 | ) 124 | 125 | 126 | class Graph(object): 127 | def __init__(self, 128 | nodes, # type: List[Node] 129 | inputs, # type: List[EdgeInfo] 130 | outputs, # type: List[EdgeInfo] 131 | shape_dict, # type: Dict[Text,Tuple[int,...]] 132 | onnx_ir_version, # type: int 133 | ): 134 | # type: (...) 
-> None 135 | self.nodes = nodes 136 | self.inputs = inputs 137 | self.outputs = outputs 138 | self.shape_dict = shape_dict # data blob name to its shape 139 | self.constants_loaded = set() # set of constants present in graph as node 140 | self.onnx_ir_version = onnx_ir_version # ONNX IR Version for current graph 141 | 142 | self.optional_inputs = [] # list of tuple(str, tuple(int)), use with recurrent layers 143 | self.optional_outputs = [] # list of tuple(str,tuple(int)), use with recurrent layers 144 | 145 | ''' 146 | All axes in CoreML Tensor shapes are annotated. That is, 147 | 0: Sequence 148 | 1: Batch 149 | 2: Channel 150 | 3: Height 151 | 4: Width 152 | This dictionary "onnx_coreml_shape_mapping" records onnx shape to coreml shape mapping for 153 | every tensor (including intermediate tensors) in the onnx graph. 154 | The requirement is to only know the "rank" (i.e. number of dimensions) of the onnx tensor, not its actual shape, during conversion time. 155 | 156 | The Dict is "str" -> List of ints 157 | 158 | e.g. "x" -> [1,3] carries the following information: 159 | - "x" is rank 2 160 | - "x" in Coreml will have the shape [Seq=1, B=x.shape[0], C=1, H=x.shape[1], W=1] 161 | 162 | e.g. "x" -> [1,3,2] carries the following information: 163 | - "x" is rank 3 164 | - "x" in Coreml will have the shape [Seq=1, B=x.shape[0], C=x.shape[2], H=x.shape[1], W=1] 165 | 166 | The dictionary "onnx_coreml_shape_mapping" is progressively built as the onnx graph is converted to CoreML graph. 167 | The op to layer conversion functions use the information in this dict to correctly set the parameters of the CoreML layer 168 | to be added and at the end they update the dict with that layer's output(s). 
169 | ''' 170 | self.onnx_coreml_shape_mapping = {} # type: Dict[Text, List[int,...]] 171 | 172 | # data blob name to the list of op types it feeds into 173 | self.blob_to_op_type = {} # type: Dict[Text, List[Text]] 174 | # data blob name to the op_type that generates it 175 | self.blob_from_op_type = {} # type: Dict[Text, Text] 176 | 177 | self.constant_layers_added = {} # type: Dict[Text, bool] 178 | 179 | for node_ in nodes: 180 | for input_ in node_.inputs: 181 | if input_ in self.blob_to_op_type: 182 | self.blob_to_op_type[input_].append(node_.op_type) 183 | else: 184 | self.blob_to_op_type[input_] = [node_.op_type] 185 | for output_ in node_.outputs: 186 | if output_ in self.blob_from_op_type: 187 | raise ValueError("Data blob: %s, is generated by more than 1 op" %(output_)) 188 | self.blob_from_op_type[output_] = node_.op_type 189 | 190 | 191 | def create_graph(self, nodes=None, inputs=None, outputs=None, shape_dict=None, onnx_ir_version=None): 192 | node = self.nodes if nodes is None else nodes 193 | inputs = self.inputs if inputs is None else inputs 194 | outputs = self.outputs if outputs is None else outputs 195 | shape_dict = self.shape_dict if shape_dict is None else shape_dict 196 | onnx_ir_version = self.onnx_ir_version if onnx_ir_version is None else onnx_ir_version 197 | return Graph(nodes, inputs, outputs, shape_dict, onnx_ir_version) 198 | 199 | def transformed(self, transformers): # type: (Iterable[Transformer]) -> Graph 200 | graph = self 201 | return _apply_graph_transformations(graph, transformers) # type: ignore 202 | 203 | 204 | def has_edge_name(self, name): # type: (Text) -> bool 205 | ''' 206 | Check if name is already used for graph inputs/outputs or for nodes 207 | inputs/outputs 208 | ''' 209 | names = set() 210 | for input in self.inputs: 211 | names.add(input[0]) 212 | for output in self.outputs: 213 | names.add(output[0]) 214 | for node in self.nodes: 215 | names.update(node.inputs) 216 | names.update(node.outputs) 217 | return name 
in names 218 | 219 | def get_unique_edge_name(self, name): # type: (Text) -> Text 220 | n_ = name 221 | i = 0 222 | while self.has_edge_name(n_): 223 | n_ = "{}_{}".format(name, i) 224 | i += 1 225 | return n_ 226 | 227 | @staticmethod 228 | def from_onnx(graph, onnx_ir_version): # type: (GraphProto) -> Graph 229 | input_tensors = { 230 | t.name: numpy_helper.to_array(t) for t in graph.initializer 231 | } 232 | nodes_ = [] 233 | nodes_by_input = {} # type: Dict[Text, List[Node]] 234 | nodes_by_output = {} 235 | for node in graph.node: 236 | node_ = Node.from_onnx(node) 237 | for input_ in node_.inputs: 238 | if input_ in input_tensors: 239 | node_.input_tensors[input_] = input_tensors[input_] 240 | else: 241 | if input_ in nodes_by_input: 242 | input_nodes = nodes_by_input[input_] 243 | else: 244 | input_nodes = [] 245 | nodes_by_input[input_] = input_nodes 246 | input_nodes.append(node_) 247 | for output_ in node_.outputs: 248 | nodes_by_output[output_] = node_ 249 | nodes_.append(node_) 250 | 251 | inputs = [] 252 | for i in graph.input: 253 | if i.name not in input_tensors: 254 | inputs.append(_input_from_onnx_input(i)) 255 | 256 | outputs = [] 257 | for o in graph.output: 258 | outputs.append(_input_from_onnx_input(o)) 259 | 260 | for node_ in nodes_: 261 | for input_ in node_.inputs: 262 | if input_ in nodes_by_output: 263 | node_.parents.append(nodes_by_output[input_]) 264 | for output_ in node_.outputs: 265 | if output_ in nodes_by_input: 266 | node_.children.extend(nodes_by_input[output_]) 267 | 268 | # Dictionary to hold the "value_info" field from ONNX graph 269 | shape_dict = {} # type: Dict[Text,Tuple[int,...]] 270 | 271 | def extract_value_info(shape_dict, # type: Dict[Text,Tuple[int,...]] 272 | value_info, # type: ValueInfoProto[...] 273 | ): 274 | # type: (...) 
-> None 275 | t = tuple([int(dim.dim_value) for dim in value_info.type.tensor_type.shape.dim]) 276 | if t: 277 | shape_dict[value_info.name] = t 278 | 279 | for value_info in graph.value_info: 280 | extract_value_info(shape_dict, value_info) 281 | for value_info in graph.input: 282 | extract_value_info(shape_dict, value_info) 283 | for value_info in graph.output: 284 | extract_value_info(shape_dict, value_info) 285 | 286 | 287 | return Graph(nodes_, inputs, outputs, shape_dict, onnx_ir_version) 288 | -------------------------------------------------------------------------------- /stubs/click/core.pyi: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | from typing import ( 3 | Any, 4 | Callable, 5 | Dict, 6 | Generator, 7 | Iterable, 8 | List, 9 | Mapping, 10 | Optional, 11 | Sequence, 12 | Set, 13 | Text, 14 | Tuple, 15 | TypeVar, 16 | Union, 17 | ) 18 | 19 | from click.formatting import HelpFormatter 20 | from click.parser import OptionParser 21 | 22 | 23 | def invoke_param_callback( 24 | callback: Callable[['Context', 'Parameter', Optional[Text]], Any], 25 | ctx: 'Context', 26 | param: 'Parameter', 27 | value: Optional[Text] 28 | ) -> Any: 29 | ... 30 | 31 | 32 | @contextmanager 33 | def augment_usage_errors( 34 | ctx: 'Context', param: Optional['Parameter'] = ... 35 | ) -> Generator[None, None, None]: 36 | ... 37 | 38 | 39 | def iter_params_for_processing( 40 | invocation_order: Sequence['Parameter'], 41 | declaration_order: Iterable['Parameter'], 42 | ) -> Iterable['Parameter']: 43 | ... 
44 | 45 | 46 | class Context: 47 | parent: Optional['Context'] 48 | command: 'Command' 49 | info_name: Optional[Text] 50 | params: Dict[Text, Any] 51 | args: List[Text] 52 | protected_args: List[Text] 53 | obj: Any 54 | default_map: Mapping[Text, Any] 55 | invoked_subcommand: Optional[Text] 56 | terminal_width: Optional[int] 57 | max_content_width: Optional[int] 58 | allow_extra_args: bool 59 | allow_interspersed_args: bool 60 | ignore_unknown_options: bool 61 | help_option_names: List[Text] 62 | token_normalize_func: Optional[Callable[[Text], Text]] 63 | resilient_parsing: bool 64 | auto_envvar_prefix: Optional[Text] 65 | color: Optional[bool] 66 | _meta: Dict[Text, Any] 67 | _close_callbacks: List[Callable[..., Any]] 68 | _depth: int 69 | 70 | # properties 71 | meta: Dict[Text, Any] 72 | command_path: Text 73 | 74 | def __init__( 75 | self, 76 | command: 'Command', 77 | parent: Optional['Context'] = ..., 78 | info_name: Optional[Text] = ..., 79 | obj: Optional[Any] = ..., 80 | auto_envvar_prefix: Optional[Text] = ..., 81 | default_map: Optional[Mapping[Text, Any]] = ..., 82 | terminal_width: Optional[int] = ..., 83 | max_content_width: Optional[int] = ..., 84 | resilient_parsing: bool = ..., 85 | allow_extra_args: Optional[bool] = ..., 86 | allow_interspersed_args: Optional[bool] = ..., 87 | ignore_unknown_options: Optional[bool] = ..., 88 | help_option_names: Optional[List[Text]] = ..., 89 | token_normalize_func: Optional[Callable[[Text], Text]] = ..., 90 | color: Optional[bool] = ... 91 | ) -> None: 92 | ... 93 | 94 | @contextmanager 95 | def scope(self, cleanup: bool = ...) -> Generator['Context', None, None]: 96 | ... 97 | 98 | def make_formatter(self) -> HelpFormatter: 99 | ... 100 | 101 | def call_on_close(self, f: Callable[..., Any]) -> Callable[..., Any]: 102 | ... 103 | 104 | def close(self) -> None: 105 | ... 106 | 107 | def find_root(self) -> 'Context': 108 | ... 109 | 110 | def find_object(self, object_type: type) -> Any: 111 | ... 
112 | 113 | def ensure_object(self, object_type: type) -> Any: 114 | ... 115 | 116 | def lookup_default(self, name: Text) -> Any: 117 | ... 118 | 119 | def fail(self, message: Text) -> None: 120 | ... 121 | 122 | def abort(self) -> None: 123 | ... 124 | 125 | def exit(self, code: Union[int, Text] = ...) -> None: 126 | ... 127 | 128 | def get_usage(self) -> Text: 129 | ... 130 | 131 | def get_help(self) -> Text: 132 | ... 133 | 134 | def invoke( 135 | self, callback: Union['Command', Callable[..., Any]], *args, **kwargs 136 | ) -> Any: 137 | ... 138 | 139 | def forward( 140 | self, callback: Union['Command', Callable[..., Any]], *args, **kwargs 141 | ) -> Any: 142 | ... 143 | 144 | class BaseCommand: 145 | allow_extra_args: bool 146 | allow_interspersed_args: bool 147 | ignore_unknown_options: bool 148 | name: Text 149 | context_settings: Dict[Text, Any] 150 | 151 | def __init__(self, name: Text, context_settings: Optional[Dict[Text, Any]] = ...) -> None: 152 | ... 153 | 154 | def get_usage(self, ctx: Context) -> Text: 155 | ... 156 | 157 | def get_help(self, ctx: Context) -> Text: 158 | ... 159 | 160 | def make_context( 161 | self, info_name: Text, args: List[Text], parent: Optional[Context] = ..., **extra 162 | ) -> Context: 163 | ... 164 | 165 | def parse_args(self, ctx: Context, args: List[Text]) -> List[Text]: 166 | ... 167 | 168 | def invoke(self, ctx: Context) -> Any: 169 | ... 170 | 171 | def main( 172 | self, 173 | args: Optional[List[Text]] = ..., 174 | prog_name: Optional[Text] = ..., 175 | complete_var: Optional[Text] = ..., 176 | standalone_mode: bool = ..., 177 | **extra 178 | ) -> Any: 179 | ... 180 | 181 | def __call__(self, *args, **kwargs) -> Any: 182 | ... 
183 | 184 | 185 | class Command(BaseCommand): 186 | callback: Optional[Callable[..., Any]] 187 | params: List['Parameter'] 188 | help: Optional[Text] 189 | epilog: Optional[Text] 190 | short_help: Optional[Text] 191 | options_metavar: Text 192 | add_help_option: bool 193 | 194 | def __init__( 195 | self, 196 | name: Text, 197 | context_settings: Optional[Dict[Text, Any]] = ..., 198 | callback: Optional[Callable[..., Any]] = ..., 199 | params: Optional[List['Parameter']] = ..., 200 | help: Optional[Text] = ..., 201 | epilog: Optional[Text] = ..., 202 | short_help: Optional[Text] = ..., 203 | options_metavar: Text = ..., 204 | add_help_option: bool = ... 205 | ) -> None: 206 | ... 207 | 208 | def get_params(self, ctx: Context) -> List['Parameter']: 209 | ... 210 | 211 | def format_usage( 212 | self, 213 | ctx: Context, 214 | formatter: HelpFormatter 215 | ) -> None: 216 | ... 217 | 218 | def collect_usage_pieces(self, ctx: Context) -> List[Text]: 219 | ... 220 | 221 | def get_help_option_names(self, ctx: Context) -> Set[Text]: 222 | ... 223 | 224 | def get_help_option(self, ctx: Context) -> Optional['Option']: 225 | ... 226 | 227 | def make_parser(self, ctx: Context) -> OptionParser: 228 | ... 229 | 230 | def format_help(self, ctx: Context, formatter: HelpFormatter) -> None: 231 | ... 232 | 233 | def format_help_text(self, ctx: Context, formatter: HelpFormatter) -> None: 234 | ... 235 | 236 | def format_options(self, ctx: Context, formatter: HelpFormatter) -> None: 237 | ... 238 | 239 | def format_epilog(self, ctx: Context, formatter: HelpFormatter) -> None: 240 | ... 
241 | 242 | 243 | _T = TypeVar('_T') 244 | _Decorator = Callable[[_T], _T] 245 | 246 | 247 | class MultiCommand(Command): 248 | no_args_is_help: bool 249 | invoke_without_command: bool 250 | subcommand_metavar: Text 251 | chain: bool 252 | result_callback: Callable[..., Any] 253 | 254 | def __init__( 255 | self, 256 | name: Optional[Text] = ..., 257 | invoke_without_command: bool = ..., 258 | no_args_is_help: Optional[bool] = ..., 259 | subcommand_metavar: Optional[Text] = ..., 260 | chain: bool = ..., 261 | result_callback: Optional[Callable[..., Any]] = ..., 262 | **attrs 263 | ) -> None: 264 | ... 265 | 266 | def resultcallback( 267 | self, replace: bool = ... 268 | ) -> _Decorator[Any]: 269 | ... 270 | 271 | def format_commands(self, ctx: Context, formatter: HelpFormatter) -> None: 272 | ... 273 | 274 | def resolve_command( 275 | self, ctx: Context, args: List[Text] 276 | ) -> Tuple[Text, Command, List[Text]]: 277 | ... 278 | 279 | def get_command(self, ctx: Context, cmd_name: Text) -> Optional[Command]: 280 | ... 281 | 282 | def list_commands(self, ctx: Context) -> Iterable[Command]: 283 | ... 284 | 285 | 286 | class Group(MultiCommand): 287 | commands: Dict[Text, Command] 288 | 289 | def __init__( 290 | self, name: Optional[Text] = ..., commands: Optional[Dict[Text, Command]] = ..., **attrs 291 | ) -> None: 292 | ... 293 | 294 | def add_command(self, cmd: Command, name: Optional[Text] = ...): 295 | ... 296 | 297 | def command(self, *args, **kwargs) -> _Decorator[Any]: 298 | ... 299 | 300 | def group(self, *args, **kwargs) -> _Decorator[Any]: 301 | ... 302 | 303 | 304 | class CommandCollection(MultiCommand): 305 | sources: List[MultiCommand] 306 | 307 | def __init__( 308 | self, name: Optional[Text] = ..., sources: Optional[List[MultiCommand]] = ..., **attrs 309 | ) -> None: 310 | ... 311 | 312 | def add_source(self, multi_cmd: MultiCommand) -> None: 313 | ... 
314 | 315 | 316 | class Parameter: 317 | param_type_name: Text 318 | name: Text 319 | opts: List[Text] 320 | secondary_opts: List[Text] 321 | type: 'ParamType' 322 | required: bool 323 | callback: Optional[Callable[[Context, 'Parameter', Text], Any]] 324 | nargs: int 325 | multiple: bool 326 | expose_value: bool 327 | default: Any 328 | is_eager: bool 329 | metavar: Optional[Text] 330 | envvar: Union[Text, List[Text], None] 331 | # properties 332 | human_readable_name: Text 333 | 334 | def __init__( 335 | self, 336 | param_decls: Optional[List[Text]] = ..., 337 | type: Optional[Union[type, 'ParamType']] = ..., 338 | required: bool = ..., 339 | default: Optional[Any] = ..., 340 | callback: Optional[Callable[[Context, 'Parameter', Text], Any]] = ..., 341 | nargs: Optional[int] = ..., 342 | metavar: Optional[Text] = ..., 343 | expose_value: bool = ..., 344 | is_eager: bool = ..., 345 | envvar: Optional[Union[Text, List[Text]]] = ... 346 | ) -> None: 347 | ... 348 | 349 | def make_metavar(self) -> Text: 350 | ... 351 | 352 | def get_default(self, ctx: Context) -> Any: 353 | ... 354 | 355 | def add_to_parser(self, parser: OptionParser, ctx: Context) -> None: 356 | ... 357 | 358 | def consume_value(self, ctx: Context, opts: Dict[Text, Any]) -> Any: 359 | ... 360 | 361 | def type_cast_value(self, ctx: Context, value: Any) -> Any: 362 | ... 363 | 364 | def process_value(self, ctx: Context, value: Any) -> Any: 365 | ... 366 | 367 | def value_is_missing(self, value: Any) -> bool: 368 | ... 369 | 370 | def full_process_value(self, ctx: Context, value: Any) -> Any: 371 | ... 372 | 373 | def resolve_envvar_value(self, ctx: Context) -> Text: 374 | ... 375 | 376 | def value_from_envvar(self, ctx: Context) -> Union[Text, List[Text]]: 377 | ... 378 | 379 | def handle_parse_result( 380 | self, ctx: Context, opts: Dict[Text, Any], args: List[Text] 381 | ) -> Tuple[Any, List[Text]]: 382 | ... 383 | 384 | def get_help_record(self, ctx: Context) -> Tuple[Text, Text]: 385 | ... 
386 | 387 | def get_usage_pieces(self, ctx: Context) -> List[Text]: 388 | ... 389 | 390 | 391 | class Option(Parameter): 392 | prompt: Text # sic 393 | confirmation_prompt: bool 394 | hide_input: bool 395 | is_flag: bool 396 | flag_value: Any 397 | is_bool_flag: bool 398 | count: bool 399 | multiple: bool 400 | allow_from_autoenv: bool 401 | help: Optional[Text] 402 | show_default: bool 403 | 404 | def __init__( 405 | self, 406 | param_decls: Optional[List[Text]] = ..., 407 | show_default: bool = ..., 408 | prompt: Union[bool, Text] = ..., 409 | confirmation_prompt: bool = ..., 410 | hide_input: bool = ..., 411 | is_flag: Optional[bool] = ..., 412 | flag_value: Optional[Any] = ..., 413 | multiple: bool = ..., 414 | count: bool = ..., 415 | allow_from_autoenv: bool = ..., 416 | type: Optional[Union[type, 'ParamType']] = ..., 417 | help: Optional[Text] = ..., 418 | **attrs 419 | ) -> None: 420 | ... 421 | 422 | def prompt_for_value(self, ctx: Context) -> Any: 423 | ... 424 | 425 | 426 | class Argument(Parameter): 427 | def __init__( 428 | self, 429 | param_decls: Optional[List[Text]] = ..., 430 | required: Optional[bool] = ..., 431 | **attrs 432 | ) -> None: 433 | ... 
434 | 435 | # cyclic dependency 436 | from click.types import ParamType # noqa: E402 437 | -------------------------------------------------------------------------------- /tests/operators_test.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 6 | import unittest 7 | import numpy as np 8 | from onnx.numpy_helper import from_array 9 | import onnx 10 | from onnx_coreml import convert 11 | 12 | from typing import Text 13 | 14 | from tests._test_utils import _test_single_node, \ 15 | _random_array, _conv_pool_output_size, \ 16 | _onnx_create_single_node_model, _assert_outputs 17 | 18 | from coremltools.models.utils import macos_version 19 | 20 | MIN_MACOS_VERSION_10_15 = (10, 15) 21 | 22 | ONNX_SHAPE_INFERENCE_FAILS = True 23 | 24 | class SingleOperatorTest(unittest.TestCase): 25 | 26 | def test_conv(self): # type: () -> None 27 | kernel_shape = (3, 2) 28 | strides = (2, 3) 29 | pads = (4, 2, 4, 2) 30 | dilations = (1, 2) 31 | group = 1 32 | weight = from_array(_random_array((16, 3, 3, 2)), name="weight") 33 | 34 | input_shape = (1, 3, 224, 224) 35 | output_size = _conv_pool_output_size(input_shape, dilations, 36 | kernel_shape, pads, strides) 37 | 38 | output_shape = (1, int(weight.dims[0]), output_size[0], output_size[1]) 39 | 40 | _test_single_node( 41 | "Conv", 42 | [input_shape], 43 | [output_shape], 44 | initializer=[weight], 45 | dilations=dilations, 46 | group=group, 47 | kernel_shape=kernel_shape, 48 | pads=pads, 49 | strides=strides 50 | ) 51 | 52 | def test_conv_transpose(self): # type: () -> None 53 | kernel_shape = (3, 3) 54 | pads = (0, 0, 0, 0) 55 | C_in = 3 56 | C_out = 12 57 | H_in, W_in = 30, 30 58 | strides = (2, 2) 59 | 60 | input_shape = (1, C_in, H_in, W_in) 61 | weight = from_array(_random_array((C_in, C_out, kernel_shape[0], kernel_shape[1])), 
62 | name="weight") 63 | 64 | H_out = (H_in-1) * strides[0] + kernel_shape[0] - pads[0] - pads[2] 65 | W_out = (W_in-1) * strides[1] + kernel_shape[1] - pads[1] - pads[3] 66 | output_shape = (1, C_out, H_out, W_out) 67 | 68 | _test_single_node( 69 | "ConvTranspose", 70 | [input_shape], 71 | [output_shape], 72 | initializer=[weight], 73 | # Default values for other attributes: dilations=[1, 1], group=1 74 | strides = strides, 75 | kernel_shape=kernel_shape, 76 | pads=pads, 77 | output_padding=(0, 0) 78 | ) 79 | 80 | def test_conv_without_pads(self): # type: () -> None 81 | kernel_shape = (3, 2) 82 | strides = (2, 3) 83 | dilations = (1, 2) 84 | group = 1 85 | weight = from_array(_random_array((16, 3, 3, 2)), name="weight") 86 | 87 | input_shape = (1, 3, 224, 224) 88 | output_size = _conv_pool_output_size(input_shape, dilations, 89 | kernel_shape, [0, 0, 0, 0], 90 | strides) 91 | 92 | output_shape = (1, int(weight.dims[0]), output_size[0], output_size[1]) 93 | _test_single_node( 94 | "Conv", 95 | [input_shape], 96 | [output_shape], 97 | initializer=[weight], 98 | dilations=dilations, 99 | group=group, 100 | kernel_shape=kernel_shape, 101 | strides=strides 102 | ) 103 | 104 | def test_max_pool(self): # type: () -> None 105 | kernel_shape = (5, 3) 106 | pads = (2, 1, 2, 1) 107 | strides = (1, 2) 108 | 109 | input_shape = (1, 3, 224, 224) 110 | 111 | output_size = _conv_pool_output_size(input_shape, [1, 1], 112 | kernel_shape, pads, strides) 113 | 114 | output_shape = (1, 3, output_size[0], output_size[1]) 115 | 116 | _test_single_node( 117 | "MaxPool", 118 | [input_shape], 119 | [output_shape], 120 | test_name='test_max_pool_1', 121 | kernel_shape=kernel_shape, 122 | pads=pads, 123 | strides=strides 124 | ) 125 | 126 | output_size = _conv_pool_output_size(input_shape, [1, 1], 127 | kernel_shape, [0, 0, 0, 0], 128 | strides) 129 | output_shape = (1, 3, output_size[0], output_size[1]) 130 | _test_single_node( 131 | "MaxPool", 132 | [input_shape], 133 | [output_shape], 
134 | test_name='test_max_pool_2', 135 | kernel_shape=kernel_shape, 136 | strides=strides 137 | ) 138 | @unittest.skip('Skip due to internal CoreML CPU backend issue') 139 | def test_avg_pool(self): # type: () -> None 140 | kernel_shape = (5, 3) 141 | pads = (2, 1, 2, 1) 142 | strides = (1, 2) 143 | 144 | input_shape = (1, 3, 224, 224) 145 | output_size = _conv_pool_output_size(input_shape, (1, 1), 146 | kernel_shape, pads, strides) 147 | output_shape = (1, 3, output_size[0], output_size[1]) 148 | _test_single_node( 149 | "AveragePool", 150 | [input_shape], 151 | [output_shape], 152 | test_name='test_avg_pool_1', 153 | kernel_shape=kernel_shape, 154 | pads=pads, 155 | strides=strides 156 | ) 157 | 158 | output_size = _conv_pool_output_size(input_shape, (1, 1), 159 | kernel_shape, [0, 0, 0, 0], 160 | strides) 161 | output_shape = (1, 3, output_size[0], output_size[1]) 162 | _test_single_node( 163 | "AveragePool", 164 | [input_shape], 165 | [output_shape], 166 | test_name='test_avg_pool_2', 167 | kernel_shape=kernel_shape, 168 | strides=strides 169 | ) 170 | 171 | def test_bn(self): # type: () -> None 172 | scale = from_array(_random_array((3,)), name="scale") 173 | bias = from_array(_random_array((3,)), name="bias") 174 | mean = from_array(_random_array((3,)), name="mean") 175 | var = from_array(_random_array((3,)), name="var") 176 | 177 | epsilon = 1e-5 178 | momentum = 0.001 179 | 180 | op_types = ["BatchNormalization", "SpatialBN"] 181 | for op_type in op_types: 182 | _test_single_node( 183 | "BatchNormalization", 184 | [(1, 3, 224, 224)], 185 | [(1, 3, 224, 224)], 186 | initializer=[scale, bias, mean, var], 187 | epsilon=epsilon, 188 | momentum=momentum 189 | ) 190 | 191 | # epsilon by default 192 | _test_single_node( 193 | "BatchNormalization", 194 | [(1, 3, 224, 224)], 195 | [(1, 3, 224, 224)], 196 | initializer=[scale, bias, mean, var], 197 | # epsilon=epsilon, 198 | momentum=momentum 199 | ) 200 | 201 | def test_gemm(self, 
minimum_ios_deployment_target='12'): # type: () -> None 202 | input_shape = (1, 2048) 203 | output_shape = (1, 5) 204 | W = from_array( 205 | _random_array((output_shape[1], input_shape[1])), name="weight" 206 | ) 207 | b = from_array( 208 | _random_array((output_shape[1],)), name="bias" 209 | ) 210 | _test_single_node( 211 | "Gemm", 212 | [input_shape], 213 | [output_shape], 214 | initializer=[W, b], 215 | decimal=3, 216 | transB=1, 217 | minimum_ios_deployment_target=minimum_ios_deployment_target 218 | ) 219 | 220 | @unittest.skipIf(macos_version() < MIN_MACOS_VERSION_10_15, 221 | 'macOS 10.15+ required. Skipping test.') 222 | def test_gemm_ios13(self): 223 | self.test_gemm(minimum_ios_deployment_target='13') 224 | 225 | def test_gemm_transB_off(self, minimum_ios_deployment_target='12'): # type: () -> None 226 | input_shape = (1, 2048) 227 | output_shape = (1, 5) 228 | W = from_array( 229 | _random_array((input_shape[1], output_shape[1])), name="weight" 230 | ) 231 | b = from_array( 232 | _random_array((output_shape[1],)), name="bias" 233 | ) 234 | _test_single_node( 235 | "Gemm", 236 | [input_shape], 237 | [output_shape], 238 | initializer=[W, b], 239 | decimal=3, 240 | transB=0, 241 | minimum_ios_deployment_target=minimum_ios_deployment_target 242 | ) 243 | 244 | @unittest.skipIf(macos_version() < MIN_MACOS_VERSION_10_15, 245 | 'macOS 10.15+ required. Skipping test.') 246 | def test_gemm_transB_off_ios13(self): 247 | self.test_gemm_transB_off(minimum_ios_deployment_target='13') 248 | 249 | def test_lrn(self): # type: () -> None 250 | _test_single_node( 251 | "LRN", 252 | [(1, 3, 224, 224)], 253 | [(1, 3, 224, 224)], 254 | alpha=9.99e-5, 255 | beta=0.75, 256 | bias=1.0, 257 | size=5 258 | ) 259 | 260 | @unittest.skipIf(macos_version() < MIN_MACOS_VERSION_10_15, 261 | 'macOS 10.15+ required. 
Skipping test.') 262 | def test_split_axis_0_rank_3(self, minimum_ios_deployment_target='12'): # type: () -> None 263 | _test_single_node( 264 | "Split", 265 | [(2, 1, 200)], 266 | [(1, 1, 200), (1, 1, 200)], 267 | axes=0, 268 | minimum_ios_deployment_target=minimum_ios_deployment_target 269 | ) 270 | 271 | @unittest.skipIf(macos_version() < MIN_MACOS_VERSION_10_15, 272 | 'macOS 10.15+ required. Skipping test.') 273 | def test_concat(self, minimum_ios_deployment_target='13'): # type: () -> None 274 | _test_single_node( 275 | "Concat", 276 | [(1, 2, 200), (1, 2, 200)], 277 | [(2, 2, 200)], 278 | axis=0, 279 | minimum_ios_deployment_target=minimum_ios_deployment_target 280 | ) 281 | 282 | @unittest.skipIf(macos_version() < MIN_MACOS_VERSION_10_15, 283 | 'macOS 10.15+ required. Skipping test.') 284 | def test_gather(self, minimum_ios_deployment_target='13'): # type: () -> None 285 | _test_single_node( 286 | "Gather", 287 | [(5, 4, 3), (3,)], 288 | [(3, 4, 3)], 289 | axis=0, 290 | minimum_ios_deployment_target=minimum_ios_deployment_target 291 | ) 292 | 293 | @unittest.skipIf(macos_version() < MIN_MACOS_VERSION_10_15, 294 | 'macOS 10.15+ required. Skipping test.') 295 | def test_reshape_same_rank(self, minimum_ios_deployment_target='13'): # type: () -> None 296 | _test_single_node( 297 | "Reshape", 298 | [(5, 4, 3), (3,)], 299 | [(4, 5, 3)], 300 | minimum_ios_deployment_target=minimum_ios_deployment_target 301 | ) 302 | 303 | @unittest.skipIf(macos_version() < MIN_MACOS_VERSION_10_15, 304 | 'macOS 10.15+ required. 
Skipping test.') 305 | def test_reshape_same_rank_infer_shape(self, minimum_ios_deployment_target='13'): # type: () -> None 306 | _test_single_node( 307 | "Reshape", 308 | [(5, 4, 3), (3,)], 309 | [(5, 2, 6)], 310 | minimum_ios_deployment_target=minimum_ios_deployment_target 311 | ) 312 | 313 | # TODO: add test_reshape_diff_rank_infer_shape where shape is Constant and known 314 | # to test rank-4 into rank-3 reshape with shape inferencing 315 | @unittest.skipIf(macos_version() < MIN_MACOS_VERSION_10_15, 316 | 'macOS 10.15+ required. Skipping test.') 317 | def test_reshape_dynamic(self, minimum_ios_deployment_target='13'): # type: () -> None 318 | _test_single_node( 319 | "Reshape", 320 | [(5, 4, 3, 2), (3,)], 321 | [(2, 3, 20)], 322 | minimum_ios_deployment_target=minimum_ios_deployment_target 323 | ) 324 | 325 | @unittest.skipIf(macos_version() < MIN_MACOS_VERSION_10_15, 326 | 'macOS 10.15+ required. Skipping test.') 327 | def test_squeeze(self, minimum_ios_deployment_target='13'): # type: () -> None 328 | _test_single_node( 329 | "Squeeze", 330 | [(5, 1, 3, 1, 1)], 331 | [(5, 3)], 332 | axes=[1, 3, 4], 333 | minimum_ios_deployment_target=minimum_ios_deployment_target 334 | ) 335 | 336 | @unittest.skipIf(macos_version() < MIN_MACOS_VERSION_10_15, 337 | 'macOS 10.15+ required. Skipping test.') 338 | def test_transpose_default(self, minimum_ios_deployment_target='13'): # type: () -> None 339 | _test_single_node( 340 | "Transpose", 341 | [(5, 3, 4, 6, 2)], 342 | [(2, 6, 4, 3, 5)], 343 | minimum_ios_deployment_target=minimum_ios_deployment_target 344 | ) 345 | 346 | @unittest.skipIf(ONNX_SHAPE_INFERENCE_FAILS, 347 | 'ONNX Shape inference fails to recongnize correct shape') 348 | @unittest.skipIf(macos_version() < MIN_MACOS_VERSION_10_15, 349 | 'macOS 10.15+ required. 
Skipping test.') 350 | def test_transpose_permute(self, minimum_ios_deployment_target='13'): # type: () -> None 351 | _test_single_node( 352 | "Transpose", 353 | [(5, 3, 4, 6, 2)], 354 | [(2, 3, 4, 6, 5)], 355 | axes=[4, 1, 2, 3, 0], 356 | minimum_ios_deployment_target=minimum_ios_deployment_target 357 | ) 358 | @unittest.skipIf(ONNX_SHAPE_INFERENCE_FAILS, 359 | 'ONNX Shape inference fails to recongnize correct shape') 360 | @unittest.skipIf(macos_version() < MIN_MACOS_VERSION_10_15, 361 | 'macOS 10.15+ required. Skipping test.') 362 | def test_unsqueeze(self, minimum_ios_deployment_target='13'): # type: () -> None 363 | _test_single_node( 364 | "Unsqueeze", 365 | [(5, 3, 4)], 366 | [(1, 5, 1, 3, 4)], 367 | axes=[0, 1], 368 | minimum_ios_deployment_target=minimum_ios_deployment_target 369 | ) 370 | 371 | 372 | # @unittest.skip("Error while preparing Caffe2 backend. Maybe something is incorrect in ONNX model definition") 373 | # def skip_test_lstm(self): # type: () -> None 374 | # x = 4 375 | # h = 2 376 | # seq_length = 3 377 | # W = from_array(_random_array((4*h, x)), name="gate_weights") 378 | # R = from_array(_random_array((4*h, h)), name="recursion_weights") 379 | # B = from_array(_random_array((8*h,)), name="biases") 380 | # seq_lens_input = from_array(np.array([seq_length]).astype(np.int32), name='seq_lens_input') 381 | # initial_h = from_array(np.zeros((1, 1, h)).astype(np.float32), name='initial_h') 382 | # initial_c = from_array(np.zeros((1, 1, h)).astype(np.float32), name='initial_c') 383 | # 384 | # input_shape = (seq_length, 1, x) 385 | # output_shape_all = (seq_length, 1, h) 386 | # output_shape_last = (1, 1, h) 387 | # 388 | # onnx_model = _onnx_create_single_node_model( 389 | # "LSTM", 390 | # [input_shape], 391 | # [output_shape_all, output_shape_last], 392 | # initializer=[W, R, B, seq_lens_input, initial_h, initial_c], 393 | # hidden_size=h 394 | # ) 395 | # X = np.random.rand(*input_shape).astype("float32") #type: ignore 396 | # import 
#     prepared_backend = caffe2.python.onnx.backend.prepare(onnx_model)
#     out = prepared_backend.run({'input0': X})
#     caffe2_out_all = out['output0']
#     caffe2_out_last = out['output1']
#
#     coreml_model = convert(onnx_model)
#     inputdict = {}
#     inputdict['input0'] = X
#     inputdict['initial_h'] = np.zeros((h), dtype=np.float32)
#     inputdict['initial_c'] = np.zeros((h), dtype=np.float32)
#     coreml_out_dict = coreml_model.predict(inputdict, useCPUOnly=True)
#     coreml_out_all = coreml_out_dict['output0']
#     coreml_out_last = coreml_out_dict['output1']
#
#     _assert_outputs(caffe2_out_all.flatten(), coreml_out_all.flatten(), decimal=5)
#     _assert_outputs(caffe2_out_last.flatten(), coreml_out_last.flatten(), decimal=5)


if __name__ == '__main__':
    unittest.main()
    # Kept for convenience: uncomment to run a single test in isolation.
    #suite = unittest.TestSuite()
    #suite.addTest(SingleOperatorTest("test_gemm_transB_off"))
    #unittest.TextTestRunner().run(suite)
--------------------------------------------------------------------------------
/stubs/numpy/__init__.pyi:
--------------------------------------------------------------------------------
"""
Numpy's mypy stub. Only type declarations for ndarray, the scalar hierarchy and array creation
methods are provided.
"""

# This file is taken from https://github.com/machinalis/mypy-data/tree/master/numpy-mypy/numpy
# under a BSD-style license (see below). This file can probably be deleted once
# https://github.com/numpy/numpy/issues/7370 is resolved.
#
# Copyright (c) 2016, Machinalis
#
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following
# conditions are met:
#
# Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
# Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the distribution.
# Neither the name of Machinalis nor the names of any contributors may be used to endorse or promote products derived from this
# software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING,
# BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT
# SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

from typing import (Any, Callable, Dict, Generic, Iterator, List, Optional, Sequence, Tuple, Type,
                    TypeVar, Union)

# Opaque placeholder for numpy's dtype objects.
class dtype: ...
# Alias used inside classes below, where attributes/parameters named `dtype`
# would otherwise shadow the class.
_dtype = dtype

# Placeholder for numpy.newaxis (declared as a class in this stub).
class newaxis: ...
class flagsobj:
    """numpy.flagsobj — the memory-layout flag record exposed as ndarray.flags."""
    aligned = None  # type: bool
    behaved = None  # type: bool
    c_contiguous = None  # type: bool
    carray = None  # type: bool
    contiguous = None  # type: bool
    f_contiguous = None  # type: bool
    farray = None  # type: bool
    fnc = None  # type: bool
    forc = None  # type: bool
    fortran = None  # type: bool
    owndata = None  # type: bool
    updateifcopy = None  # type: bool
    writeable = None  # type: bool
    # Flags are also addressable by name, e.g. flags['C_CONTIGUOUS'].
    def __getitem__(self, item: str) -> bool: ...
    def __setitem__(self, item: str, value: bool) -> None: ...

#
# Type variables. _T wasn't used to avoid confusions with ndarray's "T" attribute.
#

_S = TypeVar('_S')
_U = TypeVar('_U')
_V = TypeVar('_V')

#
# Auxiliary types
#

ShapeType = Union[int, Tuple[int, ...]]
AxesType = Union[int, Tuple[int, ...]]
OrderType = Union[str, Sequence[str]]
DtypeType = Union[dtype, type]

class flatiter(Generic[_S], Iterator[_S]):
    # Index of the current iteration position.
    coords = ...  # type: ShapeType
    def copy(self) -> flatiter[_S]: ...

class _ArrayLike(Generic[_S]):
    """
    "array-like" interface that both numpy.ndarray and all scalars (descendants of numpy.generic)
    implement this interface.
    """
    #
    # Array-like structures attributes
    #
    T = None  # type: _ArrayLike[_S]
    data = None  # type: Any
    dtype = None  # type: _dtype
    flags = None  # type: flagsobj
    flat = None  # type: flatiter[_ArrayLike[_S]]
    imag = None  # type: _ArrayLike[_S]
    real = None  # type: _ArrayLike[_S]
    size = None  # type: int
    itemsize = None  # type: int
    nbytes = None  # type: int
    ndim = None  # type: int
    shape = None  # type: Tuple[int, ...]
    strides = None  # type: Tuple[int, ...]
    base = None  # type: _ArrayLike[_S]

    #
    # Array-like methods
    #

    # Once this issue https://github.com/python/mypy/issues/1907 is resolved, most methods that
    # have an 'out' argument, will be implemented using overload instead of with a Union
    # result. mypy is smart enough to assign the proper type (_ArrayLike[_U]) when out is present
    # but it falls back to the union when it's not.
    def all(self, axis: Optional[AxesType]=None, out: 'Optional[_ArrayLike[_U]]'=None,
            keepdims: bool=False) -> Union['_ArrayLike[_U]', '_ArrayLike[bool]']: ...

    def any(self, axis: Optional[AxesType]=None, out: 'Optional[_ArrayLike[_U]]'=None,
            keepdims: bool=False) -> Union['_ArrayLike[_U]', '_ArrayLike[bool]']: ...

    def argmax(self, axis: Optional[int]=None,
               out: 'Optional[_ArrayLike[_U]]'=None) -> Union['_ArrayLike[_U]', '_ArrayLike[int]']: ...

    def argmin(self, axis: Optional[int]=None,
               out: 'Optional[_ArrayLike[_U]]'=None) -> Union['_ArrayLike[_U]', '_ArrayLike[int]']: ...

    def argpartition(self, kth: Union[int, Sequence[int]], axis: Optional[int]=-1,
                     kind: str='introselect', order: Optional[OrderType]=None) -> '_ArrayLike[int]': ...

    def argsort(self, axis: Optional[int]=None, kind: str='quicksort',
                order: Optional[OrderType]=None) -> '_ArrayLike[int]': ...

    # NOTE(review): numpy's astype defaults to copy=True; this stub says
    # copy=False — looks wrong, but changing it is out of scope here. Verify.
    def astype(self, dtype: Any, order: str='K', casting: str='unsafe', subok: bool=True,
               copy: bool=False) -> '_ArrayLike[Any]': ...

    def byteswap(self, inplace: bool=False) -> '_ArrayLike[_S]': ...

    def choose(self, choices: Sequence['_ArrayLike[_V]'], out: 'Optional[_ArrayLike[_U]]'=None,
               mode: str='raise') -> Union['_ArrayLike[_U]', '_ArrayLike[_V]']: ...

    def clip(self, a_min: Any, a_max: Any,
             out: 'Optional[_ArrayLike[_U]]'=None) -> Union['_ArrayLike[_S]', '_ArrayLike[_U]']: ...

    def compress(self, condition: Sequence[bool], axis: Optional[int]=None,
                 out: 'Optional[_ArrayLike[_U]]'=None) -> Union['_ArrayLike[_S]', '_ArrayLike[_U]']: ...

    def conj(self) -> '_ArrayLike[_S]': ...

    def conjugate(self) -> '_ArrayLike[_S]': ...

    def copy(self, order: str='C') -> '_ArrayLike[_S]': ...

    def cumprod(self, axis: Optional[int]=None, dtype: Optional[Any]=None,
                out: 'Optional[_ArrayLike[Any]]'=None) -> '_ArrayLike[Any]': ...

    def cumsum(self, axis: Optional[int]=None, dtype: Optional[DtypeType]=None,
               out: 'Optional[_ArrayLike[Any]]'=None) -> '_ArrayLike[Any]': ...

    def diagonal(self, offset: int=0, axis1: int=0, axis2: int=1) -> '_ArrayLike[_S]': ...

    def dot(self, b: '_ArrayLike[Any]', out: 'Optional[_ArrayLike[Any]]'=None) -> '_ArrayLike[Any]': ...

    def dump(self, file: str) -> None: ...

    def dumps(self) -> str: ...

    def fill(self, value: _S) -> None: ...

    def flatten(self, order: str='C') -> '_ArrayLike[_S]': ...

    def getfield(self, dtype: DtypeType, offset: int=0) -> '_ArrayLike[Any]': ...

    def item(self, args: AxesType) -> generic[_S]: ...

    def itemset(self, arg0: Union[int, Tuple[int, ...]], arg1: Optional[Any]=None) -> None: ...

    def max(self, axis: Optional[AxesType]=None,
            out: 'Optional[_ArrayLike[_U]]'=None) -> Union['_ArrayLike[_S]', '_ArrayLike[_U]']: ...

    def mean(self, axis: Optional[AxesType]=None, dtype: Optional[Any]=None,
             out: 'Optional[_ArrayLike[_U]]'=None, keepdims: bool=False) -> '_ArrayLike[floating]': ...

    def min(self, axis: Optional[AxesType]=None,
            out: 'Optional[_ArrayLike[_U]]'=None) -> Union['_ArrayLike[_S]', '_ArrayLike[_U]']: ...

    def newbyteorder(self, new_order: str='S') -> '_ArrayLike[_S]': ...

    def nonzero(self) -> '_ArrayLike[int]': ...

    def partition(self, kth: AxesType, axis: int=-1, kind: str='introselect',
                  order: Optional[OrderType]=None) -> None: ...

    def prod(self, axis: Optional[AxesType]=None, dtype: Optional[DtypeType]=None,
             out: 'Optional[_ArrayLike[_U]]'=None, keepdims: bool=False) -> '_ArrayLike[Any]': ...

    def ptp(self, axis: Optional[int]=None,
            out: 'Optional[_ArrayLike[_U]]'=None) -> Union['_ArrayLike[_S]', '_ArrayLike[_U]']: ...

    def put(self, ind: '_ArrayLike[int]', v: '_ArrayLike[_S]', mode: str='raise') -> None: ...

    def ravel(self, order: str='C') -> '_ArrayLike[_S]': ...

    def repeat(self, repeats: Union[int, Sequence[int]],
               axis: Optional[int]=None) -> '_ArrayLike[_S]': ...

    def reshape(self, newshape: ShapeType,
                order: str='C') -> '_ArrayLike[_S]': ...

    def resize(self, new_shape: ShapeType, refcheck: bool=True) -> None: ...

    def round(self, decimals: int=0,
              out: 'Optional[_ArrayLike[_U]]'=None) -> Union['_ArrayLike[_S]', '_ArrayLike[_U]']: ...

    def searchsorted(self, v: Union[_S, '_ArrayLike[_S]'], side: str='left',
                     sorter: 'Optional[_ArrayLike[int]]'=None) -> '_ArrayLike[int]': ...

    def setfield(self, val: Any, dtype: DtypeType, offset: int=0) -> None: ...

    def setflags(self, write: Optional[bool]=None, align: Optional[bool]=None,
                 uic: Optional[bool]=None) -> None: ...

    def sort(self, axis: int=-1, kind: str='quicksort', order: Optional[OrderType]=None) -> None: ...

    def squeeze(self, axis: Optional[AxesType]=None) -> '_ArrayLike[_S]': ...

    def std(self, axis: Optional[AxesType]=None, dtype: Optional[DtypeType]=None,
            out: 'Optional[_ArrayLike[_U]]'=None, ddof: int=0, keepdims: bool=False) -> '_ArrayLike[floating]': ...

    def sum(self, axis: Optional[AxesType]=None, dtype: Optional[DtypeType]=None,
            out: 'Optional[_ArrayLike[_U]]'=None,
            keepdims: bool=False) -> '_ArrayLike[Any]': ...

    def swapaxes(self, axis1: int, axis2: int) -> '_ArrayLike[_S]': ...

    def take(self, indices: Sequence[int], axis: Optional[int]=None,
             out: 'Optional[_ArrayLike[_U]]'=None,
             mode: str='raise') -> Union['_ArrayLike[_S]', '_ArrayLike[_U]']: ...

    def tobytes(self, order: str='C') -> bytes: ...

    def tofile(self, fid: object, sep: str='',  # TODO fix fid definition (There's a bug in mypy io's namespace https://github.com/python/mypy/issues/1462)
               format: str='%s') -> None: ...

    def tolist(self) -> List[Any]: ...

    def tostring(self, order: str='C') -> bytes: ...

    def trace(self, offset: int=0, axis1: int=0, axis2: int=1,
              dtype: Optional[DtypeType]=None, out: 'Optional[_ArrayLike[_U]]'=None) -> '_ArrayLike[Any]': ...

    # NOTE(review): numpy allows transpose() with no argument; `axes` here has
    # no default, so the stub forces callers to pass one — TODO confirm intent.
    def transpose(self, axes: Optional[AxesType]) -> '_ArrayLike[_S]': ...

    def var(self, axis: Optional[AxesType]=None, dtype: Optional[DtypeType]=None,
            out: 'Optional[_ArrayLike[_U]]'=None, ddof: int=0, keepdims: bool=False) -> '_ArrayLike[Any]': ...

    def view(self, dtype: Optional[Union[DtypeType, Type['ndarray[Any]']]]=None,
             type: Optional[type]=None) -> '_ArrayLike[Any]': ...

    #
    # Magic methods
    #

    def __abs__(self) -> '_ArrayLike[_S]': ...

    def __add__(self, value: object) -> '_ArrayLike[Any]': ...

    def __and__(self, value: object) -> '_ArrayLike[int]': ...

    def __array__(self, dtype: Optional[DtypeType]=None) -> '_ArrayLike[Any]': ...

    def __array_prepare__(self, context: Optional[object]=None) -> '_ArrayLike[Any]': ...

    def __array_wrap__(self, context: Optional[object]=None) -> '_ArrayLike[Any]': ...
268 | 269 | def __bool__(self) -> bool: ... 270 | 271 | def __complex__(self) -> complex: ... 272 | 273 | def __contains__(self, key: object) -> bool: ... 274 | 275 | def __copy__(self) -> '_ArrayLike[_S]': ... 276 | 277 | def __deepcopy__(self) -> '_ArrayLike[_S]': ... 278 | 279 | def __delattr__(self, name: str) -> None: ... 280 | 281 | def __delitem__(self, key: str) -> None: ... 282 | 283 | def __dir__(self) -> List[str]: ... 284 | 285 | def __divmod__(self, value: object) -> Tuple['_ArrayLike[int]', '_ArrayLike[float]']: ... 286 | 287 | def __eq__(self, value: object) -> '_ArrayLike[bool]': ... # type: ignore 288 | 289 | def __float__(self) -> float: ... 290 | 291 | def __floordiv__(self, value: object) -> '_ArrayLike[int]': ... 292 | 293 | def __ge__(self, value: object) -> '_ArrayLike[bool]': ... 294 | 295 | def __getattribute__(self, name: str) -> Any: ... 296 | 297 | def __getitem__(self, key: Any) -> '_ArrayLike[_S]': ... 298 | 299 | def __gt__(self, value: object) -> '_ArrayLike[bool]': ... 300 | 301 | def __iadd__(self, value: object) -> None: ... 302 | 303 | def __iand__(self, value: object) -> None: ... 304 | 305 | def __ifloordiv__(self, value: object) -> None: ... 306 | 307 | def __ilshift__(self, value: object) -> None: ... 308 | 309 | def __imatmul__(self, value: '_ArrayLike[Any]') -> None: ... 310 | 311 | def __imod__(self, value: object) -> None: ... 312 | 313 | def __imul__(self, value: object) -> None: ... 314 | 315 | def __index__(self) -> int: ... 316 | 317 | def __int__(self) -> int: ... 318 | 319 | def __invert__(self) -> '_ArrayLike[_S]': ... 320 | 321 | def __ior__(self, value: object) -> None: ... 322 | 323 | def __ipow__(self, value: object) -> None: ... 324 | 325 | def __irshift__(self, value: object) -> None: ... 326 | 327 | def __isub__(self, value: object) -> None: ... 328 | 329 | def __iter__(self) -> Iterator['_ArrayLike[_S]']: ... 330 | 331 | def __itruediv__(sel, value: object) -> None: ... 
    def __ixor__(self, value: object) -> None: ...

    def __le__(self, value: object) -> '_ArrayLike[bool]': ...

    def __len__(self) -> int: ...

    def __lshift__(self, value: object) -> '_ArrayLike[_S]': ...

    def __lt__(self, value: object) -> '_ArrayLike[bool]': ...

    def __matmul__(self, value: '_ArrayLike[Any]') -> '_ArrayLike[Any]': ...

    def __mod__(self, value: object) -> '_ArrayLike[_S]': ...

    def __mul__(self, value: object) -> '_ArrayLike[Any]': ...

    def __ne__(self, value: object) -> '_ArrayLike[bool]': ...  # type: ignore

    def __neg__(self) -> '_ArrayLike[_S]': ...

    def __or__(self, value: object) -> '_ArrayLike[_S]': ...

    def __pos__(self) -> '_ArrayLike[_S]': ...

    def __pow__(self, value: object) -> '_ArrayLike[Any]': ...

    # Reflected (right-hand) operators follow.
    def __radd__(self, value: object) -> '_ArrayLike[Any]': ...

    def __rand__(self, value: object) -> '_ArrayLike[_S]': ...

    def __rdivmod__(self, value: object) -> Tuple['_ArrayLike[int]', '_ArrayLike[float]']: ...

    def __rfloordiv__(self, value: object) -> '_ArrayLike[Any]': ...

    def __rlshift__(self, value: object) -> '_ArrayLike[Any]': ...

    def __rmatmul__(self, value: object) -> '_ArrayLike[Any]': ...

    def __rmod__(self, value: object) -> '_ArrayLike[Any]': ...

    def __rmul__(self, value: object) -> '_ArrayLike[Any]': ...

    def __ror__(self, value: object) -> '_ArrayLike[_S]': ...

    def __rpow__(self, value: object) -> '_ArrayLike[Any]': ...

    def __rrshift__(self, value: object) -> '_ArrayLike[Any]': ...

    def __rshift__(self, value: object) -> '_ArrayLike[Any]': ...

    def __rsub__(self, value: object) -> '_ArrayLike[Any]': ...

    def __rtruediv__(self, value: object) -> '_ArrayLike[Any]': ...
386 | 387 | def __rxor__(self, value: object) -> '_ArrayLike[_S]': ... 388 | 389 | def __setattr__(self, name: str, value: Any) -> None: ... 390 | 391 | def __setitem__(self, key: Any, value: Any) -> None: ... 392 | 393 | def __str__(self) -> str: ... 394 | 395 | def __sub__(self, value: object) -> '_ArrayLike[Any]': ... 396 | 397 | def __truediv__(sel, value: object) -> '_ArrayLike[Any]': ... 398 | 399 | def __xor__(self, value: object) -> '_ArrayLike[_S]': ... 400 | 401 | # 402 | # numpy's scalar hierarchy (http://docs.scipy.org/doc/numpy/reference/arrays.scalars.html#scalars) 403 | # 404 | 405 | class generic(_ArrayLike[_S], Generic[_S]): ... 406 | class bool_(generic[bool]): ... 407 | bool8 = bool_ 408 | class object_(generic[Any]): ... 409 | class number(generic[_S], Generic[_S]): ... 410 | class integer(number[int]): ... 411 | class signedinteger(integer): ... 412 | class byte(signedinteger): ... 413 | class short(signedinteger): ... 414 | class intc(signedinteger): ... 415 | class int_(signedinteger): ... 416 | class longlong(signedinteger): ... 417 | class int8(signedinteger): ... 418 | class int16(signedinteger): ... 419 | class int32(signedinteger): ... 420 | class int64(signedinteger): ... 421 | class unsignedinteger(integer): ... 422 | class ubyte(unsignedinteger): ... 423 | class ushort(unsignedinteger): ... 424 | class uintc(unsignedinteger): ... 425 | class uint(unsignedinteger): ... 426 | class ulonglong(unsignedinteger): ... 427 | class uint8(signedinteger): ... 428 | class uint16(signedinteger): ... 429 | class uint32(signedinteger): ... 430 | class uint64(signedinteger): ... 431 | class inexact(number[float]): ... 432 | class floating(inexact): ... 433 | class half(floating): ... 434 | class single(floating): ... 435 | class float_(floating): ... 436 | class longfloat_(floating): ... 437 | class float16(floating): ... 438 | class float32(floating): ... 439 | class float64(floating): ... 440 | class float128(floating): ... 
class complexfloating(inexact): ...
class csingle(complexfloating): ...
class complex_(complexfloating): ...
class clongfloat(complexfloating): ...
class complex64(complexfloating): ...
class complex128(complexfloating): ...
class complex256(complexfloating): ...
class flexible(generic[_S], Generic[_S]): ...
class character(flexible[str]): ...
class str_(character): ...
class unicode_(character): ...
class void(flexible[None]): ...

class ndarray(_ArrayLike[_S], Generic[_S]):
    """numpy.ndarray"""
    ctypes = None  # type: Optional[Any]  # TODO Implement ctypes type hint

    # TODO Need to find a way to restrict buffer type
    def __init__(self, shape: Tuple[int, ...], dtype: Optional[DtypeType]=None,
                 buffer: Optional[Any]=None, offset: Optional[int]=None,
                 strides: Optional[Tuple[int, ...]]=None, order: Optional[str]=None) -> None: ...

#
# Array creation routines
#

def array(object: Any, dtype: Optional[Any]=None, copy: bool=True,
          order: Optional[str]=None, subok: bool=False,
          ndmin: int=0) -> ndarray[Any]: ...
def asarray(a: Any, dtype: Optional[DtypeType]=None, order: Optional[str]=None) -> ndarray[Any]: ...
def asanyarray(a: Any, dtype: Optional[DtypeType]=None, order: Optional[str]=None) -> ndarray[Any]: ...  # TODO figure out a way to restrict the return type
def asmatrix(data: Any, dtype: Optional[DtypeType]=None) -> Any: ...  # TODO define matrix
def ascontiguousarray(a: Any, dtype: Optional[DtypeType]=None) -> ndarray[Any]: ...
def copy(a: Any, order: Optional[str]=None) -> ndarray[Any]: ...
def empty(shape: ShapeType, dtype: DtypeType=float, order: str='C') -> ndarray[Any]: ...
def empty_like(a: Any, dtype: Optional[Any]=None, order: str='K', subok: bool=True) -> ndarray[Any]: ...
def eye(N: int, M: Optional[int]=None, k: int=0, dtype: DtypeType=float) -> ndarray[Any]: ...
def frombuffer(buffer: Any, dtype: DtypeType=float, count: int=-1,  # TODO figure out a way to restrict buffer
               offset: int=0) -> ndarray[Any]: ...
def fromfile(file: object, dtype: DtypeType=float, count: int=-1, sep: str='') -> ndarray[Any]: ...  # TODO fix file definition (There's a bug in mypy io's namespace https://github.com/python/mypy/issues/1462)
def full(shape: ShapeType, fill_value: Any, dtype: Optional[DtypeType]=None,
         order: str='C') -> ndarray[Any]: ...
def full_like(a: Any, fill_value: Any, dtype: Optional[DtypeType]=None, order: str='C',
              subok: bool=True) -> ndarray[Any]: ...
def fromfunction(function: Callable[..., _S], shape: ShapeType, dtype: DtypeType=float) -> ndarray[_S]: ...
# BUG FIX: parameter was misspelled `dytpe`; numpy's signature is
# fromiter(iter, dtype, count=-1), so keyword calls with dtype=... failed to
# type-check against this stub.
def fromiter(iterable: Iterator[Any], dtype: DtypeType, count: int=-1) -> ndarray[Any]: ...
def fromstring(string: str, dtype: DtypeType=float, count: int=-1, sep: str='') -> ndarray[Any]: ...
def identity(n: int, dtype: Optional[DtypeType]=None) -> ndarray[Any]: ...
def loadtxt(fname: Any, dtype: DtypeType=float, comments: Union[str, Sequence[str]]='#',
            delimiter: Optional[str]=None, converters: Optional[Dict[int, Callable[[Any], float]]]=None,
            skiprows: int=0, usecols: Optional[Sequence[int]]=None,
            unpack: bool=False, ndmin: int=0) -> ndarray[float]: ...
def ones(shape: ShapeType, dtype: Optional[DtypeType]=..., order: str='C') -> ndarray[Any]: ...
def ones_like(a: Any, dtype: Optional[Any]=None, order: str='K', subok: bool=True) -> ndarray[Any]: ...
def zeros(shape: ShapeType, dtype: DtypeType=float, order: str='C') -> ndarray[Any]: ...
def zeros_like(a: Any, dtype: Optional[Any]=None, order: str='K', subok: bool=True) -> ndarray[Any]: ...
def squeeze(a: _ArrayLike[_S], axis: Optional[AxesType]=None) -> _ArrayLike[_S]: ...
def multiply(a: _ArrayLike[_S], b: _ArrayLike[_S], **kwargs: Any) -> _ArrayLike[_S]: ...

# Specific values
inf: float