├── nne
│   ├── analyze
│   │   ├── __init__.py
│   │   ├── tflite.py
│   │   └── onnx.py
│   ├── convert
│   │   ├── __init__.py
│   │   ├── common.py
│   │   ├── torchscript.py
│   │   ├── torch.py
│   │   ├── trt.py
│   │   ├── onnx.py
│   │   └── tflite.py
│   ├── quant
│   │   ├── __init__.py
│   │   └── onnx.py
│   ├── __init__.py
│   └── benchmark.py
├── .gitignore
├── docs
│   └── logo.png
├── .github
│   ├── FUNDING.yml
│   └── workflows
│       └── pythonapp.yml
├── examples
│   ├── edgetpu_example.py
│   ├── torch_example.py
│   ├── torchscript_example.py
│   ├── onnx_example.py
│   ├── tensorrt_example.py
│   ├── onnx_quantize.py
│   └── tflite_example.py
├── test
│   ├── test_analyze.py
│   ├── test_torch.py
│   ├── test_script.py
│   ├── test_tflite.py
│   └── test_onnx.py
├── benchmark
│   └── torch_bench.py
├── setup.py
├── bin
│   └── nne
├── README.md
└── LICENSE
/nne/analyze/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/nne/convert/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/nne/quant/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | */__pycache__/*
2 | *.pyc
3 | *.onnx
--------------------------------------------------------------------------------
/docs/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kuroko1t/nne/HEAD/docs/logo.png
--------------------------------------------------------------------------------
/.github/FUNDING.yml:
--------------------------------------------------------------------------------
1 | # These are supported funding model platforms
2 |
3 | github: kuroko1t
4 |
--------------------------------------------------------------------------------
/examples/edgetpu_example.py:
--------------------------------------------------------------------------------
1 | import torchvision
2 | import torch
3 | import numpy as np
4 | import nne
5 |
6 | input_shape = (10, 3, 112, 112)
7 | model = torchvision.models.mobilenet_v2(pretrained=True)
8 |
9 | tflite_file = 'mobilenet.tflite'
10 |
11 | nne.cv2tflite(model, input_shape, tflite_file, edgetpu=True)
12 |
--------------------------------------------------------------------------------
/examples/torch_example.py:
--------------------------------------------------------------------------------
1 | import nne
2 | import torchvision
3 | import torch
4 | import numpy as np
5 |
6 | input_shape = (1, 3, 224, 224)
7 | model = torchvision.models.resnet34(pretrained=True).cuda()
8 |
9 | input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)
10 | output_data = nne.infer_torch(model, input_data)
11 |
--------------------------------------------------------------------------------
/examples/torchscript_example.py:
--------------------------------------------------------------------------------
1 | import nne
2 | import torchvision
3 | import torch
4 | import numpy as np
5 |
6 | input_shape = (1, 3, 224, 224)
7 | model = torchvision.models.resnet50(pretrained=True).cuda()
8 | script_file = 'resnet_script.zip'
9 | input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)
10 | nne.cv2torchscript(model, input_shape, script_file)
11 | model_script = nne.load_torchscript(script_file)
12 | output = nne.infer_torchscript(model_script, input_data)
13 | print(output)
14 |
--------------------------------------------------------------------------------
/examples/onnx_example.py:
--------------------------------------------------------------------------------
1 | import nne
2 | import torchvision
3 | import torch
4 | import numpy as np
5 |
6 | input_shape = (1, 3, 64, 64)
7 | onnx_file = 'resnet.onnx'
8 | model = torchvision.models.resnet34(pretrained=True).cuda()
9 |
10 | # convert pytorch model to onnx model
11 | nne.cv2onnx(model, input_shape, onnx_file)
12 |
13 | input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)
14 |
15 | # load onnx model
16 | onnx_model = nne.load_onnx(onnx_file)
17 |
18 |
19 | # inference
20 | output_data = nne.infer_onnx(onnx_model, input_data)
21 |
22 | print(output_data)
23 |
--------------------------------------------------------------------------------
/examples/tensorrt_example.py:
--------------------------------------------------------------------------------
1 | import nne
2 | import torchvision
3 | import torch
4 | import numpy as np
5 |
6 | input_shape = (1, 3, 224, 224)
7 | trt_file = 'alexnet_trt.pth'
8 | model = torchvision.models.alexnet(pretrained=True).cuda()
9 |
10 | # convert pytorch model to TensorRT model
11 | nne.cv2trt(model, input_shape, trt_file)
12 |
13 | input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)
14 |
15 | # load TensorRT model
16 | trt_model = nne.load_trt(trt_file)
17 |
18 | # inference
19 | output_data = nne.infer_trt(trt_model, input_data)
20 |
21 | print(output_data)
22 |
--------------------------------------------------------------------------------
/examples/onnx_quantize.py:
--------------------------------------------------------------------------------
1 | from nne.quant.onnx import quant_oplist, quant_summary, quantize
2 | import torchvision
3 | import nne
4 |
5 | input_shape = (1, 3, 64, 64)
6 | onnx_file = 'resnet.onnx'
7 | model = torchvision.models.resnet34(pretrained=True)
8 |
9 | # convert pytorch model to onnx model
10 | nne.cv2onnx(model, input_shape, onnx_file)
11 |
12 | # onnx model to quantized model
13 | quantize("resnet.onnx")
14 |
15 | # list the operations that support quantization
16 | quant_ops = quant_oplist()
17 |
18 | # return summary information about quantized model
19 | summary = quant_summary("resnet.quant.onnx")
20 | print(summary)
21 |
--------------------------------------------------------------------------------
/examples/tflite_example.py:
--------------------------------------------------------------------------------
1 | import torchvision
2 | import torch
3 | import numpy as np
4 | import nne
5 |
6 | input_shape = (10, 3, 224, 224)
7 | model = torchvision.models.mobilenet_v2(pretrained=True).cuda()
8 |
9 | tflite_file = 'mobilenet.tflite'
10 |
11 | # convert pytorch model to tensorflow lite
12 | nne.cv2tflite(model, input_shape, tflite_file)
13 |
14 | input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)
15 |
16 | # load tflite model
17 | tflite_model = nne.load_tflite(tflite_file)
18 |
19 | # inference
20 | output_data = nne.infer_tflite(tflite_model, input_data)
21 |
22 | print(output_data)
23 |
--------------------------------------------------------------------------------
/test/test_analyze.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import nne
3 | import torchvision
4 | import torch
5 | import numpy as np
6 | from nne.quant.onnx import quant_oplist, quant_summary, quantize
7 |
8 | class AnalyzeTests(unittest.TestCase):
9 | def __init__(self, *args, **kwargs):
10 | super(AnalyzeTests, self).__init__(*args, **kwargs)
11 |
12 | def test_onnx(self):
13 | input_shape = (1, 3, 64, 64)
14 | onnx_file = 'resnet.onnx'
15 | model = torchvision.models.resnet34(pretrained=True)
16 | nne.cv2onnx(model, input_shape, onnx_file)
17 | nne.analyze(onnx_file)
18 | nne.analyze(onnx_file, "resnet.json")
19 |
--------------------------------------------------------------------------------
/benchmark/torch_bench.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import nne
4 | import os
5 | import torchvision
6 |
7 | input_shape = (1, 3, 224, 224)
8 | model = torchvision.models.resnet34(pretrained=True)
9 | torch.save(model, "resnet.pt")
10 |
11 | bm = nne.Benchmark(name='torch')
12 | input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)
13 | output_data = nne.infer_torch(model, input_data, bm=bm)
14 |
15 | ## onnx
16 | onnx_file = "resnet.onnx"
17 | bm = nne.Benchmark(name='onnx')
18 | nne.cv2onnx(model, input_shape, onnx_file)
19 | onnx_model = nne.load_onnx(onnx_file)
20 | nne.infer_onnx(onnx_model, input_data, bm=bm)
21 |
22 | ## tflite
23 | tflite_file = "resnet.tflite"
24 | bm = nne.Benchmark(name='tflite')
25 | nne.cv2tflite(model, input_shape, tflite_file)
26 | tflite_model = nne.load_tflite(tflite_file)
27 | nne.infer_tflite(tflite_model, input_data, bm=bm)
28 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 | import platform
3 | import os
4 |
5 | def check_tensorrt():
6 | try:
7 | import tensorrt
8 | return True
9 |     except ImportError:
10 | return False
11 |
12 | def check_jetson():
13 | if platform.machine() == "aarch64":
14 | return True
15 | else:
16 | return False
17 |
18 | def get_requires():
19 | requires = []
20 | if not check_jetson():
21 |         requires = ["tensorflow==2.11.0", "tensorflow_addons"]
22 |     requires += ["torch", "tensorflow_probability", "onnx", "onnx_tf @ git+https://github.com/onnx/onnx-tensorflow",
23 |                  "matplotlib", "onnx-simplifier"]
24 | if check_tensorrt():
25 | requires += ["pycuda"]
26 | return requires
27 |
28 | setup(
29 | name="nne",
30 | scripts=["bin/nne"],
31 | packages=find_packages(),
32 | install_requires=get_requires(),
33 | version="0.1"
34 | )
35 |
--------------------------------------------------------------------------------
/test/test_torch.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import nne
3 | import torchvision
4 | import torch
5 | import numpy as np
6 |
7 | class TorchTests(unittest.TestCase):
8 | def __init__(self, *args, **kwargs):
9 | super(TorchTests, self).__init__(*args, **kwargs)
10 |
11 | def test_torch(self):
12 | input_shape = (1, 3, 64, 64)
13 | script_file = 'resnet_script.zip'
14 | model = torchvision.models.resnet34(pretrained=True)
15 | nne.cv2torchscript(model, input_shape, script_file)
16 |
17 | input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)
18 | model_script = nne.load_torchscript(script_file)
19 | out_script = nne.infer_torchscript(model_script, input_data)
20 | model.eval()
21 | out_pytorch = model(torch.from_numpy(input_data)).detach().cpu().numpy()
22 | np.testing.assert_allclose(out_script, out_pytorch, rtol=1e-03, atol=1e-05)
23 |
24 | if __name__ == "__main__":
25 | unittest.main()
26 |
--------------------------------------------------------------------------------
/test/test_script.py:
--------------------------------------------------------------------------------
1 | #import unittest
2 | #import nne
3 | #import torchvision
4 | #import torch
5 | #import numpy as np
6 | #import subprocess
7 | #
8 | #class ScriptTests(unittest.TestCase):
9 | # def __init__(self, *args, **kwargs):
10 | # super(ScriptTests, self).__init__(*args, **kwargs)
11 | # self.onnx_file = 'resnet.onnx'
12 | # input_shape = (1, 3, 64, 64)
13 | # model = torchvision.models.resnet34(pretrained=True)
14 | # nne.cv2onnx(model, input_shape, self.onnx_file)
15 | #
16 | # def test_analyze(self):
17 | # subprocess.check_output(["nne", self.onnx_file], stderr=subprocess.STDOUT)
18 | # subprocess.check_output(["nne", self.onnx_file, "-a", "resnet.json"], stderr=subprocess.STDOUT)
19 | #
20 | # def test_convert(self):
21 | # subprocess.check_output(["nne", self.onnx_file, "-s", "resnet_smip.onnx"], stderr=subprocess.STDOUT)
22 | # subprocess.check_output(["nne", self.onnx_file, "-t", "resnet.tflite"], stderr=subprocess.STDOUT)
23 |
--------------------------------------------------------------------------------
/bin/nne:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | import nne
4 | import argparse
5 |
6 | parser = argparse.ArgumentParser(description='Neural Network Graph Analyzer')
7 | parser.add_argument("model_path", type=str,
8 |                     help="path of the model to analyze (onnx or tflite)")
9 | parser.add_argument("-a", "--analyze_path", type=str,
10 |                     help="path to write the model's node information as json")
11 | parser.add_argument("-s", "--simplyfy_path", type=str,
12 |                     help="path to write the simplified onnx model")
13 | parser.add_argument("-t", "--tflite_path", type=str,
14 |                     help="path to write the tflite model converted from the onnx model")
15 | args = parser.parse_args()
16 | if args.simplyfy_path:
17 | nne.cv2onnxsimplify(args.model_path, args.simplyfy_path)
18 | elif args.tflite_path:
19 | nne.onnx2tflite(args.model_path, args.tflite_path)
20 | elif args.analyze_path:
21 | nne.analyze(args.model_path, args.analyze_path)
22 | else:
23 | nne.analyze(args.model_path, None)
24 |
--------------------------------------------------------------------------------
/nne/__init__.py:
--------------------------------------------------------------------------------
1 | from .analyze import onnx as onnx_analyze
2 | import json
3 | from .analyze import tflite as tflite_analyze
4 | from .convert.torch import *
5 | from .benchmark import *
6 | from .convert.torchscript import *
7 | from .convert.onnx import *
8 | from .convert.tflite import *
9 | import os
10 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
11 |
12 | if check_tensorrt():
13 |     from .convert.trt import *
14 |
15 |
16 | def analyze(model_path, output_path=None):
17 | ext = os.path.splitext(model_path)[1]
18 | if ext == ".onnx":
19 | model_info = onnx_analyze.analyze_graph(model_path, output_path)
20 | return model_info
21 | elif ext == ".tflite":
22 | model_info = tflite_analyze.analyze_graph(model_path, output_path)
23 |         if output_path is None:
24 | output_path = model_path.replace(".tflite", "_tflite.json")
25 | with open(output_path, 'w') as f:
26 | json.dump(model_info, f, indent=2, cls=tflite_analyze.NumpyEncoder)
27 | print(f"Write Dump Result -> {output_path}")
28 | return model_info
29 | else:
30 |         raise Exception(f"unsupported model format: {ext}")
31 |
--------------------------------------------------------------------------------
/nne/analyze/tflite.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import json
3 | import numpy as np
4 |
5 |
6 | def analyze_graph(model_path, output_path):
7 | interpreter = tf.lite.Interpreter(model_path=model_path)
8 | interpreter.allocate_tensors()
9 | graphinfo = interpreter.get_tensor_details()
10 | for i, op in enumerate(graphinfo):
11 | graphinfo[i]["shape"] = op["shape"].tolist()
12 | graphinfo[i]["shape_signature"] = op["shape_signature"].tolist()
13 | graphinfo[i]["dtype"] = [op["dtype"].__name__]
14 | graphinfo[i]["quantization_parameters"]["scales"] = op["quantization_parameters"]["scales"].tolist()
15 | graphinfo[i]["quantization_parameters"]["zero_points"] = op["quantization_parameters"]["zero_points"].tolist()
16 | return graphinfo
17 |
18 |
19 | class NumpyEncoder(json.JSONEncoder):
20 | def default(self, obj):
21 | if isinstance(obj, np.integer):
22 | return int(obj)
23 | elif isinstance(obj, np.floating):
24 | return float(obj)
25 | elif isinstance(obj, np.ndarray):
26 | return obj.tolist()
27 | else:
28 | return super(NumpyEncoder, self).default(obj)
29 |
--------------------------------------------------------------------------------
/test/test_tflite.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import nne
3 | import torchvision
4 | import torch
5 | import numpy as np
6 |
7 | class TFliteTests(unittest.TestCase):
8 | def __init__(self, *args, **kwargs):
9 | super(TFliteTests, self).__init__(*args, **kwargs)
10 |
11 | def test_tflite(self):
12 | input_shape = (1, 3, 64, 64)
13 | tflite_file = 'mobilenet.tflite'
14 | model = torchvision.models.mobilenet_v2(pretrained=True)
15 | nne.cv2tflite(model, input_shape, tflite_file)
16 |
17 | input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)
18 | model_tflite = nne.load_tflite(tflite_file)
19 | out_tflite = nne.infer_tflite(model_tflite, input_data)
20 | model.eval()
21 | out_pytorch = model(torch.from_numpy(input_data)).detach().cpu().numpy()
22 | np.testing.assert_allclose(out_tflite, out_pytorch, rtol=1e-03, atol=1e-05)
23 |
24 | def test_tflite_quant(self):
25 | input_shape = (1, 3, 64, 64)
26 | tflite_file = 'mobilenet.tflite'
27 | model = torchvision.models.mobilenet_v2(pretrained=True)
28 | nne.cv2tflite(model, input_shape, tflite_file, quantization=True)
29 |
30 | if __name__ == "__main__":
31 | unittest.main()
32 |
--------------------------------------------------------------------------------
/.github/workflows/pythonapp.yml:
--------------------------------------------------------------------------------
1 | # This workflow will install Python dependencies, run tests and lint with a single version of Python
2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
3 |
4 | name: Python application
5 |
6 | on:
7 | push:
8 | branches: [ master ]
9 | pull_request:
10 | branches: [ master ]
11 |
12 | jobs:
13 | build:
14 |
15 | runs-on: ubuntu-latest
16 |
17 | steps:
18 | - uses: actions/checkout@v2
19 | - name: Set up Python 3.7
20 | uses: actions/setup-python@v1
21 | with:
22 | python-version: 3.7
23 | - name: Install dependencies
24 | run: |
25 | python -m pip install --upgrade pip
26 | python -m pip install numpy torchvision
27 | python -m pip install -e .
28 | - name: Lint with flake8
29 | run: |
30 | pip install flake8
31 | # stop the build if there are Python syntax errors or undefined names
32 | flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
33 | # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
34 | flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
35 | - name: Test with pytest
36 | run: |
37 | python -m pip install pytest
38 | pytest . -p no:warnings
39 |
--------------------------------------------------------------------------------
/nne/convert/common.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 kurosawa. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # =============================================================================
15 |
16 | import platform
17 | import os
18 | import onnxsim
19 |
20 |
21 | def check_jetson():
22 | if platform.machine() == "aarch64":
23 | return True
24 | else:
25 | return False
26 |
27 | def check_tensorrt():
28 | try:
29 | import tensorrt
30 | return True
31 |     except ImportError:
32 | return False
33 |
34 | def check_model_is_cuda(model):
35 | return next(model.parameters()).is_cuda
36 |
37 | def onnx_simplify(model, input_shapes):
38 | model_opt, check_ok = onnxsim.simplify(
39 | model, check_n=3, perform_optimization=False, skip_fuse_bn=False, input_shapes=None)
40 | return model_opt, check_ok
41 |
--------------------------------------------------------------------------------
/nne/convert/torchscript.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 kurosawa. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # =============================================================================
15 |
16 | import torch
17 | from .common import *
18 |
19 | def cv2torchscript(model, input_shape, script_path):
20 | model.eval()
21 | if check_model_is_cuda(model):
22 | dummy_input = torch.randn(input_shape, device="cuda")
23 | else:
24 | dummy_input = torch.randn(input_shape, device="cpu")
25 | traced = torch.jit.trace(model, dummy_input)
26 | traced.save(script_path)
27 |
28 | def load_torchscript(script_path):
29 | model = torch.jit.load(script_path)
30 | return model
31 |
32 | def infer_torchscript(model, input_data, bm=None):
33 | input_data = torch.from_numpy(input_data)
34 | if check_model_is_cuda(model):
35 | input_data = input_data.cuda()
36 | if bm:
37 | output = bm.measure(model, name="torchscript")(input_data)
38 | else:
39 | output = model(input_data)
40 | return output.detach().cpu().numpy()
41 |
--------------------------------------------------------------------------------
/test/test_onnx.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import nne
3 | import torchvision
4 | import torch
5 | import numpy as np
6 | from nne.quant.onnx import quant_oplist, quant_summary, quantize
7 |
8 | class OnnxTests(unittest.TestCase):
9 | def __init__(self, *args, **kwargs):
10 | super(OnnxTests, self).__init__(*args, **kwargs)
11 |
12 | def test_onnx(self):
13 | input_shape = (1, 3, 64, 64)
14 | onnx_file = 'resnet.onnx'
15 | model = torchvision.models.resnet34(pretrained=True)
16 | nne.cv2onnx(model, input_shape, onnx_file)
17 |
18 | input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)
19 | model_onnx = nne.load_onnx(onnx_file)
20 | out_onnx = nne.infer_onnx(model_onnx, input_data)[0]
21 | model.eval()
22 | out_pytorch = model(torch.from_numpy(input_data)).detach().cpu().numpy()
23 | np.testing.assert_allclose(out_onnx, out_pytorch, rtol=1e-03, atol=1e-05)
24 |
25 | def test_onnx_analyzer(self):
26 | input_shape = (1, 3, 64, 64)
27 | onnx_file = 'resnet.onnx'
28 | model = torchvision.models.resnet34(pretrained=True)
29 | nne.cv2onnx(model, input_shape, onnx_file)
30 | nne.analyze(onnx_file, "")
31 | nne.analyze(onnx_file, "resnet.json")
32 |
33 | def test_onnx_quant(self):
34 | input_shape = (1, 3, 64, 64)
35 | onnx_file = 'resnet.onnx'
36 | model = torchvision.models.resnet34(pretrained=True)
37 | nne.cv2onnx(model, input_shape, onnx_file)
38 | quantize("resnet.onnx")
39 |         quant_ops = quant_oplist()
40 | summary = quant_summary("resnet.quant.onnx")
41 |
42 | if __name__ == "__main__":
43 | unittest.main()
44 |
--------------------------------------------------------------------------------
/nne/convert/torch.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 kurosawa. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # =============================================================================
15 |
16 | import torch
17 | from .common import *
18 |
19 | def infer_torch(model, input_data, bm=None):
20 | """
21 | model : loaded model
22 | input_data: numpy array
23 | """
24 | model.eval()
25 | if type(input_data) == tuple:
26 | if check_model_is_cuda(model):
27 | input_data = [torch.from_numpy(data).cuda() for data in input_data]
28 | else:
29 | input_data = [torch.from_numpy(data) for data in input_data]
30 | else:
31 | input_data = torch.from_numpy(input_data)
32 | if check_model_is_cuda(model):
33 | input_data = input_data.cuda()
34 | if bm:
35 | if type(input_data) == list:
36 | output = bm.measure(model, name="torch")(*input_data)
37 | else:
38 | output = bm.measure(model, name="torch")(input_data)
39 | else:
40 | if type(input_data) == list:
41 | output = model(*input_data)
42 | else:
43 | output = model(input_data)
44 | return output.detach().cpu().numpy()
45 |
--------------------------------------------------------------------------------
/nne/quant/onnx.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 kurosawa. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # =============================================================================
15 |
16 | from onnxruntime.quantization import quantize_dynamic, QuantType
17 | from onnxruntime.quantization.registry import QLinearOpsRegistry, QDQRegistry, IntegerOpsRegistry
18 | import onnx
19 | import collections
20 |
21 |
22 | def quantize(modelpath):
23 | model_quant = modelpath.replace(".onnx", ".quant.onnx")
24 |     quantize_dynamic(modelpath, model_quant)
25 |
26 |
27 | def quant_oplist():
28 | qoplist = {}
29 | qoplist.update(QLinearOpsRegistry)
30 | qoplist.update(QDQRegistry)
31 | qoplist.update(IntegerOpsRegistry)
32 | quantized_op = {}
33 | for v in qoplist.values():
34 | quantized_op.update({v.__name__: v})
35 |     # the keys are the class names of the registered quantizers, which
36 |     # match the op_type names that appear in quantized onnx graphs
37 | return quantized_op
38 |
39 |
40 | def quant_summary(quantmodel):
41 | summary = {}
42 | model = onnx.load(quantmodel)
43 | quant_op = []
44 | summary.update({"opset_version": model.opset_import[-1].version})
45 | for node in model.graph.node:
46 | if node.op_type in quant_oplist().keys():
47 | quant_op.append(node.op_type)
48 | quant_op_counter = dict(collections.Counter(quant_op))
49 | summary.update({"quant_op": quant_op_counter})
50 | return summary
51 |
--------------------------------------------------------------------------------
/nne/benchmark.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 kurosawa. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # =============================================================================
15 |
16 | import time
17 | import matplotlib.pyplot as plt
18 |
19 | class Benchmark:
20 | """
21 | This class is for measuring inference time
22 | """
23 | def __init__(self, counter=11, name="sample"):
24 | self.ave = []
25 | self.counter = counter
26 | self.saves = []
27 | self.name = name
28 |
29 | def measure(self, func, name):
30 | def inner(*args, **kwargs):
31 | durations = []
32 | for i in range(self.counter):
33 | start = time.time()
34 | func(*args, **kwargs)
35 | end = time.time()
36 |                 if i != 0:  # discard the first run as warmup
37 |                     durations.append((end - start) * 1000)
38 |             ave = sum(durations) / len(durations)
39 | self.ave.append(ave)
40 | min_value = min(durations)
41 | max_value = max(durations)
42 | print(f"{name},average[ms],{round(ave, 4)},min[ms],{round(min_value, 4)},max[ms],{round(max_value, 4)}")
43 | return func(*args, **kwargs)
44 | return inner
45 |
46 |
47 | class Plot:
48 | """
49 | Take the Benchmark class as an argument and plot the inference time.
50 | The x-axis assumes batch size.
51 | """
52 | def __init__(self, benchmarks:list):
53 | self.benchmarks = benchmarks
54 |
55 | def plot(self, x, xlabel, title, savefile):
56 | for bench in self.benchmarks:
57 | plt.plot(x, bench.ave, "-o", label=bench.name)
58 | plt.title(title)
59 | plt.legend()
60 |         plt.xlabel(xlabel)
61 | plt.ylabel("inference time[ms]")
62 | plt.savefig(savefile)
63 |
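64 | # usage sketch (illustrative): pass a Benchmark into one of nne's infer_*
65 | # functions and the wrapped call prints average/min/max latency over
66 | # `counter` runs (the first run is discarded as warmup), e.g.
67 | #   bm = Benchmark(counter=11, name="onnx")
68 | #   nne.infer_onnx(onnx_model, input_data, bm=bm)
69 | #   Plot([bm]).plot(x=[1], xlabel="batch size", title="resnet34", savefile="bench.png")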
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |

2 |
3 | Convert pytorch models for edge devices
4 |
5 | # nne
6 | Contents:
7 |
8 | - [Install](#install)
9 | - [Example](#Example)
10 | - [onnx](#onnx)
11 | - [tflite](#tflite)
12 | - [tflite(edgetpu)](#tflite-edgetpu)
13 | - [TensorRT](#tensorrt)
14 | - [Script](#Script)
15 | - [Support Format](#Support-Format)
16 | - [License](#License)
17 |
18 | ## Install
19 |
20 | ```bash
21 | python -m pip install -e .
22 | ```
23 |
24 | * edgetpu
25 |
26 | If you want to compile pytorch model for edgetpu, [install edgetpu_compiler](https://coral.ai/docs/edgetpu/compiler/)
27 |
28 | ## Example
29 |
30 | Examples of compiling a pytorch model for edge devices. See the [examples](https://github.com/kuroko1t/nne/tree/master/examples) for details.
31 |
32 | ### onnx
33 |
34 | Convert a pytorch model to an onnx model:
35 |
36 | ```python3
37 | import nne
38 | import torchvision
39 | import torch
40 | import numpy as np
41 |
42 | input_shape = (1, 3, 64, 64)
43 | onnx_file = 'resnet.onnx'
44 | model = torchvision.models.resnet34(pretrained=True).cuda()
45 |
46 | nne.cv2onnx(model, input_shape, onnx_file)
47 | ```
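
Loading the converted model and running inference with onnxruntime then looks like this (from examples/onnx_example.py):

```python3
input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)

# load onnx model
onnx_model = nne.load_onnx(onnx_file)

# inference
output_data = nne.infer_onnx(onnx_model, input_data)
```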
48 |
49 | ### tflite
50 |
51 | Convert a pytorch model to a tflite model:
52 |
53 | ```python3
54 | import torchvision
55 | import torch
56 | import numpy as np
57 | import nne
58 |
59 | input_shape = (10, 3, 224, 224)
60 | model = torchvision.models.mobilenet_v2(pretrained=True).cuda()
61 |
62 | tflite_file = 'mobilenet.tflite'
63 |
64 | nne.cv2tflite(model, input_shape, tflite_file)
65 | ```
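
The converted file is run through the tflite interpreter, as in examples/tflite_example.py:

```python3
input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)

# load tflite model
tflite_model = nne.load_tflite(tflite_file)

# inference
output_data = nne.infer_tflite(tflite_model, input_data)
```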
66 |
67 | ### tflite(edgetpu)
68 |
69 | Convert a pytorch model to a tflite model for the edge tpu:
70 |
71 | ```python3
72 | import torchvision
73 | import torch
74 | import numpy as np
75 | import nne
76 |
77 | input_shape = (10, 3, 112, 112)
78 | model = torchvision.models.mobilenet_v2(pretrained=True)
79 |
80 | tflite_file = 'mobilenet.tflite'
81 |
82 | nne.cv2tflite(model, input_shape, tflite_file, edgetpu=True)
83 | ```
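
With `edgetpu=True`, `cv2tflite` quantizes the model to int8 and then invokes `edgetpu_compiler` on the generated tflite file, so the compiler has to be installed and on your `PATH`.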
84 |
85 | ### TensorRT
86 |
87 | Convert a pytorch model to a TensorRT engine:
88 |
89 | ```python3
90 | import nne
91 | import torchvision
92 | import torch
93 | import numpy as np
94 |
95 | input_shape = (1, 3, 224, 224)
96 | trt_file = 'alexnet_trt.pth'
97 | model = torchvision.models.alexnet(pretrained=True).cuda()
98 | nne.cv2trt(model, input_shape, trt_file)
99 | ```
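
Inference with the serialized engine, as in examples/tensorrt_example.py:

```python3
input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)

# load TensorRT model
trt_model = nne.load_trt(trt_file)

# inference
output_data = nne.infer_trt(trt_model, input_data)
```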
100 |
101 | ## Script
102 |
103 | * show a summary of the model
104 | * dump detailed model information (node names, attrs) to a json file
105 | * convert an onnx model to tflite, or simplify it with onnx-simplifier
106 |
107 | ```bash
108 | $ nne -h
109 | usage: nne [-h] [-a ANALYZE_PATH] [-s SIMPLYFY_PATH] [-t TFLITE_PATH] model_path
110 |
111 | Neural Network Graph Analyzer
112 |
113 | positional arguments:
114 |   model_path            path of the model to analyze (onnx or tflite)
115 |
116 | optional arguments:
117 | -h, --help show this help message and exit
118 | -a ANALYZE_PATH, --analyze_path ANALYZE_PATH
119 |                         path to write the model's node information as json
120 |   -s SIMPLYFY_PATH, --simplyfy_path SIMPLYFY_PATH
121 |                         path to write the simplified onnx model
122 |   -t TFLITE_PATH, --tflite_path TFLITE_PATH
123 |                         path to write the tflite model converted from the onnx model
124 | ```
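
For example (the model and output paths here are arbitrary):

```bash
# print a summary of the graph
nne resnet.onnx

# dump node details to json
nne resnet.onnx -a resnet.json

# simplify the onnx model / convert it to tflite
nne resnet.onnx -s resnet_simplified.onnx
nne resnet.onnx -t resnet.tflite
```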
125 |
126 | ## Support Format
127 |
128 | |format | support |
129 | |---|---|
130 | | tflite | :white_check_mark: |
131 | | edge tpu | trial |
132 | | onnx| :white_check_mark: |
133 | | tensorRT| :white_check_mark: |
134 |
135 | ## License
136 | Apache 2.0
137 |
--------------------------------------------------------------------------------
/nne/convert/trt.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 kurosawa. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # =============================================================================
15 |
16 | import torch
17 | import tensorrt as trt
18 | import pycuda.driver as cuda
19 | import pycuda.autoinit
20 | from .common import *
21 | from .onnx import *
22 | import numpy as np
23 |
24 | TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
25 |
26 | def cv2trt(model, input_shape, trt_file, fp16_mode=False):
27 |     """
28 |     convert torch model to TensorRT engine using onnx
29 |     """
30 | model.eval()
31 | onnx_file = os.path.splitext(trt_file)[0] + ".onnx"
32 | cv2onnx(model, input_shape, onnx_file, simplify=True)
33 |     EXPLICIT_BATCH = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
34 | builder = trt.Builder(TRT_LOGGER)
35 | network = builder.create_network(EXPLICIT_BATCH)
36 | parser = trt.OnnxParser(network, TRT_LOGGER)
37 | with open(onnx_file, "rb") as onnx_model:
38 | parser.parse(onnx_model.read())
39 | if parser.num_errors > 0:
40 | error = parser.get_error(0)
41 | raise Exception(error)
42 | max_workspace_size = 1 << 28
43 | max_batch_size = 32
44 | builder.max_batch_size = max_batch_size
45 | builder.max_workspace_size = max_workspace_size
46 | builder.fp16_mode = fp16_mode
47 | engine = builder.build_cuda_engine(network)
48 | with open(trt_file, "wb") as f:
49 | f.write(engine.serialize())
50 | os.remove(onnx_file)
51 |
52 |
53 | def load_trt(trt_file):
54 | runtime = trt.Runtime(TRT_LOGGER)
55 | with open(trt_file, "rb") as f:
56 | engine = runtime.deserialize_cuda_engine(f.read())
57 | return engine
58 |
59 |
60 | def infer_trt(engine, input_data, bm=None):
61 |     h_input = cuda.pagelocked_empty(trt.volume(engine.get_binding_shape(0)), dtype=np.float32)
62 |     h_output = cuda.pagelocked_empty(trt.volume(engine.get_binding_shape(1)), dtype=np.float32)
63 |     # copy the input into page-locked host memory before transferring it to the gpu
64 |     np.copyto(h_input, input_data.ravel())
65 | def execute():
66 | # Allocate device memory for inputs and outputs.
67 | d_input = cuda.mem_alloc(h_input.nbytes)
68 | d_output = cuda.mem_alloc(h_output.nbytes)
69 | # Create a stream in which to copy inputs/outputs and run inference.
70 | stream = cuda.Stream()
71 | with engine.create_execution_context() as context:
72 | # Transfer input data to the GPU.
73 | cuda.memcpy_htod_async(d_input, h_input, stream)
74 | # Run inference.
75 | context.execute_async(bindings=[int(d_input), int(d_output)], stream_handle=stream.handle)
76 | # Transfer predictions back from the GPU.
77 | cuda.memcpy_dtoh_async(h_output, d_output, stream)
78 | # Synchronize the stream
79 | stream.synchronize()
80 | # Return the host output.
81 | return h_output
82 |
83 | if bm:
84 | output = bm.measure(execute, name="TensorRT")()
85 | else:
86 | output = execute()
87 | return output
88 |
--------------------------------------------------------------------------------
/nne/convert/onnx.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 kurosawa. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # =============================================================================
15 | import onnx
16 | import torch
17 | from .common import *
18 | from ..analyze.onnx import Analyze
19 | import sys
20 |
21 | try:
22 |     import onnxruntime
23 | except ImportError:
24 |     # onnxruntime is optional here; load_onnx and infer_onnx require it
25 |     pass
26 |
27 |
28 | def cv2onnxsimplify(onnx_file, output_file):
29 |     analyze = Analyze(onnx_file, None)
30 |     onnx_model = onnx.load(onnx_file)
31 |     model_opt, check_ok = onnx_simplify(onnx_model, analyze.input_shapes)
32 |     if check_ok:
33 |         onnx.save(model_opt, output_file)
34 |
35 | def cv2onnx(model, input_shape, onnx_file, simplify=False, verbose=False,
36 | input_names=["input"], output_names=["output"]):
37 |     """
38 |     convert torch model to onnx model
39 |     """
40 | if type(input_shape[0]) == tuple:
41 | if check_model_is_cuda(model):
42 | dummy_input = tuple([torch.randn(ishape, device="cuda") for ishape in input_shape])
43 | else:
44 | dummy_input = tuple([torch.randn(ishape, device="cpu") for ishape in input_shape])
45 | elif type(input_shape) == tuple:
46 | if check_model_is_cuda(model):
47 | dummy_input = torch.randn(input_shape, device="cuda")
48 | else:
49 | dummy_input = torch.randn(input_shape, device="cpu")
50 | else:
51 | raise Exception("input_shape must be tuple")
52 |
53 |     try:
54 |         torch.onnx.export(model, dummy_input, onnx_file,
55 |                           input_names=input_names, output_names=output_names, verbose=verbose)
56 |     except RuntimeError:
57 |         opset_version = 11  # retry with a newer opset when the default lacks an operator
58 |         torch.onnx.export(model, dummy_input, onnx_file,
59 |                           opset_version=opset_version,
60 |                           input_names=input_names, output_names=output_names)
61 |
62 | if simplify:
63 | onnx_model = load_onnx(onnx_file)
64 | model_opt, check_ok = onnx_simplify(onnx_model, input_shape)
65 | if check_ok:
66 | onnx.save(model_opt, onnx_file)
67 |
68 | def load_onnx(onnx_file):
69 | sess = onnxruntime.InferenceSession(onnx_file)
70 | if "TensorrtExecutionProvider" in sess.get_providers():
71 | print("onnxmodel with TensorrtExecutionProvider")
72 | sess.set_providers(["TensorrtExecutionProvider"])
73 | elif "CUDAExecutionProvider" in sess.get_providers():
74 | print("onnxmodel with CUDAExecutionProvider")
75 | sess.set_providers(["CUDAExecutionProvider"])
76 | elif "CPUExecutionProvider" in sess.get_providers():
77 | print("onnxmodel with CPUExecutionProvider")
78 | sess.set_providers(["CPUExecutionProvider"])
79 | return sess
80 |
81 |
82 | def infer_onnx(sess, input_data, bm=None):
83 | if type(input_data) == tuple:
84 | for i, data in enumerate(input_data):
85 | if i == 0:
86 | ort_inputs = {sess.get_inputs()[i].name: data}
87 | else:
88 | ort_inputs[sess.get_inputs()[i].name] = data
89 | else:
90 | ort_inputs = {sess.get_inputs()[0].name: input_data}
91 | if bm:
92 | ort_outs = bm.measure(sess.run, name="onnx")(None, ort_inputs)
93 | else:
94 | ort_outs = sess.run(None, ort_inputs)
95 | return ort_outs
96 |
--------------------------------------------------------------------------------
/nne/convert/tflite.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 kurosawa. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # =============================================================================
15 |
16 | import onnx
17 | from onnx_tf.backend import prepare
18 | import torch
19 | import tensorflow as tf
20 | import os
21 | import shutil
22 | import sys
23 | import subprocess
24 | from .common import *
25 | from .onnx import cv2onnx
26 | import numpy as np
27 |
28 | def onnx2tflite(onnx_file, tflite_path):
29 |     # convert the onnx model to a tensorflow SavedModel, then to tflite
30 | onnx_model = onnx.load(onnx_file)
31 | tf_rep = prepare(onnx_model)
32 | tmp_pb_file = "tmp.pb"
33 | tf_rep.export_graph(tmp_pb_file)
34 | converter = tf.lite.TFLiteConverter.from_saved_model(tmp_pb_file)
35 | tflite_model = converter.convert()
36 | shutil.rmtree(tmp_pb_file)
37 | with open(tflite_path, "wb") as f:
38 | f.write(tflite_model)
39 |
40 | def cv2tflite(model, input_shape, tflite_path, edgetpu=False, quantization=False):
41 | """
42 | convert torch model to tflite model using onnx
43 | """
44 | onnx_input_flag = False
45 | if type(model) == str:
46 | ext = os.path.splitext(model)[1]
47 | if ext == ".onnx":
48 | onnx_input_flag = True
49 | tmp_pb_file = "tmp.pb"
50 | if not onnx_input_flag:
51 | onnx_file = "tmp.onnx"
52 | cv2onnx(model, input_shape, onnx_file)
53 | else:
54 | onnx_file = model
55 | onnx_model = onnx.load(onnx_file)
56 | onnx_input_names = [input.name for input in onnx_model.graph.input]
57 | onnx_output_names = [output.name for output in onnx_model.graph.output]
58 | tf_rep = prepare(onnx_model)
59 | tf_rep.export_graph(tmp_pb_file)
60 |
61 | converter = tf.lite.TFLiteConverter.from_saved_model(tmp_pb_file)
62 |
63 | if quantization:
64 | converter.optimizations = [tf.lite.Optimize.DEFAULT]
65 |
66 |     if edgetpu:
67 |         # the edgetpu compiler requires a fully integer quantized model, so
68 |         # the converter needs a representative dataset to calibrate the
69 |         # quantization ranges of the activations. random data matching the
70 |         # input shape is used here; real samples would calibrate better.
71 |         if type(input_shape[0]) == tuple:
72 |             raise Exception("edgetpu conversion supports a single input shape")
73 |         elif type(input_shape) == tuple:
74 |             input_data = np.array(np.random.random_sample(input_shape),
75 |                                   dtype=np.float32)
76 |         else:
77 |             raise Exception("input_shape must be tuple")
78 |         train = tf.convert_to_tensor(input_data)
79 |         my_ds = tf.data.Dataset.from_tensor_slices(train).batch(10)
80 |         def representative_dataset_gen():
81 |             for input_value in my_ds.take(10):
82 |                 yield [input_value]
83 |         converter.representative_dataset = representative_dataset_gen
84 |         # force full int8 quantization so the edgetpu compiler accepts the model
85 | converter.allow_custom_ops = True
86 | converter.experimental_new_converter = True
87 | converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
88 | converter.inference_input_type = tf.int8
89 | converter.inference_output_type = tf.int8
90 |
91 | # convert tensorflow to tflite model
92 | tflite_model = converter.convert()
93 |
94 | with open(tflite_path, "wb") as f:
95 | f.write(tflite_model)
96 | if not onnx_input_flag:
97 | os.remove(onnx_file)
98 | shutil.rmtree(tmp_pb_file)
99 |
100 | if edgetpu:
101 | subprocess.check_call(f"edgetpu_compiler {tflite_path}", shell=True)
102 |
103 |
104 | def load_tflite(tflitepath):
105 | interpreter = tf.lite.Interpreter(model_path=tflitepath)
106 | # allocate memory
107 | interpreter.allocate_tensors()
108 | return interpreter
109 |
110 |
111 | def infer_tflite(interpreter, input_data, bm=None):
112 |     # get model input and output properties
113 | input_details = interpreter.get_input_details()
114 | output_details = interpreter.get_output_details()
115 | ## get input shape
116 | #input_shape = input_details[0]["shape"]
117 | def execute():
118 | # set tensor pointer to index
119 | if type(input_data) == tuple:
120 | for i, data in enumerate(input_data):
121 | interpreter.set_tensor(input_details[i]["index"], data)
122 | else:
123 | interpreter.set_tensor(input_details[0]["index"], input_data)
124 | # execute infer
125 | interpreter.invoke()
126 | # get result from index of output_details
127 | output_data = interpreter.get_tensor(output_details[0]["index"])
128 | return output_data
129 | if bm:
130 | output_data = bm.measure(execute, name="tflite")()
131 | else:
132 | output_data = execute()
133 | return output_data
134 |
--------------------------------------------------------------------------------
/nne/analyze/onnx.py:
--------------------------------------------------------------------------------
1 | import re
2 | import collections
3 | import onnx
4 | import json
5 | from onnx import defs
6 |
7 |
8 | class Analyze:
9 | def __init__(self, model_path, output_path):
10 | self.model_path = model_path
11 | model = onnx.load(model_path)
12 | graph_def = model.graph
13 | self.input_shapes = self.get_input_shape(graph_def)
14 | self.output_shapes = self.get_output_shape(graph_def)
15 | self.opset_info = model.opset_import[0].version
16 | self.nodes = self.nodes(graph_def)
17 | self.model_info = self.modelinfo()
18 | if output_path:
19 | self.dump(self.model_info, output_path)
20 |
21 | def modelinfo(self):
22 | model_info = {}
23 | model_info["OpsetVersion"] = self.opset_info
24 | model_info["InputShape"] = self.input_shapes
25 | model_info["OutputShape"] = self.output_shapes
26 | node_dict_list = [model_info] + self.nodes2dict()
27 | return node_dict_list
28 |
29 | def nodes(self, graph_def):
30 | return [OnnxNode(node) for node in graph_def.node]
31 |
32 | def get_input_shape(self, graph_def):
33 | input_shapes = []
34 | if graph_def.initializer:
35 | initialized = {init.name for init in graph_def.initializer}
36 | else:
37 | initialized = set()
38 | for value_info in graph_def.input:
39 | if value_info.name in initialized:
40 | continue
41 | shape = list(
42 | d.dim_value if (
43 | d.dim_value > 0 and d.dim_param == "") else None
44 | for d in value_info.type.tensor_type.shape.dim)
45 | input_shapes.append(shape)
46 | return input_shapes
47 |
48 | def get_output_shape(self, graph_def):
49 | output_shapes = []
50 | for value_info in graph_def.output:
51 | output_shape = [
52 | dim.dim_value for dim in value_info.type.tensor_type.shape.dim]
53 | output_shapes.append(output_shape)
54 | return output_shapes
55 |
56 | def unique_nodes(self):
57 |         # count how many times each op_type appears in the graph
58 | strip_node_names = [node.op_type for node in self.nodes]
59 | counter_nodes = collections.Counter(strip_node_names)
60 | return dict(counter_nodes)
61 |
62 |     def nodes2dict(self):
63 |         node_dict_list = []
64 |         for node in self.nodes:
65 |             node_dict = collections.OrderedDict()
66 |             node_dict["name"] = node.name
67 |             node_dict["op_type"] = node.op_type
68 |             if node.attrs:
69 |                 # drop the raw "value" tensor payload: it is a protobuf
70 |                 # object that cannot be serialized to json. all remaining
71 |                 # attributes are kept in the dump.
72 |                 node.attrs.pop("value", None)
73 |                 node_dict["attrs"] = node.attrs
74 |             node_dict_list.append(node_dict)
75 |         return node_dict_list
76 |
77 | def dump(self, node_dict_list, output_path):
78 | with open(output_path, "w") as f:
79 | json.dump(node_dict_list, f, indent=2)
80 |
81 | def summary(self):
82 | print()
83 | print("#### SUMMARY ONNX MODEL ####")
84 | print("opset:", self.opset_info)
85 | print("INPUT:", self.input_shapes)
86 | print("OUTPUT:", self.output_shapes)
87 | unique_nodes_count = self.unique_nodes()
88 | print(f"--Node List-- num({len(self.nodes)})")
89 | print(unique_nodes_count)
90 |
91 |
92 | class OnnxNode(object):
93 | def __init__(self, node):
94 | self.name = str(node.name)
95 | self.op_type = str(node.op_type)
96 | self.domain = str(node.domain)
97 | self.attrs = dict([(attr.name,
98 | translate_onnx(
99 | attr.name, convert_onnx_attribute_proto(attr)))
100 | for attr in node.attribute])
101 | self.inputs = list(node.input)
102 | self.outputs = list(node.output)
103 | self.node_proto = node
104 |
105 |
106 | def analyze_graph(model_path, output_path):
107 | analyzer = Analyze(model_path, output_path)
108 | analyzer.summary()
109 | return analyzer.model_info
110 |
111 |
112 | def translate_onnx(key, val):
113 | return onnx_attr_translator.get(key, lambda x: x)(val)
114 |
115 |
116 | onnx_attr_translator = {
117 | "axis": lambda x: int(x),
118 | "axes": lambda x: [int(a) for a in x],
119 | "dtype": lambda x: x,
120 | "keepdims": lambda x: bool(x),
121 | "to": lambda x: x,
122 | }
123 |
124 |
125 | def convert_onnx_attribute_proto(attr_proto):
126 | """
127 | Convert an ONNX AttributeProto into an appropriate Python object
128 | for the type.
129 | NB: Tensor attribute gets returned as the straight proto.
130 | """
131 | if attr_proto.HasField('f'):
132 | return attr_proto.f
133 | elif attr_proto.HasField('i'):
134 | return attr_proto.i
135 | elif attr_proto.HasField('s'):
136 | return str(attr_proto.s, 'utf-8')
137 | elif attr_proto.HasField('t'):
138 | return attr_proto.t # this is a proto!
139 | elif attr_proto.HasField('g'):
140 | return attr_proto.g
141 | elif attr_proto.floats:
142 | return list(attr_proto.floats)
143 | elif attr_proto.ints:
144 | return list(attr_proto.ints)
145 | elif attr_proto.strings:
146 | str_list = list(attr_proto.strings)
147 | str_list = list(map(lambda x: str(x, 'utf-8'), str_list))
148 | return str_list
149 | elif attr_proto.HasField('sparse_tensor'):
150 | return attr_proto.sparse_tensor
151 | else:
152 | raise ValueError("Unsupported ONNX attribute: {}".format(attr_proto))
153 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 | Copyright yuya kurosawa 2020
5 | kurosawa.yk@gmail.com
6 |
7 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
8 |
9 | 1. Definitions.
10 |
11 | "License" shall mean the terms and conditions for use, reproduction,
12 | and distribution as defined by Sections 1 through 9 of this document.
13 |
14 | "Licensor" shall mean the copyright owner or entity authorized by
15 | the copyright owner that is granting the License.
16 |
17 | "Legal Entity" shall mean the union of the acting entity and all
18 | other entities that control, are controlled by, or are under common
19 | control with that entity. For the purposes of this definition,
20 | "control" means (i) the power, direct or indirect, to cause the
21 | direction or management of such entity, whether by contract or
22 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
23 | outstanding shares, or (iii) beneficial ownership of such entity.
24 |
25 | "You" (or "Your") shall mean an individual or Legal Entity
26 | exercising permissions granted by this License.
27 |
28 | "Source" form shall mean the preferred form for making modifications,
29 | including but not limited to software source code, documentation
30 | source, and configuration files.
31 |
32 | "Object" form shall mean any form resulting from mechanical
33 | transformation or translation of a Source form, including but
34 | not limited to compiled object code, generated documentation,
35 | and conversions to other media types.
36 |
37 | "Work" shall mean the work of authorship, whether in Source or
38 | Object form, made available under the License, as indicated by a
39 | copyright notice that is included in or attached to the work
40 | (an example is provided in the Appendix below).
41 |
42 | "Derivative Works" shall mean any work, whether in Source or Object
43 | form, that is based on (or derived from) the Work and for which the
44 | editorial revisions, annotations, elaborations, or other modifications
45 | represent, as a whole, an original work of authorship. For the purposes
46 | of this License, Derivative Works shall not include works that remain
47 | separable from, or merely link (or bind by name) to the interfaces of,
48 | the Work and Derivative Works thereof.
49 |
50 | "Contribution" shall mean any work of authorship, including
51 | the original version of the Work and any modifications or additions
52 | to that Work or Derivative Works thereof, that is intentionally
53 | submitted to Licensor for inclusion in the Work by the copyright owner
54 | or by an individual or Legal Entity authorized to submit on behalf of
55 | the copyright owner. For the purposes of this definition, "submitted"
56 | means any form of electronic, verbal, or written communication sent
57 | to the Licensor or its representatives, including but not limited to
58 | communication on electronic mailing lists, source code control systems,
59 | and issue tracking systems that are managed by, or on behalf of, the
60 | Licensor for the purpose of discussing and improving the Work, but
61 | excluding communication that is conspicuously marked or otherwise
62 | designated in writing by the copyright owner as "Not a Contribution."
63 |
64 | "Contributor" shall mean Licensor and any individual or Legal Entity
65 | on behalf of whom a Contribution has been received by Licensor and
66 | subsequently incorporated within the Work.
67 |
68 | 2. Grant of Copyright License. Subject to the terms and conditions of
69 | this License, each Contributor hereby grants to You a perpetual,
70 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
71 | copyright license to reproduce, prepare Derivative Works of,
72 | publicly display, publicly perform, sublicense, and distribute the
73 | Work and such Derivative Works in Source or Object form.
74 |
75 | 3. Grant of Patent License. Subject to the terms and conditions of
76 | this License, each Contributor hereby grants to You a perpetual,
77 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
78 | (except as stated in this section) patent license to make, have made,
79 | use, offer to sell, sell, import, and otherwise transfer the Work,
80 | where such license applies only to those patent claims licensable
81 | by such Contributor that are necessarily infringed by their
82 | Contribution(s) alone or by combination of their Contribution(s)
83 | with the Work to which such Contribution(s) was submitted. If You
84 | institute patent litigation against any entity (including a
85 | cross-claim or counterclaim in a lawsuit) alleging that the Work
86 | or a Contribution incorporated within the Work constitutes direct
87 | or contributory patent infringement, then any patent licenses
88 | granted to You under this License for that Work shall terminate
89 | as of the date such litigation is filed.
90 |
91 | 4. Redistribution. You may reproduce and distribute copies of the
92 | Work or Derivative Works thereof in any medium, with or without
93 | modifications, and in Source or Object form, provided that You
94 | meet the following conditions:
95 |
96 | (a) You must give any other recipients of the Work or
97 | Derivative Works a copy of this License; and
98 |
99 | (b) You must cause any modified files to carry prominent notices
100 | stating that You changed the files; and
101 |
102 | (c) You must retain, in the Source form of any Derivative Works
103 | that You distribute, all copyright, patent, trademark, and
104 | attribution notices from the Source form of the Work,
105 | excluding those notices that do not pertain to any part of
106 | the Derivative Works; and
107 |
108 | (d) If the Work includes a "NOTICE" text file as part of its
109 | distribution, then any Derivative Works that You distribute must
110 | include a readable copy of the attribution notices contained
111 | within such NOTICE file, excluding those notices that do not
112 | pertain to any part of the Derivative Works, in at least one
113 | of the following places: within a NOTICE text file distributed
114 | as part of the Derivative Works; within the Source form or
115 | documentation, if provided along with the Derivative Works; or,
116 | within a display generated by the Derivative Works, if and
117 | wherever such third-party notices normally appear. The contents
118 | of the NOTICE file are for informational purposes only and
119 | do not modify the License. You may add Your own attribution
120 | notices within Derivative Works that You distribute, alongside
121 | or as an addendum to the NOTICE text from the Work, provided
122 | that such additional attribution notices cannot be construed
123 | as modifying the License.
124 |
125 | You may add Your own copyright statement to Your modifications and
126 | may provide additional or different license terms and conditions
127 | for use, reproduction, or distribution of Your modifications, or
128 | for any such Derivative Works as a whole, provided Your use,
129 | reproduction, and distribution of the Work otherwise complies with
130 | the conditions stated in this License.
131 |
132 | 5. Submission of Contributions. Unless You explicitly state otherwise,
133 | any Contribution intentionally submitted for inclusion in the Work
134 | by You to the Licensor shall be under the terms and conditions of
135 | this License, without any additional terms or conditions.
136 | Notwithstanding the above, nothing herein shall supersede or modify
137 | the terms of any separate license agreement you may have executed
138 | with Licensor regarding such Contributions.
139 |
140 | 6. Trademarks. This License does not grant permission to use the trade
141 | names, trademarks, service marks, or product names of the Licensor,
142 | except as required for reasonable and customary use in describing the
143 | origin of the Work and reproducing the content of the NOTICE file.
144 |
145 | 7. Disclaimer of Warranty. Unless required by applicable law or
146 | agreed to in writing, Licensor provides the Work (and each
147 | Contributor provides its Contributions) on an "AS IS" BASIS,
148 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
149 | implied, including, without limitation, any warranties or conditions
150 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
151 | PARTICULAR PURPOSE. You are solely responsible for determining the
152 | appropriateness of using or redistributing the Work and assume any
153 | risks associated with Your exercise of permissions under this License.
154 |
155 | 8. Limitation of Liability. In no event and under no legal theory,
156 | whether in tort (including negligence), contract, or otherwise,
157 | unless required by applicable law (such as deliberate and grossly
158 | negligent acts) or agreed to in writing, shall any Contributor be
159 | liable to You for damages, including any direct, indirect, special,
160 | incidental, or consequential damages of any character arising as a
161 | result of this License or out of the use or inability to use the
162 | Work (including but not limited to damages for loss of goodwill,
163 | work stoppage, computer failure or malfunction, or any and all
164 | other commercial damages or losses), even if such Contributor
165 | has been advised of the possibility of such damages.
166 |
167 | 9. Accepting Warranty or Additional Liability. While redistributing
168 | the Work or Derivative Works thereof, You may choose to offer,
169 | and charge a fee for, acceptance of support, warranty, indemnity,
170 | or other liability obligations and/or rights consistent with this
171 | License. However, in accepting such obligations, You may act only
172 | on Your own behalf and on Your sole responsibility, not on behalf
173 | of any other Contributor, and only if You agree to indemnify,
174 | defend, and hold each Contributor harmless for any liability
175 | incurred by, or claims asserted against, such Contributor by reason
176 | of your accepting any such warranty or additional liability.
177 |
178 | END OF TERMS AND CONDITIONS
179 |
180 | APPENDIX: How to apply the Apache License to your work.
181 |
182 | To apply the Apache License to your work, attach the following
183 | boilerplate notice, with the fields enclosed by brackets "[]"
184 | replaced with your own identifying information. (Don't include
185 | the brackets!) The text should be enclosed in the appropriate
186 | comment syntax for the file format. We also recommend that a
187 | file or class name and description of purpose be included on the
188 | same "printed page" as the copyright notice for easier
189 | identification within third-party archives.
190 |
191 | Copyright [yyyy] [name of copyright owner]
192 |
193 | Licensed under the Apache License, Version 2.0 (the "License");
194 | you may not use this file except in compliance with the License.
195 | You may obtain a copy of the License at
196 |
197 | http://www.apache.org/licenses/LICENSE-2.0
198 |
199 | Unless required by applicable law or agreed to in writing, software
200 | distributed under the License is distributed on an "AS IS" BASIS,
201 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
202 | See the License for the specific language governing permissions and
203 | limitations under the License.
204 |
--------------------------------------------------------------------------------