├── nne
│   ├── analyze
│   │   ├── __init__.py
│   │   ├── tflite.py
│   │   └── onnx.py
│   ├── convert
│   │   ├── __init__.py
│   │   ├── common.py
│   │   ├── torchscript.py
│   │   ├── torch.py
│   │   ├── trt.py
│   │   ├── onnx.py
│   │   └── tflite.py
│   ├── quant
│   │   ├── __init__.py
│   │   └── onnx.py
│   ├── __init__.py
│   └── benchmark.py
├── .gitignore
├── docs
│   └── logo.png
├── .github
│   ├── FUNDING.yml
│   └── workflows
│       └── pythonapp.yml
├── examples
│   ├── edgetpu_example.py
│   ├── torch_example.py
│   ├── torchscript_example.py
│   ├── onnx_example.py
│   ├── tensorrt_example.py
│   ├── onnx_quantize.py
│   └── tflite_example.py
├── test
│   ├── test_analyze.py
│   ├── test_torch.py
│   ├── test_script.py
│   ├── test_tflite.py
│   └── test_onnx.py
├── benchmark
│   └── torch_bench.py
├── setup.py
├── bin
│   └── nne
├── README.md
└── LICENSE

/nne/analyze/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/nne/convert/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/nne/quant/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
*/__pycache__/*
*.pyc
*.onnx
--------------------------------------------------------------------------------
/docs/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kuroko1t/nne/HEAD/docs/logo.png
--------------------------------------------------------------------------------
/.github/FUNDING.yml:
--------------------------------------------------------------------------------
# These are supported funding model platforms

github: kuroko1t
--------------------------------------------------------------------------------
/examples/edgetpu_example.py:
--------------------------------------------------------------------------------
import torchvision
import torch
import numpy as np
import nne

input_shape = (10, 3, 112, 112)
model = torchvision.models.mobilenet_v2(pretrained=True)

tflite_file = 'mobilenet.tflite'

nne.cv2tflite(model, input_shape, tflite_file, edgetpu=True)
--------------------------------------------------------------------------------
/examples/torch_example.py:
--------------------------------------------------------------------------------
import nne
import torchvision
import torch
import numpy as np

input_shape = (1, 3, 224, 224)
model = torchvision.models.resnet34(pretrained=True).cuda()

input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)
output_data = nne.infer_torch(model, input_data)
--------------------------------------------------------------------------------
/examples/torchscript_example.py:
--------------------------------------------------------------------------------
import nne
import torchvision
import torch
import numpy as np

input_shape = (1, 3, 224, 224)
model = torchvision.models.resnet50(pretrained=True).cuda()
script_file = 'resnet_script.zip'
input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)
nne.cv2torchscript(model, input_shape, script_file)
model_script = nne.load_torchscript(script_file)
output = nne.infer_torchscript(model_script, input_data)
print(output)
--------------------------------------------------------------------------------
/examples/onnx_example.py:
--------------------------------------------------------------------------------
import nne
import torchvision
import torch
import numpy as np

input_shape = (1, 3, 64, 64)
onnx_file = 'resnet.onnx'
model = torchvision.models.resnet34(pretrained=True).cuda()

# convert pytorch model to onnx model
nne.cv2onnx(model, input_shape, onnx_file)

input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)

# load onnx model
onnx_model = nne.load_onnx(onnx_file)

# inference
output_data = nne.infer_onnx(onnx_model, input_data)

print(output_data)
--------------------------------------------------------------------------------
/examples/tensorrt_example.py:
--------------------------------------------------------------------------------
import nne
import torchvision
import torch
import numpy as np

input_shape = (1, 3, 224, 224)
trt_file = 'alexnet_trt.pth'
model = torchvision.models.alexnet(pretrained=True).cuda()

# convert pytorch model to TensorRT model
nne.cv2trt(model, input_shape, trt_file)

input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)

# load TensorRT model
trt_model = nne.load_trt(trt_file)

# inference
output_data = nne.infer_trt(trt_model, input_data)

print(output_data)
--------------------------------------------------------------------------------
/examples/onnx_quantize.py:
--------------------------------------------------------------------------------
from nne.quant.onnx import quant_oplist, quant_summary, quantize
import torchvision
import nne

input_shape = (1, 3, 64, 64)
onnx_file = 'resnet.onnx'
model = torchvision.models.resnet34(pretrained=True)

# convert pytorch model to onnx model
nne.cv2onnx(model, input_shape, onnx_file)

# quantize the onnx model (writes resnet.quant.onnx)
quantize("resnet.onnx")

# list the operation types that support quantization
quantized_op = quant_oplist()

# summary information about the quantized model,
# e.g. {"opset_version": <int>, "quant_op": {<op_type>: <count>}}
summary = quant_summary("resnet.quant.onnx")
print(summary)
--------------------------------------------------------------------------------
/examples/tflite_example.py:
--------------------------------------------------------------------------------
import torchvision
import torch
import numpy as np
import nne

input_shape = (10, 3, 224, 224)
model = torchvision.models.mobilenet_v2(pretrained=True).cuda()

tflite_file = 'mobilenet.tflite'

# convert pytorch model to tensorflow lite
nne.cv2tflite(model, input_shape, tflite_file)

input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)

# load tflite model
tflite_model = nne.load_tflite(tflite_file)

# inference
output_data = nne.infer_tflite(tflite_model, input_data)

print(output_data)
--------------------------------------------------------------------------------
/test/test_analyze.py:
--------------------------------------------------------------------------------
import unittest
import nne
import torchvision
import torch
import numpy as np

class AnalyzeTests(unittest.TestCase):
    def __init__(self, *args, **kwargs):
        super(AnalyzeTests, self).__init__(*args, **kwargs)

    def test_onnx(self):
        input_shape = (1, 3, 64, 64)
        onnx_file = 'resnet.onnx'
        model = torchvision.models.resnet34(pretrained=True)
        nne.cv2onnx(model, input_shape, onnx_file)
        nne.analyze(onnx_file)
        nne.analyze(onnx_file, "resnet.json")
--------------------------------------------------------------------------------
/benchmark/torch_bench.py:
--------------------------------------------------------------------------------
import torch
import numpy as np
import nne
import os
import torchvision

input_shape = (1, 3, 224, 224)
model = torchvision.models.resnet34(pretrained=True)
torch.save(model, "resnet.pt")

bm = nne.Benchmark(name='torch')
input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)
output_data = nne.infer_torch(model, input_data, bm=bm)

## onnx
onnx_file = "resnet.onnx"
bm = nne.Benchmark(name='onnx')
nne.cv2onnx(model, input_shape, onnx_file)
onnx_model = nne.load_onnx(onnx_file)
nne.infer_onnx(onnx_model, input_data, bm=bm)

## tflite
tflite_file = "resnet.tflite"
bm = nne.Benchmark(name='tflite')
nne.cv2tflite(model, input_shape, tflite_file)
tflite_model = nne.load_tflite(tflite_file)
nne.infer_tflite(tflite_model, input_data, bm=bm)
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
from setuptools import setup, find_packages
import platform

def check_tensorrt():
    try:
        import tensorrt
        return True
    except ImportError:
        return False

def check_jetson():
    return platform.machine() == "aarch64"

def get_requires():
    requires = []
    if not check_jetson():
        # on Jetson, tensorflow comes from the platform wheel
        requires = ["tensorflow==2.11.0", "tensorflow_addons"]
    requires += ["torch", "tensorflow_probability", "onnx",
                 "onnx_tf @ git+https://github.com/onnx/onnx-tensorflow",
                 "matplotlib", "onnx-simplifier"]
    if check_tensorrt():
        requires += ["pycuda"]
    return requires

setup(
    name="nne",
    scripts=["bin/nne"],
    packages=find_packages(),
    install_requires=get_requires(),
    version="0.1"
)
--------------------------------------------------------------------------------
/test/test_torch.py:
--------------------------------------------------------------------------------
import unittest
import nne
import torchvision
import torch
import numpy as np

class TorchTests(unittest.TestCase):
    def __init__(self, *args, **kwargs):
        super(TorchTests, self).__init__(*args, **kwargs)

    def test_torch(self):
        input_shape = (1, 3, 64, 64)
        script_file = 'resnet_script.zip'
        model = torchvision.models.resnet34(pretrained=True)
        nne.cv2torchscript(model, input_shape, script_file)

        input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)
        model_script = nne.load_torchscript(script_file)
        out_script = nne.infer_torchscript(model_script, input_data)
        model.eval()
        out_pytorch = model(torch.from_numpy(input_data)).detach().cpu().numpy()
        np.testing.assert_allclose(out_script, out_pytorch, rtol=1e-03, atol=1e-05)

if __name__ == "__main__":
    unittest.main()
--------------------------------------------------------------------------------
/test/test_script.py:
--------------------------------------------------------------------------------
#import unittest
#import nne
#import torchvision
#import torch
#import numpy as np
#import subprocess
#
#class ScriptTests(unittest.TestCase):
#    def __init__(self, *args, **kwargs):
#        super(ScriptTests, self).__init__(*args, **kwargs)
#        self.onnx_file = 'resnet.onnx'
#        input_shape = (1, 3, 64, 64)
#        model = torchvision.models.resnet34(pretrained=True)
#        nne.cv2onnx(model, input_shape, self.onnx_file)
#
#    def test_analyze(self):
#        subprocess.check_output(["nne", self.onnx_file], stderr=subprocess.STDOUT)
#        subprocess.check_output(["nne", self.onnx_file, "-a", "resnet.json"], stderr=subprocess.STDOUT)
#
#    def test_convert(self):
#        subprocess.check_output(["nne", self.onnx_file, "-s", "resnet_smip.onnx"], stderr=subprocess.STDOUT)
#        subprocess.check_output(["nne", self.onnx_file, "-t", "resnet.tflite"], stderr=subprocess.STDOUT)
--------------------------------------------------------------------------------
/bin/nne:
--------------------------------------------------------------------------------
#!/usr/bin/env python

import nne
import argparse

parser = argparse.ArgumentParser(description='Neural Network Graph Analyzer')
parser.add_argument("model_path", type=str,
                    help="model path for analyzing (onnx or tflite)")
parser.add_argument("-a", "--analyze_path", type=str,
                    help="Specify the path to output the Node information of the model in json format.")
parser.add_argument("-s", "--simplyfy_path", type=str,
                    help="simplify the onnx model and write the result to this path")
parser.add_argument("-t", "--tflite_path", type=str,
                    help="convert the onnx model to tflite and write the result to this path")
args = parser.parse_args()
if args.simplyfy_path:
    nne.cv2onnxsimplify(args.model_path, args.simplyfy_path)
elif args.tflite_path:
    nne.onnx2tflite(args.model_path, args.tflite_path)
elif args.analyze_path:
    nne.analyze(args.model_path, args.analyze_path)
else:
    nne.analyze(args.model_path, None)
--------------------------------------------------------------------------------
/nne/__init__.py:
--------------------------------------------------------------------------------
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # set before the tensorflow-backed imports below

import json

from .analyze import onnx as onnx_analyze
from .analyze import tflite as tflite_analyze
from .convert.torch import *
from .benchmark import *
from .convert.torchscript import *
from .convert.onnx import *
from .convert.tflite import *

if check_tensorrt():
    from .convert.trt import *


def analyze(model_path, output_path=None):
    ext = os.path.splitext(model_path)[1]
    if ext == ".onnx":
        model_info = onnx_analyze.analyze_graph(model_path, output_path)
        return model_info
    elif ext == ".tflite":
        model_info = tflite_analyze.analyze_graph(model_path, output_path)
        if output_path is None:
            output_path = model_path.replace(".tflite", "_tflite.json")
        with open(output_path, 'w') as f:
            json.dump(model_info, f, indent=2, cls=tflite_analyze.NumpyEncoder)
        print(f"Wrote dump result -> {output_path}")
        return model_info
    else:
        raise Exception(f"no support for {ext} files")
--------------------------------------------------------------------------------
/nne/analyze/tflite.py:
--------------------------------------------------------------------------------
import tensorflow as tf
import json
import numpy as np


def analyze_graph(model_path, output_path):
    interpreter = tf.lite.Interpreter(model_path=model_path)
    interpreter.allocate_tensors()
    graphinfo = interpreter.get_tensor_details()
    for i, op in enumerate(graphinfo):
        graphinfo[i]["shape"] = op["shape"].tolist()
        graphinfo[i]["shape_signature"] = op["shape_signature"].tolist()
        graphinfo[i]["dtype"] = [op["dtype"].__name__]
        graphinfo[i]["quantization_parameters"]["scales"] = op["quantization_parameters"]["scales"].tolist()
        graphinfo[i]["quantization_parameters"]["zero_points"] = op["quantization_parameters"]["zero_points"].tolist()
    return graphinfo


class NumpyEncoder(json.JSONEncoder):
    def default(self, obj):
        if isinstance(obj, np.integer):
            return int(obj)
        elif isinstance(obj, np.floating):
            return float(obj)
        elif isinstance(obj, np.ndarray):
            return obj.tolist()
        else:
            return super(NumpyEncoder, self).default(obj)
--------------------------------------------------------------------------------
/test/test_tflite.py:
--------------------------------------------------------------------------------
import unittest
import nne
import torchvision
import torch
import numpy as np

class TFliteTests(unittest.TestCase):
    def __init__(self, *args, **kwargs):
        super(TFliteTests, self).__init__(*args, **kwargs)

    def test_tflite(self):
        input_shape = (1, 3, 64, 64)
        tflite_file = 'mobilenet.tflite'
        model = torchvision.models.mobilenet_v2(pretrained=True)
        nne.cv2tflite(model, input_shape, tflite_file)

        input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)
        model_tflite = nne.load_tflite(tflite_file)
        out_tflite = nne.infer_tflite(model_tflite, input_data)
        model.eval()
        out_pytorch = model(torch.from_numpy(input_data)).detach().cpu().numpy()
        np.testing.assert_allclose(out_tflite, out_pytorch, rtol=1e-03, atol=1e-05)

    def test_tflite_quant(self):
        input_shape = (1, 3, 64, 64)
        tflite_file = 'mobilenet.tflite'
        model = torchvision.models.mobilenet_v2(pretrained=True)
        nne.cv2tflite(model, input_shape, tflite_file, quantization=True)

if __name__ == "__main__":
    unittest.main()
--------------------------------------------------------------------------------
/.github/workflows/pythonapp.yml:
--------------------------------------------------------------------------------
# This workflow will install Python dependencies, run tests and lint with a single version of Python
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions

name: Python application

on:
  push:
    branches: [ master ]
  pull_request:
    branches: [ master ]

jobs:
  build:

    runs-on: ubuntu-latest

    steps:
    - uses: actions/checkout@v2
    - name: Set up Python 3.7
      uses: actions/setup-python@v1
      with:
        python-version: 3.7
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        python -m pip install numpy torchvision
        python -m pip install -e .
    - name: Lint with flake8
      run: |
        pip install flake8
        # stop the build if there are Python syntax errors or undefined names
        flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
        # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
        flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
    - name: Test with pytest
      run: |
        python -m pip install pytest
        pytest . -p no:warnings
--------------------------------------------------------------------------------
/nne/convert/common.py:
--------------------------------------------------------------------------------
# Copyright 2020 kurosawa. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

import platform
import os  # re-exported: convert.trt uses os via star import of this module
import onnxsim


def check_jetson():
    return platform.machine() == "aarch64"


def check_tensorrt():
    try:
        import tensorrt
        return True
    except ImportError:
        return False


def check_model_is_cuda(model):
    return next(model.parameters()).is_cuda


def onnx_simplify(model, input_shapes):
    # NOTE: the exported model already carries static input shapes,
    # so input_shapes is accepted for API symmetry but not forwarded.
    model_opt, check_ok = onnxsim.simplify(
        model, check_n=3, perform_optimization=False, skip_fuse_bn=False,
        input_shapes=None)
    return model_opt, check_ok
--------------------------------------------------------------------------------
/nne/convert/torchscript.py:
--------------------------------------------------------------------------------
# Copyright 2020 kurosawa. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

import torch
from .common import *

def cv2torchscript(model, input_shape, script_path):
    model.eval()
    if check_model_is_cuda(model):
        dummy_input = torch.randn(input_shape, device="cuda")
    else:
        dummy_input = torch.randn(input_shape, device="cpu")
    traced = torch.jit.trace(model, dummy_input)
    traced.save(script_path)

def load_torchscript(script_path):
    model = torch.jit.load(script_path)
    return model

def infer_torchscript(model, input_data, bm=None):
    input_data = torch.from_numpy(input_data)
    if check_model_is_cuda(model):
        input_data = input_data.cuda()
    if bm:
        output = bm.measure(model, name="torchscript")(input_data)
    else:
        output = model(input_data)
    return output.detach().cpu().numpy()
--------------------------------------------------------------------------------
/test/test_onnx.py:
--------------------------------------------------------------------------------
import unittest
import nne
import torchvision
import torch
import numpy as np
from nne.quant.onnx import quant_oplist, quant_summary, quantize

class OnnxTests(unittest.TestCase):
    def __init__(self, *args, **kwargs):
        super(OnnxTests, self).__init__(*args, **kwargs)

    def test_onnx(self):
        input_shape = (1, 3, 64, 64)
        onnx_file = 'resnet.onnx'
        model = torchvision.models.resnet34(pretrained=True)
        nne.cv2onnx(model, input_shape, onnx_file)

        input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)
        model_onnx = nne.load_onnx(onnx_file)
        out_onnx = nne.infer_onnx(model_onnx, input_data)[0]
        model.eval()
        out_pytorch = model(torch.from_numpy(input_data)).detach().cpu().numpy()
        np.testing.assert_allclose(out_onnx, out_pytorch, rtol=1e-03, atol=1e-05)

    def test_onnx_analyzer(self):
        input_shape = (1, 3, 64, 64)
        onnx_file = 'resnet.onnx'
        model = torchvision.models.resnet34(pretrained=True)
        nne.cv2onnx(model, input_shape, onnx_file)
        nne.analyze(onnx_file)
        nne.analyze(onnx_file, "resnet.json")

    def test_onnx_quant(self):
        input_shape = (1, 3, 64, 64)
        onnx_file = 'resnet.onnx'
        model = torchvision.models.resnet34(pretrained=True)
        nne.cv2onnx(model, input_shape, onnx_file)
        quantize("resnet.onnx")
        quantized_op = quant_oplist()
        summary = quant_summary("resnet.quant.onnx")

if __name__ == "__main__":
    unittest.main()
--------------------------------------------------------------------------------
/nne/convert/torch.py:
--------------------------------------------------------------------------------
# Copyright 2020 kurosawa. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

import torch
from .common import *

def infer_torch(model, input_data, bm=None):
    """
    model : loaded model
    input_data: numpy array
    """
    model.eval()
    if type(input_data) == tuple:
        if check_model_is_cuda(model):
            input_data = [torch.from_numpy(data).cuda() for data in input_data]
        else:
            input_data = [torch.from_numpy(data) for data in input_data]
    else:
        input_data = torch.from_numpy(input_data)
        if check_model_is_cuda(model):
            input_data = input_data.cuda()
    if bm:
        if type(input_data) == list:
            output = bm.measure(model, name="torch")(*input_data)
        else:
            output = bm.measure(model, name="torch")(input_data)
    else:
        if type(input_data) == list:
            output = model(*input_data)
        else:
            output = model(input_data)
    return output.detach().cpu().numpy()
--------------------------------------------------------------------------------
/nne/quant/onnx.py:
--------------------------------------------------------------------------------
# Copyright 2020 kurosawa. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

from onnxruntime.quantization import quantize_dynamic
from onnxruntime.quantization.registry import QLinearOpsRegistry, QDQRegistry, IntegerOpsRegistry
import onnx
import collections


def quantize(modelpath):
    model_quant = modelpath.replace(".onnx", ".quant.onnx")
    quantize_dynamic(modelpath, model_quant)


def quant_oplist():
    qoplist = {}
    qoplist.update(QLinearOpsRegistry)
    qoplist.update(QDQRegistry)
    qoplist.update(IntegerOpsRegistry)
    quantized_op = {}
    for v in qoplist.values():
        quantized_op.update({v.__name__: v})
    return quantized_op


def quant_summary(quantmodel):
    summary = {}
    model = onnx.load(quantmodel)
    quant_op = []
    summary.update({"opset_version": model.opset_import[-1].version})
    for node in model.graph.node:
        if node.op_type in quant_oplist().keys():
            quant_op.append(node.op_type)
    quant_op_counter = dict(collections.Counter(quant_op))
    summary.update({"quant_op": quant_op_counter})
    return summary
--------------------------------------------------------------------------------
/nne/benchmark.py:
--------------------------------------------------------------------------------
# Copyright 2020 kurosawa. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

import time
import matplotlib.pyplot as plt

class Benchmark:
    """
    This class is for measuring inference time
    """
    def __init__(self, counter=11, name="sample"):
        self.ave = []
        self.counter = counter
        self.saves = []
        self.name = name

    def measure(self, func, name):
        def inner(*args, **kwargs):
            durations = []
            for i in range(self.counter):
                start = time.time()
                func(*args, **kwargs)
                end = time.time()
                # skip the first iteration as warm-up
                if i != 0:
                    durations.append((end - start) * 1000)
            ave = sum(durations) / len(durations)
            self.ave.append(ave)
            min_value = min(durations)
            max_value = max(durations)
            print(f"{name},average[ms],{round(ave, 4)},min[ms],{round(min_value, 4)},max[ms],{round(max_value, 4)}")
            return func(*args, **kwargs)
        return inner


class Plot:
    """
    Take the Benchmark class as an argument and plot the inference time.
    The x-axis assumes batch size.
    """
    def __init__(self, benchmarks: list):
        self.benchmarks = benchmarks

    def plot(self, x, xlabel, title, savefile):
        for bench in self.benchmarks:
            plt.plot(x, bench.ave, "-o", label=bench.name)
        plt.title(title)
        plt.legend()
        plt.xlabel(xlabel)
        plt.ylabel("inference time[ms]")
        plt.savefig(savefile)
--------------------------------------------------------------------------------
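A minimal sketch of how `Benchmark` and `Plot` compose — one `Benchmark` per backend, one recorded average per `measure` call, then the averages plotted against batch size. The batch sizes and file names below are illustrative, and the model is re-exported per batch size because `cv2onnx` exports a fixed batch dimension:

```python3
import numpy as np
import torchvision
import nne

batch_sizes = [1, 2, 4, 8]
bm_torch = nne.Benchmark(name='torch')
bm_onnx = nne.Benchmark(name='onnx')
model = torchvision.models.resnet34(pretrained=True)

for bs in batch_sizes:
    data = np.random.random_sample((bs, 3, 224, 224)).astype(np.float32)
    nne.infer_torch(model, data, bm=bm_torch)             # appends one average
    nne.cv2onnx(model, (bs, 3, 224, 224), "resnet.onnx")  # re-export: batch dim is static
    onnx_model = nne.load_onnx("resnet.onnx")
    nne.infer_onnx(onnx_model, data, bm=bm_onnx)          # appends one average

# one x value per recorded average
nne.Plot([bm_torch, bm_onnx]).plot(batch_sizes, "batch size", "resnet34", "bench.png")
```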
/README.md:
--------------------------------------------------------------------------------

Convert PyTorch models for edge devices

# nne
contents

- [Install](#install)
- [Example](#Example)
  - [onnx](#onnx)
  - [tflite](#tflite)
  - [tflite(edgetpu)](#tflite-edgetpu)
  - [TensorRT](#tensorrt)
- [Script](#Script)
- [Support Format](#Support-Format)
- [License](#License)

## Install

```bash
python -m pip install -e .
```

* edgetpu

If you want to compile a PyTorch model for EdgeTPU, [install the edgetpu_compiler](https://coral.ai/docs/edgetpu/compiler/).

## Example

Examples of compiling a PyTorch model for edge devices. See [examples](https://github.com/kuroko1t/nne/tree/master/examples) for details.

### onnx

Convert to an ONNX model:

```python3
import nne
import torchvision
import torch
import numpy as np

input_shape = (1, 3, 64, 64)
onnx_file = 'resnet.onnx'
model = torchvision.models.resnet34(pretrained=True).cuda()

nne.cv2onnx(model, input_shape, onnx_file)
```

### tflite

Convert to a TFLite model:

```python3
import torchvision
import torch
import numpy as np
import nne

input_shape = (10, 3, 224, 224)
model = torchvision.models.mobilenet_v2(pretrained=True).cuda()

tflite_file = 'mobilenet.tflite'

nne.cv2tflite(model, input_shape, tflite_file)
```

### tflite(edgetpu)

Convert to a TFLite model (EdgeTPU):

```python3
import torchvision
import torch
import numpy as np
import nne

input_shape = (10, 3, 112, 112)
model = torchvision.models.mobilenet_v2(pretrained=True)

tflite_file = 'mobilenet.tflite'

nne.cv2tflite(model, input_shape, tflite_file, edgetpu=True)
```

### TensorRT

Convert to a TensorRT model:

```python3
import nne
import torchvision
import torch
import numpy as np

input_shape = (1, 3, 224, 224)
trt_file = 'alexnet_trt.pth'
model = torchvision.models.alexnet(pretrained=True).cuda()
nne.cv2trt(model, input_shape, trt_file)
```

## Script

* show a summary of model info
* dump detailed model information (node names, attrs) to a json file
* convert an onnx model to tflite, or simplify it

```bash
$ nne -h
usage: nne [-h] [-a ANALYZE_PATH] [-s SIMPLYFY_PATH] [-t TFLITE_PATH] model_path

Neural Network Graph Analyzer

positional arguments:
  model_path            model path for analyzing (onnx or tflite)

optional arguments:
  -h, --help            show this help message and exit
  -a ANALYZE_PATH, --analyze_path ANALYZE_PATH
                        Specify the path to output the Node information of the model in json format.
  -s SIMPLYFY_PATH, --simplyfy_path SIMPLYFY_PATH
                        simplify the onnx model and write the result to this path
  -t TFLITE_PATH, --tflite_path TFLITE_PATH
                        convert the onnx model to tflite and write the result to this path
```
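For example, after exporting `resnet.onnx` as in the snippets above (the file names are illustrative, not fixed by the tool):

```bash
# print a summary of the graph
nne resnet.onnx

# dump per-node details to json
nne resnet.onnx -a resnet.json

# write a simplified onnx model, or a tflite conversion
nne resnet.onnx -s resnet_simplified.onnx
nne resnet.onnx -t resnet.tflite
```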
## Support Format

| format | support |
|---|---|
| tflite | :white_check_mark: |
| edge tpu | trial |
| onnx | :white_check_mark: |
| tensorRT | :white_check_mark: |

## License
Apache 2.0
--------------------------------------------------------------------------------
/nne/convert/trt.py:
--------------------------------------------------------------------------------
# Copyright 2020 kurosawa. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

import torch
import tensorrt as trt
import pycuda.driver as cuda
import pycuda.autoinit
from .common import *
from .onnx import *
import numpy as np

TRT_LOGGER = trt.Logger(trt.Logger.WARNING)

def cv2trt(model, input_shape, trt_file, fp16_mode=False):
    """
    convert a torch model to a TensorRT engine via onnx
    """
    model.eval()
    onnx_file = os.path.splitext(trt_file)[0] + ".onnx"
    cv2onnx(model, input_shape, onnx_file, simplify=True)
    EXPLICIT_BATCH = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    builder = trt.Builder(TRT_LOGGER)
    network = builder.create_network(EXPLICIT_BATCH)
    parser = trt.OnnxParser(network, TRT_LOGGER)
    with open(onnx_file, "rb") as onnx_model:
        parser.parse(onnx_model.read())
        if parser.num_errors > 0:
            error = parser.get_error(0)
            raise Exception(error)
    max_workspace_size = 1 << 28
    max_batch_size = 32
    builder.max_batch_size = max_batch_size
    builder.max_workspace_size = max_workspace_size
    builder.fp16_mode = fp16_mode
    engine = builder.build_cuda_engine(network)
    with open(trt_file, "wb") as f:
        f.write(engine.serialize())
    os.remove(onnx_file)


def load_trt(trt_file):
    runtime = trt.Runtime(TRT_LOGGER)
    with open(trt_file, "rb") as f:
        engine = runtime.deserialize_cuda_engine(f.read())
    return engine


def infer_trt(engine, input_data, bm=None):
    h_input = cuda.pagelocked_empty(trt.volume(engine.get_binding_shape(0)), dtype=np.float32)
    h_output = cuda.pagelocked_empty(trt.volume(engine.get_binding_shape(1)), dtype=np.float32)
    # copy the input into the page-locked host buffer
    np.copyto(h_input, input_data.ravel())

    def execute():
        # Allocate device memory for inputs and outputs.
        d_input = cuda.mem_alloc(h_input.nbytes)
        d_output = cuda.mem_alloc(h_output.nbytes)
        # Create a stream in which to copy inputs/outputs and run inference.
        stream = cuda.Stream()
        with engine.create_execution_context() as context:
            # Transfer input data to the GPU.
            cuda.memcpy_htod_async(d_input, h_input, stream)
            # Run inference.
            context.execute_async(bindings=[int(d_input), int(d_output)], stream_handle=stream.handle)
            # Transfer predictions back from the GPU.
            cuda.memcpy_dtoh_async(h_output, d_output, stream)
            # Synchronize the stream
            stream.synchronize()
        # Return the host output.
        return h_output

    if bm:
        output = bm.measure(execute, name="TensorRT")()
    else:
        output = execute()
    return output
--------------------------------------------------------------------------------
/nne/convert/onnx.py:
--------------------------------------------------------------------------------
# Copyright 2020 kurosawa. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
import onnx
import torch
from .common import *
from ..analyze.onnx import Analyze
try:
    import onnxruntime
except ImportError:
    pass


def cv2onnxsimplify(onnx_file, output_file):
    analyze = Analyze(onnx_file, None)
    onnx_model = onnx.load(onnx_file)
    model_opt, check_ok = onnx_simplify(onnx_model, analyze.input_shapes)
    if check_ok:
        onnx.save(model_opt, output_file)

def cv2onnx(model, input_shape, onnx_file, simplify=False, verbose=False,
            input_names=["input"], output_names=["output"]):
    """
    convert torch model to an onnx model
    """
    if type(input_shape[0]) == tuple:
        if check_model_is_cuda(model):
            dummy_input = tuple([torch.randn(ishape, device="cuda") for ishape in input_shape])
        else:
            dummy_input = tuple([torch.randn(ishape, device="cpu") for ishape in input_shape])
    elif type(input_shape) == tuple:
        if check_model_is_cuda(model):
            dummy_input = torch.randn(input_shape, device="cuda")
        else:
            dummy_input = torch.randn(input_shape, device="cpu")
    else:
        raise Exception("input_shape must be tuple")

    try:
        torch.onnx.export(model, dummy_input, onnx_file,
                          input_names=input_names, output_names=output_names, verbose=verbose)
    except RuntimeError:
        # retry with a newer opset when the default one lacks an operator
        opset_version = 11
        torch.onnx.export(model, dummy_input, onnx_file,
                          opset_version=opset_version,
                          input_names=input_names, output_names=output_names)

    if simplify:
        onnx_model = onnx.load(onnx_file)
        model_opt, check_ok = onnx_simplify(onnx_model, input_shape)
        if check_ok:
            onnx.save(model_opt, onnx_file)

def load_onnx(onnx_file):
    sess = onnxruntime.InferenceSession(onnx_file)
    if "TensorrtExecutionProvider" in sess.get_providers():
        print("onnxmodel with TensorrtExecutionProvider")
        sess.set_providers(["TensorrtExecutionProvider"])
    elif "CUDAExecutionProvider" in sess.get_providers():
        print("onnxmodel with CUDAExecutionProvider")
        sess.set_providers(["CUDAExecutionProvider"])
    elif "CPUExecutionProvider" in sess.get_providers():
        print("onnxmodel with CPUExecutionProvider")
        sess.set_providers(["CPUExecutionProvider"])
    return sess


def infer_onnx(sess, input_data, bm=None):
    if type(input_data) == tuple:
        for i, data in enumerate(input_data):
            if i == 0:
                ort_inputs = {sess.get_inputs()[i].name: data}
            else:
                ort_inputs[sess.get_inputs()[i].name] = data
    else:
        ort_inputs = {sess.get_inputs()[0].name: input_data}
    if bm:
        ort_outs = bm.measure(sess.run, name="onnx")(None, ort_inputs)
    else:
        ort_outs = sess.run(None, ort_inputs)
    return ort_outs
--------------------------------------------------------------------------------
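`infer_onnx` also accepts a tuple of arrays for multi-input graphs, matched to the session inputs in declaration order. A minimal sketch, assuming a hypothetical `two_input_model.onnx` with two graph inputs:

```python3
import numpy as np
import nne

# hypothetical model with two graph inputs
sess = nne.load_onnx("two_input_model.onnx")
a = np.random.random_sample((1, 3, 64, 64)).astype(np.float32)
b = np.random.random_sample((1, 8)).astype(np.float32)
# tuple input -> one array per graph input, in order
outputs = nne.infer_onnx(sess, (a, b))
```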
/nne/convert/tflite.py:
--------------------------------------------------------------------------------
# Copyright 2020 kurosawa. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

import onnx
from onnx_tf.backend import prepare
import torch
import tensorflow as tf
import os
import shutil
import subprocess
from .common import *
from .onnx import cv2onnx
import numpy as np

def onnx2tflite(onnx_file, tflite_path):
    print(onnx_file, tflite_path)
    onnx_model = onnx.load(onnx_file)
    tf_rep = prepare(onnx_model)
    tmp_pb_file = "tmp.pb"
    tf_rep.export_graph(tmp_pb_file)
    converter = tf.lite.TFLiteConverter.from_saved_model(tmp_pb_file)
    tflite_model = converter.convert()
    shutil.rmtree(tmp_pb_file)  # export_graph writes a SavedModel directory
    with open(tflite_path, "wb") as f:
        f.write(tflite_model)

def cv2tflite(model, input_shape, tflite_path, edgetpu=False, quantization=False):
    """
    convert torch model (or an onnx file path) to tflite model using onnx
    """
    onnx_input_flag = False
    if type(model) == str:
        ext = os.path.splitext(model)[1]
        if ext == ".onnx":
            onnx_input_flag = True
    tmp_pb_file = "tmp.pb"
    if not onnx_input_flag:
        onnx_file = "tmp.onnx"
        cv2onnx(model, input_shape, onnx_file)
    else:
        onnx_file = model
    onnx_model = onnx.load(onnx_file)
    onnx_input_names = [input.name for input in onnx_model.graph.input]
    onnx_output_names = [output.name for output in onnx_model.graph.output]
    tf_rep = prepare(onnx_model)
    tf_rep.export_graph(tmp_pb_file)

    converter = tf.lite.TFLiteConverter.from_saved_model(tmp_pb_file)

    if quantization:
        converter.optimizations = [tf.lite.Optimize.DEFAULT]

    if edgetpu:
        if type(input_shape) != tuple:
            raise Exception("input_shape must be tuple")
        if type(input_shape[0]) == tuple:
            raise Exception("edgetpu conversion supports a single input shape")
        # random calibration data for full-integer post-training quantization
        dummy_input = np.random.random_sample(input_shape).astype(np.float32)
        train = tf.convert_to_tensor(dummy_input)
        my_ds = tf.data.Dataset.from_tensor_slices((train)).batch(10)

        def representative_dataset_gen():
            for input_value in my_ds.take(10):
                yield [input_value]

        converter.representative_dataset = representative_dataset_gen
        converter.allow_custom_ops = True
        converter.experimental_new_converter = True
        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
        converter.inference_input_type = tf.int8
        converter.inference_output_type = tf.int8

    # convert tensorflow to tflite model
    tflite_model = converter.convert()

    with open(tflite_path, "wb") as f:
        f.write(tflite_model)
    if not onnx_input_flag:
        os.remove(onnx_file)
    shutil.rmtree(tmp_pb_file)

    if edgetpu:
        subprocess.check_call(f"edgetpu_compiler {tflite_path}", shell=True)


def load_tflite(tflitepath):
    interpreter = tf.lite.Interpreter(model_path=tflitepath)
    # allocate memory
    interpreter.allocate_tensors()
    return interpreter


def infer_tflite(interpreter, input_data, bm=None):
    # get model input and output properties
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()
    ## get input shape
    #input_shape = input_details[0]["shape"]
    def execute():
        # set tensor pointer to index
        if type(input_data) == tuple:
            for i, data in enumerate(input_data):
                interpreter.set_tensor(input_details[i]["index"], data)
        else:
            interpreter.set_tensor(input_details[0]["index"], input_data)
        # execute infer
        interpreter.invoke()
        # get result from index of output_details
        output_data = interpreter.get_tensor(output_details[0]["index"])
        return output_data
    if bm:
        output_data = bm.measure(execute, name="tflite")()
    else:
        output_data = execute()
    return output_data
--------------------------------------------------------------------------------
/nne/analyze/onnx.py:
--------------------------------------------------------------------------------
import collections
import onnx
import json


class Analyze:
    def __init__(self, model_path, output_path):
        self.model_path = model_path
        model = onnx.load(model_path)
        graph_def = model.graph
        self.input_shapes = self.get_input_shape(graph_def)
        self.output_shapes = self.get_output_shape(graph_def)
        self.opset_info = model.opset_import[0].version
        self.nodes = self.get_nodes(graph_def)
        self.model_info = self.modelinfo()
        if output_path:
            self.dump(self.model_info, output_path)

    def modelinfo(self):
        model_info = {}
        model_info["OpsetVersion"] = self.opset_info
        model_info["InputShape"] = self.input_shapes
        model_info["OutputShape"] = self.output_shapes
        node_dict_list = [model_info] + self.nodes2dict()
        return node_dict_list

    def get_nodes(self, graph_def):
        return [OnnxNode(node) for node in graph_def.node]

    def get_input_shape(self, graph_def):
        input_shapes = []
        if graph_def.initializer:
            initialized = {init.name for init in graph_def.initializer}
        else:
            initialized = set()
        for value_info in graph_def.input:
            if value_info.name in initialized:
                continue
            shape = list(
                d.dim_value if (
                    d.dim_value > 0 and d.dim_param == "") else None
                for d in value_info.type.tensor_type.shape.dim)
            input_shapes.append(shape)
        return input_shapes

    def get_output_shape(self, graph_def):
        output_shapes = []
        for value_info in graph_def.output:
            output_shape = [
                dim.dim_value for dim in value_info.type.tensor_type.shape.dim]
            output_shapes.append(output_shape)
        return output_shapes

    def unique_nodes(self):
        strip_node_names = [node.op_type for node in self.nodes]
        counter_nodes = collections.Counter(strip_node_names)
        return dict(counter_nodes)

    def nodes2dict(self):
        node_dict_list = []
        for node in self.nodes:
            node_dict = collections.OrderedDict()
            node_dict["name"] = node.name
            node_dict["op_type"] = node.op_type
            if node.attrs:
                # drop the constant tensor payload; it is not json-serializable
                node.attrs.pop("value", None)
                node_dict["attrs"] = node.attrs
            node_dict_list.append(node_dict)
        return node_dict_list

    def dump(self, node_dict_list, output_path):
        with open(output_path, "w") as f:
            json.dump(node_dict_list, f, indent=2)

    def summary(self):
        print()
        print("#### SUMMARY ONNX MODEL ####")
        print("opset:", self.opset_info)
        print("INPUT:", self.input_shapes)
        print("OUTPUT:", self.output_shapes)
        unique_nodes_count = self.unique_nodes()
        print(f"--Node List-- num({len(self.nodes)})")
        print(unique_nodes_count)


class OnnxNode(object):
    def __init__(self, node):
        self.name = str(node.name)
        self.op_type = str(node.op_type)
        self.domain = str(node.domain)
        self.attrs = dict([(attr.name,
                            translate_onnx(
                                attr.name, convert_onnx_attribute_proto(attr)))
                           for attr in node.attribute])
        self.inputs = list(node.input)
        self.outputs = list(node.output)
        self.node_proto = node


def analyze_graph(model_path, output_path):
    analyzer = Analyze(model_path, output_path)
    analyzer.summary()
    return analyzer.model_info


def translate_onnx(key, val):
    return onnx_attr_translator.get(key, lambda x: x)(val)


onnx_attr_translator = {
    "axis": lambda x: int(x),
    "axes": lambda x: [int(a) for a in x],
    "dtype": lambda x: x,
    "keepdims": lambda x: bool(x),
    "to": lambda x: x,
}


def convert_onnx_attribute_proto(attr_proto):
    """
    Convert an ONNX AttributeProto into an appropriate Python object
    for the type.
    NB: Tensor attribute gets returned as the straight proto.
    """
    if attr_proto.HasField('f'):
        return attr_proto.f
    elif attr_proto.HasField('i'):
        return attr_proto.i
    elif attr_proto.HasField('s'):
        return str(attr_proto.s, 'utf-8')
    elif attr_proto.HasField('t'):
        return attr_proto.t  # this is a proto!
    elif attr_proto.HasField('g'):
        return attr_proto.g
    elif attr_proto.floats:
        return list(attr_proto.floats)
    elif attr_proto.ints:
        return list(attr_proto.ints)
    elif attr_proto.strings:
        str_list = list(attr_proto.strings)
        str_list = list(map(lambda x: str(x, 'utf-8'), str_list))
        return str_list
    elif attr_proto.HasField('sparse_tensor'):
        return attr_proto.sparse_tensor
    else:
        raise ValueError("Unsupported ONNX attribute: {}".format(attr_proto))
--------------------------------------------------------------------------------
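`modelinfo()` above returns a list whose first entry is the model-level summary and whose remaining entries are per-node records; `dump()` writes exactly that list as json. An illustrative shape (the values here are placeholders, not real output):

```python3
model_info = [
    {"OpsetVersion": 13, "InputShape": [[1, 3, 64, 64]], "OutputShape": [[1, 1000]]},
    {"name": "Conv_0", "op_type": "Conv", "attrs": {"kernel_shape": [7, 7], "strides": [2, 2]}},
    # ... one dict per graph node
]
```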
/LICENSE:
--------------------------------------------------------------------------------
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/
Copyright yuya kurosawa 2020
kurosawa.yk@gmail.com

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

1. Definitions.

   "License" shall mean the terms and conditions for use, reproduction,
   and distribution as defined by Sections 1 through 9 of this document.

   "Licensor" shall mean the copyright owner or entity authorized by
   the copyright owner that is granting the License.

   "Legal Entity" shall mean the union of the acting entity and all
   other entities that control, are controlled by, or are under common
   control with that entity. For the purposes of this definition,
   "control" means (i) the power, direct or indirect, to cause the
   direction or management of such entity, whether by contract or
   otherwise, or (ii) ownership of fifty percent (50%) or more of the
   outstanding shares, or (iii) beneficial ownership of such entity.

   "You" (or "Your") shall mean an individual or Legal Entity
   exercising permissions granted by this License.

   "Source" form shall mean the preferred form for making modifications,
   including but not limited to software source code, documentation
   source, and configuration files.

   "Object" form shall mean any form resulting from mechanical
   transformation or translation of a Source form, including but
   not limited to compiled object code, generated documentation,
   and conversions to other media types.

   "Work" shall mean the work of authorship, whether in Source or
   Object form, made available under the License, as indicated by a
   copyright notice that is included in or attached to the work
   (an example is provided in the Appendix below).

   "Derivative Works" shall mean any work, whether in Source or Object
   form, that is based on (or derived from) the Work and for which the
   editorial revisions, annotations, elaborations, or other modifications
   represent, as a whole, an original work of authorship. For the purposes
   of this License, Derivative Works shall not include works that remain
   separable from, or merely link (or bind by name) to the interfaces of,
   the Work and Derivative Works thereof.

   "Contribution" shall mean any work of authorship, including
   the original version of the Work and any modifications or additions
   to that Work or Derivative Works thereof, that is intentionally
   submitted to Licensor for inclusion in the Work by the copyright owner
   or by an individual or Legal Entity authorized to submit on behalf of
   the copyright owner. For the purposes of this definition, "submitted"
   means any form of electronic, verbal, or written communication sent
   to the Licensor or its representatives, including but not limited to
   communication on electronic mailing lists, source code control systems,
   and issue tracking systems that are managed by, or on behalf of, the
   Licensor for the purpose of discussing and improving the Work, but
   excluding communication that is conspicuously marked or otherwise
   designated in writing by the copyright owner as "Not a Contribution."

   "Contributor" shall mean Licensor and any individual or Legal Entity
   on behalf of whom a Contribution has been received by Licensor and
   subsequently incorporated within the Work.

2. Grant of Copyright License. Subject to the terms and conditions of
   this License, each Contributor hereby grants to You a perpetual,
   worldwide, non-exclusive, no-charge, royalty-free, irrevocable
   copyright license to reproduce, prepare Derivative Works of,
   publicly display, publicly perform, sublicense, and distribute the
   Work and such Derivative Works in Source or Object form.

3. Grant of Patent License. Subject to the terms and conditions of
   this License, each Contributor hereby grants to You a perpetual,
   worldwide, non-exclusive, no-charge, royalty-free, irrevocable
   (except as stated in this section) patent license to make, have made,
   use, offer to sell, sell, import, and otherwise transfer the Work,
   where such license applies only to those patent claims licensable
   by such Contributor that are necessarily infringed by their
   Contribution(s) alone or by combination of their Contribution(s)
   with the Work to which such Contribution(s) was submitted. If You
   institute patent litigation against any entity (including a
   cross-claim or counterclaim in a lawsuit) alleging that the Work
   or a Contribution incorporated within the Work constitutes direct
   or contributory patent infringement, then any patent licenses
   granted to You under this License for that Work shall terminate
   as of the date such litigation is filed.

4. Redistribution. You may reproduce and distribute copies of the
   Work or Derivative Works thereof in any medium, with or without
   modifications, and in Source or Object form, provided that You
   meet the following conditions:

   (a) You must give any other recipients of the Work or
       Derivative Works a copy of this License; and

   (b) You must cause any modified files to carry prominent notices
       stating that You changed the files; and

   (c) You must retain, in the Source form of any Derivative Works
       that You distribute, all copyright, patent, trademark, and
       attribution notices from the Source form of the Work,
       excluding those notices that do not pertain to any part of
       the Derivative Works; and

   (d) If the Work includes a "NOTICE" text file as part of its
       distribution, then any Derivative Works that You distribute must
       include a readable copy of the attribution notices contained
       within such NOTICE file, excluding those notices that do not
       pertain to any part of the Derivative Works, in at least one
       of the following places: within a NOTICE text file distributed
       as part of the Derivative Works; within the Source form or
       documentation, if provided along with the Derivative Works; or,
       within a display generated by the Derivative Works, if and
       wherever such third-party notices normally appear. The contents
       of the NOTICE file are for informational purposes only and
       do not modify the License. You may add Your own attribution
       notices within Derivative Works that You distribute, alongside
       or as an addendum to the NOTICE text from the Work, provided
       that such additional attribution notices cannot be construed
       as modifying the License.

   You may add Your own copyright statement to Your modifications and
   may provide additional or different license terms and conditions
   for use, reproduction, or distribution of Your modifications, or
   for any such Derivative Works as a whole, provided Your use,
   reproduction, and distribution of the Work otherwise complies with
   the conditions stated in this License.

5. Submission of Contributions. Unless You explicitly state otherwise,
   any Contribution intentionally submitted for inclusion in the Work
   by You to the Licensor shall be under the terms and conditions of
   this License, without any additional terms or conditions.
   Notwithstanding the above, nothing herein shall supersede or modify
   the terms of any separate license agreement you may have executed
   with Licensor regarding such Contributions.

6. Trademarks. This License does not grant permission to use the trade
   names, trademarks, service marks, or product names of the Licensor,
   except as required for reasonable and customary use in describing the
   origin of the Work and reproducing the content of the NOTICE file.

7. Disclaimer of Warranty. Unless required by applicable law or
   agreed to in writing, Licensor provides the Work (and each
   Contributor provides its Contributions) on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
   implied, including, without limitation, any warranties or conditions
   of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
   PARTICULAR PURPOSE. You are solely responsible for determining the
   appropriateness of using or redistributing the Work and assume any
   risks associated with Your exercise of permissions under this License.

8. Limitation of Liability. In no event and under no legal theory,
   whether in tort (including negligence), contract, or otherwise,
   unless required by applicable law (such as deliberate and grossly
   negligent acts) or agreed to in writing, shall any Contributor be
   liable to You for damages, including any direct, indirect, special,
   incidental, or consequential damages of any character arising as a
   result of this License or out of the use or inability to use the
   Work (including but not limited to damages for loss of goodwill,
   work stoppage, computer failure or malfunction, or any and all
   other commercial damages or losses), even if such Contributor
   has been advised of the possibility of such damages.

9. Accepting Warranty or Additional Liability. While redistributing
   the Work or Derivative Works thereof, You may choose to offer,
   and charge a fee for, acceptance of support, warranty, indemnity,
   or other liability obligations and/or rights consistent with this
   License. However, in accepting such obligations, You may act only
   on Your own behalf and on Your sole responsibility, not on behalf
   of any other Contributor, and only if You agree to indemnify,
   defend, and hold each Contributor harmless for any liability
   incurred by, or claims asserted against, such Contributor by reason
   of your accepting any such warranty or additional liability.

END OF TERMS AND CONDITIONS

APPENDIX: How to apply the Apache License to your work.

   To apply the Apache License to your work, attach the following
   boilerplate notice, with the fields enclosed by brackets "[]"
   replaced with your own identifying information. (Don't include
   the brackets!) The text should be enclosed in the appropriate
   comment syntax for the file format. We also recommend that a
   file or class name and description of purpose be included on the
   same "printed page" as the copyright notice for easier
   identification within third-party archives.

Copyright [yyyy] [name of copyright owner]

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
--------------------------------------------------------------------------------