├── VERSION_NUMBER ├── MANIFEST.in ├── jenkins ├── test.sh ├── build.sh └── setup.sh ├── travis ├── after_failure.sh ├── after_success.sh ├── script.sh ├── install.sh ├── setup.sh └── before_install.sh ├── .gitignore ├── tests ├── __init__.py ├── helper_test.py ├── test_utils.py ├── onnx_backend_test.py ├── ssa_test.py ├── optimize_onnx_test.py ├── ONNXOpCoverage.md ├── conversion_test.py └── caffe2_ref_test.py ├── onnx_caffe2 ├── bin │ ├── __init__.py │ └── conversion.py ├── __init__.py ├── error.py ├── workspace.py ├── backend_rep.py ├── helper.py ├── frontend.py └── backend.py ├── setup.cfg ├── .gitmodules ├── install-develop.sh ├── test.sh ├── install.sh ├── LICENSE ├── .travis.yml ├── README.md ├── setup.py └── examples └── pytorch_to_caffe2.py /VERSION_NUMBER: -------------------------------------------------------------------------------- 1 | 1.0.0 2 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include LICENSE 2 | include VERSION_NUMBER 3 | -------------------------------------------------------------------------------- /jenkins/test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | time "$TOP_DIR/test.sh" 4 | -------------------------------------------------------------------------------- /travis/after_failure.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | source "$(dirname $(readlink -e "${BASH_SOURCE[0]}"))/setup.sh" 4 | -------------------------------------------------------------------------------- /travis/after_success.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | source "$(dirname $(readlink -e "${BASH_SOURCE[0]}"))/setup.sh" 4 | -------------------------------------------------------------------------------- /jenkins/build.sh: 
-------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | time CMAKE_ARGS='-DUSE_ATEN=ON -DUSE_OPENMP=ON' "$TOP_DIR/install-develop.sh" 4 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | .cache/ 3 | .coverage 4 | .eggs/ 5 | build/ 6 | onnx_caffe2.egg-info/ 7 | onnx_caffe2/version.py 8 | .pytest_cache/ -------------------------------------------------------------------------------- /travis/script.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | source "$(dirname $(readlink -e "${BASH_SOURCE[0]}"))/setup.sh" 4 | 5 | time "$top_dir/test.sh" 6 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | -------------------------------------------------------------------------------- /onnx_caffe2/bin/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | -------------------------------------------------------------------------------- /travis/install.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | source "$(dirname $(readlink -e "${BASH_SOURCE[0]}"))/setup.sh" 4 | 5 | time CMAKE_ARGS='-DUSE_ATEN=ON -DUSE_OPENMP=ON' "$top_dir/install-develop.sh" 6 | -------------------------------------------------------------------------------- /travis/setup.sh: 
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals


# Root of the onnx_caffe2 exception hierarchy.
# NOTE(review): this name shadows Python's builtin ``BaseException`` within
# this module.  Callers reference it as ``onnx_caffe2.error.BaseException``,
# so renaming it (e.g. to ``Error``) would break existing ``except`` clauses;
# the name is kept for backward compatibility — confirm before changing.
class BaseException(Exception): pass
# Raised when conversion hits an ONNX/Caffe2 feature this package does not
# support.
class Unsupported(BaseException): pass
-------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 6 | import unittest 7 | 8 | from onnx_caffe2.helper import dummy_name 9 | 10 | from tests.test_utils import TestCase 11 | 12 | 13 | class TestCaffe2Basic(TestCase): 14 | def test_dummy_name(self): 15 | dummy_name([]) 16 | names_1 = [dummy_name() for _ in range(3)] 17 | dummy_name([]) 18 | names_2 = [dummy_name() for _ in range(3)] 19 | self.assertEqual(names_1, names_2) 20 | 21 | dummy_name(names_1) 22 | names_3 = [dummy_name() for _ in range(3)] 23 | self.assertFalse(set(names_1) & set(names_3)) 24 | 25 | 26 | if __name__ == '__main__': 27 | unittest.main() 28 | -------------------------------------------------------------------------------- /install-develop.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -ex 4 | 5 | # realpath might not be available on MacOS 6 | script_path=$(python -c "import os; import sys; print(os.path.realpath(sys.argv[1]))" "${BASH_SOURCE[0]}") 7 | top_dir=$(dirname "$script_path") 8 | REPOS_DIR="$top_dir/third_party" 9 | BUILD_DIR="$top_dir/build" 10 | mkdir -p "$BUILD_DIR" 11 | 12 | _pip_install() { 13 | if [[ -n "$CI" ]]; then 14 | ccache -z 15 | fi 16 | if [[ -n "$CI" ]]; then 17 | time pip install "$@" 18 | else 19 | pip install "$@" 20 | fi 21 | if [[ -n "$CI" ]]; then 22 | ccache -s 23 | fi 24 | } 25 | 26 | # Install caffe2 27 | _pip_install -b "$BUILD_DIR/caffe2" "file://$REPOS_DIR/caffe2#egg=caffe2" 28 | 29 | # Install onnx 30 | _pip_install -e "$REPOS_DIR/onnx" 31 | 32 | # Install onnx-caffe2 33 | _pip_install -e . 
#!/bin/bash
# Run the test suite.  Pass -p/--parallel to distribute tests across two
# pytest-xdist workers; any unrecognized arguments are preserved in $@.

set -ex

# defaults
PARALLEL=0
UNKNOWN=()

while [[ $# -gt 0 ]]; do
  case "$1" in
    -p|--parallel)
      PARALLEL=1
      shift # past argument
      ;;
    *)
      # unknown option: keep it for later
      UNKNOWN+=("$1")
      shift # past argument
      ;;
  esac
done
# restore the unrecognized arguments
set -- "${UNKNOWN[@]}"

# realpath might not be available on MacOS, so resolve via python instead
script_path=$(python -c "import os; import sys; print(os.path.realpath(sys.argv[1]))" "${BASH_SOURCE[0]}")
top_dir=$(dirname "$script_path")
TEST_DIR="$top_dir/tests"

pip install pytest-cov tabulate

if [[ $PARALLEL == 1 ]]; then
  pip install pytest-xdist
  pytest "$TEST_DIR" -n 2
else
  pytest "$TEST_DIR"
fi
-v '^\-e' | cut -d = -f 1 | xargs -n1 pip install -U 23 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 6 | import unittest 7 | 8 | import numpy as np 9 | 10 | 11 | class TestCase(unittest.TestCase): 12 | def setUp(self): 13 | np.random.seed(seed=0) 14 | 15 | def assertSameOutputs(self, outputs1, outputs2, decimal=7): 16 | self.assertEqual(len(outputs1), len(outputs2)) 17 | for o1, o2 in zip(outputs1, outputs2): 18 | np.testing.assert_almost_equal(o1, o2, decimal=decimal) 19 | 20 | def add_test_case(name, test_func): 21 | if not name.startswith('test_'): 22 | raise ValueError('Test name must start with test_: {}'.format(name)) 23 | if hasattr(self, name): 24 | raise ValueError('Duplicated test name: {}'.format(name)) 25 | setattr(self, name, test_func) 26 | -------------------------------------------------------------------------------- /install.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -ex 4 | 5 | # realpath might not be available on MacOS 6 | script_path=$(python -c "import os; import sys; print(os.path.realpath(sys.argv[1]))" "${BASH_SOURCE[0]}") 7 | top_dir=$(dirname "$script_path") 8 | REPOS_DIR="$top_dir/third_party" 9 | BUILD_DIR="$top_dir/build" 10 | mkdir -p "$BUILD_DIR" 11 | 12 | _pip_install() { 13 | if [[ -n "$CI" ]]; then 14 | ccache -z 15 | fi 16 | if [[ -n "$CI" ]]; then 17 | time pip install "$@" 18 | else 19 | pip install "$@" 20 | fi 21 | if [[ -n "$CI" ]]; then 22 | ccache -s 23 | fi 24 | } 25 | 26 | # Install caffe2 27 | _pip_install -b "$BUILD_DIR/caffe2" "file://$REPOS_DIR/caffe2#egg=caffe2" 28 | python -c 'from caffe2.python import build; from pprint import pprint; 
pprint(build.build_options)' 29 | 30 | # Install onnx 31 | _pip_install -b "$BUILD_DIR/onnx" "file://$REPOS_DIR/onnx#egg=onnx" 32 | 33 | # Install onnx-caffe2 34 | _pip_install . 35 | -------------------------------------------------------------------------------- /jenkins/setup.sh: -------------------------------------------------------------------------------- 1 | set -ex 2 | 3 | export CI=true 4 | 5 | export TOP_DIR=$(dirname $(dirname $(readlink -e "${BASH_SOURCE[0]}"))) 6 | 7 | export OS="$(uname)" 8 | 9 | # setup ccache 10 | if [[ "$OS" == "Darwin" ]]; then 11 | export PATH="/usr/local/opt/ccache/libexec:$PATH" 12 | else 13 | if [[ -d "/usr/lib/ccache" ]]; then 14 | export PATH="/usr/lib/ccache:$PATH" 15 | elif hash ccache > /dev/null; then 16 | mkdir -p "$TOP_DIR/ccache" 17 | ln -sf "$(which ccache)" "$TOP_DIR/ccache/cc" 18 | ln -sf "$(which ccache)" "$TOP_DIR/ccache/c++" 19 | ln -sf "$(which ccache)" "$TOP_DIR/ccache/gcc" 20 | ln -sf "$(which ccache)" "$TOP_DIR/ccache/g++" 21 | ln -sf "$(which ccache)" "$TOP_DIR/ccache/x86_64-linux-gnu-gcc" 22 | export PATH="$TOP_DIR/ccache:$PATH" 23 | fi 24 | export LC_ALL=C.UTF-8 25 | export LANG=C.UTF-8 26 | fi 27 | 28 | # setup virtualenv 29 | virtualenv "$TOP_DIR/venv" 30 | source "$TOP_DIR/venv/bin/activate" 31 | pip install -U pip setuptools 32 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Open Neural Network Exchange 2 | 3 | Copyright (c) Facebook, Inc. and Microsoft Corporation. 4 | All rights reserved. 
5 | 6 | MIT License 7 | 8 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the ""Software""), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 11 | 12 | THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 13 | -------------------------------------------------------------------------------- /tests/onnx_backend_test.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 6 | import os 7 | 8 | import unittest 9 | import onnx.backend.test 10 | 11 | import onnx_caffe2.backend as c2 12 | 13 | # This is a pytest magic variable to load extra plugins 14 | pytest_plugins = 'onnx.backend.test.report', 15 | 16 | backend_test = onnx.backend.test.BackendTest(c2, __name__) 17 | 18 | backend_test.exclude(r'(test_ceil|test_floor' # Does not support Ceil and Floor. 19 | '|test_hardsigmoid|test_pow' # Does not support Hardsigmoid and Pow. 20 | '|test_mean|test_hardmax' # Does not support Mean and Hardmax. 
21 | '|test_cast.*FLOAT16.*)') # Does not support Cast in Float16 case. 22 | 23 | # Skip vgg to speed up CI 24 | if 'CI' in os.environ: 25 | backend_test.exclude(r'(test_vgg19|test_vgg)') 26 | 27 | # import all test cases at global scope to make them visible to python.unittest 28 | globals().update(backend_test 29 | .enable_report() 30 | .test_cases) 31 | 32 | if __name__ == '__main__': 33 | unittest.main() 34 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | os: linux 2 | dist: trusty 3 | sudo: required 4 | language: python 5 | python: 6 | - "2.7" 7 | - "3.6" 8 | 9 | before_install: 10 | - ./travis/before_install.sh 11 | 12 | install: 13 | - ./travis/install.sh 14 | 15 | script: 16 | - ./travis/script.sh 17 | 18 | after_success: 19 | - ./travis/after_success.sh 20 | 21 | after_failure: 22 | - ./travis/after_failure.sh 23 | 24 | cache: 25 | - directories: 26 | - $HOME/.cache/pb 27 | - $HOME/.ccache 28 | 29 | # We don't want to trigger travis builds for auto PR because there are 30 | # too many of them and they are updated very frequently. travis can 31 | # barely catchup and so causing the build status of these PRs always 32 | # being shown as pending. 33 | 34 | # For each PR, travis triggers two builds, one is "branch updates" 35 | # which uses the commit in the PR branch, another one is "pull request 36 | # updates" which uses the merge commit that merges the PR branch into 37 | # master branch. The following branches filter will only be able to 38 | # disable builds of "branch updates", but not the "pull request 39 | # updates", because the latter has branch name "master". To achieve 40 | # that, we have turned off "Build pull request updates" option in the 41 | # travis website. 
42 | 43 | branches: 44 | except: 45 | - /^auto-.*$/ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | NOTICE: THIS REPO IS DEPRECATED! onnx-caffe2 has been merge into [Caffe2](https://github.com/caffe2/caffe2/tree/master/caffe2/python/onnx). 2 | ======= 3 | 4 | 5 | onnx-caffe2 6 | ======== 7 | | Travis | Jenkins | 8 | |--------|---------| 9 | | [![Build Status](https://travis-ci.org/onnx/onnx-caffe2.svg?branch=master)](https://travis-ci.org/onnx/onnx-caffe2) | [![Build Status](https://ci.pytorch.org/jenkins/buildStatus/icon?job=onnx-caffe2-master)](https://ci.pytorch.org/jenkins/job/onnx-caffe2-master/) | 10 | 11 | Caffe2 implementation of Open Neural Network Exchange (ONNX). 12 | 13 | Repository location may change. 14 | 15 | # Installation 16 | 17 | ``` 18 | pip install onnx-caffe2 19 | ``` 20 | 21 | # Usage 22 | 23 | * [ONNX to Caffe2](https://github.com/onnx/tutorials/blob/master/tutorials/OnnxCaffe2Import.ipynb) 24 | * [Caffe2 to ONNX](https://github.com/onnx/tutorials/blob/master/tutorials/Caffe2OnnxExport.ipynb) 25 | * [other end-to-end tutorials](https://github.com/onnx/tutorials) 26 | 27 | # Folder Structure 28 | 29 | - onnx_caffe2/: the main folder that all code lies under 30 | - frontend.py: translate from caffe2 model to onnx model 31 | - backend.py: execution engine that runs onnx on caffe2 32 | - tests/: test files 33 | 34 | # Testing 35 | 36 | onnx-caffe2 uses [pytest](https://docs.pytest.org) as test driver. In order to run tests, first you need to install pytest: 37 | 38 | 39 | ``` 40 | pip install pytest-cov 41 | ``` 42 | 43 | After installing pytest, do 44 | 45 | ``` 46 | pytest 47 | ``` 48 | 49 | to run tests. 
50 | 51 | Testing coverage issues/status: https://github.com/onnx/onnx-caffe2/blob/master/tests/ONNXOpCoverage.md 52 | 53 | # Development 54 | 55 | During development it's convenient to install onnx-caffe2 in development mode: 56 | 57 | ``` 58 | git clone https://github.com/onnx/onnx-caffe2.git --recursive 59 | pip install -e onnx-caffe2/ 60 | ``` 61 | 62 | # License 63 | 64 | [MIT License](LICENSE) 65 | 66 | -------------------------------------------------------------------------------- /onnx_caffe2/workspace.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 6 | import uuid 7 | 8 | from caffe2.python import workspace 9 | 10 | 11 | class Workspace(object): 12 | """ 13 | An object representing a Caffe2 workspace. It is a context manager, 14 | so you can say 'with workspace:' to use the represented workspace 15 | as your global workspace. It also supports every method supported 16 | by caffe2.python.workspace, but instead of running these operations 17 | in the global workspace, it runs them in the workspace represented 18 | by this object. When this object goes dead, the workspace (and all 19 | nets and blobs within it) are freed. 20 | 21 | Why do we need this class? Caffe2's workspace model is very "global state" 22 | oriented, in that there is always some ambient global workspace you are 23 | working in which holds on to all of your networks and blobs. This class 24 | makes it possible to work with workspaces more locally, and without 25 | forgetting to deallocate everything in the end. 26 | """ 27 | def __init__(self): 28 | # Caffe2 (apparently) doesn't provide any native method of generating 29 | # a fresh, unused workspace, so we have to fake it by generating 30 | # a unique ID and hoping it's not used already / will not be used 31 | # directly in the future. 
32 | self.workspace_id = str(uuid.uuid4()) 33 | # A stack, so that the context manager is reentrant. 34 | self.workspace_stack = [] 35 | 36 | def __getattr__(self, attr): 37 | def f(*args, **kwargs): 38 | with self: 39 | return getattr(workspace, attr)(*args, **kwargs) 40 | return f 41 | 42 | def __enter__(self): 43 | self.workspace_stack.append(workspace.CurrentWorkspace()) 44 | workspace.SwitchWorkspace(self.workspace_id, create_if_missing=True) 45 | 46 | def __exit__(self, exc_type, exc_value, traceback): 47 | w = self.workspace_stack.pop() 48 | # Strictly speaking, create_if_missing here is unnecessary, since a user 49 | # is not supposed to be allowed to destruct a workspace while we're in 50 | # it. However, empirically, it has been observed that during abnormal 51 | # shutdown, Caffe2 deletes its default workspace fairly early in the 52 | # final calls to destructors. In this case, we may attempt to exit 53 | # to a default workspace which no longer exists. create_if_missing=True 54 | # will (harmlessly) recreate the workspace before we finally quit.) 55 | workspace.SwitchWorkspace(w, create_if_missing=True) 56 | 57 | def __del__(self): 58 | # NB: This is a 'self' call because we need to switch into the workspace 59 | # we want to reset before we actually reset it. A direct call to 60 | # workspace.ResetWorkspace() will reset the ambient workspace, which 61 | # is not want we want. 
62 | self.ResetWorkspace() 63 | -------------------------------------------------------------------------------- /onnx_caffe2/backend_rep.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 6 | from caffe2.python import core, workspace 7 | from caffe2.proto import caffe2_pb2 8 | from onnx.backend.base import BackendRep, namedtupledict 9 | 10 | class Caffe2Rep(BackendRep): 11 | def __init__(self, init_net, predict_net, workspace, uninitialized): 12 | super(Caffe2Rep, self).__init__() 13 | self.init_net = init_net 14 | self.predict_net = predict_net 15 | self.workspace = workspace 16 | # The list of uninitialized external_inputs in workspace, we need this to 17 | # pair the name with given sequence inputs. 18 | self.uninitialized = uninitialized 19 | self.nets_created = False 20 | self.ran_init_net = False 21 | 22 | @property 23 | def _name_scope(self): 24 | if self.predict_net.device_option.device_type == caffe2_pb2.CUDA: 25 | return 'gpu_{}'.format(self.predict_net.device_option.cuda_gpu_id) 26 | return '' 27 | 28 | def run(self, inputs, **kwargs): 29 | super(Caffe2Rep, self).run(inputs, **kwargs) 30 | with self.workspace: 31 | with core.DeviceScope(self.predict_net.device_option): 32 | if isinstance(inputs, dict): 33 | with core.NameScope(self._name_scope): 34 | for key, value in inputs.items(): 35 | workspace.FeedBlob(key, value) 36 | elif isinstance(inputs, list) or isinstance(inputs, tuple): 37 | if len(self.uninitialized) != len(inputs): 38 | raise RuntimeError('Expected {} values for uninitialized ' 39 | 'graph inputs ({}), but got {}.'.format( 40 | len(self.uninitialized), 41 | ', '.join(self.uninitialized), 42 | len(inputs))) 43 | for i, value in enumerate(inputs): 44 | # namescope already baked into protobuf 45 | workspace.FeedBlob(self.uninitialized[i], 
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json

from caffe2.proto import caffe2_pb2
import click
import numpy as np
from onnx import checker, ModelProto

from onnx_caffe2.backend import Caffe2Backend as c2
import onnx_caffe2.frontend as c2_onnx


@click.command(
    help='convert caffe2 net to onnx model',
    context_settings={
        'help_option_names': ['-h', '--help']
    }
)
@click.argument('caffe2_net', type=click.File('rb'))
@click.option('--caffe2-net-name',
              type=str,
              help="Name of the caffe2 net")
@click.option('--caffe2-init-net',
              type=click.File('rb'),
              help="Path of the caffe2 init net pb file")
@click.option('--value-info',
              type=str,
              help='A json string providing the '
                   'type and shape information of the inputs')
@click.option('-o', '--output', required=True,
              type=click.File('wb'),
              help='Output path for the onnx model pb file')
def caffe2_to_onnx(caffe2_net,
                   caffe2_net_name,
                   caffe2_init_net,
                   value_info,
                   output):
    """CLI entry point: serialize a Caffe2 predict/init net pair as an ONNX model."""
    c2_net_proto = caffe2_pb2.NetDef()
    c2_net_proto.ParseFromString(caffe2_net.read())
    # ONNX graphs require a name; take it from the net or the CLI flag.
    if not c2_net_proto.name and not caffe2_net_name:
        raise click.BadParameter(
            'The input caffe2 net does not have name, '
            '--caffe2-net-name must be provided')
    c2_net_proto.name = caffe2_net_name or c2_net_proto.name
    if caffe2_init_net:
        c2_init_net_proto = caffe2_pb2.NetDef()
        c2_init_net_proto.ParseFromString(caffe2_init_net.read())
        c2_init_net_proto.name = '{}_init'.format(caffe2_net_name)
    else:
        c2_init_net_proto = None

    # value_info arrives as a JSON string mapping input names to
    # (type, shape); decode it before handing off to the frontend.
    if value_info:
        value_info = json.loads(value_info)

    onnx_model = c2_onnx.caffe2_net_to_onnx_model(
        predict_net=c2_net_proto,
        init_net=c2_init_net_proto,
        value_info=value_info)

    output.write(onnx_model.SerializeToString())


@click.command(
    help='convert onnx model to caffe2 net',
    context_settings={
        'help_option_names': ['-h', '--help']
    }
)
@click.argument('onnx_model', type=click.File('rb'))
@click.option('-o', '--output', required=True,
              type=click.File('wb'),
              help='Output path for the caffe2 net file')
@click.option('--init-net-output',
              required=True,
              type=click.File('wb'),
              help='Output path for the caffe2 init net file')
def onnx_to_caffe2(onnx_model, output, init_net_output):
    """CLI entry point: split an ONNX model into Caffe2 init and predict nets."""
    onnx_model_proto = ModelProto()
    onnx_model_proto.ParseFromString(onnx_model.read())

    init_net, predict_net = c2.onnx_graph_to_caffe2_net(onnx_model_proto)
    init_net_output.write(init_net.SerializeToString())
    output.write(predict_net.SerializeToString())
onnx 7 | import numpy as np 8 | from caffe2.proto import caffe2_pb2 9 | from caffe2.python import core 10 | from onnx import helper, TensorProto 11 | 12 | import onnx_caffe2.frontend as c2_onnx 13 | from onnx_caffe2.helper import c2_native_run_net 14 | from tests.test_utils import TestCase 15 | 16 | 17 | class TestFrontendSSAConversion(TestCase): 18 | def test_ssa(self): 19 | X = np.random.randn(4, 2).astype(np.float32) 20 | W = np.random.randn(3, 2).astype(np.float32) 21 | b = np.random.randn(3).astype(np.float32) 22 | s = np.random.randn(1).astype(np.float32) 23 | np_result = X.dot(W.transpose()) + b + s 24 | 25 | net = caffe2_pb2.NetDef() 26 | net.name = 'test-ssa' 27 | net.external_input[:] = ['W', 'X', 'b', 's'] 28 | net.op.extend([ 29 | core.CreateOperator( 30 | 'FC', 31 | ['X', 'W', 'b'], 32 | ['Y'] 33 | ), 34 | core.CreateOperator( 35 | 'Add', 36 | ['Y', 's'], 37 | ['Y'], 38 | broadcast=True, 39 | ) 40 | ]) 41 | net.external_output[:] = ['Y'] 42 | 43 | init_net = caffe2_pb2.NetDef() 44 | init_net.name = 'test-ssa-init' 45 | init_net.op.extend([ 46 | core.CreateOperator( 47 | 'GivenTensorFill', 48 | [], 49 | ['W'], 50 | values=W, 51 | shape=W.shape, 52 | ), 53 | core.CreateOperator( 54 | 'GivenTensorFill', 55 | [], 56 | ['b'], 57 | values=b, 58 | shape=b.shape, 59 | ), 60 | core.CreateOperator( 61 | 'GivenTensorFill', 62 | [], 63 | ['s'], 64 | values=s, 65 | shape=s.shape, 66 | ) 67 | ]) 68 | init_net.external_output[:] = ['W', 'b', 's'] 69 | 70 | _, orig_output = c2_native_run_net( 71 | predict_net=net, 72 | init_net=init_net, 73 | inputs=[X]) 74 | 75 | value_info = {'X': (TensorProto.FLOAT, X.shape)} 76 | c2_onnx.Caffe2Frontend._ssa_rewrite( 77 | net, 78 | init_net, 79 | value_info) 80 | 81 | self.assertEqual(net.external_input, ['W_0', 'X_0', 'b_0', 's_0']) 82 | self.assertEqual(net.op[0].input, ['X_0', 'W_0', 'b_0']) 83 | self.assertEqual(net.op[0].output, ['Y_1']) 84 | self.assertEqual(net.op[1].input, ['Y_1', 's_0']) 85 | 
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from collections import namedtuple
import os
import sys
from setuptools import setup, find_packages, Command
import distutils.command.build
import setuptools.command.build_py
import setuptools.command.develop
import setuptools.command.install
import subprocess
from textwrap import dedent

TOP_DIR = os.path.realpath(os.path.dirname(__file__))
SRC_DIR = os.path.join(TOP_DIR, 'onnx_caffe2')

################################################################################
# Version
################################################################################

# Record the current git commit so builds are traceable; fall back to None
# when git is unavailable (e.g. building from an sdist).
try:
    git_version = subprocess.check_output(['git', 'rev-parse', 'HEAD'], cwd=TOP_DIR).decode('ascii').strip()
except (OSError, subprocess.CalledProcessError):
    git_version = None

with open(os.path.join(TOP_DIR, 'VERSION_NUMBER')) as version_file:
    VersionInfo = namedtuple('VersionInfo', ['version', 'git_version'])(
        version=version_file.read().strip(),
        git_version=git_version
    )

################################################################################
# Customized commands
################################################################################

class create_version(Command):
    """Generate onnx_caffe2/version.py from VERSION_NUMBER + git state."""
    user_options = []

    def initialize_options(self):
        pass

    def finalize_options(self):
        pass

    def run(self):
        with open(os.path.join(SRC_DIR, 'version.py'), 'w') as f:
            f.write(dedent('''
            version = '{version}'
            git_version = '{git_version}'
            '''.format(**dict(VersionInfo._asdict()))))


class build_py(setuptools.command.build_py.build_py):
    # Ensure version.py exists before the normal build_py copies sources.
    def run(self):
        self.run_command('create_version')
        setuptools.command.build_py.build_py.run(self)

class build(distutils.command.build.build):
    # NOTE(review): this override runs ONLY build_py, skipping the parent
    # build's other sub-commands — presumably intentional for this
    # pure-Python package; confirm before changing.
    def run(self):
        self.run_command('build_py')

class develop(setuptools.command.develop.develop):
    # `pip install -e` path: generate version.py, then build, then the
    # standard develop install.
    def run(self):
        self.run_command('create_version')
        self.run_command('build')
        setuptools.command.develop.develop.run(self)

cmdclass={
    'create_version': create_version,
    'build_py': build_py,
    'build': build,
    'develop': develop,
}

################################################################################
# Dependencies
################################################################################

install_requires = ['click']
# enum34 backports the enum module to Python < 3.4.
if sys.version_info < (3, 4):
    install_requires.append('enum34')

################################################################################
# Final
################################################################################

setup(
    name="onnx-caffe2",
    version=VersionInfo.version,
    description="Caffe2 frontend and backend of Open Neural Network Exchange",
    install_requires=install_requires,
    setup_requires=['pytest-runner'],
    tests_require=['numpy', 'pytest-cov'],
    cmdclass=cmdclass,
    packages=find_packages(),
    author='bddppq',
    author_email='jbai@fb.com',
    url='https://github.com/onnx/onnx-caffe2',
    entry_points={
        'console_scripts': [
            'convert-caffe2-to-onnx = onnx_caffe2.bin.conversion:caffe2_to_onnx',
            'convert-onnx-to-caffe2 = onnx_caffe2.bin.conversion:onnx_to_caffe2'
        ]
    },
)
F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2)) 39 | x = x.view(-1, 320) 40 | x = F.relu(self.fc1(x)) 41 | x = F.dropout(x, training=self.training) 42 | x = self.fc2(x) 43 | return F.log_softmax(x) 44 | 45 | # Create a pytorch model. 46 | log.info("Create a PyTorch model.") 47 | pytorch_model = MNIST() 48 | pytorch_model.train(False) 49 | 50 | # Make the inputs in tuple format. 51 | inputs = (Variable(torch.randn(3, 1, 28, 28), requires_grad=True), ) 52 | 53 | # Export an ONNX model. 54 | log.info("Export an ONNX model from the PyTorch model.") 55 | f = io.BytesIO() 56 | torch.onnx.export(pytorch_model, inputs, f, verbose=True) 57 | onnx_model = onnx.ModelProto.FromString(f.getvalue()) 58 | 59 | # Check whether the onnx_model is valid or not. 60 | log.info("Check the ONNX model.") 61 | onnx.checker.check_model(onnx_model) 62 | 63 | # Convert the ONNX model to a Caffe2 model. 64 | log.info("Convert the model to a Caffe2 model.") 65 | init_net, predict_net = Caffe2Backend.onnx_graph_to_caffe2_net(onnx_model.graph, device="CPU") 66 | 67 | # Caffe2 model takes a numpy array list as input. 68 | caffe2_inputs = [var.data.numpy() for var in inputs] 69 | 70 | # Save and load the converted Caffe2 model in the protobuf files. 71 | log.info("Save the Caffe2 models as pb files.") 72 | init_file = "./mymodel_init.pb" 73 | predict_file = "./mymodel_predict.pb" 74 | save_caffe2_net(init_net, init_file, output_txt=False) 75 | save_caffe2_net(predict_net, predict_file, output_txt=True) 76 | log.info("Load the Caffe2 models back.") 77 | init_net = load_caffe2_net(init_file) 78 | predict_net = load_caffe2_net(predict_file) 79 | 80 | # Compute the results using the PyTorch model. 81 | log.info("Run the PyTorch model.") 82 | pytorch_results = pytorch_model(*inputs) 83 | 84 | # Compute the results using the Caffe2 model. 
85 | log.info("Run the Caffe2 model.") 86 | _, caffe2_results = c2_native_run_net(init_net, predict_net, caffe2_inputs) 87 | 88 | # Check the decimal precision of the exported Caffe2. 89 | expected_decimal = 5 90 | for p, c in zip([pytorch_results], caffe2_results): 91 | np.testing.assert_almost_equal(p.data.cpu().numpy(), c, decimal=expected_decimal) 92 | log.info("The exported model achieves {}-decimal precision.".format(expected_decimal)) 93 | 94 | pytorch_time = benchmark_pytorch_model(pytorch_model, inputs) 95 | caffe2_time = benchmark_caffe2_model(init_net, predict_net) 96 | 97 | print("PyTorch model's execution time is {} milliseconds/ iteration, {} iterations per second.".format( 98 | pytorch_time, 1000 / pytorch_time)) 99 | print("Caffe2 model's execution time is {} milliseconds / iteration, {} iterations per second".format( 100 | caffe2_time, 1000 / caffe2_time)) 101 | -------------------------------------------------------------------------------- /tests/optimize_onnx_test.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 6 | import os 7 | import tarfile 8 | import tempfile 9 | import unittest 10 | 11 | from collections import namedtuple 12 | from subprocess import Popen, PIPE 13 | from six.moves.urllib.request import urlretrieve 14 | import numpy as np 15 | 16 | import onnx 17 | from onnx import helper, ModelProto, TensorProto 18 | from onnx.backend.test.runner import Runner 19 | import onnx_caffe2.backend as c2 20 | 21 | from tests.test_utils import TestCase 22 | 23 | class TestRoundtrip(TestCase): 24 | def _roundtrip(self, model_name): 25 | model_dir = Runner(c2)._prepare_model_data( 26 | namedtuple('dummy', ['model_name'])(model_name)) 27 | 28 | pb_path = os.path.join(model_dir, 'model.pb') 29 | 30 | before_roundtrip = onnx.load(pb_path) 31 | 32 | 
with open(pb_path, 'rb') as pb: 33 | after_roundtrip = onnx.load_from_string(pb.read()) 34 | 35 | assert onnx.helper.printable_graph(before_roundtrip.graph) \ 36 | == onnx.helper.printable_graph(after_roundtrip.graph) 37 | 38 | with open(pb_path, 'rb') as pb: 39 | assert after_roundtrip.SerializeToString() == pb.read() 40 | 41 | # arbitrarily pick one relatively small model to sanity test with 42 | def test_squeezenet_v3(self): 43 | self._roundtrip('squeezenet-ir-version-3') 44 | 45 | # testing just to be sure that we no-op instead of breaking on an 46 | # older IR version. 47 | def test_squeezenet_v1(self): 48 | self._roundtrip('squeezenet-ir-version-1') 49 | 50 | class TestOptimize(TestCase): 51 | def _optimized(self, graph): 52 | orig_model = helper.make_model(graph, producer_name='onnx-to-caffe2-test') 53 | orig_model_str = orig_model.SerializeToString() 54 | optimized_model_str = c2.Caffe2Backend.optimize_onnx(orig_model_str) 55 | optimized_model = ModelProto() 56 | optimized_model.ParseFromString(optimized_model_str) 57 | return optimized_model 58 | 59 | def test_nop_transpose(self): 60 | trans = helper.make_node("Transpose", ["X"], ["Y"], perm=[0,1]) 61 | graph = helper.make_graph( 62 | [trans], 63 | "test", 64 | [helper.make_tensor_value_info("X", TensorProto.FLOAT, (2, 3))], 65 | [helper.make_tensor_value_info("Y", TensorProto.FLOAT, (3, 2))]) 66 | optimized_model = self._optimized(graph) 67 | 68 | for node in optimized_model.graph.node: 69 | assert node.op_type != "Transpose" 70 | 71 | def test_fuse_transpose(self): 72 | trans1 = helper.make_node("Transpose", ["X"], ["Y"], perm=[1,0,2]) 73 | trans2 = helper.make_node("Transpose", ["Y"], ["Z"], perm=[2,0,1]) 74 | trans3 = helper.make_node("Transpose", ["Z"], ["A"], perm=[2,0,1]) 75 | graph = helper.make_graph( 76 | [trans1, trans2, trans3], 77 | "test", 78 | [helper.make_tensor_value_info("X", TensorProto.FLOAT, (2, 3, 4))], 79 | [helper.make_tensor_value_info("A", TensorProto.FLOAT, (4, 3, 2))]) 80 | 
optimized_model = self._optimized(graph) 81 | 82 | assert len(list(optimized_model.graph.node)) == 1 83 | 84 | def test_fuse_transpose_into_gemm(self): 85 | trans1 = helper.make_node("Transpose", ["X"], ["A"], perm=[1,0]) 86 | trans2 = helper.make_node("Transpose", ["Y"], ["B"], perm=[1,0]) 87 | gemm = helper.make_node("Gemm", ["A", "B", "C"], ["Z"]) 88 | graph = helper.make_graph( 89 | [trans1, trans2, gemm], 90 | "test", 91 | [helper.make_tensor_value_info("X", TensorProto.FLOAT, (2, 3)), 92 | helper.make_tensor_value_info("Y", TensorProto.FLOAT, (5, 2)), 93 | helper.make_tensor_value_info("C", TensorProto.FLOAT, (3, 5))], 94 | [helper.make_tensor_value_info("Z", TensorProto.FLOAT, (3, 5))]) 95 | optimized_model = self._optimized(graph) 96 | 97 | assert len(list(optimized_model.graph.node)) == 1 98 | 99 | if __name__ == '__main__': 100 | unittest.main() 101 | -------------------------------------------------------------------------------- /tests/ONNXOpCoverage.md: -------------------------------------------------------------------------------- 1 | # Tracking why operators are not covered 2 | [ONNX backend test script](https://github.com/onnx/onnx-caffe2/blob/master/tests/onnx_backend_test.py) 3 | reports the coverage on the operators and attributes. But we have various of reasons for the missing test coverage on operators. 4 | This doc keeps tracking why operators are not covered by the testcases. 5 | 6 | - 💚 The ONNX operator can map to a Caffe2 operator. 7 | - 💛 The solution is not perfect/finished, for example, the operator can map to a combination of Caffe2 operators. 8 | - 💔 Hard to find a solution with existing Caffe2 operators. 
9 | 10 | | Operator | Test Coverage | PyTorch | Caffe2 | 11 | |---|:--:|:---:|:---:| 12 | |Abs|Yes|OK|💚OK| 13 | |Add|Yes|OK|💚OK| 14 | |And||Support int tensor, but no bool tensor|💚OK| 15 | |ArgMax|||💔No op| 16 | |ArgMin|||💔No op| 17 | |AveragePool|Yes|OK|💚OK| 18 | |BatchNormalization|Yes|OK|💚OK| 19 | |Cast|||💔No op| 20 | |Ceil|||💔No op| 21 | |Clip|Yes|OK|💚OK| 22 | |Concat|Yes|OK|💚OK| 23 | |Constant|Yes|OK|💛Special handling| 24 | |Conv|Yes|OK|💚OK| 25 | |ConvTranspose|||💚OK| 26 | |DepthToSpace|||💛Should be BatchToSpace, no tests| 27 | |Div|Yes|OK|💚OK| 28 | |Dropout|Yes|OK|💚OK| 29 | |Elu|Yes|OK|💚OK| 30 | |Equal|Yes|OK|💚OK| 31 | |Exp|Yes|OK|💚OK| 32 | |Flatten|Yes|OK|💚OK| 33 | |Floor|||💔No op| 34 | |GRU|||💛Under development| 35 | |Gather|Yes|OK|💛C2 only support axis=0 or 1| 36 | |Gemm|Yes|OK|💛C2 use FC or MatMul + Add| 37 | |GlobalAveragePool|Yes|No direct mapping|💚OK| 38 | |GlobalLpPool|||💔No op| 39 | |GlobalMaxPool|||💚OK| 40 | |Greater|||💔Only support int tensor| 41 | |HardSigmoid|||💔No op| 42 | |Hardmax|||💔No op| 43 | |InstanceNormalization|||💚OK| 44 | |LRN|Yes|OK|💚OK| 45 | |LSTM|||💛Under development| 46 | |LeakyRelu|Yes|OK|💚OK| 47 | |Less|||💔Only support int tensor| 48 | |Log|Yes|OK|💚OK| 49 | |LogSoftmax||OK|💛No op, translated in onnx-caffe2| 50 | |LpNormalization|||💚Should be LpNorm, no tests| 51 | |LpPool|||💚Should be LpPool, no tests| 52 | |MatMul|Yes|OK|💚OK| 53 | |Max|Yes|OK|💚OK| 54 | |MaxPool|Yes|OK|💚OK| 55 | |MaxRoiPool|||💔No op| 56 | |Mean|||💔No op| 57 | |Min|Yes|OK|💚OK| 58 | |Mul|Yes|OK|💚OK| 59 | |Neg|Yes|OK|💚OK| 60 | |Not|||💚OK| 61 | |Or|||💚OK| 62 | |PRelu|Yes|OK|💚OK| 63 | |Pad|Yes|OK|💚OK| 64 | |Pow||OK|💛Under development, C2 only accepts exponent as argument, not an input| 65 | |RNN|||💛Under development| 66 | |RandomNormal|||💔No op| 67 | |RandomNormalLike|||💔No op| 68 | |RandomUniform|||💔No op| 69 | |RandomUniformLike|||💔No op| 70 | |Reciprocal|||💛Use Pow to implement| 71 | |ReduceL1|||💔No op| 72 | |ReduceL2|||💔No op| 73 | |ReduceLogSum|||💔No op| 74 | 
|ReduceLogSumExp|||💔No op| 75 | |ReduceMax|||💔No op| 76 | |ReduceMean|||💔No op| 77 | |ReduceMin|||💔No op| 78 | |ReduceProd|||💔No op| 79 | |ReduceSum|||💔No op| 80 | |ReduceSumSquare|||💔No op| 81 | |Relu|Yes|OK|💚OK| 82 | |Reshape|Yes|OK|💚OK| 83 | |Selu|Yes|OK|💚OK| 84 | |Sigmoid|Yes|OK|💚OK| 85 | |Slice|Yes|OK|💔ScatterAssign + Cast, very hacky implementaion, Slice in C2 only supports one dimension| 86 | |Softmax|Yes|OK|💔Axis and dim has different semantics| 87 | |Softplus|Yes|OK|💚OK| 88 | |Softsign|||💚OK, no tests| 89 | |SpaceToDepth|||💛Should be SpaceToBatch, no tests| 90 | |Split|Yes|OK|💚OK| 91 | |Sqrt|||💛Use Pow to implement| 92 | |Squeeze|||💚OK, no tests| 93 | |Sub||OK|💚OK| 94 | |Sum|Yes|OK|💚OK| 95 | |Tanh|Yes|OK|💚OK| 96 | |Tile|||💚OK, no tests| 97 | |Transpose|Yes|OK|💚OK| 98 | |Xor|||💚OK| 99 | |experimental ATen|||💚OK| 100 | |experimental Affine|||💔No op| 101 | |experimental ConstantFill|||💚OK| 102 | |experimental Crop|||💔No op| 103 | |experimental FC|||💚OK| 104 | |experimental GRUUnit|||💚OK, no tests| 105 | |experimental GivenTensorFill|||💚OK| 106 | |experimental Identity|||💚OK| 107 | |experimental ImageScaler|||💔No op| 108 | |experimental MeanVarianceNormalization|||💔No op| 109 | |experimental ParametricSoftplus|||💔No op| 110 | |experimental Scale|||💚OK| 111 | |experimental ScaledTanh|||💔No op| 112 | |experimental ThresholdedRelu|||💔No op| 113 | |experimental Upsample|||💔No bilinear| 114 | -------------------------------------------------------------------------------- /onnx_caffe2/helper.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 6 | from caffe2.proto import caffe2_pb2 7 | from onnx import helper 8 | from onnx.backend.base import namedtupledict 9 | 10 | from onnx_caffe2.workspace import Workspace 11 | 12 | import io 13 | import logging 14 | import time 
15 | 16 | 17 | log = logging.getLogger(__name__) 18 | 19 | 20 | class _DummyNameFactory(object): 21 | used_names = set() 22 | counter = 0 23 | 24 | @classmethod 25 | def dummy_name(cls, used_names=None): 26 | if used_names is not None: 27 | cls.used_names.clear() 28 | cls.used_names.update(used_names) 29 | cls.counter = 0 30 | return None 31 | else: 32 | while True: 33 | name = 'OC2_DUMMY_{}'.format(cls.counter) 34 | cls.counter += 1 35 | if name not in cls.used_names: 36 | cls.used_names.add(name) 37 | return name 38 | 39 | dummy_name = _DummyNameFactory.dummy_name 40 | 41 | 42 | def make_model(graph, **kwargs): 43 | kwargs.setdefault('producer_name', 'onnx-caffe2') 44 | return helper.make_model(graph=graph, **kwargs) 45 | 46 | 47 | def c2_native_run_op(op_def, inputs): 48 | ws = Workspace() 49 | if isinstance(inputs, dict): 50 | for key, value in inputs.items(): 51 | ws.FeedBlob(key, value, op_def.device_option) 52 | else: 53 | assert(len(op_def.input) == len(inputs)) 54 | for key, value in zip(op_def.input, inputs): 55 | ws.FeedBlob(key, value, op_def.device_option) 56 | 57 | ws.RunOperatorOnce(op_def) 58 | 59 | output_names = op_def.output 60 | output_values = [ws.FetchBlob(name) for name in output_names] 61 | return ws, namedtupledict('Outputs', output_names)(*output_values) 62 | 63 | 64 | def c2_native_run_net(init_net, predict_net, inputs): 65 | ws = Workspace() 66 | if init_net: 67 | ws.RunNetOnce(init_net) 68 | 69 | if isinstance(inputs, dict): 70 | for key, value in inputs.items(): 71 | ws.FeedBlob(key, value, predict_net.device_option) 72 | else: 73 | uninitialized = [input_name 74 | for input_name in predict_net.external_input 75 | if not ws.HasBlob(input_name)] 76 | if len(uninitialized) == len(inputs): 77 | for key, value in zip(uninitialized, inputs): 78 | ws.FeedBlob(key, value, predict_net.device_option) 79 | else: 80 | # If everything is initialized, 81 | # we just initialized the first len(inputs) external_input. 
82 | assert(len(inputs) <= len(predict_net.external_input)) 83 | for i in range(len(inputs)): 84 | ws.FeedBlob(predict_net.external_input[i], inputs[i], 85 | predict_net.device_option) 86 | 87 | ws.RunNetOnce(predict_net) 88 | 89 | output_names = predict_net.external_output 90 | output_values = [ws.FetchBlob(name) for name in output_names] 91 | return ws, namedtupledict('Outputs', output_names)(*output_values) 92 | 93 | 94 | def load_caffe2_net(file): 95 | net = caffe2_pb2.NetDef() 96 | with open(file, "rb") as f: 97 | net.ParseFromString(f.read()) 98 | return net 99 | 100 | 101 | def save_caffe2_net(net, file, output_txt=False): 102 | with open(file, "wb") as f: 103 | f.write(net.SerializeToString()) 104 | if output_txt: 105 | with open(file + "txt", "w") as f: 106 | f.write(str(net)) 107 | 108 | 109 | def benchmark_caffe2_model(init_net, predict_net, warmup_iters=3, main_iters=10, layer_details=True): 110 | ''' 111 | Run the benchmark net on the target model. 112 | Return the execution time per iteration (millisecond). 113 | ''' 114 | ws = Workspace() 115 | if init_net: 116 | ws.RunNetOnce(init_net) 117 | ws.CreateNet(predict_net) 118 | results = ws.BenchmarkNet(predict_net.name, warmup_iters, main_iters, layer_details) 119 | del ws 120 | return results[0] 121 | 122 | 123 | def benchmark_pytorch_model(model, inputs, training=False, warmup_iters=3, 124 | main_iters=10, verbose=False): 125 | ''' 126 | Run the model several times, and measure the execution time. 127 | Return the execution time per iteration (millisecond). 
128 | ''' 129 | for _i in range(warmup_iters): 130 | model(*inputs) 131 | total_pytorch_time = 0.0 132 | for _i in range(main_iters): 133 | ts = time.time() 134 | model(*inputs) 135 | te = time.time() 136 | total_pytorch_time += te - ts 137 | log.info("The PyTorch model execution time per iter is {} milliseconds, " 138 | "{} iters per second.".format(total_pytorch_time / main_iters * 1000, 139 | main_iters / total_pytorch_time)) 140 | return total_pytorch_time * 1000 / main_iters 141 | -------------------------------------------------------------------------------- /tests/conversion_test.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import json 6 | import tempfile 7 | import textwrap 8 | import traceback 9 | 10 | from caffe2.proto import caffe2_pb2 11 | from caffe2.python import brew, core 12 | from caffe2.python.model_helper import ModelHelper 13 | from click.testing import CliRunner 14 | import numpy as np 15 | from onnx import helper, ModelProto, TensorProto 16 | from onnx_caffe2.helper import make_model, c2_native_run_net 17 | 18 | from onnx_caffe2.bin.conversion import caffe2_to_onnx, onnx_to_caffe2 19 | from onnx_caffe2.helper import dummy_name 20 | import onnx_caffe2.backend as c2 21 | from tests.test_utils import TestCase 22 | 23 | 24 | class TestConversion(TestCase): 25 | def _run_command(self, cmd, *args, **kwargs): 26 | runner = CliRunner() 27 | result = runner.invoke(cmd, *args, **kwargs) 28 | self.assertEqual(result.exit_code, 0, textwrap.dedent(''' 29 | Command exited with non-zero exit code: 30 | output: {} 31 | exception: {} 32 | exc_info: {} 33 | '''.format(result.output, 34 | result.exception, 35 | traceback.format_exception(*result.exc_info)))) 36 | return result 37 | 38 | def test_caffe2_to_onnx(self): 39 | caffe2_net = tempfile.NamedTemporaryFile() 40 | caffe2_init_net = 
tempfile.NamedTemporaryFile() 41 | output = tempfile.NamedTemporaryFile() 42 | 43 | model = ModelHelper(name='caffe2-to-onnx-test') 44 | brew.relu(model, ["X"], "Y") 45 | caffe2_net.write(model.net.Proto().SerializeToString()) 46 | caffe2_net.flush() 47 | 48 | init_model = ModelHelper(name='caffe2-to-onnx-init-test') 49 | init_model.net.GivenTensorFill([], 'X', shape=[2, 2], 50 | values=np.zeros((2, 2)).flatten().astype(float)) 51 | caffe2_init_net.write(init_model.net.Proto().SerializeToString()) 52 | caffe2_init_net.flush() 53 | 54 | result = self._run_command( 55 | caffe2_to_onnx, [ 56 | caffe2_net.name, 57 | '--caffe2-init-net', caffe2_init_net.name, 58 | '--output', output.name, 59 | ], 60 | catch_exceptions=False, 61 | ) 62 | 63 | onnx_model = ModelProto() 64 | onnx_model.ParseFromString(output.read()) 65 | self.assertEqual(len(onnx_model.graph.node), 1) 66 | self.assertEqual(onnx_model.graph.node[0].op_type, 'Relu') 67 | self.assertEqual(len(onnx_model.graph.initializer), 1) 68 | self.assertEqual(onnx_model.graph.initializer[0].name, onnx_model.graph.input[0].name) 69 | 70 | def test_caffe2_to_onnx_value_info(self): 71 | caffe2_net = tempfile.NamedTemporaryFile() 72 | output = tempfile.NamedTemporaryFile() 73 | 74 | model = ModelHelper(name='caffe2-to-onnx-test') 75 | brew.relu(model, ["X"], "Y") 76 | caffe2_net.write(model.net.Proto().SerializeToString()) 77 | caffe2_net.flush() 78 | 79 | args = [caffe2_net.name, '--output', output.name] 80 | self.assertRaisesRegexp(Exception, 81 | 'value info', 82 | self._run_command, caffe2_to_onnx, args) 83 | 84 | args.extend([ 85 | '--value-info', 86 | json.dumps({ 87 | 'X': (TensorProto.FLOAT, (2, 2)), 88 | })]) 89 | result = self._run_command(caffe2_to_onnx, args) 90 | 91 | onnx_model = ModelProto() 92 | onnx_model.ParseFromString(output.read()) 93 | self.assertEqual(len(onnx_model.graph.node), 1) 94 | self.assertEqual(onnx_model.graph.node[0].op_type, 'Relu') 95 | self.assertEqual(len(onnx_model.graph.initializer), 
0) 96 | 97 | def test_onnx_to_caffe2(self): 98 | onnx_model = tempfile.NamedTemporaryFile() 99 | output = tempfile.NamedTemporaryFile() 100 | init_net_output = tempfile.NamedTemporaryFile() 101 | 102 | node_def = helper.make_node( 103 | "Mul", ["X", "W"], ["Y"]) 104 | graph_def = helper.make_graph( 105 | [node_def], 106 | "test", 107 | [helper.make_tensor_value_info("X", TensorProto.FLOAT, (2, 3)), 108 | helper.make_tensor_value_info("W", TensorProto.FLOAT, (3, 2))], 109 | [helper.make_tensor_value_info("Y", TensorProto.FLOAT, (2, 2))], 110 | initializer=[helper.make_tensor("W", 111 | TensorProto.FLOAT, 112 | [3, 2], 113 | np.zeros((3, 2)).flatten().astype(float))]) 114 | model_def = make_model(graph_def, producer_name='onnx-to-caffe2-test') 115 | onnx_model.write(model_def.SerializeToString()) 116 | onnx_model.flush() 117 | 118 | result = self._run_command( 119 | onnx_to_caffe2, [ 120 | onnx_model.name, 121 | '--output', output.name, 122 | '--init-net-output', init_net_output.name, 123 | ]) 124 | 125 | caffe2_net = caffe2_pb2.NetDef() 126 | caffe2_net.ParseFromString(output.read()) 127 | self.assertEqual(len(caffe2_net.op), 1) 128 | self.assertEqual(caffe2_net.op[0].type, 'Mul') 129 | 130 | caffe2_init_net = caffe2_pb2.NetDef() 131 | caffe2_init_net.ParseFromString(init_net_output.read()) 132 | self.assertEqual(len(caffe2_init_net.op), 1) 133 | self.assertEqual(set(sum([list(init_op.output) 134 | for init_op in caffe2_init_net.op], [])), 135 | {'W'}) 136 | 137 | def test_convert_end2end(self): 138 | predict_net_f = tempfile.NamedTemporaryFile() 139 | init_net_f = tempfile.NamedTemporaryFile() 140 | onnx_model_f = tempfile.NamedTemporaryFile() 141 | 142 | x = 'X' 143 | w = 'W' 144 | b = 'b' 145 | y = 'Y' 146 | 147 | predict_net = caffe2_pb2.NetDef() 148 | predict_net.name = 'test-convert-end2end' 149 | predict_net.external_input[:] = [x, w, b] 150 | predict_net.external_output[:] = [y] 151 | predict_net.op.extend([ 152 | core.CreateOperator( 153 | 'FC', 154 | 
inputs=[x, w, b], 155 | outputs=[y], 156 | axis=2, 157 | ), 158 | ]) 159 | predict_net_f.write(predict_net.SerializeToString()) 160 | predict_net_f.flush() 161 | 162 | init_net = caffe2_pb2.NetDef() 163 | init_net.name = 'test-convert-end2end-init' 164 | init_net.external_output[:] = [w, b] 165 | x_val = np.random.randn(1, 3, 2).astype(np.float32) 166 | w_val = np.random.randn(4, 2).astype(np.float32) 167 | b_val = np.random.randn(4).astype(np.float32) 168 | init_net.op.extend([ 169 | core.CreateOperator( 170 | 'GivenTensorFill', 171 | [], 172 | [w], 173 | values=w_val, 174 | shape=w_val.shape, 175 | ), 176 | core.CreateOperator( 177 | 'GivenTensorFill', 178 | [], 179 | [b], 180 | values=b_val, 181 | shape=b_val.shape, 182 | ), 183 | ]) 184 | init_net_f.write(init_net.SerializeToString()) 185 | init_net_f.flush() 186 | 187 | y_val = np.matmul(x_val, w_val.transpose()) + b_val 188 | for _ in range(5): 189 | self._run_command( 190 | caffe2_to_onnx, [ 191 | predict_net_f.name, 192 | '--caffe2-init-net', init_net_f.name, 193 | '--output', onnx_model_f.name, 194 | '--value-info', 195 | json.dumps({ 196 | x: (TensorProto.FLOAT, (1, 3, 2)), 197 | }), 198 | ], 199 | catch_exceptions=False, 200 | ) 201 | 202 | onnx_model_f.seek(0) 203 | onnx_model = ModelProto() 204 | onnx_model.ParseFromString(onnx_model_f.read()) 205 | np.testing.assert_almost_equal( 206 | c2.run_model( 207 | onnx_model, {onnx_model.graph.input[0].name: x_val}), 208 | [y_val]) 209 | 210 | self._run_command( 211 | onnx_to_caffe2, [ 212 | onnx_model_f.name, 213 | '--output', predict_net_f.name, 214 | '--init-net-output', init_net_f.name, 215 | ]) 216 | predict_net_f.seek(0) 217 | predict_net = caffe2_pb2.NetDef() 218 | predict_net.ParseFromString(predict_net_f.read()) 219 | init_net_f.seek(0) 220 | init_net = caffe2_pb2.NetDef() 221 | init_net.ParseFromString(init_net_f.read()) 222 | x = predict_net.external_input[0] 223 | np.testing.assert_almost_equal(c2_native_run_net(init_net=init_net, 224 | 
predict_net=predict_net, 225 | inputs={x: x_val})[1], 226 | [y_val]) 227 | -------------------------------------------------------------------------------- /tests/caffe2_ref_test.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 6 | import json 7 | import os 8 | import unittest 9 | 10 | from caffe2.python import core 11 | from caffe2.proto import caffe2_pb2 12 | 13 | import onnx 14 | from onnx.helper import make_node, make_graph, make_tensor, make_tensor_value_info 15 | from onnx_caffe2.helper import make_model, c2_native_run_net, c2_native_run_op 16 | 17 | from onnx import defs, mapping 18 | import onnx_caffe2.frontend as c2_onnx 19 | import onnx_caffe2.backend as c2 20 | 21 | import numpy as np 22 | from caffe2.python.models.download import downloadFromURLToFile, getURLFromName, deleteDirectory 23 | 24 | from onnx_caffe2.helper import dummy_name 25 | from tests.test_utils import TestCase 26 | 27 | 28 | class TestCaffe2Basic(TestCase): 29 | def test_dummy_name(self): 30 | n1 = dummy_name() 31 | n2 = dummy_name() 32 | assert n1 != n2, "Got same names in different calls: {}".format(n1) 33 | 34 | def test_relu_node_inplace(self): 35 | X = np.random.randn(3, 2).astype(np.float32) 36 | Y_ref = np.clip(X, 0, np.inf) 37 | 38 | node_def = make_node( 39 | "Relu", ["X"], ["Y"], consumed_inputs=[1]) 40 | output = c2.run_node( 41 | node_def, {"X": X}) 42 | np.testing.assert_almost_equal(output.X, Y_ref) 43 | 44 | node_def = make_node( 45 | "Relu", ["X"], ["Y"], consumed_inputs=[1]) 46 | graph_def = make_graph( 47 | [node_def], 48 | name="test", 49 | inputs=[make_tensor_value_info("X", onnx.TensorProto.FLOAT, [3, 2])], 50 | outputs=[make_tensor_value_info("X", onnx.TensorProto.FLOAT, [3, 2])]) 51 | c2_rep = c2.prepare(make_model(graph_def)) 52 | output = c2_rep.run({"X": X}) 
53 | np.testing.assert_almost_equal(output.X, Y_ref) 54 | 55 | def test_relu_graph(self): 56 | X = np.random.randn(3, 2).astype(np.float32) 57 | Y_ref = np.clip(X, 0, np.inf) 58 | 59 | node_def = make_node( 60 | "Relu", ["X"], ["Y"]) 61 | output = c2.run_node( 62 | node_def, {"X": X}) 63 | np.testing.assert_almost_equal(output.Y, Y_ref) 64 | 65 | graph_def = make_graph( 66 | [node_def], 67 | name="test", 68 | inputs=[make_tensor_value_info("X", onnx.TensorProto.FLOAT, [3, 2])], 69 | outputs=[make_tensor_value_info("Y", onnx.TensorProto.FLOAT, [3, 2])]) 70 | c2_rep = c2.prepare(make_model(graph_def)) 71 | output = c2_rep.run(X) 72 | np.testing.assert_almost_equal(output.Y, Y_ref) 73 | 74 | def test_initializer(self): 75 | X = np.array([[1, 2], [3, 4]]).astype(np.float32) 76 | Y = np.array([[1, 2], [3, 4]]).astype(np.float32) 77 | weight = np.array([[1, 0], [0, 1]]) 78 | graph_def = make_graph( 79 | [make_node("Add", ["X", "Y"], ["Z0"]), 80 | make_node("Cast", ["Z0"], ["Z"], to="float"), 81 | make_node("Mul", ["Z", "weight"], ["W0"]), 82 | make_node("Tanh", ["W0"], ["W1"]), 83 | make_node("Sigmoid", ["W1"], ["W2"]), 84 | make_node("Scale", ["W2"], ["W3"], scale=-1.0)], 85 | name="test_initializer", 86 | inputs=[ 87 | make_tensor_value_info("X", onnx.TensorProto.FLOAT, (2, 2)), 88 | make_tensor_value_info("Y", onnx.TensorProto.FLOAT, (2, 2)), 89 | make_tensor_value_info("weight", onnx.TensorProto.FLOAT, (2, 2)), 90 | ], 91 | outputs=[ 92 | make_tensor_value_info("W3", onnx.TensorProto.FLOAT, (2, 2)) 93 | ], 94 | initializer=[make_tensor("weight", 95 | onnx.TensorProto.FLOAT, 96 | [2, 2], 97 | weight.flatten().astype(float))] 98 | ) 99 | 100 | def sigmoid(x): 101 | return 1 / (1 + np.exp(-x)) 102 | 103 | W_ref = -sigmoid(np.tanh((X + Y) * weight)) 104 | c2_rep = c2.prepare(make_model(graph_def)) 105 | output = c2_rep.run({"X": X, "Y": Y}) 106 | np.testing.assert_almost_equal(output["W3"], W_ref) 107 | 108 | def test_gemm(self): 109 | # simple 110 | A = 
np.random.randn(3, 2).astype(np.float32) 111 | B = np.random.randn(2, 4).astype(np.float32) 112 | C = np.random.randn(3, 4).astype(np.float32) 113 | node_def = make_node( 114 | 'Gemm', 115 | ['A', 'B', 'C'], 116 | ["Y"]) 117 | output = c2.run_node(node_def, [A, B, C]) 118 | np.testing.assert_almost_equal(output["Y"], np.dot(A, B) + C) 119 | 120 | # transA 121 | A = np.transpose(A) 122 | node_def = make_node( 123 | 'Gemm', 124 | ['A', 'B', 'C'], 125 | ["Y"], 126 | transA=True) 127 | output = c2.run_node(node_def, [A, B, C]) 128 | np.testing.assert_almost_equal( 129 | output["Y"], 130 | np.dot(np.transpose(A), B) + C) 131 | # revert A 132 | A = np.transpose(A) 133 | 134 | # transB 135 | B = np.transpose(B) 136 | node_def = make_node( 137 | 'Gemm', 138 | ['A', 'B', 'C'], 139 | ["Y"], 140 | transB=True) 141 | output = c2.run_node(node_def, [A, B, C]) 142 | np.testing.assert_almost_equal( 143 | output["Y"], 144 | np.dot(A, np.transpose(B)) + C) 145 | # revert A 146 | B = np.transpose(B) 147 | 148 | # scale 149 | alpha = np.random.random() 150 | beta = np.random.random() 151 | node_def = make_node( 152 | 'Gemm', 153 | ['A', 'B', 'C'], 154 | ["Y"], 155 | alpha=alpha, 156 | beta=beta) 157 | output = c2.run_node(node_def, [A, B, C]) 158 | np.testing.assert_almost_equal( 159 | output["Y"], 160 | alpha * np.dot(A, B) + beta * C) 161 | 162 | # broadcast 163 | C = np.random.randn(4).astype(np.float32) 164 | node_def = make_node( 165 | 'Gemm', 166 | ['A', 'B', 'C'], 167 | ["Y"], 168 | alpha=alpha, 169 | beta=beta, 170 | broadcast=1) 171 | output = c2.run_node(node_def, [A, B, C]) 172 | np.testing.assert_almost_equal( 173 | output["Y"], 174 | alpha * np.dot(A, B) + beta * C) 175 | 176 | def test_tensor_filling_ops(self): 177 | for dtype in [ 178 | onnx.TensorProto.FLOAT, 179 | onnx.TensorProto.DOUBLE, 180 | onnx.TensorProto.BOOL, 181 | onnx.TensorProto.INT8, 182 | onnx.TensorProto.INT16, 183 | onnx.TensorProto.INT32, 184 | onnx.TensorProto.INT64, 185 | onnx.TensorProto.UINT8, 186 
| onnx.TensorProto.UINT16, 187 | onnx.TensorProto.UINT32, 188 | ]: 189 | shape = (1, 2, 3) 190 | vals = np.random.randn(*shape) 191 | if dtype != onnx.TensorProto.BOOL: 192 | vals *= 5 193 | vals = vals.astype( 194 | mapping.TENSOR_TYPE_TO_NP_TYPE[dtype]) 195 | tensor = make_tensor( 196 | name='test-tensor-{}'.format(dtype), 197 | data_type=dtype, 198 | dims=[1, 2, 3], 199 | vals=vals.flatten().tolist(), 200 | ) 201 | op = c2.Caffe2Backend._create_tensor_filling_op(tensor) 202 | self.assertEqual(len(op.input), 0) 203 | self.assertEqual(op.output, [tensor.name]) 204 | ws, output = c2_native_run_op(op, inputs=[]) 205 | self.assertEqual(len(output), 1) 206 | np.testing.assert_almost_equal(output[0], vals) 207 | np.testing.assert_almost_equal(ws.FetchBlob(op.output[0]), vals) 208 | 209 | def test_slice(self): 210 | X = np.random.randn(1, 2, 3).astype(np.float32) 211 | starts = np.array([0, 1, 0], dtype=np.int32) 212 | ends = np.array([-1, 2, 3], dtype=np.int32) 213 | 214 | predict_net = caffe2_pb2.NetDef() 215 | predict_net.name = 'test-slice-net' 216 | predict_net.external_input[:] = ['X'] 217 | predict_net.external_output[:] = ['Y'] 218 | predict_net.op.extend([ 219 | core.CreateOperator( 220 | 'Slice', 221 | inputs=['X'], 222 | outputs=['Y'], 223 | starts=starts, 224 | ends=ends, 225 | ), 226 | ]) 227 | ws, (Y,) = c2_native_run_net( 228 | init_net=None, 229 | predict_net=predict_net, 230 | inputs=[X]) 231 | 232 | onnx_model = c2_onnx.caffe2_net_to_onnx_model( 233 | predict_net=predict_net, 234 | value_info={ 235 | 'X': (onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[X.dtype], X.shape) 236 | }) 237 | Y, = c2.run_model(onnx_model, inputs=[X]) 238 | np.testing.assert_almost_equal(Y, X[:, 1:2, :]) 239 | 240 | 241 | class TestCaffe2End2End(TestCase): 242 | def _model_dir(self, model): 243 | caffe2_home = os.path.expanduser(os.getenv('ONNX_HOME', '~/.caffe2')) 244 | models_dir = os.getenv('ONNX_MODELS', os.path.join(caffe2_home, 'models')) 245 | return os.path.join(models_dir, 
model) 246 | 247 | def _test_net(self, 248 | net_name, 249 | input_blob_dims=(1, 3, 224, 224), 250 | decimal=7): 251 | np.random.seed(seed=0) 252 | model_dir = self._model_dir(net_name) 253 | if not os.path.exists(model_dir): 254 | self._download(net_name) 255 | c2_predict_pb = os.path.join(model_dir, 'predict_net.pb') 256 | c2_predict_net = caffe2_pb2.NetDef() 257 | with open(c2_predict_pb, 'rb') as f: 258 | c2_predict_net.ParseFromString(f.read()) 259 | c2_predict_net.name = net_name 260 | 261 | c2_init_pb = os.path.join(model_dir, 'init_net.pb') 262 | c2_init_net = caffe2_pb2.NetDef() 263 | with open(c2_init_pb, 'rb') as f: 264 | c2_init_net.ParseFromString(f.read()) 265 | c2_init_net.name = net_name + '_init' 266 | 267 | n, c, h, w = input_blob_dims 268 | data = np.random.randn(n, c, h, w).astype(np.float32) 269 | inputs = [data] 270 | _, c2_outputs = c2_native_run_net(c2_init_net, c2_predict_net, inputs) 271 | del _ 272 | 273 | model = c2_onnx.caffe2_net_to_onnx_model( 274 | predict_net=c2_predict_net, 275 | init_net=c2_init_net, 276 | value_info=json.load(open(os.path.join(model_dir, 'value_info.json')))) 277 | c2_ir = c2.prepare(model) 278 | onnx_outputs = c2_ir.run(inputs) 279 | self.assertSameOutputs(c2_outputs, onnx_outputs, decimal=decimal) 280 | 281 | def _download(self, model): 282 | model_dir = self._model_dir(model) 283 | assert not os.path.exists(model_dir) 284 | os.makedirs(model_dir) 285 | for f in ['predict_net.pb', 'init_net.pb', 'value_info.json']: 286 | url = getURLFromName(model, f) 287 | dest = os.path.join(model_dir, f) 288 | try: 289 | try: 290 | downloadFromURLToFile(url, dest, 291 | show_progress=False) 292 | except TypeError: 293 | # show_progress not supported prior to 294 | # Caffe2 78c014e752a374d905ecfb465d44fa16e02a28f1 295 | # (Sep 17, 2017) 296 | downloadFromURLToFile(url, dest) 297 | except Exception as e: 298 | print("Abort: {reason}".format(reason=e)) 299 | print("Cleaning up...") 300 | deleteDirectory(model_dir) 301 | exit(1) 
302 | 303 | def test_alexnet(self): 304 | self._test_net('bvlc_alexnet', decimal=4) 305 | 306 | def test_resnet50(self): 307 | self._test_net('resnet50') 308 | 309 | @unittest.skipIf( 310 | os.environ.get('CI'), 311 | 'Taking too long to download!') 312 | def test_vgg16(self): 313 | self._test_net('vgg16') 314 | 315 | @unittest.skipIf( 316 | os.environ.get('CI'), 317 | 'Running vgg19 on Travis with Python 2 keeps getting OOM!') 318 | def test_vgg19(self): 319 | self._test_net('vgg19') 320 | 321 | def test_inception_v1(self): 322 | self._test_net('inception_v1', decimal=2) 323 | 324 | def test_inception_v2(self): 325 | self._test_net('inception_v2') 326 | 327 | @unittest.skip('Need to add support for ConstantFill operator') 328 | def test_squeezenet(self): 329 | self._test_net('squeezenet') 330 | 331 | def test_shufflenet(self): 332 | self._test_net('shufflenet') 333 | 334 | def test_densenet121(self): 335 | self._test_net('densenet121') 336 | 337 | 338 | if __name__ == '__main__': 339 | unittest.main() 340 | -------------------------------------------------------------------------------- /onnx_caffe2/frontend.py: -------------------------------------------------------------------------------- 1 | """Caffe2 Protobuf to ONNX converter 2 | 3 | To run this, you will need to have Caffe2 installed as well. 
4 | """ 5 | 6 | from __future__ import absolute_import 7 | from __future__ import division 8 | from __future__ import print_function 9 | from __future__ import unicode_literals 10 | 11 | import itertools 12 | import collections 13 | import logging 14 | import re 15 | 16 | from caffe2.python import core as caffe2_core 17 | from enum import Enum 18 | from onnx import (defs, checker, helper, numpy_helper, mapping, 19 | ModelProto, GraphProto, NodeProto, AttributeProto, TensorProto, OperatorSetIdProto) 20 | from onnx.helper import make_tensor, make_tensor_value_info 21 | import numpy as np 22 | 23 | from onnx_caffe2.helper import make_model, c2_native_run_net, dummy_name 24 | from onnx_caffe2.error import Unsupported 25 | 26 | logging.basicConfig(level=logging.INFO) 27 | logger = logging.getLogger(__name__) 28 | 29 | 30 | class Caffe2Frontend(object): 31 | # This number controls the semantics of the operators we target. Whenever 32 | # ONNX makes a BC breaking change to semantics of operators, having this set 33 | # to an accurate number will prevent our models form exporting. However, 34 | # we should strive to keep this up-to-date as much as possible. 
35 | _target_opset_version = 3 36 | 37 | _renamed_operators = { 38 | 'SpatialBN': 'BatchNormalization', 39 | 'Conv1D': 'Conv', 40 | 'Conv2D': 'Conv', 41 | 'Conv3D': 'Conv', 42 | 'ConvTranspose1D': 'ConvTranspose', 43 | 'ConvTranspose2D': 'ConvTranspose', 44 | 'ConvTranspose3D': 'ConvTranspose', 45 | 'MaxPool1D': 'MaxPool', 46 | 'MaxPool2D': 'MaxPool', 47 | 'MaxPool3D': 'MaxPool', 48 | 'AveragePool1D': 'AveragePool', 49 | 'AveragePool2D': 'AveragePool', 50 | 'AveragePool3D': 'AveragePool', 51 | } 52 | 53 | # caffe2 arguments that are completely removed in onnx 54 | _blacklist_caffe2_args = { 55 | 'order': {b'NCHW'}, 56 | 'cudnn_exhaustive_search': {0, 1}, 57 | 'use_cudnn': {0, 1}, 58 | } 59 | 60 | _global_renamed_args = { 61 | 'kernels': 'kernel_shape', 62 | } 63 | 64 | _per_op_renamed_args = { 65 | 'Squeeze': {'dims': 'axes'}, 66 | 'Transpose': {'axes': 'perm'}, 67 | } 68 | 69 | _special_operators = { 70 | 'Conv': '_create_conv_pool_op', 71 | 'ConvTranspose': '_create_conv_pool_op', 72 | 'ChannelShuffle': '_create_channel_shuffle', 73 | 'MaxPool': '_create_conv_pool_op', 74 | 'AveragePool': '_create_conv_pool_op', 75 | 'Concat': '_create_concat', 76 | 'FC': '_create_gemm', 77 | 'LRN': '_create_lrn', 78 | 'Slice': '_create_slice', 79 | 'Reshape': '_create_reshape', 80 | } 81 | 82 | @classmethod 83 | def _common_caffe2_arg_to_onnx_attr(cls, op_def, arg): 84 | # name 85 | op_type = op_def.type 86 | if op_type in cls._per_op_renamed_args: 87 | name = cls._per_op_renamed_args[op_type].get( 88 | arg.name, arg.name) 89 | else: 90 | name = cls._global_renamed_args.get(arg.name, arg.name) 91 | 92 | # value 93 | if arg.HasField('f'): 94 | value = arg.f 95 | elif arg.HasField('i'): 96 | value = arg.i 97 | elif arg.HasField('s'): 98 | value = arg.s 99 | elif arg.floats: 100 | value = arg.floats 101 | elif arg.ints: 102 | value = arg.ints 103 | elif arg.strings: 104 | value = arg.strings 105 | else: 106 | raise ValueError('Could not find data field in arg: {}'.format(arg)) 107 
| 108 | if name in cls._blacklist_caffe2_args: 109 | assert value in cls._blacklist_caffe2_args[arg.name] 110 | return None 111 | 112 | return helper.make_attribute(name, value) 113 | 114 | @classmethod 115 | def caffe2_arg_to_onnx_attr(cls, op_def, arg): 116 | return cls._common_caffe2_arg_to_onnx_attr(op_def, arg) 117 | 118 | @classmethod 119 | def _common_caffe2_op_to_onnx_node(cls, op_def, shapes): 120 | node_def = NodeProto() 121 | node_def.name = op_def.name 122 | 123 | node_def.op_type = cls._renamed_operators.get(op_def.type, op_def.type) 124 | 125 | node_def.input.extend(op_def.input) 126 | node_def.output.extend(op_def.output) 127 | 128 | attrs = filter(None, [cls.caffe2_arg_to_onnx_attr(op_def, arg) 129 | for arg in op_def.arg]) 130 | node_def.attribute.extend(attrs) 131 | 132 | return node_def 133 | 134 | @classmethod 135 | def _create_concat(cls, op_def, shapes): 136 | node = cls._common_caffe2_op_to_onnx_node(op_def, shapes) 137 | if len(node.output) == 2: 138 | del node.output[1] 139 | explicit_axis = any(arg.name == 'axis' for arg in op_def.arg) 140 | if not explicit_axis: 141 | node.attribute.extend([helper.make_attribute('axis', 1)]) 142 | return node 143 | 144 | @classmethod 145 | def _create_reshape(cls, op_def, shapes): 146 | node = cls._common_caffe2_op_to_onnx_node(op_def, shapes) 147 | if len(node.output) == 2: 148 | del node.output[1] 149 | return node 150 | 151 | @classmethod 152 | def _create_conv_pool_op(cls, op_def, shapes): 153 | node = cls._common_caffe2_op_to_onnx_node(op_def, shapes) 154 | 155 | if node.op_type in ['MaxPool', 'AveragePool']: 156 | for i, attr in enumerate(node.attribute): 157 | if attr.name == 'global_pooling' and attr.i: 158 | node.op_type = 'Global{}'.format(node.op_type) 159 | del node.attribute[i] 160 | break 161 | 162 | attrs = {attr.name: attr for attr in node.attribute} 163 | def apply_trans(k, dim=2, ks=None): 164 | ks = ks or (k + 's') 165 | if dim == 2: 166 | k_h, k_w = k + '_h', k + '_w' 167 | else: 168 | 
k_t, k_l, k_b, k_r = k + '_t', k + '_l', k + '_b', k + '_r' 169 | 170 | vals = None 171 | if (dim == 2 and k_h in attrs and k_w in attrs): 172 | vals = [attrs[k_h].i, attrs[k_w].i] 173 | del attrs[k_h] 174 | del attrs[k_w] 175 | elif (dim == 4 and 176 | k_t in attrs and k_l in attrs and k_b in attrs and k_r in attrs): 177 | vals = [attrs[k_t].i, 178 | attrs[k_l].i, 179 | attrs[k_b].i, 180 | attrs[k_r].i] 181 | del attrs[k_t] 182 | del attrs[k_l] 183 | del attrs[k_b] 184 | del attrs[k_r] 185 | elif k in attrs: 186 | vals = [attrs[k].i] * dim 187 | del attrs[k] 188 | 189 | if vals and not node.op_type.startswith('Global'): 190 | attrs[ks] = helper.make_attribute(ks, vals) 191 | 192 | apply_trans('kernel', ks='kernel_shape') 193 | apply_trans('stride') 194 | apply_trans('dilation') 195 | apply_trans('adj') 196 | apply_trans('pad', 4) 197 | 198 | del node.attribute[:] 199 | node.attribute.extend(attrs.values()) 200 | return node 201 | 202 | @classmethod 203 | def _create_gemm(cls, op_def, shapes): 204 | x, w, b = op_def.input 205 | args = {arg.name: arg for arg in op_def.arg} 206 | y, = op_def.output 207 | x_shape = list(shapes[x]) 208 | 209 | nodes = [] 210 | if 'axis' in args: 211 | axis = args['axis'].i 212 | outer = np.prod(x_shape[:axis]).astype(int) 213 | inner = np.prod(x_shape[axis:]).astype(int) 214 | reshaped_x = dummy_name() 215 | nodes.append(helper.make_node( 216 | 'Reshape', 217 | inputs=[x], 218 | outputs=[reshaped_x], 219 | shape=[outer, inner], 220 | )) 221 | x = reshaped_x 222 | 223 | if 'axis_w' in args: 224 | axis_w = args['axis_w'].i 225 | w_shape = shapes[w] 226 | outer = np.prod(w_shape[:axis_w]).astype(int).item() 227 | inner = np.prod(w_shape[axis_w:]).astype(int).item() 228 | reshaped_w = dummy_name() 229 | nodes.append(helper.make_node( 230 | 'Reshape', 231 | inputs=[w], 232 | outputs=[reshaped_w], 233 | shape=[outer, inner], 234 | )) 235 | w = reshaped_w 236 | 237 | gemm_y_output = dummy_name() if 'axis' in args else y 238 | 
nodes.append(helper.make_node( 239 | 'Gemm', 240 | inputs=[x, w, b], 241 | outputs=[gemm_y_output], 242 | name=op_def.name, 243 | transB=1, 244 | broadcast=1, 245 | )) 246 | 247 | if 'axis' in args: 248 | axis = args['axis'].i 249 | nodes.append(helper.make_node( 250 | 'Reshape', 251 | inputs=[gemm_y_output], 252 | outputs=[y], 253 | shape=x_shape[:axis] + [-1], 254 | )) 255 | 256 | return nodes 257 | 258 | @classmethod 259 | def _create_lrn(cls, op_def, shapes): 260 | node = cls._common_caffe2_op_to_onnx_node(op_def, shapes) 261 | if len(node.output) == 2: 262 | del node.output[1] 263 | return node 264 | 265 | @classmethod 266 | def _create_slice(cls, op_def, shapes): 267 | if len(op_def.input) > 1: 268 | raise Unsupported( 269 | 'ONNX Slice operator does not support dynamic slice.') 270 | node = cls._common_caffe2_op_to_onnx_node(op_def, shapes) 271 | attrs = {attr.name: attr for attr in node.attribute} 272 | ndims = len(attrs['starts'].ints) 273 | 274 | node.attribute.extend([helper.make_attribute('axes', range(ndims))]) 275 | 276 | data, = node.input 277 | shape = shapes[data] 278 | 279 | ends = attrs['ends'].ints 280 | for i, end in enumerate(ends): 281 | if end >= 0: 282 | continue 283 | if end == -1: 284 | end = shape[i] 285 | else: 286 | end = end + 1 287 | ends[i] = end 288 | 289 | return node 290 | 291 | @classmethod 292 | def _create_channel_shuffle(cls, op_def, shapes): 293 | x, = op_def.input 294 | y, = op_def.output 295 | n, c, h, w = shapes[x] 296 | args = {arg.name: arg for arg in op_def.arg} 297 | g = args['group'].i 298 | assert c % g == 0 299 | 300 | nodes = [] 301 | 302 | tmp1 = dummy_name() 303 | nodes.append(helper.make_node( 304 | 'Reshape', 305 | inputs=[x], 306 | outputs=[tmp1], 307 | shape=[n, g, c // g, h, w], 308 | )) 309 | 310 | tmp2 = dummy_name() 311 | nodes.append(helper.make_node( 312 | 'Transpose', 313 | inputs=[tmp1], 314 | outputs=[tmp2], 315 | perm=[0, 2, 1, 3, 4], 316 | )) 317 | 318 | nodes.append(helper.make_node( 319 | 
'Reshape', 320 | inputs=[tmp2], 321 | outputs=[y], 322 | shape=[n, c, h, w], 323 | )) 324 | return nodes 325 | 326 | @classmethod 327 | def caffe2_op_to_onnx_node(cls, op_def, shapes): 328 | if op_def.type in cls._special_operators: 329 | translator = getattr(cls, cls._special_operators[op_def.type]) 330 | else: 331 | translator = cls._common_caffe2_op_to_onnx_node 332 | nodes = translator(op_def, shapes) 333 | if not isinstance(nodes, collections.Iterable): 334 | nodes = [nodes] 335 | return nodes 336 | 337 | @staticmethod 338 | def _all_names_in_net(net): 339 | if net is None: 340 | return set() 341 | 342 | names = set() 343 | names.update(net.external_input) 344 | names.update(net.external_output) 345 | for op in net.op: 346 | names.update(op.input) 347 | names.update(op.output) 348 | return names 349 | 350 | @classmethod 351 | def caffe2_net_to_onnx_graph(cls, 352 | predict_net, 353 | init_net=None, 354 | value_info=None): 355 | if value_info is None: 356 | value_info = {} 357 | if not isinstance(value_info, dict): 358 | raise ValueError('Please pass value_info as a ' 359 | 'name -> (type, shape) dictionary') 360 | 361 | cls._ssa_rewrite(predict_net, init_net, value_info) 362 | 363 | if init_net: 364 | initializer = cls.caffe2_init_net_to_initializer(init_net) 365 | value_info.update({init.name: (init.data_type, init.dims) 366 | for init in initializer}) 367 | else: 368 | initializer = [] 369 | 370 | # Check whether we have got type shape info of all input 371 | missing = (set(list(predict_net.external_input)) - 372 | set(value_info.keys())) 373 | if missing: 374 | raise RuntimeError('Could not find value info of inputs: {}'.format( 375 | ', '.join(missing))) 376 | 377 | inputs = {} 378 | for name in predict_net.external_input: 379 | elem_type, shape = value_info[name] 380 | inputs[name] = np.random.randn(*shape).astype( 381 | mapping.TENSOR_TYPE_TO_NP_TYPE[elem_type]) 382 | 383 | ws, outputs = c2_native_run_net( 384 | init_net, 385 | predict_net, 386 | inputs) 
387 | 388 | for name in predict_net.external_output: 389 | output = outputs[name] 390 | elem_type = mapping.NP_TYPE_TO_TENSOR_TYPE[output.dtype] 391 | shape = output.shape 392 | value_info[name] = (elem_type, shape) 393 | 394 | graph_def = GraphProto() 395 | graph_def.name = predict_net.name 396 | graph_def.initializer.extend(initializer) 397 | # This is a mapping from Caffe2 names to ONNX names 398 | graph_def.input.extend( 399 | make_tensor_value_info( 400 | name=name, 401 | elem_type=value_info[name][0], 402 | shape=value_info[name][1]) 403 | for name in predict_net.external_input) 404 | 405 | dummy_name(cls._all_names_in_net(predict_net) | 406 | cls._all_names_in_net(init_net)) 407 | 408 | for op in predict_net.op: 409 | shapes = {} 410 | for name in itertools.chain(op.input, op.output): 411 | blob = ws.FetchBlob(name) 412 | if hasattr(blob, 'shape'): 413 | shapes[name] = blob.shape 414 | graph_def.node.extend( 415 | cls.caffe2_op_to_onnx_node( 416 | op, shapes=shapes)) 417 | 418 | all_output = set(sum((list(node.output) for node in graph_def.node), 419 | [init.name for init in graph_def.initializer])) 420 | redundant_output = set(vi.name for vi in graph_def.output) - all_output 421 | if redundant_output: 422 | logger.warning( 423 | 'There are graph output not produced by any node or initializer: {}' 424 | '! 
Will drop them.'.format(', '.join(redundant_output))) 425 | graph_def.output.extend( 426 | make_tensor_value_info( 427 | name=name, 428 | elem_type=value_info[name][0], 429 | shape=value_info[name][1]) 430 | for name in predict_net.external_output 431 | if name in all_output) 432 | 433 | cls._annotate_consumed(graph_def) 434 | checker.check_graph(graph_def) 435 | return graph_def 436 | 437 | @classmethod 438 | def caffe2_init_net_to_initializer(cls, init_net): 439 | initializer = [] 440 | for op in init_net.op: 441 | assert not op.input 442 | try: 443 | data_type, field_name = { 444 | 'GivenTensorFill': (TensorProto.FLOAT, 'floats'), 445 | 'GivenTensorInt64Fill': (TensorProto.INT64, 'ints'), 446 | 'GivenTensorIntFill': (TensorProto.INT32, 'ints'), 447 | 'GivenTensorBoolFill': (TensorProto.BOOL, 'ints'), 448 | 'GivenTensorStringFill': (TensorProto.STRING, 'strings'), 449 | }[op.type] 450 | except KeyError: 451 | raise RuntimeError( 452 | "Can not translate init_net with operator '{}' " 453 | "to initializer".format(op.type) 454 | ) 455 | raw = (data_type != TensorProto.STRING) 456 | args = {a.name: a for a in op.arg} 457 | vals = getattr(args['values'], field_name) 458 | if raw: 459 | vals = np.asarray( 460 | vals, 461 | dtype=mapping.TENSOR_TYPE_TO_NP_TYPE[data_type]).tobytes() 462 | initializer.append(make_tensor( 463 | name=op.output[0], 464 | data_type=data_type, 465 | dims=args['shape'].ints, 466 | vals=vals, 467 | raw=raw, 468 | )) 469 | return initializer 470 | 471 | @classmethod 472 | def _annotate_consumed(cls, graph_def): 473 | for node in graph_def.node: 474 | schema = defs.get_schema(node.op_type) 475 | consumes = [] 476 | for i, _input_name in enumerate(node.input): 477 | consume_type, output_idx = schema.consumed(i) 478 | if consume_type == defs.OpSchema.UseType.CONSUME_ENFORCED: 479 | consumes.append(1) 480 | else: 481 | consumes.append(0) 482 | 483 | if any(consumes): 484 | node.attribute.extend([helper.make_attribute( 485 | 'consumed_inputs', 486 | 
consumes, 487 | )]) 488 | 489 | @classmethod 490 | def _ssa_rewrite(cls, net, init_net, value_info): 491 | def ssa_name(name, version): 492 | return '{}_{}'.format(name, version) 493 | 494 | if init_net: 495 | for op in init_net.op: 496 | assert re.match('GivenTensor.*Fill', op.type) 497 | assert len(op.output) == 1 498 | op.output[0] = ssa_name(op.output[0], 0) 499 | init_net.external_input[:] = [ssa_name(name, 0) 500 | for name in init_net.external_input] 501 | init_net.external_output[:] = [ssa_name(name, 0) 502 | for name in init_net.external_output] 503 | if value_info: 504 | ssa_value_info = {ssa_name(name, 0): value 505 | for name, value in value_info.items()} 506 | value_info.clear() 507 | value_info.update(ssa_value_info) 508 | net.external_input[:] = [ssa_name(name, 0) 509 | for name in net.external_input] 510 | ssa, blob_versions = caffe2_core.get_ssa(net) 511 | assert len(net.op) == len(ssa) 512 | for op, (versioned_inputs, versioned_outputs) in zip(net.op, ssa): 513 | op.input[:] = [ssa_name(name, version) 514 | for name, version in versioned_inputs] 515 | op.output[:] = [ssa_name(name, version) 516 | for name, version in versioned_outputs] 517 | net.external_output[:] = [ssa_name(name, blob_versions[name]) 518 | for name in net.external_output] 519 | 520 | @classmethod 521 | def caffe2_net_to_onnx_model(cls, *args, **kwargs): 522 | model = make_model(cls.caffe2_net_to_onnx_graph(*args, **kwargs)) 523 | opset_id = OperatorSetIdProto() 524 | opset_id.domain = '' # ONNX 525 | opset_id.version = cls._target_opset_version 526 | model.opset_import.extend([opset_id]) 527 | checker.check_model(model) 528 | return model 529 | 530 | 531 | caffe2_net_to_onnx_graph = Caffe2Frontend.caffe2_net_to_onnx_graph 532 | caffe2_net_to_onnx_model = Caffe2Frontend.caffe2_net_to_onnx_model 533 | caffe2_init_net_to_initializer = Caffe2Frontend.caffe2_init_net_to_initializer 534 | -------------------------------------------------------------------------------- 
/onnx_caffe2/backend.py: -------------------------------------------------------------------------------- 1 | """Backend for running ONNX on Caffe2 2 | 3 | To run this, you will need to have Caffe2 installed as well. 4 | """ 5 | from __future__ import absolute_import 6 | from __future__ import division 7 | from __future__ import print_function 8 | from __future__ import unicode_literals 9 | 10 | import os 11 | import collections 12 | from subprocess import Popen, PIPE 13 | 14 | import caffe2 15 | from caffe2.python import core, workspace, rnn_cell, gru_cell 16 | from caffe2.python.model_helper import ModelHelper 17 | from caffe2.proto import caffe2_pb2 18 | import caffe2.python.utils 19 | import numpy as np 20 | import onnx 21 | from onnx import checker, GraphProto, TensorProto, AttributeProto, ModelProto 22 | import onnx.numpy_helper 23 | import onnx.defs 24 | import onnx.optimizer 25 | from onnx.backend.base import Backend, Device, DeviceType, namedtupledict 26 | 27 | from onnx_caffe2.workspace import Workspace 28 | from onnx_caffe2.backend_rep import Caffe2Rep 29 | from onnx_caffe2.helper import dummy_name 30 | 31 | import warnings 32 | 33 | def force_unicode(s): 34 | try: 35 | return s.decode('utf-8') 36 | except AttributeError: 37 | return s 38 | 39 | def get_device_option(device): 40 | m = {DeviceType.CPU: caffe2_pb2.CPU, 41 | DeviceType.CUDA: caffe2_pb2.CUDA} 42 | return core.DeviceOption(m[device.type], device.device_id) 43 | 44 | 45 | class OnnxAttributes(dict): 46 | """ 47 | This is a more convenient way to work with ONNX/Caffe2 attributes 48 | that is not the protobuf representation. 
# TODO: Move this into ONNX main library
def convertAttributeProto(onnx_arg):
    """
    Convert an ONNX AttributeProto into an appropriate Python object
    for the type.

    NB: Tensor attribute gets returned as the straight proto.
    """
    # Singular fields first: float, int, string, tensor (tensor stays a proto).
    for field in ('f', 'i', 's', 't'):
        if onnx_arg.HasField(field):
            return getattr(onnx_arg, field)
    # Then the repeated fields, materialized as plain lists.
    for field in ('floats', 'ints', 'strings'):
        values = getattr(onnx_arg, field)
        if len(values):
            return list(values)
    raise ValueError("Unsupported ONNX attribute: {}".format(onnx_arg))


# TODO: Move this into ONNX main library
class OnnxNode(object):
    """
    Reimplementation of NodeProto from ONNX, but in a form
    more convenient to work with from Python.

    We may temporarily edit these nodes to get them into Caffe2 form,
    before actually translating into the Caffe2 protobuf, since this
    is easier than decomposing everything, and putting it back together
    when we're ready.
    """
    def __init__(self, node):
        # Copy every proto field into plain Python containers.
        self.name = str(node.name)
        self.op_type = str(node.op_type)
        self.attrs = OnnxAttributes.from_onnx(node.attribute)
        # Legacy 'consumed_inputs' is handled separately from normal attrs.
        self.consumed_inputs = self.attrs.pop("consumed_inputs", None)
        self.inputs = list(node.input)
        self.outputs = list(node.output)
Caffe2Ops = collections.namedtuple('Caffe2Ops', ['ops', 'init_ops', 'interface_blobs'])


class Caffe2Backend(Backend):

    # The greatest version of the ONNX operator set which we are aware of.
    # Models whose version is larger than this will cause us to emit a warning
    # that we are attempting to translate on a "best effort" basis.
    #
    # If you increase this, make SURE you cross-reference all BC-breaking
    # changes from one version to the next, and any that you did not
    # implement, mark as broken in _broken_operators
    _known_opset_version = 3

    # This dictionary will record operators which are KNOWN to be
    # broken, so we give a good error message rather than do something
    # bogus and then fail.
    _broken_operators = {
        # 'BrokenOp': version_it_was_broken_in
    }

    # Operators that are different between Caffe2 and
    # ONNX but only in their name.
    # In most cases, this should be empty - as the effort of ONNX is
    # to unify the operator definitions.
    # BUG FIX: 'InstanceNormalization' was listed twice (duplicate dict
    # key, silently collapsed by Python); keep a single entry.
    _renamed_operators = {
        'Caffe2ConvTranspose': 'ConvTranspose',
        'GlobalMaxPool': 'MaxPool',
        'GlobalAveragePool': 'AveragePool',
        'Pad': 'PadImage',
        'Neg': 'Negative',
        'BatchNormalization': 'SpatialBN',
        'InstanceNormalization': 'InstanceNorm',
        'MatMul': 'BatchMatMul',
        'Upsample': 'ResizeNearest',
        'Identity': 'Copy',
        'Equal': 'EQ',
        'Less': 'LT',
        'Greater': 'GT',
        'Unsqueeze': 'ExpandDims',
    }

    _global_renamed_attrs = {'kernel_shape': 'kernels'}
    _per_op_renamed_attrs = {
        'Squeeze': {'axes': 'dims'},
        'Unsqueeze': {'axes': 'dims'},
        'Transpose': {'perm': 'axes'},
        'Upsample': {'mode': ''},
        'ConvTranspose': {'output_padding': 'adjs'},
        'Selu': {'gamma': 'scale'},
    }

    # operators whose behavior is different beyond renaming
    # the value is an attribute of this class that is a
    # function from ToffeIR node_def to caffe2 op_def
    _special_operators = {
        'Constant': '_create_constant',
        'Conv': '_create_conv_pool_op_base',
        'AveragePool': '_create_conv_pool_op_base',
        'GlobalAveragePool': '_create_conv_pool_op_base',
        'GlobalMaxPool': '_create_conv_pool_op_base',
        'MaxPool': '_create_conv_pool_op_base',
        'Reshape': '_create_reshape',
        'Gather': '_create_gather',
        'Gemm': '_create_gemm',
        'Pad': '_create_pad',
        'Concat': '_create_concat',
        'LogSoftmax': '_create_logsoftmax',
        'Slice': '_create_slice',
        'LSTM': '_create_lstm',
        'GRU': '_create_gru',
        'RNN': '_create_rnn',
        'Sqrt': '_create_sqrt',
        'Reciprocal': '_create_reciprocal',
    }

    # NB: By default, you will use the LATEST definition of the operator,
    # so this interface MAY make BC-breaking changes.  Specify an
    # opset_version if you don't want this to version.
189 | @classmethod 190 | def run_node(cls, node, inputs, device='CPU', opset_version=_known_opset_version): 191 | super(Caffe2Backend, cls).run_node(node, inputs, device) 192 | 193 | device_option = get_device_option(Device(device)) 194 | with Workspace(), core.DeviceScope(device_option): # temporary! 195 | if isinstance(inputs, dict): 196 | for key, value in inputs.items(): 197 | workspace.FeedBlob(key, value) 198 | else: 199 | assert len(node.input) == len(inputs), "{}: expected {} but got {}".format( 200 | node.op_type, len(node.input), len(inputs)) 201 | for key, value in zip(node.input, inputs): 202 | workspace.FeedBlob(key, value) 203 | 204 | cls._inplace_rewrite([node]) 205 | init_ops, ops, _ = cls._onnx_node_to_caffe2_op( 206 | None, None, node, opset_version or cls._known_opset_version) 207 | ops = init_ops + ops 208 | for op in ops: 209 | op.device_option.CopyFrom(device_option) 210 | workspace.RunOperatorsOnce(ops) 211 | output_values = [workspace.FetchBlob(name) for name in node.output] 212 | return namedtupledict('Outputs', node.output)(*output_values) 213 | 214 | @classmethod 215 | def _create_tensor_filling_op(cls, onnx_tensor, name=None): 216 | """ 217 | Given an Onnx TensorProto, translate it into a Caffe2 operator 218 | which produces the given tensor filling op. 
219 | """ 220 | assert name or onnx_tensor.name 221 | name = name or onnx_tensor.name 222 | 223 | c2_op = caffe2_pb2.OperatorDef() 224 | 225 | c2_values = c2_op.arg.add() 226 | c2_values.name = "values" 227 | 228 | def tensor2list(onnx_tensor): 229 | # Use the onnx.numpy_helper because the data may be raw 230 | return onnx.numpy_helper.to_array(onnx_tensor).flatten().tolist() 231 | 232 | if onnx_tensor.data_type in [TensorProto.FLOAT]: 233 | c2_op.type = 'GivenTensorFill' 234 | c2_values.floats.extend(tensor2list(onnx_tensor)) 235 | elif onnx_tensor.data_type in [TensorProto.DOUBLE]: 236 | c2_op.type = 'GivenTensorDoubleFill' 237 | c2_values.floats.extend(tensor2list(onnx_tensor)) 238 | elif onnx_tensor.data_type in [TensorProto.INT64, 239 | TensorProto.UINT32]: 240 | c2_op.type = 'GivenTensorInt64Fill' 241 | c2_values.ints.extend(tensor2list(onnx_tensor)) 242 | elif onnx_tensor.data_type in [TensorProto.UINT8, 243 | TensorProto.INT8, 244 | TensorProto.UINT16, 245 | TensorProto.INT16, 246 | TensorProto.INT32]: 247 | c2_op.type = 'GivenTensorIntFill' 248 | c2_values.ints.extend(tensor2list(onnx_tensor)) 249 | elif onnx_tensor.data_type == TensorProto.BOOL: 250 | c2_op.type = 'GivenTensorBoolFill' 251 | c2_values.ints.extend(tensor2list(onnx_tensor)) 252 | elif onnx_tensor.data_type == TensorProto.STRING: 253 | c2_op.type = 'GivenTensorStringFill' 254 | c2_values.strings.extend(onnx_tensor.string_data) 255 | else: 256 | raise RuntimeError( 257 | "unrecognized tensor type {}".format(onnx_tensor.data_type)) 258 | 259 | c2_shape = c2_op.arg.add() 260 | c2_shape.name = "shape" 261 | c2_shape.ints.extend(onnx_tensor.dims) 262 | 263 | c2_op.output.append(name) 264 | 265 | return c2_op 266 | 267 | @classmethod 268 | def _create_constant(cls, init_model, pred_model, n, opset_version): 269 | assert len(n.outputs) == 1 270 | return cls._create_tensor_filling_op(n.attrs["value"], n.outputs[0]) 271 | 272 | @classmethod 273 | def _create_gather(cls, init_model, pred_model, n, 
opset_version): 274 | (A, B) = n.inputs 275 | (Y, ) = n.outputs 276 | axis = n.attrs.get('axis', 0) 277 | 278 | if axis == 0: 279 | return core.CreateOperator("Gather", [A, B], [Y]) 280 | elif axis == 1: 281 | return core.CreateOperator("BatchGather", [A, B], [Y]) 282 | raise ValueError( 283 | 'Caffe2 only supports Gather with axis being 0 or 1,' + 284 | 'whereas axis is ' + str(axis)) 285 | 286 | @classmethod 287 | def _create_logsoftmax(cls, init_model, pred_model, n, opset_version): 288 | # NB: this implementation is not backward stable. 289 | (A,) = n.inputs 290 | (Y,) = n.outputs 291 | axis = n.attrs.get('axis', 1) 292 | ops = [] 293 | softmax_A = dummy_name() 294 | ops.append(core.CreateOperator('Softmax', [A], [softmax_A], axis=axis)) 295 | ops.append(core.CreateOperator('Log', [softmax_A], [Y])) 296 | return ops 297 | 298 | @classmethod 299 | def _create_gemm(cls, init_model, pred_model, n, opset_version): 300 | (A, B, C) = n.inputs 301 | (Y,) = n.outputs 302 | alpha = n.attrs.get('alpha', 1.) 303 | beta = n.attrs.get('beta', 1.) 
        ops = []
        # Fold alpha/beta into explicit Scale ops when they are not 1.
        if alpha != 1:
            scaled_A = dummy_name()
            ops.append(core.CreateOperator('Scale', [A], [scaled_A], scale=alpha))
            A = scaled_A
        if beta != 1:
            scaled_C = dummy_name()
            ops.append(core.CreateOperator('Scale', [C], [scaled_C], scale=beta))
            C = scaled_C

        trans_a = n.attrs.get('transA', 0)
        trans_b = n.attrs.get('transB', 0)
        broadcast = n.attrs.get('broadcast', 0)
        # The A * B^T + broadcast-C pattern maps directly onto Caffe2's FC.
        if not trans_a and trans_b and broadcast:
            ops.append(core.CreateOperator('FC',
                                           [A, B, C],
                                           [Y]))
        else:
            AB = dummy_name()
            ops.append(core.CreateOperator('MatMul',
                                           [A, B],
                                           [AB],
                                           trans_a=trans_a,
                                           trans_b=trans_b))
            ops.append(core.CreateOperator('Add',
                                           [AB, C],
                                           [Y],
                                           broadcast=broadcast))

        return ops

    @classmethod
    def _rnn_shape_inference(cls, init_model, pred_model, n, input_blob, W):
        # ad-hoc, informally-specified, bug-ridden, slow
        # implementation of shape inference

        # if the weight matrices are directly provided as
        # initializers, their dimensions should be available in the
        # init net model.
        for x in init_model.graph.input:
            if x.name == W:
                return x.type.tensor_type.shape.dim[1].dim_value

        # otherwise, assume that the input_blob is either a direct
        # graph input, or another rnn op of the same type. This
        # matches the pattern produced by exporting from pytorch
        # (where the weight matrices are unusable for this purpose due
        # to reshaping operations that lose shape information).
        for x in pred_model.graph.input:
            if x.name == input_blob:
                # dim[2] of the direct graph input is the feature size.
                return x.type.tensor_type.shape.dim[2].dim_value

        # Walk producers backwards until we hit a graph input fed through
        # Gather, or another RNN op of the same type whose hidden_size we
        # can reuse as our input size.
        curr = n
        while True:
            for x in pred_model.graph.input:
                if x.name == curr.inputs[0] and curr.op_type == 'Gather':
                    return x.type.tensor_type.shape.dim[1].dim_value
            prev = [x for x in map(OnnxNode, pred_model.graph.node) if x.outputs[0] == curr.inputs[0]]
            if len(prev) != 1:
                # Ambiguous or missing producer: give up (returns None).
                return
            prev = prev[0]
            if prev.op_type == n.op_type:
                return prev.attrs['hidden_size']
            curr = prev

    @classmethod
    def _create_rnn(cls, init_model, pred_model, n, opset_version):
        """Translate an ONNX RNN node into Caffe2 rnn_cell.BasicRNN nets."""
        assert init_model is not None, "cannot convert RNNs without access to the full model"
        assert pred_model is not None, "cannot convert RNNs without access to the full model"

        attrs = dict(n.attrs)  # make a copy, which is safe to mutate
        hidden_size = attrs.pop('hidden_size')
        activation = force_unicode(attrs.pop('activations', ('tanh',))[0])
        direction = force_unicode(attrs.pop('direction', 'forward'))
        # Anything left over is an attribute we do not know how to translate.
        assert not attrs, "unsupported RNN attributes: " + str(attrs.keys())
        assert direction in ['forward', 'bidirectional'], "unsupported backwards RNN"

        input_blob, W, R, B, sequence_lens, initial_h = n.inputs

        # An empty name means the optional sequence_lens input was omitted.
        if sequence_lens == "":
            sequence_lens = None

        input_size = cls._rnn_shape_inference(init_model, pred_model, n, input_blob, W)
        if input_size is None:
            raise RuntimeError("best-effort shape inference for RNN input failed")

        init_net = core.Net("init-net")
        pred_mh = ModelHelper()

        def make_rnn(direction_offset):
            name = dummy_name()

            # input and recurrence biases are squashed together in
            # onnx but not in caffe2

            bias_offset = 2 * direction_offset * hidden_size
            init_net.Slice(B, name + "/i2h_b",
                           starts=[bias_offset + 0 * hidden_size],
                           ends
                                =[bias_offset + 1 * hidden_size])
            init_net.Slice(B, name + "/gates_t_b",
                           starts=[bias_offset + 1 * hidden_size],
                           ends  =[bias_offset + 2 * hidden_size])

            # Slice this direction's weight rows out of the stacked W/R.
            weight_offset = direction_offset * hidden_size
            init_net.Slice(W, name + '/i2h_w',
                           starts=[weight_offset + 0 * hidden_size, 0],
                           ends  =[weight_offset + 1 * hidden_size, -1])
            init_net.Slice(R, name + '/gates_t_w',
                           starts=[weight_offset + 0 * hidden_size, 0],
                           ends  =[weight_offset + 1 * hidden_size, -1])

            # Select this direction's initial hidden state slice.
            initial_h_sliced = name + '/initial_h'
            init_net.Slice(initial_h, initial_h_sliced,
                           starts=[direction_offset + 0, 0, 0],
                           ends  =[direction_offset + 1, -1, -1])

            # The backward pass is realized by reversing the packed input,
            # running a forward RNN, and reversing the output again.
            if direction_offset == 1:
                input = pred_mh.net.ReversePackedSegs(
                    [input_blob, sequence_lens], name + "/input-reversed")
            else:
                input = input_blob

            hidden_t_all, hidden_t_last = rnn_cell.BasicRNN(
                pred_mh,
                input,
                sequence_lens,
                [initial_h_sliced],
                input_size,
                hidden_size,
                name,
                drop_states=True,
                forward_only=True,
                activation=activation
            )

            if direction_offset == 1:
                hidden_t_all = pred_mh.net.ReversePackedSegs(
                    [hidden_t_all, sequence_lens], name + "/output-reversed")

            return hidden_t_all, hidden_t_last

        if direction == 'forward':
            hidden_t_all, hidden_t_last = make_rnn(0)
            # Clone so the net-internal blob names are remapped onto the
            # ONNX node's declared outputs.
            pred_mh.net = pred_mh.net.Clone(
                "dummy-clone-net",
                blob_remap={ hidden_t_all: n.outputs[0], hidden_t_last: n.outputs[1] }
            )
        elif direction == 'bidirectional':
            hidden_t_all_f, hidden_t_last_f = make_rnn(0)
            hidden_t_all_b, hidden_t_last_b = make_rnn(1)
            # Concatenate forward/backward results along axis 2.
            pred_mh.net.Concat([hidden_t_all_f, hidden_t_all_b],
                               [n.outputs[0], dummy_name()], axis=2)
            pred_mh.net.Concat([hidden_t_last_f, hidden_t_last_b],
                               [n.outputs[1], dummy_name()], axis=2)

        return Caffe2Ops(list(pred_mh.Proto().op),
                         list(init_net.Proto().op),
                         list(pred_mh.Proto().external_input))

    @classmethod
    def _create_lstm(cls, init_model, pred_model, n, opset_version):
        """Translate an ONNX LSTM node into Caffe2 rnn_cell.LSTM nets."""
        assert init_model is not None, "cannot convert LSTMs without access to the full model"
        assert pred_model is not None, "cannot convert LSTMs without access to the full model"

        attrs = dict(n.attrs)  # make a copy, which is safe to mutate
        hidden_size = attrs.pop('hidden_size')
        direction = force_unicode(attrs.pop('direction', 'forward'))
        assert not attrs, "unsupported LSTM attributes: " + str(attrs.keys())
        assert direction in ['forward', 'bidirectional'], "unsupported backwards LSTM"

        input_blob, W, R, B, sequence_lens, initial_h, initial_c = n.inputs

        # An empty name means the optional sequence_lens input was omitted.
        if sequence_lens == "":
            sequence_lens = None

        input_size = cls._rnn_shape_inference(init_model, pred_model, n, input_blob, W)
        if input_size is None:
            raise RuntimeError("best-effort shape inference for LSTM input failed")

        init_net = core.Net("init-net")
        pred_mh = ModelHelper()

        def make_lstm(direction_offset):
            name = dummy_name()

            # input and recurrence biases are squashed together in
            # onnx but not in caffe2

            # 8 * hidden biases per direction: 4 input-side + 4 recurrent.
            bias_offset = 8 * direction_offset * hidden_size
            Bi = init_net.Slice(B, name + "_bias_i2h",
                                starts=[bias_offset + 0 * hidden_size],
                                ends  =[bias_offset + 4 * hidden_size])
            Br = init_net.Slice(B, name + "_bias_gates",
                                starts=[bias_offset + 4 * hidden_size],
                                ends  =[bias_offset + 8 * hidden_size])

            weight_offset = 4 * direction_offset * hidden_size
            W_ = init_net.Slice(W, name + '/i2h_w_pre',
                                starts=[weight_offset + 0 * hidden_size, 0],
                                ends  =[weight_offset + 4 * hidden_size, -1])
            R_ = init_net.Slice(R, name + '/gates_t_w_pre',
                                starts=[weight_offset + 0 * hidden_size, 0],
                                ends  =[weight_offset + 4 * hidden_size, -1])

            # caffe2 has a different order from onnx. We need to rearrange
            #  i o f c -> i f o c
            reforms = ((W_, 'i2h_w',     [(0, -1)]),
                       (R_, 'gates_t_w', [(0, -1)]),
                       (Bi, 'i2h_b',     []),
                       (Br, 'gates_t_b', []))
            for name_from, name_to, extra_dims in reforms:
                xi, xo, xf, xc = [name_from + suffix for suffix in ("_i", "_o", "_f", "_c")]
                for i, x in enumerate([xi, xo, xf, xc]):
                    # Slice out each gate's hidden_size-wide stripe.
                    dim0 = i * hidden_size, (i + 1) * hidden_size
                    starts, ends = zip(dim0, *extra_dims)
                    init_net.Slice(name_from, x, starts=starts, ends=ends)
                # Re-concatenate in Caffe2's i f o c gate order.
                init_net.Concat([xi, xf, xo, xc], ['%s/%s' % (name, name_to), dummy_name()], axis=0)

            initial_h_sliced = name + '/initial_h'
            init_net.Slice(initial_h, initial_h_sliced,
                           starts=[direction_offset + 0, 0, 0],
                           ends  =[direction_offset + 1, -1, -1])
            initial_c_sliced = name + '/initial_c'
            init_net.Slice(initial_c, initial_c_sliced,
                           starts=[direction_offset + 0, 0, 0],
                           ends  =[direction_offset + 1, -1, -1])

            # Backward direction: reverse input, run forward, reverse output.
            if direction_offset == 1:
                input = pred_mh.net.ReversePackedSegs(
                    [input_blob, sequence_lens], name + "/input-reversed")
            else:
                input = input_blob

            hidden_t_all, hidden_t_last, _, _, params = rnn_cell.LSTM(
                pred_mh,
                input,
                sequence_lens,
                [initial_h_sliced, initial_c_sliced],
                input_size,
                hidden_size,
                name,
                drop_states=True,
                forward_only=True,
                return_params=True
            )

            if direction_offset == 1:
                hidden_t_all = pred_mh.net.ReversePackedSegs(
                    [hidden_t_all, sequence_lens], name + "/output-reversed")

            return hidden_t_all, hidden_t_last

        if direction == 'forward':
            hidden_t_all, hidden_t_last = make_lstm(0)
            # Clone so the net-internal blob names are remapped onto the
            # ONNX node's declared outputs.
            pred_mh.net = pred_mh.net.Clone(
                "dummy-clone-net",
                blob_remap={ hidden_t_all: n.outputs[0], hidden_t_last: n.outputs[1] }
            )
        elif direction == 'bidirectional':
            hidden_t_all_f, hidden_t_last_f = make_lstm(0)
            hidden_t_all_b, hidden_t_last_b = make_lstm(1)
            pred_mh.net.Concat([hidden_t_all_f, hidden_t_all_b],
                               [n.outputs[0], dummy_name()], axis=2)
            pred_mh.net.Concat([hidden_t_last_f, hidden_t_last_b],
                               [n.outputs[1], dummy_name()], axis=2)

        return Caffe2Ops(list(pred_mh.Proto().op),
                         list(init_net.Proto().op),
                         list(pred_mh.Proto().external_input))

    @classmethod
    def _create_gru(cls, init_model, pred_model, n, opset_version):
        """Translate an ONNX GRU node into Caffe2 gru_cell.GRU nets."""
        assert init_model is not None, "cannot convert GRUs without access to the full model"
        assert pred_model is not None, "cannot convert GRUs without access to the full model"

        attrs = dict(n.attrs)  # make a copy, which is safe to mutate
        hidden_size = attrs.pop('hidden_size')
        linear_before_reset = attrs.pop('linear_before_reset', 0)
        direction = force_unicode(attrs.pop('direction', 'forward'))
        assert not attrs, "unsupported GRU attributes: " + str(attrs.keys())
        assert direction in ['forward', 'bidirectional'], "unsupported backwards GRU"

        input_blob, W, R, B, sequence_lens, initial_h = n.inputs

        # An empty name means the optional sequence_lens input was omitted.
        if sequence_lens == "":
            sequence_lens = None

        input_size = cls._rnn_shape_inference(init_model, pred_model, n, input_blob, W)
        if input_size is None:
            raise RuntimeError("best-effort shape inference for GRU input failed")

        init_net = core.Net("init-net")
        pred_mh = ModelHelper()

        def make_gru(direction_offset):
            name = dummy_name()

            # input and recurrence biases are squashed together in
            # onnx but not in caffe2

            # 6 * hidden biases per direction: 3 input-side + 3 recurrent.
            bias_offset = 6 * direction_offset * hidden_size
            Bi = init_net.Slice(B, name + "_bias_i2h",
                                starts=[bias_offset + 0 * hidden_size],
                                ends  =[bias_offset + 3 * hidden_size])
            Br = init_net.Slice(B, name + "_bias_gates",
                                starts=[bias_offset
                                        + 3 * hidden_size],
                                ends  =[bias_offset + 6 * hidden_size])

            weight_offset = 3 * direction_offset * hidden_size
            W_ = init_net.Slice(W, name + '/i2h_w_pre',
                                starts=[weight_offset + 0 * hidden_size, 0],
                                ends  =[weight_offset + 3 * hidden_size, -1])
            R_ = init_net.Slice(R, name + '/gates_t_w_pre',
                                starts=[weight_offset + 0 * hidden_size, 0],
                                ends  =[weight_offset + 3 * hidden_size, -1])

            # caffe2 has a different order from onnx. We need to rearrange
            #  z r h -> r z h
            reforms = ((W_, 'i2h_w', True, [(0, -1)]),
                       (R_, 'gate_t_w', False, [(0, -1)]),
                       (Bi, 'i2h_b', True, []),
                       (Br, 'gate_t_b', False, []))
            for name_from, name_to, do_concat, extra_dims in reforms:
                xz, xr, xh = ['%s/%s_%s' % (name, prefix, name_to) for prefix in ('update', 'reset', 'output')]
                for i, x in enumerate([xz, xr, xh]):
                    # Slice out each gate's hidden_size-wide stripe.
                    dim0 = i * hidden_size, (i + 1) * hidden_size
                    starts, ends = zip(dim0, *extra_dims)
                    init_net.Slice(name_from, x, starts=starts, ends=ends)
                # Only the entries flagged do_concat are re-concatenated
                # (in r z h order) into a single blob.
                if do_concat:
                    init_net.Concat([xr, xz, xh], ['%s/%s' % (name, name_to), dummy_name()], axis=0)

            initial_h_sliced = name + '/initial_h'
            init_net.Slice(initial_h, initial_h_sliced,
                           starts=[direction_offset + 0, 0, 0],
                           ends  =[direction_offset + 1, -1, -1])

            # Backward direction: reverse input, run forward, reverse output.
            if direction_offset == 1:
                input = pred_mh.net.ReversePackedSegs(
                    [input_blob, sequence_lens], name + "/input-reversed")
            else:
                input = input_blob

            hidden_t_all, hidden_t_last = gru_cell.GRU(
                pred_mh,
                input,
                sequence_lens,
                [initial_h_sliced],
                input_size,
                hidden_size,
                name,
                drop_states=True,
                forward_only=True,
                linear_before_reset=linear_before_reset
            )

            if direction_offset == 1:
                hidden_t_all = pred_mh.net.ReversePackedSegs(
                    [hidden_t_all, sequence_lens], name + "/output-reversed")

            return hidden_t_all, hidden_t_last

        if direction == 'forward':
            hidden_t_all, hidden_t_last = make_gru(0)
            # Clone so the net-internal blob names are remapped onto the
            # ONNX node's declared outputs.
            pred_mh.net = pred_mh.net.Clone(
                "dummy-clone-net",
                blob_remap={ hidden_t_all: n.outputs[0], hidden_t_last: n.outputs[1] }
            )
        elif direction == 'bidirectional':
            hidden_t_all_f, hidden_t_last_f = make_gru(0)
            hidden_t_all_b, hidden_t_last_b = make_gru(1)
            pred_mh.net.Concat([hidden_t_all_f, hidden_t_all_b],
                               [n.outputs[0], dummy_name()], axis=2)
            pred_mh.net.Concat([hidden_t_last_f, hidden_t_last_b],
                               [n.outputs[1], dummy_name()], axis=2)

        return Caffe2Ops(list(pred_mh.Proto().op),
                         list(init_net.Proto().op),
                         list(pred_mh.Proto().external_input))

    @classmethod
    def _create_pad(cls, init_model, pred_model, n, opset_version):
        # The attribute was renamed 'paddings' -> 'pads' in opset 2.
        if opset_version < 2:
            pads = n.attrs['paddings']
        else:
            pads = n.attrs['pads']
        if not (len(pads) == 8 and
                # first two dim is for batch and channel
                set(pads[:2] + pads[4:6]) == {0}):
            raise ValueError('Caffe2 only supports padding 2D Tensor, whereas padding is ' + str(pads))
        # Guard the invalid (negative) pads attribute.
        if min(pads) < 0:
            raise ValueError('ONNX does not support negative pads in Pad, but get {}.'.format(pads))
        # Keep only the H/W pads; batch/channel pads were verified to be 0.
        pads[:] = pads[2:4] + pads[6:8]
        return cls._common_onnx_node_to_caffe2_op(init_model, pred_model, n, opset_version)

    @classmethod
    def _create_concat(cls, init_model, pred_model, n, opset_version):
        # TODO: Caffe2 Concat has an extra output. It should be only
        # used when doing training, so we should change Caffe2 to allow
        # 1 output.
        op = cls._common_onnx_node_to_caffe2_op(init_model, pred_model, n, opset_version)
        assert len(op.output) == 1
        # Append a throwaway blob for Caffe2 Concat's extra split-info output.
        op.output.append(dummy_name())
        return op

    @classmethod
    def _create_slice(cls, init_model, pred_model, n, opset_version):
        """Translate ONNX Slice into a Caffe2 Slice plus the ops needed to
        build full-rank starts/ends tensors at net execution time."""
        op = cls._common_onnx_node_to_caffe2_op(init_model, pred_model, n, opset_version)
        args = {arg.name: arg for arg in op.arg}
        starts_vals = np.array(
            args.pop('starts').ints, dtype=np.int64).tolist()
        # Shift negative ends by one to bridge the off-by-one between the
        # ONNX and Caffe2 interpretations of negative end indices.
        ends_vals = np.array(
            [i - 1 if i < 0 else i for i in args.pop('ends').ints],
            dtype=np.int64).tolist()
        if 'axes' in args:
            axes_vals = np.array(
                args.pop('axes').ints, dtype=np.int32).tolist()
        else:
            # No axes attribute: the slice applies to the leading dims in order.
            ndims = len(starts_vals)
            axes_vals = np.array(range(ndims), dtype=np.int32).tolist()

        data, = op.input
        ops = []

        # Runtime shape of the input, used to size the starts/ends tensors.
        shape_tensor = dummy_name()
        ops.append(core.CreateOperator(
            'Shape',
            [data],
            [shape_tensor]
        ))

        axes_tensor = dummy_name()
        ops.extend([
            core.CreateOperator(
                'GivenTensorIntFill',
                [],
                [axes_tensor],
                shape=[len(axes_vals)],
                values=axes_vals,
            ),
        ])

        # starts: fill with 0 for every dim, scatter the requested values
        # into the sliced axes, then cast to what Slice accepts.
        starts_vals_tensor = dummy_name()
        starts_tensor = dummy_name()
        casted_starts_tensor = dummy_name()
        ops.extend([
            core.CreateOperator(
                'GivenTensorInt64Fill',
                [],
                [starts_vals_tensor],
                shape=[len(starts_vals)],
                values=starts_vals,
            ),
            core.CreateOperator(
                'ConstantFill',
                [shape_tensor],
                [starts_tensor],
                dtype=caffe2_pb2.TensorProto.INT64,
                value=0,
            ),
            core.CreateOperator(
                'ScatterAssign',
                [starts_tensor, axes_tensor, starts_vals_tensor],
                [starts_tensor],
            ),
            # Slice only accepts starts as int
            core.CreateOperator(
                'Cast',
                [starts_tensor],
                [casted_starts_tensor],
                to=caffe2_pb2.TensorProto.INT32,
            ),
        ])

        # ends: fill with -1 (end of dim) for every dim, scatter the
        # requested values into the sliced axes, then cast for Slice.
        ends_vals_tensor = dummy_name()
        ends_tensor = dummy_name()
        casted_ends_tensor = dummy_name()
        ops.extend([
            core.CreateOperator(
                'GivenTensorInt64Fill',
                [],
                [ends_vals_tensor],
                shape=[len(ends_vals)],
                values=ends_vals,
            ),
            core.CreateOperator(
                'ConstantFill',
                [shape_tensor],
                [ends_tensor],
                dtype=caffe2_pb2.TensorProto.INT64,
                value=-1,
            ),
            core.CreateOperator(
                'ScatterAssign',
                [ends_tensor, axes_tensor, ends_vals_tensor],
                [ends_tensor],
            ),
            # Slice only accepts ends as int
            core.CreateOperator(
                'Cast',
                [ends_tensor],
                [casted_ends_tensor],
                to=caffe2_pb2.TensorProto.INT32,
            ),
        ])

        # Feed starts/ends to Slice as inputs instead of attributes.
        op.input[:] = [data, casted_starts_tensor, casted_ends_tensor]
        del op.arg[:]
        op.arg.extend(args.values())
        ops.append(op)

        return ops

    # Note [Caffe2 ConvPoolOpBase]
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # To understand what is going on here, we have to talk a little bit about
    # Caffe2's internals.
    #
    # First, it's important to know that all of Caffe2's pooling and convolution
    # operators inherit from "ConvPoolOpBase", which is an abstract class that
    # defines all of the attributes (kernels, dilations, strides, etc) which one
    # sees on these operators. Unfortunately, Caffe2's documentation generator
    # doesn't know how to handle cases like this, so for example, if you look at
    # the docs for MaxPool at
    # you won't see any of the attributes. You have to go source diving to
    # find the information; in particular, you want to look at:
    # https://github.com/caffe2/caffe2/blob/master/caffe2/operators/conv_pool_op_base.h
    # This class handles *global* pooling as well.
    #
    # Second, it's important to know what Caffe2 expects for padding, which can
    # be somewhat difficult to understand from the code because Caffe2 handles
    # both singular/pluralized spellings of padding, and there is also legacy
    # padding business. The short version of the story is that, for NON-legacy
    # padding (which is what we want to output), padding is expected to be
    # *twice* the size of kernels. So if you have a 2D convolution, Caffe2
    # will accept two values in 'kernels', but FOUR values in 'pads';
    # furthermore, this is *mandatory.*
    #
    # Finally, ConvPoolOpBase is not the only class of it's kind; there is
    # also ConvTransposeUnpoolBase, which backs ConvTranspose. So don't
    # be tricked by the fact that Conv and ConvTranspose have similar
    # parameters; they exercise different codepaths and need to be handled
    # differently.

    @classmethod
    def _create_conv_pool_op_base(cls, init_model, pred_model, n, opset_version):
        # See Note [Caffe2 ConvPoolOpBase] above for the padding semantics.
        if n.op_type.startswith('Global'):
            n.attrs['global_pooling'] = 1

        try:
            kernels = n.attrs['kernel_shape']
            pads = n.attrs['pads']
        except KeyError:
            # Missing attrs: pass the node through unmodified.
            pass
        else:
            if len(kernels) == len(pads):
                # Caffe2 requires pads to be twice the size of kernels.
                n.attrs['pads'] = pads * 2

        return cls._common_onnx_node_to_caffe2_op(init_model, pred_model, n, opset_version)

    @classmethod
    def _create_reshape(cls, init_model, pred_model, n, opset_version):
        c2_op = cls._common_onnx_node_to_caffe2_op(init_model, pred_model, n, opset_version)
        # Caffe2 has an extra output
        c2_op.output.append(dummy_name())
        return c2_op

    @classmethod
    def _create_sqrt(cls, init_model, pred_model, n, opset_version):
        # Sqrt is lowered as Pow with exponent 0.5.
        (X,) = n.inputs
        (Y,) = n.outputs
        return core.CreateOperator(
            'Pow',
            [X],
            [Y],
            exponent=0.5,
        )

    @classmethod
    def _create_reciprocal(cls, init_model, pred_model, n, opset_version):
        # Reciprocal is lowered as Pow with exponent -1.
        (X,) = n.inputs
        (Y,) = n.outputs
        return core.CreateOperator(
            'Pow',
            [X],
            [Y],
            exponent=-1.0,
        )

    @classmethod
    def _direct_initialize_parameters(cls, initializer, ws, device_option):
        # Feed each graph initializer directly into the workspace.
        for tp in initializer:
            ws.FeedBlob(tp.name, onnx.numpy_helper.to_array(tp), device_option)

    @classmethod
    def _direct_initialize_inputs(cls, inputs, initialized, ws, device_option):
        # Feed placeholder (ones) blobs for graph inputs not covered by an
        # initializer, so blobs exist with known shapes before nets run.
        for value_info in inputs:
            if value_info.name in initialized:
                continue
            shape = list(d.dim_value for d in value_info.type.tensor_type.shape.dim)
            ws.FeedBlob(value_info.name, np.ones(shape), device_option)

    @staticmethod
    def optimize_onnx(input, init=False, predict=False):
        # Run ONNX graph optimization passes; optionally also split out the
        # init or predict half of the graph.
        passes = ['fuse_consecutive_transposes',
                  'eliminate_nop_transpose',
                  'fuse_transpose_into_gemm']
        if init:
            passes.append('split_init')
        if predict:
            passes.append('split_predict')
        out = onnx.optimizer.optimize(input, passes)
        return out

    @classmethod
    def prepare(cls, model, device='CPU', **kwargs):
        '''
        For Onnx Caffe2Backend, we require that init_graph don't initialize the actual input of
        the predict_graph,

        for example, if "img" is the input blob for the predict_net, we require that in init_graph and in
        initializer of the predict_graph, "img" is not initalized. We don't have a check for this, since
        there is no way we can know which blob is the input of the predict_graph.
        '''
        # Base-class prepare performs the generic model checks.
        super(Caffe2Backend, cls).prepare(model, device, **kwargs)

        # Pick the default-domain opset version; warn when the model is newer
        # than what this backend is known to support.
        opset_version = None
        for imp in model.opset_import:
            if not imp.HasField("domain") or imp.domain == "":
                opset_version = imp.version
                if imp.version > cls._known_opset_version:
                    warnings.warn("This version of onnx-caffe2 targets ONNX operator set version {}, but the model we are trying to import uses version {}. We will try to import it anyway, but if the model uses operators which had BC-breaking changes in the intervening versions, import will fail.".format(cls._known_opset_version, imp.version))
            else:
                warnings.warn("Unrecognized operator set {}".format(imp.domain))
        if opset_version is None:
            if model.ir_version >= 0x00000003:
                raise RuntimeError("Model with IR version >= 3 did not specify ONNX operator set version (onnx-caffe2 requires it)")
            else:
                # Pre-IR-3 models implicitly use operator set version 1.
                opset_version = 1

        ws = Workspace()
        device_option = get_device_option(Device(device))

        # Directly load initializer data into blobs in workspace
        cls._direct_initialize_parameters(
            model.graph.initializer,
            ws,
            device_option,
        )

        initialized = {init.name for init in model.graph.initializer}

        # Remaining graph inputs get placeholder blobs.
        cls._direct_initialize_inputs(
            model.graph.input,
            initialized,
            ws,
            device_option,
        )

        uninitialized = [value_info.name for value_info in model.graph.input if value_info.name not in initialized]

        init_net, predict_net = cls._onnx_model_to_caffe2_net(model, device, opset_version, False)

        retval = Caffe2Rep(init_net, predict_net, ws,
                           uninitialized)
        return retval

    @classmethod
    # TODO: This method needs a refactor for clarity
    def _onnx_node_to_caffe2_op(cls, init_model, pred_model, node_def, opset_version):
        # Dispatch to a registered special-case translator when one exists;
        # otherwise fall back to the generic attribute-marshalling path.
        if node_def.op_type in cls._special_operators:
            translator = getattr(cls, cls._special_operators[node_def.op_type])
        else:
            translator = cls._common_onnx_node_to_caffe2_op
        ops = translator(init_model, pred_model, OnnxNode(node_def), opset_version)
        if isinstance(ops, Caffe2Ops):
            return ops
        # Normalize a single op or a list of ops into a Caffe2Ops triple.
        # NOTE(review): collections.Iterable is collections.abc.Iterable on
        # Python 3 and was removed from `collections` in Python 3.10.
        if not isinstance(ops, collections.Iterable):
            ops = [ops]
        return Caffe2Ops(ops, [], [])

    @classmethod
    def _common_onnx_node_to_caffe2_op(cls, init_model, pred_model, onnx_node, opset_version):
        """
        This translator performs the basic translation of ONNX nodes into
        Caffe2 operators. Besides doing a straightforward marshalling from
        one format to another, it also does these extra things:

          - Renames operators based on '_renamed_operators'
          - Renames attributes based on '_global_renamed_attrs' and
            '_per_op_renamed_attrs'

        If you're writing a custom translator, consider calling this first,
        and then fixing things up further.
        """
        c2_op = caffe2_pb2.OperatorDef()

        c2_op.input.extend(onnx_node.inputs)
        c2_op.output.extend(onnx_node.outputs)
        c2_op.name = onnx_node.name

        onnx_op_type = onnx_node.op_type
        # Refuse ops whose semantics changed incompatibly at a known opset.
        broken_version = cls._broken_operators.get(onnx_op_type, float('Inf'))
        if broken_version <= opset_version:
            raise ValueError(
                "Don't know how to translate op {} in ONNX operator set v{} (I only support prior to v{})".format(onnx_op_type, opset_version, broken_version))
        c2_op.type = cls._renamed_operators.get(onnx_op_type, onnx_op_type)
        if not core.IsOperator(c2_op.type):
            raise ValueError(
                "Don't know how to translate op {}".format(onnx_op_type))

        def kmap(k):
            # Attribute-renaming policy: per-op renames take precedence
            # over the global renames.
            if (onnx_op_type in cls._per_op_renamed_attrs and
                    k in cls._per_op_renamed_attrs[onnx_op_type]):
                return cls._per_op_renamed_attrs[onnx_op_type][k]
            if k in cls._global_renamed_attrs:
                return cls._global_renamed_attrs[k]
            return k
        c2_op.arg.extend(onnx_node.attrs.caffe2(kmap=kmap))

        return c2_op


    @classmethod
    def _inplace_rewrite(cls, graph_or_nodes):
        '''
        currently we use this to translate ONNX-style
        consumed_input annotations to Caffe2-style in place
        updates (use same input and output names).
        '''
        # Accepts either a whole GraphProto or a bare list of nodes.
        is_graph = isinstance(graph_or_nodes, GraphProto)
        if is_graph:
            nodes = graph_or_nodes.node
        else:
            nodes = graph_or_nodes

        # Maps an old (pre-rewrite) blob name to the name it was folded into.
        renamed = {}

        for node in nodes:
            # Apply renames accumulated from earlier nodes to our inputs.
            node.input[:] = [renamed.get(input_name, input_name)
                             for input_name in node.input]
            consumed_inputs = OnnxNode(node).consumed_inputs or []
            output_idxes = set(range(len(node.output)))
            schema = onnx.defs.get_schema(node.op_type)
            for i, consumed in enumerate(consumed_inputs):
                if not consumed:
                    continue
                _, output_idx = schema.consumed(i)
                # consumed outputs are not always present
                # for instance batch norm in test mode
                # does not return the consumed inputs
                if output_idx < len(node.output):
                    output_idxes.remove(output_idx)
                    old_val = node.output[output_idx]
                    new_val = node.input[i]
                    # Reuse the consumed input's name for the output and
                    # remember the rename for downstream consumers.
                    node.output[output_idx] = new_val
                    renamed[old_val] = new_val
            # Non-consumed outputs still need earlier renames applied.
            for idx in output_idxes:
                name = node.output[idx]
                node.output[idx] = renamed.get(name, name)
        if is_graph:
            for output in graph_or_nodes.output:
                output.name = renamed.get(output.name, output.name)

    @staticmethod
    def _all_names_in_graph(graph):
        # Collect every blob name mentioned anywhere in the graph.
        if graph is None:
            return set()

        names = set()
        names.update(value_info.name for value_info in graph.input)
        names.update(value_info.name for value_info in graph.output)
        for node in graph.node:
            names.update(node.input)
            names.update(node.output)
        return names

    @classmethod
    def _onnx_model_to_caffe2_net(cls, onnx_model, device, opset_version, include_initializers):
        """Convert an ONNX model into an (init_net, predict_net) pair."""
        device_option = get_device_option(Device(device))

        # Optimize twice: once extracting the init half of the graph, once
        # the predict half.
        init_model = ModelProto()
        init_model.ParseFromString(cls.optimize_onnx(onnx_model.SerializeToString(), init=True))
        cls._inplace_rewrite(init_model.graph)

        pred_model = ModelProto()
        pred_model.ParseFromString(cls.optimize_onnx(onnx_model.SerializeToString(), predict=True))
        cls._inplace_rewrite(pred_model.graph)

        init_net = caffe2_pb2.NetDef()
        pred_net = caffe2_pb2.NetDef()

        init_net.name = onnx_model.graph.name + '_init'
        pred_net.name = onnx_model.graph.name + '_predict'

        if include_initializers:
            init_net.op.extend(cls._create_tensor_filling_op(tp) for tp in onnx_model.graph.initializer)

        # Seed the dummy-name generator with every name already used so that
        # generated intermediate blob names cannot collide.
        dummy_name(cls._all_names_in_graph(init_model.graph) | cls._all_names_in_graph(pred_model.graph))

        for net, model in ( (init_net, init_model), (pred_net, pred_model) ):
            net.device_option.CopyFrom(device_option)
            for node in model.graph.node:
                c2ops = cls._onnx_node_to_caffe2_op(
                    init_model, pred_model, node, opset_version)
                # Init ops produced by translators go into the init net,
                # unless everything is being folded into a single net.
                (init_net if include_initializers else net).op.extend(c2ops.init_ops)
                net.op.extend(c2ops.ops)
                net.external_input.extend(c2ops.interface_blobs)
            net.external_output.extend(
                value_info.name for value_info in model.graph.output)
            net.external_input.extend(
                value_info.name for value_info in model.graph.input)

        return init_net, pred_net

    # wrapper for backwards compatability
    @classmethod
    def onnx_graph_to_caffe2_net(cls, model, device="CPU", opset_version=_known_opset_version):
        return cls._onnx_model_to_caffe2_net(model, device=device, opset_version=opset_version, include_initializers=True)

    @classmethod
    def supports_device(cls, device_str):
        # CPU is always supported; CUDA only with a GPU-enabled build.
        device = Device(device_str)
        if device.type == DeviceType.CPU:
            return True
        elif device.type == DeviceType.CUDA:
            return workspace.has_gpu_support
        return False


# Module-level convenience aliases exposing the ONNX backend API.
prepare = Caffe2Backend.prepare

run_node = Caffe2Backend.run_node

run_model = Caffe2Backend.run_model

supports_device = Caffe2Backend.supports_device # noqa