├── onnx2tflite
│   ├── onnx2tflite.egg-info
│   │   ├── dependency_links.txt
│   │   ├── top_level.txt
│   │   ├── requires.txt
│   │   ├── SOURCES.txt
│   │   └── PKG-INFO
│   ├── onnx2tflite
│   │   ├── __init__.py
│   │   ├── utils
│   │   │   ├── __init__.py
│   │   │   ├── op_registry.py
│   │   │   ├── dimension_utils.py
│   │   │   ├── definitions.py
│   │   │   └── graph_tools.py
│   │   ├── layers
│   │   │   ├── __init__.py
│   │   │   ├── activations_layers.py
│   │   │   ├── mathematics_layers.py
│   │   │   ├── deformation_layers.py
│   │   │   └── common_layers.py
│   │   ├── components
│   │   │   ├── __init__.py
│   │   │   ├── dataloader.py
│   │   │   ├── output_check.py
│   │   │   ├── builder.py
│   │   │   ├── onnx_loader.py
│   │   │   └── builder1.py
│   │   ├── __main__.py
│   │   └── converter.py
│   ├── build
│   │   └── lib
│   │       └── onnx2tflite
│   │           ├── __init__.py
│   │           ├── __main__.py
│   │           └── converter.py
│   ├── dist
│   │   └── onnx2tflite-2.0-py3.8.egg
│   ├── onnx2tflite.py
│   ├── setup.py
│   └── test
│       ├── test_squeeze.py
│       ├── test_reshape_transpose.py
│       ├── test_concat.py
│       └── test_torchvison.py
├── figs
│   └── framework.png
├── pretrain
│   ├── lolv1.onnx
│   ├── lolv1.tflite
│   ├── zrr_best_slim.pkl
│   ├── lolv1_best_slim.pkl
│   ├── uieb_best_slim.pkl
│   ├── lolv2_real_best_slim.pkl
│   └── mai25_isp_challenge_best_slim.pkl
├── datasets
│   ├── lle
│   │   ├── gt
│   │   │   ├── 1.png
│   │   │   └── 22.png
│   │   └── input
│   │       ├── 1.png
│   │       └── 22.png
│   ├── isp
│   │   ├── gt
│   │   │   ├── 10.png
│   │   │   └── 15.png
│   │   └── Input
│   │       ├── 10.png
│   │       └── 15.png
│   └── uie
│       ├── gt
│       │   ├── 3_img_.png
│       │   └── 8_img_.png
│       └── input
│           ├── 3_img_.png
│           └── 8_img_.png
├── onnx_to_tf.py
├── tf_to_TFLite.py
├── config
│   ├── isp.yaml
│   └── lle.yaml
├── data
│   ├── lledata.py
│   ├── ispdata.py
│   └── __init__.py
├── model
│   ├── __init__.py
│   ├── lle.py
│   ├── isp.py
│   ├── utils.py
│   └── utils_IWO.py
├── logger.py
├── option.py
├── complexity.py
├── torch_to_onnx.py
├── test_TFLite_RGB.py
├── test_TFLite_ISP.py
├── README.md
├── loss.py
├── main.py
└── LICENSE
/onnx2tflite/onnx2tflite.egg-info/dependency_links.txt:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/onnx2tflite/onnx2tflite.egg-info/top_level.txt:
--------------------------------------------------------------------------------
1 | onnx2tflite
2 |
--------------------------------------------------------------------------------
/figs/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AVC2-UESTC/MobileIE/HEAD/figs/framework.png
--------------------------------------------------------------------------------
/pretrain/lolv1.onnx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AVC2-UESTC/MobileIE/HEAD/pretrain/lolv1.onnx
--------------------------------------------------------------------------------
/datasets/lle/gt/1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AVC2-UESTC/MobileIE/HEAD/datasets/lle/gt/1.png
--------------------------------------------------------------------------------
/pretrain/lolv1.tflite:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AVC2-UESTC/MobileIE/HEAD/pretrain/lolv1.tflite
--------------------------------------------------------------------------------
/datasets/isp/gt/10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AVC2-UESTC/MobileIE/HEAD/datasets/isp/gt/10.png
--------------------------------------------------------------------------------
/datasets/isp/gt/15.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AVC2-UESTC/MobileIE/HEAD/datasets/isp/gt/15.png
--------------------------------------------------------------------------------
/datasets/lle/gt/22.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AVC2-UESTC/MobileIE/HEAD/datasets/lle/gt/22.png
--------------------------------------------------------------------------------
/datasets/lle/input/1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AVC2-UESTC/MobileIE/HEAD/datasets/lle/input/1.png
--------------------------------------------------------------------------------
/onnx2tflite/onnx2tflite/__init__.py:
--------------------------------------------------------------------------------
1 | __VERSION__ = "2.0"
2 |
3 | from .converter import onnx_converter
--------------------------------------------------------------------------------
/datasets/isp/Input/10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AVC2-UESTC/MobileIE/HEAD/datasets/isp/Input/10.png
--------------------------------------------------------------------------------
/datasets/isp/Input/15.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AVC2-UESTC/MobileIE/HEAD/datasets/isp/Input/15.png
--------------------------------------------------------------------------------
/datasets/lle/input/22.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AVC2-UESTC/MobileIE/HEAD/datasets/lle/input/22.png
--------------------------------------------------------------------------------
/datasets/uie/gt/3_img_.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AVC2-UESTC/MobileIE/HEAD/datasets/uie/gt/3_img_.png
--------------------------------------------------------------------------------
/datasets/uie/gt/8_img_.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AVC2-UESTC/MobileIE/HEAD/datasets/uie/gt/8_img_.png
--------------------------------------------------------------------------------
/pretrain/zrr_best_slim.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AVC2-UESTC/MobileIE/HEAD/pretrain/zrr_best_slim.pkl
--------------------------------------------------------------------------------
/datasets/uie/input/3_img_.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AVC2-UESTC/MobileIE/HEAD/datasets/uie/input/3_img_.png
--------------------------------------------------------------------------------
/datasets/uie/input/8_img_.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AVC2-UESTC/MobileIE/HEAD/datasets/uie/input/8_img_.png
--------------------------------------------------------------------------------
/onnx2tflite/build/lib/onnx2tflite/__init__.py:
--------------------------------------------------------------------------------
1 | __VERSION__ = "2.0"
2 |
3 | from .converter import onnx_converter
--------------------------------------------------------------------------------
/pretrain/lolv1_best_slim.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AVC2-UESTC/MobileIE/HEAD/pretrain/lolv1_best_slim.pkl
--------------------------------------------------------------------------------
/pretrain/uieb_best_slim.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AVC2-UESTC/MobileIE/HEAD/pretrain/uieb_best_slim.pkl
--------------------------------------------------------------------------------
/pretrain/lolv2_real_best_slim.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AVC2-UESTC/MobileIE/HEAD/pretrain/lolv2_real_best_slim.pkl
--------------------------------------------------------------------------------
/onnx2tflite/dist/onnx2tflite-2.0-py3.8.egg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AVC2-UESTC/MobileIE/HEAD/onnx2tflite/dist/onnx2tflite-2.0-py3.8.egg
--------------------------------------------------------------------------------
/onnx2tflite/onnx2tflite/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .dimension_utils import *
2 | from .op_registry import OPERATOR
3 | from .definitions import *
--------------------------------------------------------------------------------
/pretrain/mai25_isp_challenge_best_slim.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AVC2-UESTC/MobileIE/HEAD/pretrain/mai25_isp_challenge_best_slim.pkl
--------------------------------------------------------------------------------
/onnx2tflite/onnx2tflite.egg-info/requires.txt:
--------------------------------------------------------------------------------
1 | onnx
2 | onnxruntime
3 | onnx-simplifier
4 | numpy<=1.24
5 | tensorflow<2.13,>=2.5
6 | opencv-python
7 |
--------------------------------------------------------------------------------
/onnx2tflite/onnx2tflite/layers/__init__.py:
--------------------------------------------------------------------------------
1 | from .conv_layers import *
2 | from .common_layers import *
3 | from .activations_layers import *
4 | from .mathematics_layers import *
5 | from .deformation_layers import *
--------------------------------------------------------------------------------
/onnx2tflite/onnx2tflite/components/__init__.py:
--------------------------------------------------------------------------------
1 | from .output_check import get_elements_error
2 | from .onnx_loader import load_onnx_modelproto
3 | from .builder import keras_builder, tflite_builder
4 |
5 | __all__ = ['load_onnx_modelproto', 'keras_builder', 'tflite_builder', 'get_elements_error']
--------------------------------------------------------------------------------
/onnx2tflite/onnx2tflite.egg-info/SOURCES.txt:
--------------------------------------------------------------------------------
1 | setup.py
2 | onnx2tflite/__init__.py
3 | onnx2tflite/__main__.py
4 | onnx2tflite/converter.py
5 | onnx2tflite.egg-info/PKG-INFO
6 | onnx2tflite.egg-info/SOURCES.txt
7 | onnx2tflite.egg-info/dependency_links.txt
8 | onnx2tflite.egg-info/requires.txt
9 | onnx2tflite.egg-info/top_level.txt
10 | test/test_concat.py
11 | test/test_reshape_transpose.py
12 | test/test_squeeze.py
13 | test/test_torchvison.py
--------------------------------------------------------------------------------
/onnx_to_tf.py:
--------------------------------------------------------------------------------
1 | import onnx
2 | from onnx_tf.backend import prepare
3 | import os
4 | os.environ['CUDA_VISIBLE_DEVICES'] = '0'
5 |
6 | onnx_model_path = './LLE.onnx'
7 | onnx_model = onnx.load(onnx_model_path)
8 |
9 | onnx.checker.check_model(onnx_model)
10 | print("ONNX to TensorFlow")
11 |
12 | try:
13 | tf_rep = prepare(onnx_model)
14 | tf_model_path = 'lle_tf'
15 | tf_rep.export_graph(tf_model_path)
16 | print(f"Success, and save to {tf_model_path}")
17 | except Exception as e:
18 | print(f"ERROR: {e}")
19 |
20 |
--------------------------------------------------------------------------------
/tf_to_TFLite.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import os
3 |
4 | os.environ['CUDA_VISIBLE_DEVICES'] = '0'
5 |
6 | saved_model_dir = "./lle_tf"
7 | converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
8 | converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]
9 |
10 | tflite_model = converter.convert()
11 |
12 | tflite_model_path = "LLE.tflite"
13 | with open(tflite_model_path, 'wb') as f:
14 | f.write(tflite_model)
15 |
16 | print(f"TFLite save to {tflite_model_path}")
17 |
--------------------------------------------------------------------------------
/onnx2tflite/onnx2tflite.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.environ['CUDA_VISIBLE_DEVICES'] = '2'
3 | from onnx2tflite.converter import onnx_converter
4 | onnx_path = "/data2/yanhailong/IR-Based/ICCV2025/MobileIE/LLE.onnx"
5 |
6 | onnx_converter(
7 | onnx_model_path = onnx_path,
8 | need_simplify = True,
9 | output_path = "//data2/yanhailong/IR-Based/ICCV2025/MobileIE/",
10 | target_formats = ['tflite'], # or ['keras'], ['keras', 'tflite']
11 | weight_quant = False,
12 | int8_model = False,
13 | int8_mean = None,
14 | int8_std = None,
15 | image_root = None
16 | )
--------------------------------------------------------------------------------
/onnx2tflite/setup.py:
--------------------------------------------------------------------------------
1 | import os
2 | from setuptools import setup, find_packages
3 | abs_path = os.path.dirname(os.path.abspath(__file__))
4 |
5 | setup(
6 | name="onnx2tflite",
7 | version="2.0",
8 | author="MPolaris",
9 | description="onnx to keras/tensorflow lite",
10 | long_description=open(os.path.join(abs_path, "readme.md")).read(),
11 | long_description_content_type='text/markdown',
12 | packages=find_packages(include=['onnx2tflite']),
13 | license="Apache-2.0",
14 | platforms=["Windows", "linux"],
15 | install_requires=open(os.path.join(abs_path, "requirements.txt")).read().splitlines()
16 | )
--------------------------------------------------------------------------------
/config/isp.yaml:
--------------------------------------------------------------------------------
1 | exp_name: isp
2 |
3 | train:
4 | warmup: false
5 | warmup_epoch: 10
6 | lr_warmup: 1e-6
7 | train_inp: ./isp/train/huawei_raw
8 | train_gt: ./isp/train/canon
9 | valid_inp: ./isp/test/huawei_raw
10 | valid_gt: ./isp/test/canon
11 | batch_size: 6
12 | epoch: 1000
13 | lr: 1e-3
14 | num_workers: 20
15 | save_every: 20
16 | save_slim: true
17 |
18 | test:
19 | test_inp: ./isp/test/huawei_raw
20 | test_gt: ./isp/test/canon
21 | num_workers: 0
22 | save: false
23 |
24 | demo:
25 | demo_inp: ./traindata/isp/test/huawei_raw
26 | num_workers: 0
27 |
28 | model:
29 | type: original # [original, re-parameterized]
30 | pretrained: false
31 | need_slim: false #true
32 | rep_scale: 4
33 | channels: 12
34 |
--------------------------------------------------------------------------------
/config/lle.yaml:
--------------------------------------------------------------------------------
1 | exp_name: lle
2 |
3 | train:
4 | warmup: True
5 | warmup_epoch: 10
6 | lr_warmup: 1e-6
7 | train_inp: ./lowlight/LOLdataset/our485/low
8 | train_gt: ./lowlight/LOLdataset/our485/high
9 | valid_inp: ./lowlight/LOLdataset/eval15/low
10 | valid_gt: ./lowlight/LOLdataset/eval15/high
11 | batch_size: 4
12 | epoch: 2000
13 | lr: 1e-3
14 | num_workers: 0
15 | save_every: 20
16 | save_slim: true
17 |
18 | test:
19 | test_inp: ./lowlight/LOLdataset/eval15/low
20 | test_gt: ./lowlight/LOLdataset/eval15/high
21 | num_workers: 0
22 | save: false
23 |
24 | demo:
25 | demo_inp: ./lowlight/LOLdataset/eval15/low
26 | num_workers: 0
27 |
28 | model:
29 | type: original # [original, re-parameterized]
30 | pretrained: false
31 | need_slim: false # true
32 | rep_scale: 4
33 | channels: 12
34 |
--------------------------------------------------------------------------------
/data/lledata.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import os
4 | from PIL import Image
5 |
6 |
7 | class LLEData(torch.utils.data.Dataset):
8 | def __init__(self, opt, inp_path, gt_path=None):
9 | super(LLEData, self).__init__()
10 | self.img_li = [path for path in os.listdir(inp_path)]
11 | self.inp_path = inp_path
12 | self.gt_path = gt_path
13 | self.opt = opt
14 |
15 | def __getitem__(self, index):
16 | inp = Image.open(os.path.join(self.inp_path, self.img_li[index]))
17 | inp = np.array(inp).transpose([2, 0, 1])
18 | inp = inp.astype(np.float32) / 255
19 |
20 | inp = torch.Tensor(np.array(inp))
21 | inp = inp.to(self.opt.device)
22 |
23 | if self.gt_path: # gt_path -> train/test not demo
24 | gt = Image.open(os.path.join(self.gt_path, self.img_li[index]))
25 | gt = np.array(gt).transpose([2, 0, 1])
26 | gt = gt.astype(np.float32) / 255
27 |
28 | gt = torch.Tensor(np.array(gt))
29 | gt = gt.to(self.opt.device)
30 |
31 | return inp, gt, self.img_li[index].split('.')[0]
32 | return inp, self.img_li[index].split('.')[0]
33 |
34 | def __len__(self):
35 | return len(self.img_li)
36 |
--------------------------------------------------------------------------------
/onnx2tflite/onnx2tflite/utils/op_registry.py:
--------------------------------------------------------------------------------
1 | class Registry(object):
2 | def __init__(self, name) -> None:
3 | self._name = name
4 | self._operator_dict = dict()
5 |
6 | def __len__(self):
7 | return len(self._operator_dict)
8 |
9 | @property
10 | def name(self):
11 | return self._name
12 |
13 | @property
14 | def operator_dict(self):
15 | return self._operator_dict
16 |
17 | def get(self, key):
18 | return self._operator_dict.get(key, None)
19 |
20 | def _register_operator(self, op_class, op_name=None):
21 | if (not isinstance(op_name, str)) or op_name is None:
22 | op_name = op_class.__name__
23 |
24 | if self._operator_dict.get(op_name, None):
25 | raise KeyError(f'{op_name} is already registered in {self._name}')
26 |
27 | self._operator_dict[op_name] = op_class
28 |
29 | def register_operator(self, name=None, op_class=None):
30 | if op_class is not None:
31 | self._register_operator(op_class, name)
32 | return op_class
33 |
34 | def _register(cls):
35 | self._register_operator(cls, name)
36 | return cls
37 |
38 | return _register
39 |
40 | OPERATOR = Registry("TensorflowOP")
--------------------------------------------------------------------------------
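
The layer modules imported by `onnx2tflite/layers/__init__.py` are not reproduced in this dump, but the registry above is meant to be used as a decorator. A minimal, hypothetical registration (the op name "Relu" and the class are illustrative, not taken from the repo):

```python
import tensorflow as tf
from onnx2tflite.utils.op_registry import OPERATOR

@OPERATOR.register_operator("Relu")  # stores the class under the ONNX op type
class TFRelu:
    def __init__(self, *args, **kwargs) -> None:
        pass

    def __call__(self, x):
        return tf.nn.relu(x)

print(OPERATOR.get("Relu"))  # -> the TFRelu class, retrievable by op name
```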
/onnx2tflite/test/test_squeeze.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pytest
3 |
4 | import torch
5 | import torch.nn as nn
6 | from onnx2tflite import onnx_converter
7 |
8 | MODEL_ROOT = "./unit_test"
9 | os.makedirs(MODEL_ROOT, exist_ok=True)
10 |
11 | @pytest.mark.filterwarnings('ignore::UserWarning')
12 | @pytest.mark.filterwarnings('ignore::DeprecationWarning')
13 | def test_squeeze():
14 | class Squeeze(nn.Module):
15 | def __init__(self, *args, **kwargs) -> None:
16 | super().__init__(*args, **kwargs)
17 |
18 | def forward(self, x):
19 | x = torch.unsqueeze(x, dim=1)
20 | # x = torch.tile(x, dims=(2,1,1))
21 | x = torch.squeeze(x, dim=1)
22 |
23 | return x
24 |
25 | model = Squeeze()
26 | x = torch.randn(1,1,1,2)
27 |
28 | onnx_model_path = os.path.join(MODEL_ROOT, "test_squeeze.onnx")
29 | torch.onnx.export(model, x, onnx_model_path, opset_version=11)
30 |
31 | res = onnx_converter(
32 | onnx_model_path = onnx_model_path,
33 | need_simplify = True,
34 | output_path = MODEL_ROOT,
35 | target_formats = ['tflite'],
36 | native_groupconv=False,
37 | fp16_model=False,
38 | int8_model=False,
39 | )
40 |
41 | assert res['tflite_error'] < 1e-3
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from importlib import import_module
3 | from .lle import MobileIELLENet, MobileIELLENetS
4 | from .isp import MobileIEISPNet, MobileIEISPNetS
5 |
6 | __all__ = {
7 | 'MobileIELLENet',
8 | 'MobileIELLENetS',
9 | 'MobileIEISPNet',
10 | 'MobileIEISPNetS',
11 | 'import_model'
12 | }
13 |
14 | def import_model(opt):
15 | model_name = 'MobileIE'+opt.model_task.upper()
16 | kwargs = {'channels': opt.config['model']['channels']}
17 |
18 | if opt.config['model']['type'] == 're-parameterized':
19 | model_name += 'NetS'
20 | elif opt.config['model']['type'] == 'original':
21 | model_name += 'Net'
22 | kwargs['rep_scale'] = opt.config['model']['rep_scale']
23 | else:
24 | raise ValueError('unknown model type, please choose from [original, re-parameterized]')
25 |
26 | model = getattr(import_module('model'), model_name)(**kwargs)
27 | model = model.to(opt.device)
28 |
29 | if opt.config['model']['pretrained']:
30 | #model.load_state_dict(torch.load(opt.config['model']['pretrained']))
31 | model.load_state_dict(torch.load(opt.config['model']['pretrained']), strict=False)
32 |
33 | if opt.config['model']['type'] == 'original' and opt.config['model']['need_slim'] is True:
34 | model = model.slim().to(opt.device)
35 | return model
36 |
--------------------------------------------------------------------------------
/onnx2tflite/test/test_reshape_transpose.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pytest
3 |
4 | import torch
5 | import torch.nn as nn
6 | from onnx2tflite import onnx_converter
7 |
8 | MODEL_ROOT = "./unit_test"
9 | os.makedirs(MODEL_ROOT, exist_ok=True)
10 |
11 | @pytest.mark.filterwarnings('ignore::UserWarning')
12 | @pytest.mark.filterwarnings('ignore::DeprecationWarning')
13 | def test_reshape_trans():
14 | class test1(nn.Module):
15 | def __init__(self, *args, **kwargs) -> None:
16 | super().__init__(*args, **kwargs)
17 | self.conv1 = nn.Conv2d(3, 3, 3, 2, 1)
18 | self.conv2 = nn.Conv2d(3, 3, 3, 2, 1)
19 |
20 | def forward(self, x):
21 | x = torch.reshape(x, (1, 3, 32, 16))
22 | # x = torch.transpose(x, (0, 1, 3, 2))
23 | x = torch.transpose(x, 3, 2)
24 | x = self.conv1(x)
25 | x = self.conv2(x)
26 | return x
27 |
28 | model = test1()
29 | x = torch.randn(1, 3*32*16)
30 |
31 | onnx_model_path = os.path.join(MODEL_ROOT, "test_reshape_trans.onnx")
32 | torch.onnx.export(model, x, onnx_model_path, opset_version=11)
33 |
34 | res = onnx_converter(
35 | onnx_model_path = onnx_model_path,
36 | need_simplify = True,
37 | output_path = MODEL_ROOT,
38 | target_formats = ['tflite'],
39 | native_groupconv=False,
40 | fp16_model=False,
41 | int8_model = False,
42 | )
43 |
44 | assert res['tflite_error'] < 1e-3
--------------------------------------------------------------------------------
/data/ispdata.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import os
4 | from PIL import Image
5 |
6 |
7 | class ISPData(torch.utils.data.Dataset):
8 | def __init__(self, opt, raw_path, rgb_path=None):
9 | super(ISPData, self).__init__()
10 | self.img_li = [path for path in os.listdir(raw_path)]
11 | self.raw_path = raw_path
12 | self.rgb_path = rgb_path
13 | self.opt = opt
14 |
15 | def __getitem__(self, index):
16 | raw = Image.open(os.path.join(self.raw_path, self.img_li[index]))
17 | raw = np.array(raw)
18 | raw = self.bayer2rggb(raw)
19 | raw = raw.astype(np.float32) / 4095
20 |
21 | raw = torch.Tensor(np.array(raw))
22 | raw = raw.to(self.opt.device)
23 |
24 | if self.rgb_path: # gt_path -> train/test not demo
25 | rgb = Image.open(os.path.join(self.rgb_path, self.img_li[index]))
26 | rgb = np.array(rgb).transpose([2, 0, 1])
27 | rgb = rgb.astype(np.float32) / 255
28 |
29 | rgb = torch.Tensor(np.array(rgb))
30 | rgb = rgb.to(self.opt.device)
31 |
32 | return raw, rgb, self.img_li[index].split('.')[0]
33 |
34 | return raw, self.img_li[index].split('.')[0]
35 |
36 | def __len__(self):
37 | return len(self.img_li)
38 |
39 | def bayer2rggb(self, img_bayer):
40 | h, w = img_bayer.shape
41 | img_bayer = img_bayer.reshape(h // 2, 2, w // 2, 2)
42 | img_bayer = img_bayer.transpose([1, 3, 0, 2]).reshape([-1, h // 2, w // 2])
43 | return img_bayer
44 |
--------------------------------------------------------------------------------
/logger.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 |
4 | class Logger:
5 |
6 | def __init__(
7 | self,
8 | opt,
9 | logging_level=logging.INFO,
10 | file_level=logging.INFO,
11 | stream_level=logging.INFO
12 | ):
13 | self.opt = opt
14 | self.log_path = opt.log_path
15 | self.logging_level = logging_level
16 |
17 | self.file_level = file_level
18 | self.stream_level = stream_level
19 |
20 | self.logger = logging.getLogger('logger.log')
21 | self.logger.setLevel(self.logging_level)
22 |
23 | self.configure()
24 |
25 | def configure(self):
26 | log_format = logging.Formatter('%(asctime)s %(levelname)s: %(message)s')
27 |
28 | stream_handler = logging.StreamHandler()
29 | stream_handler.setLevel(self.stream_level)
30 | stream_handler.setFormatter(log_format)
31 |
32 | file_handler = logging.FileHandler(self.log_path)
33 | file_handler.setLevel(self.file_level)
34 | file_handler.setFormatter(log_format)
35 |
36 | self.logger.addHandler(file_handler)
37 | self.logger.addHandler(stream_handler)
38 |
39 | def debug(self, message):
40 | self.logger.debug(message)
41 |
42 | def info(self, message):
43 | self.logger.info(message)
44 |
45 | def warn(self, message):
46 | self.logger.warning(message)
47 |
48 | def error(self, message):
49 | self.logger.error(message)
50 |
51 | def critical(self, message):
52 | self.logger.critical(message)
53 |
--------------------------------------------------------------------------------
/onnx2tflite/test/test_concat.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pytest
3 |
4 | import torch
5 | import torch.nn as nn
6 | from onnx2tflite import onnx_converter
7 |
8 | MODEL_ROOT = "./unit_test"
9 | os.makedirs(MODEL_ROOT, exist_ok=True)
10 |
11 | @pytest.mark.filterwarnings('ignore::UserWarning')
12 | @pytest.mark.filterwarnings('ignore::DeprecationWarning')
13 | def test_concat():
14 | class Concat(nn.Module):
15 | def __init__(self, *args, **kwargs) -> None:
16 | super().__init__(*args, **kwargs)
17 | self.conv1 = nn.Conv2d(3, 3, 3, 2, 1)
18 | # self.conv2 = nn.Conv2d(3, 3, 3, 2, 1)
19 | self._const = torch.randn(1,2,16,8)
20 |
21 | def forward(self, x1, x2, x3):
22 | x1 = torch.reshape(x1, (1, 3, 16, 8))
23 | # x = torch.transpose(x, (0, 1, 3, 2))
24 | x2 = torch.transpose(x2, 3, 2)
25 | x3 = self.conv1(x3)
26 | x = torch.concat([x1,x2,x3,self._const], dim=1)
27 | return x
28 |
29 | model = Concat()
30 | x1 = torch.randn(1,3*16*8)
31 | x2 = torch.randn(1,3,8,16)
32 | x3 = torch.randn(1,3,32,16)
33 |
34 | onnx_model_path = os.path.join(MODEL_ROOT, "test_concat.onnx")
35 | torch.onnx.export(model, (x1,x2,x3), onnx_model_path, opset_version=11)
36 |
37 | res = onnx_converter(
38 | onnx_model_path = onnx_model_path,
39 | need_simplify = True,
40 | output_path = MODEL_ROOT,
41 | target_formats = ['tflite'],
42 | native_groupconv=False,
43 | fp16_model=False,
44 | int8_model=False,
45 | )
46 |
47 | assert res['tflite_error'] < 1e-3
--------------------------------------------------------------------------------
/onnx2tflite/onnx2tflite/utils/dimension_utils.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | '''
3 | shape and axis transform utils func.
4 | '''
5 | def channel_to_last_dimension(axis):
6 | '''
7 | make channel first to channel last
8 | '''
9 | if axis == 0:
10 | axis = 0
11 | elif axis == 1:
12 | axis = -1
13 | else:
14 | axis -= 1
15 | return axis
16 |
17 | def shape_NCD_to_NDC_format(shape):
18 | '''
19 | make shape format from channel first to channel last
20 | '''
21 | if len(shape) <= 2:
22 | return tuple(shape)
23 | new_shape = [shape[0], *shape[2:], shape[1]]
24 | return tuple(new_shape)
25 |
26 | def shape_NDC_to_NCD_format(shape):
27 | '''
28 | make shape format from channel last to channel first
29 | '''
30 | if len(shape) <= 2:
31 | return tuple(shape)
32 | new_shape = [shape[0], shape[-1], *shape[1:-1]]
33 | return tuple(new_shape)
34 |
35 | def tensor_NCD_to_NDC_format(tensor):
36 | '''
37 | make tensor format from channel first to channel last
38 | '''
39 | if(len(tensor.shape) > 2):
40 | shape = [i for i in range(len(tensor.shape))]
41 | shape = shape_NCD_to_NDC_format(shape)
42 | tensor = tf.transpose(tensor, perm=shape)
43 | return tensor
44 |
45 | def tensor_NDC_to_NCD_format(tensor):
46 | '''
47 | make tensor format from channel last to channel first
48 | '''
49 | if(len(tensor.shape) > 2):
50 | shape = [i for i in range(len(tensor.shape))]
51 | shape = shape_NDC_to_NCD_format(shape)
52 | tensor = tf.transpose(tensor, perm=shape)
53 | return tensor
54 |
55 | def intfloat_to_list(x:int or float, lens:int):
56 | if isinstance(x, (int, float)):
57 | return [x]*lens
58 | else:
59 | return x
--------------------------------------------------------------------------------
/onnx2tflite/onnx2tflite/utils/definitions.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from abc import ABC
3 | from enum import Enum, unique
4 |
5 | @unique
6 | class Layout(Enum):
7 | Default = 0
8 | Channel_First = 1 << 0# for onnx format
9 | Channel_Last = 1 << 1 # for tensorflow format
10 | Channel_None = 1 << 2 # no channel
11 |
12 | class Node_Layout:
13 | def __init__(self, name:str, pre:list=[], nxt:list=[]) -> None:
14 | self.name = name
15 | self.pre = pre
16 | self.nxt = nxt
17 | self.layout = Layout.Default
18 |
19 | class BaseOP(ABC):
20 | def __init__(self, tensor_graph, const_weights, node_attributes, node_inputs, node_outputs, layout_dict) -> None:
21 | pass
22 |
23 | onnx2tf_type = {
24 | 1: tf.float32, # ONNX_FLOAT
25 | 2: tf.uint8, # ONNX_UINT8
26 | 3: tf.int8, # ONNX_INT8
27 | 4: tf.uint16, # ONNX_UINT16
28 | 5: tf.int16, # ONNX_INT16
29 | 6: tf.int32, # ONNX_INT32
30 | 7: tf.int64, # ONNX_INT64
31 | 8: tf.string, # ONNX_STRING
32 | 9: tf.bool, # ONNX_BOOL
33 | 10: tf.float16, # ONNX_FLOAT16
34 | 11: tf.float64, # ONNX_DOUBLE
35 | 12: tf.uint32, # ONNX_UINT32
36 | 13: tf.uint64, # ONNX_UINT64
37 | 14: tf.complex64, # ONNX_COMPLEX64
38 | 15: tf.complex128 # ONNX_COMPLEX128
39 | }
40 |
41 | np2tf_type = {
42 | "int32": tf.int32,
43 | "int64": tf.int64,
44 | "float32": tf.float32,
45 | "float64": tf.float64,
46 | "bool": tf.bool,
47 | "uint8": tf.uint8,
48 | "int8": tf.int8,
49 | "int16": tf.int16,
50 | "uint16": tf.uint16,
51 | "uint32": tf.uint32,
52 | "uint64": tf.uint64,
53 | "complex64": tf.complex64,
54 | "complex128": tf.complex128
55 | }
56 |
57 | FORCE_CHANNEL_LAST_OP = ["Conv", "ConvTranspose", "DepthToSpace", "Pad", "AveragePool", "MaxPool", "Upsample", "Resize", "Gemm"]
58 | FORCE_CHANNEL_FIRST_OP = ["Reshape", "Transpose", "ScatterND", "MatMul"]
59 |
60 |
--------------------------------------------------------------------------------
/onnx2tflite/test/test_torchvison.py:
--------------------------------------------------------------------------------
1 | '''
2 | unit test for torchvision models
3 | '''
4 | import os
5 | import pytest
6 |
7 | import torch
8 | import torchvision
9 | from onnx2tflite import onnx_converter
10 |
11 | MODEL_ROOT = "./unit_test"
12 | os.makedirs(MODEL_ROOT, exist_ok=True)
13 |
14 | @pytest.mark.filterwarnings('ignore::UserWarning')
15 | @pytest.mark.filterwarnings('ignore::DeprecationWarning')
16 | def test_resnet():
17 | model = torchvision.models.resnet18(False)
18 | onnx_model_path = os.path.join(MODEL_ROOT, "resnet18.onnx")
19 | torch.onnx.export(model, torch.randn(1, 3, 224, 224), onnx_model_path, opset_version=13)
20 | error = onnx_converter(onnx_model_path, need_simplify = True, output_path = MODEL_ROOT, target_formats = ['tflite'])['tflite_error']
21 | assert error < 1e-3
22 |
23 | @pytest.mark.filterwarnings('ignore::UserWarning')
24 | @pytest.mark.filterwarnings('ignore::DeprecationWarning')
25 | def test_mobilenet():
26 | model = torchvision.models.mobilenet_v2(False)
27 | onnx_model_path = os.path.join(MODEL_ROOT, "mobilenet_v2.onnx")
28 | torch.onnx.export(model, torch.randn(1, 3, 224, 224), onnx_model_path, opset_version=13)
29 | error = onnx_converter(onnx_model_path, need_simplify = True, output_path = MODEL_ROOT, target_formats = ['tflite'])['tflite_error']
30 | assert error < 1e-3
31 |
32 | @pytest.mark.filterwarnings('ignore::UserWarning')
33 | @pytest.mark.filterwarnings('ignore::DeprecationWarning')
34 | def test_deeplabv3():
35 | model = torchvision.models.segmentation.deeplabv3_resnet50(False)
36 | onnx_model_path = os.path.join(MODEL_ROOT, "deeplabv3_resnet50.onnx")
37 | torch.onnx.export(model, torch.randn(1, 3, 512, 1024), onnx_model_path, opset_version=13)
38 | error = onnx_converter(onnx_model_path, need_simplify = True, output_path = MODEL_ROOT, target_formats = ['tflite'])['tflite_error']
39 | assert error < 1e-3
40 |
41 | @pytest.mark.filterwarnings('ignore::UserWarning')
42 | @pytest.mark.filterwarnings('ignore::DeprecationWarning')
43 | def test_vit():
44 | model = torchvision.models.vit_b_16(False)
45 | onnx_model_path = os.path.join(MODEL_ROOT, "vit_b_16.onnx")
46 | torch.onnx.export(model, torch.randn(1, 3, 224, 224), onnx_model_path, opset_version=13)
47 | error = onnx_converter(onnx_model_path, need_simplify = True, output_path = MODEL_ROOT, target_formats = ['tflite'])['tflite_error']
48 | assert error < 1e-3
--------------------------------------------------------------------------------
/option.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import yaml
4 | from datetime import datetime
5 |
6 |
7 | def get_option():
8 | parser = argparse.ArgumentParser()
9 | parser.add_argument(
10 | '-task',
11 | default='train',
12 | type=str,
13 | choices=['train', 'test', 'demo'],
14 | help='choose the task for running the model'
15 | )
16 | parser.add_argument(
17 | '-model_task',
18 | default='isp',
19 | type=str,
20 | choices=['isp', 'lle', 'sr'],
21 | help='the model of the task'
22 | )
23 | parser.add_argument(
24 | '-device',
25 | default='cuda',
26 | type=str,
27 | help='choose the device to run the model'
28 | )
29 | opt = parser.parse_args()
30 | opt = opt_format(opt)
31 | return opt
32 |
33 |
34 | def load_yaml(path):
35 | with open(path, 'r') as f:
36 | model_config = yaml.load(f, Loader=yaml.FullLoader)
37 | return model_config
38 |
39 |
40 | def save_yaml(path, file_dict):
41 | with open(path, 'w') as f:
42 | f.write(yaml.dump(file_dict, allow_unicode=True))
43 |
44 |
45 | def opt_format(opt):
46 | opt.root = os.getcwd()
47 | opt.config = r'{}/config/{}.yaml'.format(opt.root, opt.model_task)
48 | opt.config = load_yaml(opt.config)
49 |
50 | proper_time = str(datetime.now()).split('.')[0].replace(':', '-')
51 |
52 | opt.config['exp_name'] = '{}_{}'.format(opt.task, opt.config['exp_name'])
53 |
54 | opt.experiments = r'{}/experiments/{}'.format(opt.root, '{} {}'.format(proper_time, opt.config['exp_name']))
55 | if not os.path.exists(opt.experiments):
56 | os.mkdir(opt.experiments)
57 |
58 | config_path = r'{}/config.yaml'.format(opt.experiments)
59 | save_yaml(config_path, opt.config)
60 |
61 | if opt.task == 'demo' or (opt.task == 'test' and opt.config['test']['save'] != False):
62 | opt.save_image = True
63 | opt.save_image_dir = r'{}/{}'.format(opt.experiments, 'images')
64 | if not os.path.exists(opt.save_image_dir):
65 | os.mkdir(opt.save_image_dir)
66 |
67 | opt.log_path = r'{}/logger.log'.format(opt.experiments)
68 |
69 | if opt.task == 'train':
70 | opt.save_model = True
71 | opt.save_model_dir = r'{}/{}'.format(opt.experiments, 'models')
72 | if not os.path.exists(opt.save_model_dir):
73 | os.mkdir(opt.save_model_dir)
74 |
75 | return opt
76 |
--------------------------------------------------------------------------------
/complexity.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import time
3 | from thop import profile, clever_format
4 | from model import lle
5 | import os
6 | os.environ['CUDA_VISIBLE_DEVICES'] = '0'
7 |
8 | width = 640
9 | height = 480
10 |
11 |
12 | def compute_FLOPs_and_model_size(model, width, height):
13 | input = torch.randn(1, 3, width, height).cuda()
14 | macs, params = profile(model, inputs=(input,), verbose=False)
15 | return macs, params
16 |
17 | @torch.no_grad()
18 | def compute_fps_and_inference_time(model, shape, epoch=100, warmup=10, device=None):
19 | total_time = 0.0
20 |
21 | if not device:
22 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
23 | model = model.to(device)
24 |
25 | model.eval() # Switch to evaluation mode
26 |
27 | # Warm-up iterations
28 | for _ in range(warmup):
29 | data = torch.randn(shape).to(device)
30 | model(data)
31 |
32 | # Actual timing iterations
33 | for _ in range(epoch):
34 | data = torch.randn(shape).to(device)
35 |
36 | start = time.time()
37 | outputs = model(data)
38 | torch.cuda.synchronize() # Ensure CUDA has finished all tasks
39 | end = time.time()
40 |
41 | total_time += (end - start)
42 |
43 | avg_inference_time = total_time / epoch
44 | fps = epoch / total_time
45 |
46 | return fps, avg_inference_time
47 |
48 | def test_model_flops(width, height):
49 | model = lle.MobileIES(channels=12)
50 | model.cuda()
51 |
52 | FLOPs, params = compute_FLOPs_and_model_size(model, width, height)
53 |
54 | model_size = params * 4.0 / 1024 / 1024
55 | flops, params = clever_format([FLOPs, params], "%.3f")
56 |
57 | print('Number of parameters: {}'.format(params))
58 | print('Size of model: {:.2f} MB'.format(model_size))
59 | print('Computational complexity: {} FLOPs'.format(flops))
60 |
61 | def test_fps_and_inference_time(width, height):
62 | model = lle.MobileIES(channels=12)
63 | model.cuda()
64 |
65 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
66 | fps, avg_inference_time = compute_fps_and_inference_time(model, (1, 3, width, height), device=device)
67 | print('device: {} - fps: {:.3f}, average inference time per frame: {:.6f} seconds'.format(device.type, fps, avg_inference_time))
68 |
69 | if __name__ == '__main__':
70 | test_model_flops(width, height)
71 | test_fps_and_inference_time(width, height)
72 |
73 |
--------------------------------------------------------------------------------
/onnx2tflite/onnx2tflite/__main__.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from .converter import onnx_converter
3 |
4 | def parse_opt():
5 | parser = argparse.ArgumentParser()
6 | parser.add_argument('--weights', type=str, required=True, help='onnx model path')
7 | parser.add_argument('--outpath', type=str, default=None, help='tflite model save path')
8 | parser.add_argument('--input-node-names', nargs="+", default=None, help='input node names to keep; intermediate layers are supported; None uses the original ONNX inputs')
9 | parser.add_argument('--output-node-names', nargs="+", default=None, help='output node names to keep; intermediate layers are supported; None uses the original ONNX outputs')
10 | parser.add_argument('--nosimplify', default=False, action='store_true', help='do not simplify model')
11 | parser.add_argument("--native-groupconv", default=False, action='store_true', help='use the native grouped-convolution op; only supported for TFLite >= 2.9')
12 | parser.add_argument('--weigthquant', default=False, action='store_true', help='weight-only int8 quantization')
13 | parser.add_argument('--fp16', default=False, action='store_true', help='fp16 quantization, including inputs and outputs')
14 | parser.add_argument('--int8', default=False, action='store_true', help='int8 quantization, including inputs and outputs')
15 | parser.add_argument('--imgroot', type=str, default=None, help='when --int8 is set, imgroot must be provided to compute the calibration mean and std')
16 | parser.add_argument('--int8mean', type=float, nargs='+', default=[123.675, 116.28, 103.53], help='int8 image preprocessing mean, float or list')
17 | parser.add_argument('--int8std', type=float, nargs='+', default=[58.395, 57.12, 57.375], help='int8 image preprocessing std, float or list')
18 | parser.add_argument('--formats', nargs='+', default=['keras', 'tflite'], help='available formats are (h5, tflite)')
19 | opt = parser.parse_args()
20 | return opt
21 |
22 | def run():
23 | opt = parse_opt()
24 | onnx_converter(
25 | onnx_model_path = opt.weights,
26 | need_simplify = not opt.nosimplify,
27 | input_node_names = opt.input_node_names,
28 | output_node_names = opt.output_node_names,
29 | output_path = opt.outpath,
30 | target_formats = opt.formats,
31 | native_groupconv = opt.native_groupconv,
32 | weight_quant=opt.weigthquant,
33 | fp16_model=opt.fp16,
34 | int8_model=opt.int8,
35 | int8_mean=opt.int8mean,
36 | int8_std=opt.int8std,
37 | image_root=opt.imgroot
38 | )
39 |
40 | if __name__ == "__main__":
41 | run()
--------------------------------------------------------------------------------
/onnx2tflite/build/lib/onnx2tflite/__main__.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from .converter import onnx_converter
3 |
4 | def parse_opt():
5 | parser = argparse.ArgumentParser()
6 | parser.add_argument('--weights', type=str, required=True, help='onnx model path')
7 | parser.add_argument('--outpath', type=str, default=None, help='tflite model save path')
8 | parser.add_argument('--input-node-names', nargs="+", default=None, help='input node names to keep; intermediate layers are supported; None uses the original ONNX inputs')
9 | parser.add_argument('--output-node-names', nargs="+", default=None, help='output node names to keep; intermediate layers are supported; None uses the original ONNX outputs')
10 | parser.add_argument('--nosimplify', default=False, action='store_true', help='do not simplify model')
11 | parser.add_argument("--native-groupconv", default=False, action='store_true', help='use the native grouped-convolution op; only supported for TFLite >= 2.9')
12 | parser.add_argument('--weigthquant', default=False, action='store_true', help='weight-only int8 quantization')
13 | parser.add_argument('--fp16', default=False, action='store_true', help='fp16 quantization, including inputs and outputs')
14 | parser.add_argument('--int8', default=False, action='store_true', help='int8 quantization, including inputs and outputs')
15 | parser.add_argument('--imgroot', type=str, default=None, help='when --int8 is set, imgroot must be provided to compute the calibration mean and std')
16 | parser.add_argument('--int8mean', type=float, nargs='+', default=[123.675, 116.28, 103.53], help='int8 image preprocessing mean, float or list')
17 | parser.add_argument('--int8std', type=float, nargs='+', default=[58.395, 57.12, 57.375], help='int8 image preprocessing std, float or list')
18 | parser.add_argument('--formats', nargs='+', default=['keras', 'tflite'], help='available formats are (h5, tflite)')
19 | opt = parser.parse_args()
20 | return opt
21 |
22 | def run():
23 | opt = parse_opt()
24 | onnx_converter(
25 | onnx_model_path = opt.weights,
26 | need_simplify = not opt.nosimplify,
27 | input_node_names = opt.input_node_names,
28 | output_node_names = opt.output_node_names,
29 | output_path = opt.outpath,
30 | target_formats = opt.formats,
31 | native_groupconv = opt.native_groupconv,
32 | weight_quant=opt.weigthquant,
33 | fp16_model=opt.fp16,
34 | int8_model=opt.int8,
35 | int8_mean=opt.int8mean,
36 | int8_std=opt.int8std,
37 | image_root=opt.imgroot
38 | )
39 |
40 | if __name__ == "__main__":
41 | run()
--------------------------------------------------------------------------------
/torch_to_onnx.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | class MobileIENetS(nn.Module):
5 | def __init__(self, channels):
6 | super(MobileIENetS, self).__init__()
7 | self.head = FST(
8 | nn.Sequential(
9 | nn.Conv2d(3, channels, 5, 1, 2),
10 | nn.PReLU(channels),
11 | nn.Conv2d(channels, channels, 3, 1, 1)
12 | ),
13 | channels
14 | )
15 | self.body = FST(
16 | nn.Conv2d(channels, channels, 3, 1, 1),
17 | channels
18 | )
19 | self.att = nn.Sequential(
20 | nn.AdaptiveAvgPool2d(1),
21 | nn.Conv2d(channels, channels, 1),
22 | nn.Sigmoid()
23 | )
24 | self.att1 = nn.Sequential(
25 | nn.Conv2d(1, channels, 1, 1),
26 | nn.Sigmoid()
27 | )
28 | self.tail = nn.Conv2d(channels, 3, 3, 1, 1)
29 |
30 | def forward(self, x):
31 | x0 = self.head(x)
32 | x1 = self.body(x0)
33 | x2 = self.att(x1)
34 | max_out, _ = torch.max(x2 * x1, dim=1, keepdim=True)
35 | x3 = self.att1(max_out)
36 | x4 = torch.mul(x3, x2) * x1
37 | return self.tail(x4)
38 |
39 | class FST(nn.Module):
40 | def __init__(self, block1, channels):
41 | super(FST, self).__init__()
42 | self.block1 = block1
43 | self.weight1 = nn.Parameter(torch.randn(1))
44 | self.weight2 = nn.Parameter(torch.randn(1))
45 | self.bias = nn.Parameter(torch.randn((1, channels, 1, 1)))
46 |
47 | def forward(self, x):
48 | x1 = self.block1(x)
49 | weighted_block1 = self.weight1 * x1
50 | weighted_block2 = self.weight2 * x1
51 | return weighted_block1 * weighted_block2 + self.bias
52 |
53 | def export_onnx(pretrained_model_path):
54 | model = MobileIENetS(12)
55 |
56 | checkpoint = torch.load(pretrained_model_path)
57 | model.load_state_dict(checkpoint)
58 | model.eval()
59 |
60 | dummy_input = torch.randn(1, 3, 400, 600)
61 |
62 | torch.onnx.export(
63 | model,
64 | dummy_input,
65 | "LLE.onnx",
66 | opset_version=12,
67 | export_params=True,
68 | do_constant_folding=True,
69 | input_names=['input'],
70 | output_names=['output'],
71 | dynamic_axes=None
72 | )
73 | print("ONNX Success.")
74 |
75 | if __name__ == "__main__":
76 | pretrained_model_path = r'./pretrain/lolv1_best_slim.pkl'
77 | export_onnx(pretrained_model_path)
78 |
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 | from torch.utils import data
2 | from importlib import import_module
3 | import torchvision.transforms as transforms
4 |
5 | from .ispdata import ISPData
6 | from .lledata import LLEData
7 |
8 | __all__ = {
9 | 'ISPData',
10 | 'LLEData',
11 | 'import_loader'
12 | }
13 |
14 |
15 | def import_loader(opt):
16 | dataset_name = opt.model_task.upper()+'Data'
17 | dataset = getattr(import_module('data'), dataset_name)
18 |
19 | if opt.task == 'train':
20 | train_inp_path = opt.config['train']['train_inp']
21 | train_gt_path = opt.config['train']['train_gt']
22 | valid_inp_path = opt.config['train']['valid_inp']
23 | valid_gt_path = opt.config['train']['valid_gt']
24 |
25 |
26 | train_data = dataset(opt, train_inp_path, train_gt_path)
27 | #train_data = dataset(opt, train_inp_path, train_gt_path, transform=train_transform)
28 |
29 | if opt.model_task == 'sr':
30 | valid_data = dataset(opt, valid_inp_path, valid_gt_path, 'valid')
31 | else:
32 | valid_data = dataset(opt, valid_inp_path, valid_gt_path)
33 |
34 | train_loader = data.DataLoader(
35 | train_data,
36 | batch_size=opt.config['train']['batch_size'],
37 | shuffle=True,
38 | num_workers=opt.config['train']['num_workers'],
39 | drop_last=True,
40 | )
41 | valid_loader = data.DataLoader(
42 | valid_data,
43 | batch_size=1,
44 | shuffle=False,
45 | num_workers=opt.config['train']['num_workers'],
46 | drop_last=False,
47 | )
48 | return train_loader, valid_loader
49 |
50 | elif opt.task == 'test':
51 | inp_test_path = opt.config['test']['test_inp']
52 | gt_test_path = opt.config['test']['test_gt']
53 |
54 | test_data = dataset(opt, inp_test_path, gt_test_path)
55 | test_loader = data.DataLoader(
56 | test_data,
57 | batch_size=1,
58 | shuffle=False,
59 | num_workers=opt.config['test']['num_workers'],
60 | drop_last=False,
61 | )
62 | return test_loader
63 |
64 | elif opt.task == 'demo':
65 | inp_demo_path = opt.config['demo']['demo_inp']
66 | demo_data = dataset(opt, inp_demo_path)
67 | demo_loader = data.DataLoader(
68 | demo_data,
69 | batch_size=1,
70 | shuffle=False,
71 | num_workers=opt.config['demo']['num_workers'],
72 | drop_last=False,
73 | )
74 | return demo_loader
75 |
76 | else:
77 | raise ValueError('unknown task, please choose from [train, test, demo]')
78 |
--------------------------------------------------------------------------------
/test_TFLite_RGB.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from tensorflow import lite
3 | import os
4 | from PIL import Image
5 |
6 | def load_model(model_path):
7 | interpreter = lite.Interpreter(model_path=model_path)
8 | interpreter.allocate_tensors()
9 | return interpreter
10 |
11 | def print_model_structure(interpreter):
12 | print("Model Structure:")
13 | input_details = interpreter.get_input_details()
14 | output_details = interpreter.get_output_details()
15 |
16 | print("\nInput details:")
17 | for input_tensor in input_details:
18 | print(f"Name: {input_tensor['name']}, Shape: {input_tensor['shape']}, Type: {input_tensor['dtype']}")
19 |
20 | print("\nOutput details:")
21 | for output_tensor in output_details:
22 | print(f"Name: {output_tensor['name']}, Shape: {output_tensor['shape']}, Type: {output_tensor['dtype']}")
23 |
24 | def preprocess_image(image_path):
25 | #target_size = (1024, 1024)
26 | img = Image.open(image_path).convert("RGB")
27 | #img = img.resize(target_size, Image.BICUBIC)
28 | img_array = np.array(img).astype(np.float32) / 255.0
29 | img_array = np.expand_dims(img_array, axis=0)
30 | print(img_array.shape)
31 | return img_array
32 |
33 | def inference_and_save_results(interpreter, input_image_folder, output_image_folder):
34 | input_details = interpreter.get_input_details()
35 | output_details = interpreter.get_output_details()
36 |
37 | if not os.path.exists(output_image_folder):
38 | os.makedirs(output_image_folder)
39 |
40 | for image_name in os.listdir(input_image_folder):
41 | image_path = os.path.join(input_image_folder, image_name)
42 | if not image_path.lower().endswith(('png', 'jpg', 'jpeg')):
43 | continue
44 |
45 | img_array = preprocess_image(image_path)
46 |
47 | interpreter.set_tensor(input_details[0]['index'], img_array)
48 | interpreter.invoke()
49 | output_img = interpreter.get_tensor(output_details[0]['index'])
50 |
51 | output_img = np.clip(output_img, 0., 1.)
52 | output_img = np.squeeze(output_img)
53 | output_img = (output_img * 255).astype(np.uint8)
54 |
55 | pil_img = Image.fromarray(output_img)
56 | output_image_path = os.path.join(output_image_folder, image_name)
57 | pil_img.save(output_image_path)
58 |
59 | def main():
60 | model_path = './LLE.tflite'
61 | input_image_folder = './lowlight/LOLdataset/eval15/low'
62 | output_image_folder = './experiments/results'
63 |
64 | interpreter = load_model(model_path)
65 | print_model_structure(interpreter)
66 | inference_and_save_results(interpreter, input_image_folder, output_image_folder)
67 |
68 | if __name__ == "__main__":
69 | main()
70 |
--------------------------------------------------------------------------------
/onnx2tflite/onnx2tflite/components/dataloader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import logging
4 | import numpy as np
5 |
6 | LOG = logging.getLogger("Quantization DataLoader")
7 |
8 | class RandomLoader(object):
9 | def __init__(self, target_size):
10 | self.target_size = target_size
11 | LOG.warning("Generating quantization data randomly; this may degrade model accuracy!")
12 |
13 | def __iter__(self):
14 | self.index = 0
15 | return self
16 |
17 | def __next__(self):
18 | if self.index > 5:
19 | raise StopIteration()
20 | self.index += 1
21 | return [np.random.randn(*self.target_size).astype(np.float32)]
22 |
23 | class ImageLoader(object):
24 | '''
25 | generate data for quantization from image datas.
26 | img_quan_data = (img - mean)/std, it's important for accuracy of model.
27 | '''
28 | VALID_FORMAT = ['.jpg', '.png', '.jpeg']
29 |
30 | def __init__(self, img_root, target_size, mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]) -> None:
31 | assert os.path.exists(img_root), f"{img_root} does not exist, please check!"
32 | self.fns = os.listdir(img_root)
33 | self.fns = list(filter(lambda fn: os.path.splitext(fn)[-1].lower() in self.VALID_FORMAT, self.fns))
34 | self.nums = len(self.fns)
35 | assert self.nums > 0, f"No images detected in {img_root}."
36 | if self.nums > 100:
37 | LOG.warning(f"{self.nums} images detected, the number of recommended images is less than 100.")
38 | else:
39 | LOG.info(f"{self.nums} images detected.")
40 | self.fns = [os.path.join(img_root, fn) for fn in self.fns]
41 |
42 | self.batch, self.size = target_size[0], target_size[1:-1]
43 | if isinstance(mean, list):
44 | mean = np.array(mean, dtype=np.float32)
45 | if isinstance(std, list):
46 | std = np.array(std, dtype=np.float32)
47 | self.mean, self.std = mean, std
48 |
49 | def __iter__(self):
50 | self.index = 0
51 | return self
52 |
53 | def __next__(self):
54 | if self.index >= self.nums:
55 | raise StopIteration()
56 |
57 | _input = cv2.imread(self.fns[self.index])
58 | _input = cv2.resize(_input, self.size)[:, :, ::-1]#BGR->RGB
59 | _input = _input.astype(np.float32)
60 |
61 | if self.mean is not None:
62 | _input = (_input - self.mean)
63 | if self.std is not None:
64 | _input = _input/self.std
65 |
66 | _input = np.expand_dims(_input, axis=0)
67 | if self.batch > 1:
68 | _input = np.repeat(_input, self.batch, axis=0).astype(np.float32)
69 |
70 | self.index += 1
71 | return [_input]
72 |
--------------------------------------------------------------------------------
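
builder.py, which consumes these loaders, is not shown in this dump, but both loaders yield a single-element list per step, which is exactly the shape a `tf.lite` representative-dataset generator is expected to produce. A rough sketch of that wiring, with the saved-model path and calibration image folder as placeholder assumptions:

```python
import tensorflow as tf
from onnx2tflite.components.dataloader import ImageLoader

# Placeholder paths; the real hookup lives in onnx2tflite/components/builder.py.
loader = ImageLoader(img_root="./calib_images", target_size=(1, 224, 224, 3))

converter = tf.lite.TFLiteConverter.from_saved_model("./saved_model")
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = lambda: iter(loader)  # yields [np.float32 array] per step
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.int8
converter.inference_output_type = tf.int8

with open("model_int8.tflite", "wb") as f:
    f.write(converter.convert())
```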
/onnx2tflite/onnx2tflite/utils/graph_tools.py:
--------------------------------------------------------------------------------
1 | from onnx import numpy_helper
2 | import tensorflow as tf
3 | from tensorflow import keras
4 | from .definitions import *
5 |
6 | # copy from https://github.com/gmalivenko/onnx2keras
7 | def decode_node_attribute(node)->dict:
8 | """
9 | Parse ONNX attributes to Python dictionary
10 | :param args: ONNX attributes object
11 | :return: Python dictionary
12 | """
13 | def onnx_attribute_to_dict(onnx_attr):
14 | """
15 | Parse ONNX attribute
16 | :param onnx_attr: ONNX attribute
17 | :return: Python data type
18 | """
19 | if onnx_attr.HasField('t'):
20 | return numpy_helper.to_array(getattr(onnx_attr, 't'))
21 |
22 | for attr_type in ['f', 'i']:
23 | if onnx_attr.HasField(attr_type):
24 | return getattr(onnx_attr, attr_type)
25 |
26 | # s need to be decode, bytes to string
27 | if onnx_attr.HasField('s'):
28 | return getattr(onnx_attr, 's').decode()
29 |
30 | for attr_type in ['floats', 'ints', 'strings']:
31 | if getattr(onnx_attr, attr_type):
32 | return list(getattr(onnx_attr, attr_type))
33 | return {arg.name: onnx_attribute_to_dict(arg) for arg in node.attribute}
34 |
35 | def build_tf_inputs(model_graph, layout_dict:dict):
36 | inputs_name = []
37 | for inp in model_graph.input:
38 | input_shape = [x.dim_value for x in inp.type.tensor_type.shape.dim]
39 | if input_shape == []:
40 | continue
41 | inputs_name.append(inp.name)
42 | layout_dict[inp.name] = Layout.Default
43 | if len(input_shape) < 3:
44 | layout_dict[inp.name] = Layout.Channel_None
45 |
46 | _inputs_name = inputs_name.copy()
47 | for node in model_graph.node:
48 | op_name, node_inputs = node.op_type, node.input
49 | # output_layout = Layout.Default
50 | for ninp in node_inputs:
51 | if ninp in _inputs_name and op_name in FORCE_CHANNEL_LAST_OP and layout_dict[ninp] == Layout.Default:
52 | layout_dict[ninp] = Layout.Channel_Last
53 | _inputs_name.remove(ninp)
54 | if ninp in _inputs_name and op_name in FORCE_CHANNEL_FIRST_OP and layout_dict[ninp] == Layout.Default:
55 | layout_dict[ninp] = Layout.Channel_First
56 | _inputs_name.remove(ninp)
57 | # output_layout = output_layout | node_dict[ninp]
58 |
59 | if len(_inputs_name) == 0:
60 | break
61 |
62 | input_nodes = {}
63 | for inp in model_graph.input:
64 | input_shape = [x.dim_value for x in inp.type.tensor_type.shape.dim]
65 | if input_shape == []:
66 | continue
67 | batch_size = 1 if input_shape[0] <= 0 else input_shape[0]
68 | input_shape = input_shape[1:]
69 | if layout_dict[inp.name] == Layout.Channel_Last:
70 | input_shape = input_shape[1:] + input_shape[0:1]
71 |
72 | input_nodes[inp.name] = keras.Input(shape=input_shape, batch_size=batch_size, dtype=onnx2tf_type.get(inp.type.tensor_type.elem_type))
73 |
74 | return input_nodes
75 |
--------------------------------------------------------------------------------
/test_TFLite_ISP.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from tensorflow import lite
3 | import os
4 | import cv2
5 | from PIL import Image
6 |
7 | def load_model(model_path):
8 | interpreter = lite.Interpreter(model_path=model_path)
9 | interpreter.allocate_tensors()
10 | return interpreter
11 |
12 | def print_model_structure(interpreter):
13 | print("Model Structure:")
14 | input_details = interpreter.get_input_details()
15 | output_details = interpreter.get_output_details()
16 |
17 | print("\nInput details:")
18 | for input_tensor in input_details:
19 | print(f"Name: {input_tensor['name']}, Shape: {input_tensor['shape']}, Type: {input_tensor['dtype']}")
20 |
21 | print("\nOutput details:")
22 | for output_tensor in output_details:
23 | print(f"Name: {output_tensor['name']}, Shape: {output_tensor['shape']}, Type: {output_tensor['dtype']}")
24 |
25 | def bayer2rggb(img_bayer):
26 | h, w = img_bayer.shape
27 | img_bayer = img_bayer.reshape(h // 2, 2, w // 2, 2)
28 | img_bayer = img_bayer.transpose([1, 3, 0, 2]).reshape([4, h // 2, w // 2]) # [4, h//2, w//2]
29 | return img_bayer
30 |
31 | def inference_and_save_results(interpreter, input_image_folder, output_image_folder):
32 | input_details = interpreter.get_input_details()
33 | output_details = interpreter.get_output_details()
34 |
35 | if not os.path.exists(output_image_folder):
36 | os.makedirs(output_image_folder)
37 |
38 | for image_name in os.listdir(input_image_folder):
39 | image_path = os.path.join(input_image_folder, image_name)
40 | if not image_path.lower().endswith(('png', 'jpg', 'jpeg')):
41 | continue
42 |
43 | img = Image.open(image_path)
44 | img_array = np.array(img)
45 | img_array = bayer2rggb(img_array) # Convert Bayer pattern to RGGB
46 | img_array = img_array.astype(np.float32) / 4095.0 # Normalize to [0, 1]
47 |
48 |         # Convert to NHWC and add a batch dimension: [1, H//2, W//2, 4] for the model input
49 | img_array = np.transpose(img_array, (1, 2, 0))
50 | img_array = np.expand_dims(img_array, axis=0)
51 |
52 | interpreter.set_tensor(input_details[0]['index'], img_array)
53 | interpreter.invoke()
54 | output_img = interpreter.get_tensor(output_details[0]['index'])
55 | output_img = np.clip(output_img, 0., 1.)
56 | output_img = np.squeeze(output_img)
57 | print(output_img.shape)
58 |
59 | #output_img = output_img.transpose(1, 2, 0)
60 | output_img = (output_img * 255).astype(np.uint8)
61 | pil_img = Image.fromarray(output_img)
62 |
63 | # Save the output image using PIL (Image.save)
64 | output_image_path = os.path.join(output_image_folder, image_name)
65 | pil_img.save(output_image_path) # This automatically saves as RGB
66 |
67 | def main():
68 | model_path = './ISP.tflite' # TFLite model path
69 |     input_image_folder = './ISP/Input'  # Input image folder path
70 |     output_image_folder = './Output'  # Output image folder path
71 |
72 | interpreter = load_model(model_path)
73 | print_model_structure(interpreter)
74 | inference_and_save_results(interpreter, input_image_folder, output_image_folder)
75 |
76 | if __name__ == "__main__":
77 | main()
78 |
--------------------------------------------------------------------------------
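To make the preprocessing in `test_TFLite_ISP.py` concrete, here is a small illustrative check (the 4x4 array is a stand-in for a real 12-bit Bayer frame): `bayer2rggb` packs an (H, W) mosaic into 4 half-resolution planes, which the script then normalizes and reshapes to `[1, H//2, W//2, 4]` for the TFLite interpreter.

```python
import numpy as np

def bayer2rggb(img_bayer):
    # Same packing as in test_TFLite_ISP.py: (H, W) -> (4, H//2, W//2).
    h, w = img_bayer.shape
    img_bayer = img_bayer.reshape(h // 2, 2, w // 2, 2)
    return img_bayer.transpose([1, 3, 0, 2]).reshape([4, h // 2, w // 2])

bayer = np.arange(16, dtype=np.float32).reshape(4, 4)
planes = bayer2rggb(bayer)
print(planes.shape)        # (4, 2, 2)

# The script then divides by 4095 (12-bit data), moves channels last, and adds a batch dim:
model_input = np.expand_dims(np.transpose(planes / 4095.0, (1, 2, 0)), axis=0)
print(model_input.shape)   # (1, 2, 2, 4)
```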
/README.md:
--------------------------------------------------------------------------------
1 | # [ICCV 2025] MobileIE: An Extremely Lightweight and Effective ConvNet for Real-Time Image Enhancement on Mobile Devices
2 |
3 | Hailong Yan¹, Ao Li¹, Xiangtao Zhang¹, Zhe Liu¹, Zenglin Shi², Ce Zhu¹, Le Zhang¹ †
4 |
5 | ¹ UESTC   ² Hefei University of Technology
6 | † Corresponding author.
7 |
22 | Abstract: Recent advancements in deep neural networks have driven significant progress in image enhancement (IE). However, deploying deep learning models on resource-constrained platforms, such as mobile devices, remains challenging due to high computation and memory demands. To address these challenges and facilitate real-time IE on mobile, we introduce an extremely lightweight Convolutional Neural Network (CNN) framework with around 4K parameters. Our approach integrates reparameterization with an Incremental Weight Optimization strategy to ensure efficiency. Additionally, we enhance performance with a Feature Self-Transform module and a Hierarchical Dual-Path Attention mechanism, optimized with a Local Variance-Weighted loss. With this efficient framework, we are the first to achieve real-time IE inference at up to 1,100 frames per second (FPS) while delivering competitive image quality, achieving the best trade-off between speed and performance across multiple IE tasks.
23 |
24 |
25 |
26 |
27 |
28 |
29 | ---
30 |
31 |
32 | ### Preparation
33 |
34 | 1. Replace the dataset path in the config file.
35 | 2. To train a model, set `type` in the config to "original" and `need_slims` to "false".
36 | 3. To test a pretrained model, set `type` in the config to "re-parameterized", set `need_slims` to "true", and load the re-parameterized pretrained weights (a minimal sketch of these two fields follows this list). You can also run inference with the TFLite models by executing `test_TFLite_RGB.py` / `test_TFLite_ISP.py`.
37 | 4. You can import the TFLite model into AI Benchmark (https://ai-benchmark.com/) to measure its inference speed on mobile devices.
38 | 5. To perform the UIE task, replace the dataset path in config/lle.yaml with your underwater image dataset.
39 |
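A minimal sketch of the two config fields referenced above. Only `type` and `need_slims` come from these instructions; any other keys and the exact nesting are assumptions, so check `config/lle.yaml` / `config/isp.yaml` for the real layout.

```yaml
# Training from scratch
type: original
need_slims: false

# Testing a re-parameterized (slimmed) checkpoint, e.g. pretrain/lolv1_best_slim.pkl
# type: re-parameterized
# need_slims: true
```
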
40 | ### Train
41 |
42 | ```bash
43 | python main.py -task train -model_task lle/isp -device cuda
44 | ```
45 |
46 | ### Test
47 |
48 | ```bash
49 | python main.py -task test -model_task lle/isp -device cuda
50 | ```
51 |
52 | ### Demo
53 |
54 | ```bash
55 | python main.py -task demo -model_task lle/isp -device cuda
56 | ```
57 |
58 | ### Contact
59 | If you have any questions, please contact me by e-mail (yanhailong@std.uestc.edu.cn; yhl00825@163.com).
60 |
61 | ### Citation
62 |
63 | If you find the code helpful in your research or work, please cite the following paper:
64 |
65 | ```
66 | @InProceedings{yan2025mobileie,
67 | author = {Yan, Hailong and Li, Ao and Zhang, Xiangtao and Liu, Zhe and Shi, Zenglin and Zhu, Ce and Zhang, Le},
68 | title = {MobileIE: An Extremely Lightweight and Effective ConvNet for Real-Time Image Enhancement on Mobile Devices},
69 | booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision},
70 | month = {October},
71 | year = {2025},
72 | }
73 | ```
74 |
--------------------------------------------------------------------------------
/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from option import get_option
4 |
5 | class CharbonnierLoss(nn.Module):
6 | def __init__(self, eps=1e-6):
7 | super(CharbonnierLoss, self).__init__()
8 | self.eps2 = eps ** 2
9 |
10 | def forward(self, inp, target):
11 | return ((nn.functional.mse_loss(inp, target, reduction='none') + self.eps2) ** .5).mean()
12 | #####################################################################################################
13 | class OutlierAwareLoss(nn.Module):
14 | def __init__(self,):
15 | super(OutlierAwareLoss, self).__init__()
16 |
17 | def forward(self, out, lab):
18 | delta = out - lab
19 | var = delta.std((2, 3), keepdims=True) / (2 ** .5)
20 | avg = delta.mean((2, 3), True)
21 | weight = torch.tanh((delta - avg).abs() / (var + 1e-6)).detach()
22 | loss = (delta.abs() * weight).mean()
23 | return loss
24 |
25 | #####################################################################################################
26 | class LossWarmup(nn.Module):
27 | def __init__(self):
28 | super(LossWarmup, self).__init__()
29 | self.loss_cb = CharbonnierLoss(1e-8)
30 | self.loss_cs = nn.CosineSimilarity()
31 |
32 | def forward(self, inp, gt, warmup1, warmup2):
33 | loss = self.loss_cb(warmup2, inp) + \
34 | (self.loss_cb(warmup1, gt) + (1 - self.loss_cs(warmup1.clip(0, 1), gt)).mean())
35 |
36 | return loss
37 |
38 |
39 | class LossLLE(nn.Module):
40 | def __init__(self):
41 | super(LossLLE, self).__init__()
42 | self.loss_cs = nn.CosineSimilarity()
43 | self.loss_oa = OutlierAwareLoss()
44 | self.psnr = PSNRLoss()
45 |
46 | def forward(self, out, gt):
47 | loss = (self.loss_oa(out, gt) + (1 - self.loss_cs(out.clip(0, 1), gt)).mean()) + 2 * self.psnr(out, gt)
48 | return loss
49 |
50 | class LossISP(nn.Module):
51 | def __init__(self):
52 | super(LossISP, self).__init__()
53 | self.loss_cs = nn.CosineSimilarity()
54 | self.loss_oa = OutlierAwareLoss()
55 | self.psnr = PSNRLoss()
56 |
57 | def forward(self, out, gt):
58 | loss = (self.loss_oa(out, gt) + (1 - self.loss_cs(out.clip(0, 1), gt)).mean()) + 2 * self.psnr(out, gt)
59 | return loss
60 |
61 | def import_loss(training_task):
62 | if training_task == 'isp':
63 | return LossISP()
64 | elif training_task == 'lle':
65 | return LossLLE()
66 | elif training_task == 'warmup':
67 | return LossWarmup()
68 | else:
69 | raise ValueError('unknown training task, please choose from [isp, lle, warmup].')
70 |
71 | class PSNRLoss(nn.Module):
72 |
73 | def __init__(self, loss_weight=1.0, reduction='mean', toY=False):
74 | super(PSNRLoss, self).__init__()
75 | assert reduction == 'mean'
76 | self.loss_weight = loss_weight
77 | self.toY = toY
78 | self.coef = torch.tensor([65.481, 128.553, 24.966]).reshape(1, 3, 1, 1)
79 | self.first = True
80 |
81 | def forward(self, pred, target):
82 | assert len(pred.size()) == 4
83 | if self.toY:
84 | if self.first:
85 | self.coef = self.coef.to(pred.device)
86 | self.first = False
87 |
88 | pred = (pred * self.coef).sum(dim=1).unsqueeze(dim=1) + 16.
89 | target = (target * self.coef).sum(dim=1).unsqueeze(dim=1) + 16.
90 |
91 | pred, target = pred / 255., target / 255.
92 | pass
93 | assert len(pred.size()) == 4
94 |         imdff = pred - target
95 |         rmse = ((imdff ** 2).mean(dim=(1, 2, 3)) + 1e-8).sqrt()
96 |         loss = 20 * torch.log10(1 / rmse).mean()
97 |         loss = (50.0 - loss) / 100.0
98 | return loss
99 |
--------------------------------------------------------------------------------
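A small worked example of the PSNR-to-loss mapping in `PSNRLoss` above (assuming only that `loss.py` is importable from the repo root): with a uniform error of 0.1, PSNR is 20*log10(1/0.1) = 20 dB, so the loss is (50 - 20)/100 = 0.3.

```python
import torch
from loss import PSNRLoss  # assumes the repo root is on PYTHONPATH

criterion = PSNRLoss()
pred = torch.full((1, 3, 8, 8), 0.5)
target = torch.full((1, 3, 8, 8), 0.6)

# RMSE = 0.1 -> PSNR = 20 dB -> loss = (50 - 20) / 100 = 0.3
print(criterion(pred, target))  # ~0.3
```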
/onnx2tflite/onnx2tflite/components/output_check.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import tensorflow as tf
4 | import onnxruntime as ort
5 | from onnx2tflite.utils.definitions import Layout
6 | from onnx2tflite.utils.dimension_utils import tensor_NDC_to_NCD_format
7 |
8 | def tflite_run(model_path:str) -> np.ndarray:
9 | '''
10 | tflite runtime
11 | '''
12 | tflite_runtime = tf.lite.Interpreter(model_path, num_threads=4)
13 | tflite_runtime.allocate_tensors()
14 | input_details, output_details = tflite_runtime.get_input_details(), tflite_runtime.get_output_details()
15 | for i in range(len(input_details)):
16 | tflite_runtime.set_tensor(input_details[i]['index'], np.ones(input_details[i]['shape'], dtype=np.float32))
17 | tflite_runtime.invoke()
18 |
19 |     # comparing only one output is sufficient.
20 | tflite_output = tflite_runtime.get_tensor(output_details[0]['index'])
21 | return tflite_output
22 |
23 | def keras_run(model_path:str) -> np.ndarray:
24 | '''
25 | keras runtime
26 | '''
27 | keras_runtime = tf.keras.models.load_model(model_path)
28 | _input = []
29 | for inp in keras_runtime.inputs:
30 | _input.append(np.ones(list(inp.shape), dtype=np.float32))
31 |
32 | keras_output = keras_runtime.predict(_input)
33 |     # comparing only one output is sufficient.
34 | if isinstance(keras_output, list):
35 | keras_output = keras_output[0]
36 | return keras_output
37 |
38 |
39 | def get_elements_error(onnx_proto, keras_model_path:str, tflite_model_path:str, input_layout:dict, output_layout:dict) -> dict:
40 | '''
41 | use ones input arr to check model.
42 | more carefully check is up to youself custom code.
43 | '''
44 | result = {}
45 | # test onnx
46 | onnx_runtime = ort.InferenceSession(onnx_proto.SerializeToString())
47 | onnx_inputs = {}
48 | for inp in onnx_runtime.get_inputs():
49 | shape = inp.shape
50 | if isinstance(shape[0], str) or shape[0] < 1:
51 | shape[0] = 1
52 | onnx_inputs[inp.name] = np.ones(shape, dtype=np.float32)
53 | if len(shape) > 2:
54 | _transpose_index = [i for i in range(len(shape))]
55 | _transpose_index = _transpose_index[0:1] + _transpose_index[2:] + _transpose_index[1:2]
56 | onnx_outputs = onnx_runtime.run([], onnx_inputs)
57 |
58 | channel_last = False
59 | for oup in onnx_proto.graph.output:
60 | channel_last = output_layout[oup.name] == Layout.Channel_Last
61 | break
62 |
63 | if keras_model_path is not None:
64 | # test keras model
65 | keras_output = keras_run(keras_model_path)
66 | if channel_last:
67 | keras_output = tensor_NDC_to_NCD_format(keras_output)
68 | # get max error
69 | keras_max_error = 1000
70 | for onnx_output in onnx_outputs:
71 | if onnx_output.shape != keras_output.shape:
72 | continue
73 | diff = np.abs(onnx_output - keras_output)
74 | max_diff = np.max(diff)
75 | keras_max_error = min(keras_max_error, max_diff)
76 | result['keras'] = keras_max_error
77 |
78 | if tflite_model_path is not None:
79 | # test tflite
80 | tflite_output = tflite_run(tflite_model_path)
81 | if channel_last:
82 | tflite_output = tensor_NDC_to_NCD_format(tflite_output)
83 | # get max error
84 | tflite_max_error = 1000
85 | for onnx_output in onnx_outputs:
86 | if onnx_output.shape != tflite_output.shape:
87 | continue
88 | diff = np.abs(onnx_output - tflite_output)
89 | max_diff = np.max(diff)
90 | tflite_max_error = min(tflite_max_error, max_diff)
91 | result['tflite'] = tflite_max_error
92 |
93 | return result
--------------------------------------------------------------------------------
/model/lle.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | from .utils import (
4 | MBRConv5,
5 | MBRConv3,
6 | MBRConv1,
7 | DropBlock,
8 | FST,
9 | FSTS,
10 | )
11 |
12 | class MobileIELLENet(nn.Module):
13 | def __init__(self, channels, rep_scale=4):
14 | super(MobileIELLENet, self).__init__()
15 | self.channels = channels
16 | self.head = FST(
17 | nn.Sequential(
18 | MBRConv5(3, channels, rep_scale=rep_scale),
19 | nn.PReLU(channels),
20 | MBRConv3(channels, channels, rep_scale=rep_scale)
21 | ),
22 | channels
23 | )
24 | self.body = FST(
25 | MBRConv3(channels, channels, rep_scale=rep_scale),
26 | channels
27 | )
28 | self.att = nn.Sequential(
29 | nn.AdaptiveAvgPool2d(1),
30 | MBRConv1(channels, channels, rep_scale=rep_scale),
31 | nn.Sigmoid()
32 | )
33 | self.att1= nn.Sequential(
34 | MBRConv1(1, channels, rep_scale=rep_scale),
35 | nn.Sigmoid()
36 | )
37 | self.tail = MBRConv3(channels, 3, rep_scale=rep_scale)
38 | self.tail_warm = MBRConv3(channels, 3, rep_scale=rep_scale)
39 | self.drop = DropBlock(3)
40 |
41 | def forward(self, x):
42 | x0 = self.head(x)
43 | x1 = self.body(x0)
44 | x2 = self.att(x1)
45 | max_out, _ = torch.max(x2 * x1 , dim=1, keepdim=True)
46 | x3 = self.att1(max_out)
47 | x4 = torch.mul(x2, x3) * x1
48 | return self.tail(x4)
49 |
50 | def forward_warm(self, x):
51 | x = self.drop(x)
52 | x = self.head(x)
53 | x = self.body(x)
54 | return self.tail(x), self.tail_warm(x)
55 |
56 | def slim(self):
57 | net_slim = MobileIELLENetS(self.channels)
58 | weight_slim = net_slim.state_dict()
59 | for name, mod in self.named_modules():
60 | if isinstance(mod, MBRConv3) or isinstance(mod, MBRConv5) or isinstance(mod, MBRConv1):
61 | if '%s.weight' % name in weight_slim:
62 | w, b = mod.slim()
63 | weight_slim['%s.weight' % name] = w
64 | weight_slim['%s.bias' % name] = b
65 | elif isinstance(mod, FST):
66 | weight_slim['%s.bias' % name] = mod.bias
67 | weight_slim['%s.weight1' % name] = mod.weight1
68 | weight_slim['%s.weight2' % name] = mod.weight2
69 | elif isinstance(mod, nn.PReLU):
70 | weight_slim['%s.weight' % name] = mod.weight
71 | net_slim.load_state_dict(weight_slim)
72 | return net_slim
73 |
74 | class MobileIELLENetS(nn.Module):
75 | def __init__(self, channels):
76 | super(MobileIELLENetS, self).__init__()
77 | self.head = FSTS(
78 | nn.Sequential(
79 | nn.Conv2d(3, channels, 5, 1, 2),
80 | nn.PReLU(channels),
81 | nn.Conv2d(channels, channels, 3, 1, 1)
82 | ),
83 | channels
84 | )
85 | self.body = FSTS(
86 | nn.Conv2d(channels, channels, 3, 1, 1),
87 | channels
88 | )
89 | self.att = nn.Sequential(
90 | nn.AdaptiveAvgPool2d(1),
91 | nn.Conv2d(channels, channels, 1),
92 | nn.Sigmoid()
93 | )
94 | self.att1 = nn.Sequential(
95 | nn.Conv2d(1, channels, 1, 1),
96 | nn.Sigmoid()
97 | )
98 | self.tail = nn.Conv2d(channels, 3, 3, 1, 1)
99 |
100 | def forward(self, x):
101 | x0 = self.head(x)
102 | x1 = self.body(x0)
103 | x2 = self.att(x1)
104 | max_out, _ = torch.max(x2 * x1, dim=1, keepdim=True)
105 | x3 = self.att1(max_out)
106 | x4 = torch.mul(x3, x2) * x1
107 | return self.tail(x4)
108 |
109 |
110 |
--------------------------------------------------------------------------------
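A brief, hedged sketch of the re-parameterization step that `slim()` above performs (the channel width here is an illustrative guess; in practice it comes from the YAML config, and this mirrors what `main.py` does when `save_slim` is enabled):

```python
import torch
from model.lle import MobileIELLENet

# channels=12 is only an assumption for the sketch; the real value lives in config/lle.yaml.
net = MobileIELLENet(channels=12, rep_scale=4)

# slim() folds the MBRConv*/FST blocks into plain nn.Conv2d weights and returns
# the deployment-ready MobileIELLENetS, as done in main.py before export.
net_slim = net.slim()
torch.save(net_slim.state_dict(), 'model_best_slim_demo.pkl')
```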
/model/isp.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | from .utils import (
4 | MBRConv5,
5 | MBRConv3,
6 | MBRConv1,
7 | DropBlock,
8 | FST,
9 | FSTS,
10 | )
11 |
12 | class MobileIEISPNet(nn.Module):
13 | def __init__(self, channels, rep_scale=4):
14 | super(MobileIEISPNet, self).__init__()
15 | self.channels = channels
16 | self.head = FST(
17 | nn.Sequential(
18 | MBRConv5(4, channels, rep_scale=rep_scale),
19 | nn.PReLU(channels),
20 | MBRConv3(channels, channels, rep_scale=rep_scale)
21 | ),
22 | channels
23 | )
24 | self.body = FST(
25 | MBRConv3(channels, channels, rep_scale=rep_scale),
26 | channels
27 | )
28 | self.att = nn.Sequential(
29 | nn.AdaptiveAvgPool2d(1),
30 | MBRConv1(channels, channels, rep_scale=rep_scale),
31 | nn.Sigmoid()
32 | )
33 | self.att1= nn.Sequential(
34 | MBRConv1(1, channels, rep_scale=rep_scale),
35 | nn.Sigmoid()
36 | )
37 | self.tail = nn.Sequential(nn.PixelShuffle(2), MBRConv3(3, 3, rep_scale=rep_scale))
38 | self.tail_warm = MBRConv3(channels, 4, rep_scale=rep_scale)
39 | self.drop = DropBlock(3)
40 |
41 | def forward(self, x):
42 | x0 = self.head(x)
43 | x1 = self.body(x0)
44 | x2 = self.att(x1)
45 | max_out, _ = torch.max(x2 * x1, dim=1, keepdim=True)
46 | x3 = self.att1(max_out)
47 | x4 = torch.mul(x3, x2) * x1
48 | return self.tail(x4)
49 |
50 | def forward_warm(self, x):
51 | x = self.drop(x)
52 | x = self.head(x)
53 | x = self.body(x)
54 | return self.tail(x), self.tail_warm(x)
55 |
56 | def slim(self):
57 | net_slim = MobileIEISPNetS(self.channels)
58 | weight_slim = net_slim.state_dict()
59 | for name, mod in self.named_modules():
60 | if isinstance(mod, MBRConv3) or isinstance(mod, MBRConv5) or isinstance(mod, MBRConv1):
61 | if '%s.weight' % name in weight_slim:
62 | w, b = mod.slim()
63 | weight_slim['%s.weight' % name] = w
64 | weight_slim['%s.bias' % name] = b
65 | elif isinstance(mod, FST):
66 | weight_slim['%s.bias' % name] = mod.bias
67 | weight_slim['%s.weight1' % name] = mod.weight1
68 | weight_slim['%s.weight2' % name] = mod.weight2
69 | elif isinstance(mod, nn.PReLU):
70 | weight_slim['%s.weight' % name] = mod.weight
71 | net_slim.load_state_dict(weight_slim)
72 | return net_slim
73 |
74 | class MobileIEISPNetS(nn.Module):
75 | def __init__(self, channels):
76 | super(MobileIEISPNetS, self).__init__()
77 | self.head = FSTS(
78 | nn.Sequential(
79 | nn.Conv2d(4, channels, 5, 1, 2),
80 | nn.PReLU(channels),
81 | nn.Conv2d(channels, channels, 3, 1, 1)
82 | ),
83 | channels
84 | )
85 | self.body = FSTS(
86 | nn.Conv2d(channels, channels, 3, 1, 1),
87 | channels
88 | )
89 | self.att = nn.Sequential(
90 | nn.AdaptiveAvgPool2d(1),
91 | nn.Conv2d(channels, channels, 1),
92 | nn.Sigmoid()
93 | )
94 | self.att1 = nn.Sequential(
95 | nn.Conv2d(1, channels, 1, 1),
96 | nn.Sigmoid()
97 | )
98 | self.tail = nn.Sequential(nn.PixelShuffle(2), nn.Conv2d(3, 3, 3, 1, 1))
99 |
100 | def forward(self, x):
101 | x0 = self.head(x)
102 | x1 = self.body(x0)
103 | x2 = self.att(x1)
104 | max_out, _ = torch.max(x2 * x1, dim=1, keepdim=True)
105 | x3 = self.att1(max_out)
106 | x4 = torch.mul(x3, x2) * x1
107 | return self.tail(x4)
108 |
--------------------------------------------------------------------------------
/onnx2tflite/onnx2tflite/components/builder.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
3 |
4 | import tensorflow as tf
5 | from tensorflow import keras
6 | from onnx import numpy_helper
7 | from .dataloader import RandomLoader, ImageLoader
8 |
9 | from onnx2tflite.utils import OPERATOR
10 | from onnx2tflite.layers import conv_layers
11 | from onnx2tflite.utils.definitions import *
12 | from onnx2tflite.utils.graph_tools import build_tf_inputs, decode_node_attribute
13 |
14 | def keras_builder(onnx_model, native_groupconv:bool=False):
15 |
16 | conv_layers.USE_NATIVE_GROUP_CONV = native_groupconv
17 |
18 | model_graph = onnx_model.graph
19 | layout_dict, tf_tensor = {}, {}
20 |
21 | '''
22 |     initialize the ONNX model's built-in tensors (weights)
23 | '''
24 | onnx_weights = dict()
25 | for initializer in model_graph.initializer:
26 | onnx_weights[initializer.name] = numpy_helper.to_array(initializer)
27 |
28 | '''
29 | build input nodes
30 | '''
31 | input_nodes = build_tf_inputs(model_graph, layout_dict)
32 | tf_tensor.update(input_nodes)
33 |
34 | '''
35 |     build the model node by node by iterating over the ONNX graph.
36 | '''
37 | for node in model_graph.node:
38 | op_name, node_inputs, node_outputs = node.op_type, node.input, node.output
39 | op_attr = decode_node_attribute(node)
40 |
41 | tf_operator = OPERATOR.get(op_name)
42 | if tf_operator is None:
43 | raise KeyError(f"{op_name} not implemented yet")
44 |
45 | _inputs = None
46 | if len(node_inputs) > 0:
47 | _inputs = tf_tensor[node_inputs[0]] if node_inputs[0] in tf_tensor else onnx_weights[node_inputs[0]]
48 |
49 | # init layout
50 | for index in range(len(node_outputs)):
51 | layout_dict[node_outputs[index]] = layout_dict.get(node_inputs[0], Layout.Default)
52 |
53 | res = tf_operator(tf_tensor, onnx_weights, node_inputs, op_attr, node_outputs, layout_dict)(_inputs)
54 | if isinstance(res, list):
55 | for index in range(len(node_outputs)):
56 | tf_tensor[node_outputs[index]] = res[index]
57 | else:
58 | tf_tensor[node_outputs[0]] = res
59 |
60 | '''
61 | build keras model
62 | '''
63 | input_nodes = [tf_tensor[x.name] for x in model_graph.input]
64 | outputs_nodes = [tf_tensor[x.name] for x in model_graph.output]
65 | keras_model = keras.Model(inputs=input_nodes, outputs=outputs_nodes)
66 | keras_model.trainable = False
67 | # keras_model.summary()
68 | # print(layout_dict)
69 | input_layout, output_layout = {}, {}
70 | for inp in model_graph.input:
71 | input_layout[inp.name] = layout_dict[inp.name]
72 | for oup in model_graph.output:
73 | output_layout[oup.name] = layout_dict[oup.name]
74 | return keras_model, input_layout, output_layout
75 |
76 | def tflite_builder(keras_model, weight_quant:bool=False, fp16_model=False, int8_model:bool=False, image_root:str=None,
77 | int8_mean:list or float = [123.675, 116.28, 103.53], int8_std:list or float = [58.395, 57.12, 57.375]):
78 | converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
79 | converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]
80 | if weight_quant or int8_model or fp16_model:
81 | converter.experimental_new_converter = True
82 | converter.optimizations = [tf.lite.Optimize.DEFAULT]
83 |
84 | if fp16_model:
85 | converter.target_spec.supported_types = [tf.float16]
86 | converter.inference_input_type = tf.float32
87 | converter.inference_output_type = tf.float32
88 | elif int8_model:
89 |         assert len(keras_model.inputs) == 1, "only single-input models are supported for INT8 quantization."
90 | shape = list(keras_model.inputs[0].shape)
91 | dataset = RandomLoader(shape) if image_root is None else ImageLoader(image_root, shape, int8_mean, int8_std)
92 | converter.representative_dataset = lambda: dataset
93 | converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8, tf.lite.OpsSet.SELECT_TF_OPS]
94 | converter.target_spec.supported_types = []
95 | converter.inference_input_type = tf.uint8
96 | converter.inference_output_type = tf.uint8
97 | converter.experimental_new_converter = True
98 |
99 | tflite_model = converter.convert()
100 | return tflite_model
--------------------------------------------------------------------------------
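A hedged sketch of driving `keras_builder` and `tflite_builder` directly, the two stages that the converter module wires together (note the ONNX simplification step from `onnx_loader` is skipped here, which may matter for some graphs; the input file is one shipped in `pretrain/`):

```python
import onnx
from onnx2tflite.components.builder import keras_builder, tflite_builder

# Stage 1: ONNX graph -> Keras model (plus input/output layout info).
model_proto = onnx.load("pretrain/lolv1.onnx")
keras_model, input_layout, output_layout = keras_builder(model_proto)

# Stage 2: Keras model -> float32 TFLite flatbuffer (no quantization flags set).
tflite_bytes = tflite_builder(keras_model)
with open("lolv1_manual.tflite", "wb") as fp:
    fp.write(tflite_bytes)
```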
/onnx2tflite/onnx2tflite/components/onnx_loader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import onnx
3 | import logging
4 | from onnxsim import simplify  # ONNX model simplification tool
5 |
6 | LOG = logging.getLogger("onnx_loader running:")
7 | LOG.setLevel(logging.INFO)
8 |
9 | def clean_model_input(model_proto):
10 | """
11 | 清理 ONNX 模型的输入,删除 ONNX 计算图中冗余的输入节点。
12 |
13 | 逻辑:
14 | - 遍历 ONNX 计算图中的 `graph.input`
15 | - 如果某个 `input` 也出现在 `initializer` 中,则说明它是一个冗余输入(即它的值已经在 `initializer` 中存储)
16 | - 从 `graph.input` 中移除这些冗余输入
17 |
18 | 参数:
19 | - model_proto (onnx.ModelProto): 需要清理的 ONNX 模型
20 | """
21 |     inputs = model_proto.graph.input  # inputs of the ONNX graph
22 |     name_to_input = {}  # map from input name to input object
23 |     for input in inputs:
24 |         name_to_input[input.name] = input
25 |
26 |     names = []
27 |     for initializer in model_proto.graph.initializer:  # iterate over all initializers
28 |         if initializer.name in name_to_input:  # the initializer's name also appears in the inputs
29 |             inputs.remove(name_to_input[initializer.name])  # remove that redundant input
30 | names.append(initializer.name)
31 |
32 | if len(names) > 0:
33 | LOG.warning(f"[{len(names)}] redundant input nodes are removed.\n \
34 | nodes name : {','.join(names)}")
35 |
36 | def get_onnx_submodel(onnx_model_path:str, input_node_names:list=None, output_node_names:list=None):
37 | """
38 | 截取 ONNX 子模型,即从原始 ONNX 模型中提取以 `input_node_names` 为输入、
39 | `output_node_names` 为输出的子图。
40 |
41 | 逻辑:
42 | - 载入 ONNX 模型
43 | - 确定输入节点和输出节点(如果未指定,则默认使用整个模型的输入/输出)
44 | - 使用 `onnx.utils.extract_model` 提取子模型并保存
45 | - 载入提取后的子模型并返回
46 |
47 | 参数:
48 | - onnx_model_path (str): ONNX 模型文件路径
49 | - input_node_names (list, optional): 指定子模型的输入节点名称
50 | - output_node_names (list, optional): 指定子模型的输出节点名称
51 |
52 | 返回:
53 | - model_proto (onnx.ModelProto): 提取的子模型
54 | """
55 |     model_proto = onnx.load(onnx_model_path)  # load the ONNX model
56 |     # If no input nodes are specified, default to all inputs of the ONNX model
57 |     if input_node_names is None:
58 |         input_node_names = []
59 |         for inp in model_proto.graph.input:
60 |             input_node_names.append(inp.name)
61 |
62 |     # If no output nodes are specified, default to all outputs of the ONNX model
63 |     if output_node_names is None:
64 |         output_node_names = []
65 |         for oup in model_proto.graph.output:
66 |             output_node_names.append(oup.name)
67 |     del model_proto  # free the original model's memory
68 |
69 |     # Build the file path for the new sub-model
70 |     new_model_path = os.path.splitext(onnx_model_path)[0] + "_sub.onnx"
71 |     # Extract the sub-model and save it
72 |     onnx.utils.extract_model(onnx_model_path, new_model_path, input_node_names, output_node_names)
73 |     # Load the extracted sub-model
74 |     model_proto = onnx.load(new_model_path)
75 | return model_proto
76 |
77 | def get_proto(onnx_model_path:str, input_node_names:list=None, output_node_names:list=None):
78 | if input_node_names is None and output_node_names is None:
79 | return onnx.load(onnx_model_path)
80 | else:
81 | return get_onnx_submodel(onnx_model_path, input_node_names, output_node_names)
82 |
83 | def load_onnx_modelproto(onnx_model_path:str, input_node_names:list=None, output_node_names:list=None, need_simplify:bool=True):
84 | """
85 | 载入 ONNX 模型,并根据需要进行简化和清理。
86 |
87 | 逻辑:
88 | - 检查 ONNX 模型文件是否存在
89 | - 载入完整的 ONNX 模型或子模型
90 | - 检测是否存在动态输入
91 | - 如果 `need_simplify=True`,则尝试使用 `onnx-simplifier` 进行模型优化
92 | - 移除 ONNX 计算图中冗余的输入
93 |
94 | 参数:
95 | - onnx_model_path (str): ONNX 模型文件路径
96 | - input_node_names (list, optional): 需要提取的输入节点
97 | - output_node_names (list, optional): 需要提取的输出节点
98 | - need_simplify (bool, optional): 是否对 ONNX 进行简化,默认启用
99 |
100 | 返回:
101 | - model_proto (onnx.ModelProto): 处理后的 ONNX 模型
102 | """
103 |     # 1. Check that the ONNX file exists
104 |     if not os.path.exists(onnx_model_path):
105 |         LOG.error(f"{onnx_model_path} does not exist.")
106 |         raise FileNotFoundError(f"{onnx_model_path} does not exist.")
107 |     # 2. Load the ONNX model or sub-model
108 |     model_proto = get_proto(onnx_model_path, input_node_names, output_node_names)
109 |     # 3. Check for dynamic inputs (i.e. unspecified dimensions in the input shape)
110 |     dynamic_input = False
111 |     for inp in model_proto.graph.input:
112 |         for x in inp.type.tensor_type.shape.dim:
113 |             if x.dim_value <= 0:  # dynamic input found
114 | dynamic_input = True
115 | break
116 |     # 4. Simplify the ONNX model (if enabled)
117 |     if need_simplify:
118 |         success = False
119 |         try:
120 |             # Optimize the model with `onnxsim`, allowing dynamic inputs
121 |             model_proto, success = simplify(model_proto, check_n=1, dynamic_input_shape=dynamic_input)
122 |         except:
123 |             success = False
124 |         # If simplification fails, log a warning
125 |         if not success:
126 |             LOG.warning("onnxsim failed, the conversion may fail as a result.")
127 |
128 |             model_proto = onnx.load(onnx_model_path)
129 |
130 | # 5. 清理 ONNX 模型的冗余输入
131 | clean_model_input(model_proto)
132 |     # Before returning the ONNX model, check whether dynamic inputs remain
133 |
134 | ##################################################################################
135 | for inp in model_proto.graph.input:
136 | for x in inp.type.tensor_type.shape.dim:
137 |             if x.dim_value <= 0:  # still a dynamic input
138 |                 LOG.warning(f"ONNX still contains a dynamic input: {inp.name}, dims: {[dim.dim_value for dim in inp.type.tensor_type.shape.dim]}")
139 |                 break  # one warning per input is enough
140 | ##################################################################################
141 |
142 | return model_proto
--------------------------------------------------------------------------------
/onnx2tflite/build/lib/onnx2tflite/converter.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | from .components import load_onnx_modelproto, keras_builder, tflite_builder, get_elements_error
4 |
5 | logging.basicConfig(level=logging.INFO)
6 | LOG = logging.getLogger("converter running:")
7 |
8 | def onnx_converter(onnx_model_path:str, output_path:str=None,
9 | input_node_names:list=None, output_node_names:list=None,
10 | need_simplify:bool=True, target_formats:list = ['keras', 'tflite'],
11 | native_groupconv:bool=False,
12 | weight_quant:bool=False, fp16_model:bool=False, int8_model:bool=False, image_root:str=None,
13 | int8_mean:list or float = [123.675, 116.28, 103.53], int8_std:list or float = [58.395, 57.12, 57.375])->float:
14 | """
15 | Converts an ONNX model to various target formats with optional optimizations.
16 |
17 | Parameters:
18 | onnx_model_path (str): Path to the input ONNX model file.
19 | output_path (str, optional): Path to save the converted model(s). If None, the converted model(s) will be saved in the same directory as the input model.
20 | input_node_names (list, optional): List of input node names. If None, the default input nodes of the ONNX model are used.
21 | output_node_names (list, optional): List of output node names. If None, the default output nodes of the ONNX model are used.
22 | need_simplify (bool, optional): If True, the ONNX model will be simplified before conversion. Default is True.
23 | target_formats (list, optional): List of target formats to convert the ONNX model to. Default is ['keras', 'tflite'].
24 | native_groupconv (bool, optional): If True, retains native group convolution operations during conversion. Default is False.
25 | weight_quant (bool, optional): If True, applies weight quantization to the converted model. Default is False.
26 | fp16_model (bool, optional): If True, converts the model to use FP16 precision. Default is False.
27 | int8_model (bool, optional): If True, converts the model to use INT8 precision. Default is False.
28 | image_root (str, optional): Path to the root directory of images for calibration if INT8 quantization is enabled. Default is None.
29 | int8_mean (list or float, optional): Mean values for INT8 quantization. Default is [123.675, 116.28, 103.53].
30 | int8_std (list or float, optional): Standard deviation values for INT8 quantization. Default is [58.395, 57.12, 57.375].
31 |
32 | Returns:
33 |         dict: Conversion result with the saved model paths and the maximum element-wise errors.
34 |
35 | Note:
36 | - The function supports multiple target formats for conversion and allows for various optimizations such as simplification, quantization, and precision reduction.
37 | - When INT8 quantization is enabled, 'image_root', 'int8_mean', and 'int8_std' parameters are used for calibration.
38 | """
39 |     if not isinstance(target_formats, list) or ('keras' not in target_formats and 'tflite' not in target_formats):
40 |         raise KeyError("target_formats must be a list containing 'keras' and/or 'tflite'")
41 |
42 | model_proto = load_onnx_modelproto(onnx_model_path, input_node_names, output_node_names, need_simplify)
43 |
44 | keras_model, input_layout, output_layout = keras_builder(model_proto, native_groupconv)
45 |
46 | if 'tflite' in target_formats:
47 | tflite_model = tflite_builder(keras_model, weight_quant, fp16_model, int8_model, image_root, int8_mean, int8_std)
48 |
49 | onnx_path, model_name = os.path.split(onnx_model_path)
50 | if output_path is None:
51 | output_path = onnx_path
52 | output_path = os.path.join(output_path, model_name.split('.')[0])
53 |
54 | if fp16_model:
55 | output_path = output_path + "_fp16"
56 | elif int8_model:
57 | output_path = output_path + "_int8"
58 |
59 | keras_model_path = None
60 | if 'keras' in target_formats:
61 | keras_model_path = output_path + ".h5"
62 | keras_model.save(keras_model_path)
63 | LOG.info(f"keras model saved in {keras_model_path}")
64 |
65 | tflite_model_path = None
66 | if 'tflite' in target_formats:
67 | tflite_model_path = output_path + ".tflite"
68 | with open(tflite_model_path, "wb") as fp:
69 | fp.write(tflite_model)
70 |
71 | convert_result = {"keras":keras_model_path, "tflite":tflite_model_path, "keras_error":0, "tflite_error":0}
72 | # ignore quantization model
73 | if int8_model:
74 | return convert_result
75 |
76 | error_dict = {}
77 | try:
78 | error_dict = get_elements_error(model_proto, keras_model_path, tflite_model_path, input_layout, output_layout)
79 | keras_error, tflite_error = error_dict.get("keras", None), error_dict.get("tflite", None)
80 | if keras_error:
81 | if keras_error > 1e-2:
82 | LOG.error("h5 model elements' max error has reached {:^.4E}, but convert is done, please check {} carefully!".format(keras_error, keras_model_path))
83 | elif keras_error > 1e-4:
84 | LOG.warning("h5 model elements' max error is {:^.4E}, pass, h5 saved in {}".format(keras_error, keras_model_path))
85 | else:
86 | LOG.info("h5 model elements' max error is {:^.4E}, pass, h5 saved in {}".format(keras_error, keras_model_path))
87 | if tflite_error:
88 | if tflite_error > 1e-2:
89 | LOG.error("tflite model elements' max error has reached {:^.4E}, but convert is done, please check {} carefully!".format(tflite_error, tflite_model_path))
90 | elif tflite_error > 1e-4:
91 | LOG.warning("tflite model elements' max error is {:^.4E}, pass, tflite saved in {}".format(tflite_error, tflite_model_path))
92 | else:
93 | LOG.info("tflite model elements' max error is {:^.4E}, pass, tflite saved in {}".format(tflite_error, tflite_model_path))
94 | except:
95 | LOG.warning("convert is successed, but model running is failed, please check carefully!")
96 |
97 | convert_result["keras_error"] = error_dict.get("keras", None)
98 | convert_result["tflite_error"] = error_dict.get("tflite", None)
99 | return convert_result
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import cv2
4 | from tqdm import tqdm
5 | from logger import Logger
6 | from option import get_option
7 | from data import import_loader
8 | from loss import import_loss
9 | from model import import_model
10 | import multiprocessing as mp
11 | import os
12 | os.environ['CUDA_VISIBLE_DEVICES'] = '5'
13 |
14 | def count_parameters(model):
15 | return sum(p.numel() for p in model.parameters() if p.requires_grad)
16 |
17 | def train(opt, logger):
18 | logger.info('task: {}, model task: {}'.format(opt.task, opt.model_task))
19 |
20 | train_loader, valid_loader = import_loader(opt)
21 | lr = float(opt.config['train']['lr'])
22 | lr_warmup = float(opt.config['train']['lr_warmup'])
23 |
24 | loss_warmup = import_loss('warmup')
25 | loss_training = import_loss(opt.model_task)
26 | net = import_model(opt)
27 | # logger.info(net)
28 | num_params = count_parameters(net)
29 | print("Total number of parameters: ", num_params)
30 |
31 | net.train()
32 | # Phase Warming-up
33 | if opt.config['train']['warmup']:
34 | logger.info('start warming-up')
35 |
36 | optim_warm = torch.optim.Adam(net.parameters(), lr_warmup, weight_decay=0)
37 | epochs = opt.config['train']['warmup_epoch']
38 | for epo in range(epochs):
39 | loss_li = []
40 | for img_inp, img_gt, _ in tqdm(train_loader, ncols=80):
41 | optim_warm.zero_grad()
42 | warmup_out1, warmup_out2 = net.forward_warm(img_inp)
43 | loss = loss_warmup(img_inp, img_gt, warmup_out1, warmup_out2)
44 | loss.backward()
45 | optim_warm.step()
46 | loss_li.append(loss.item())
47 |
48 | logger.info('epoch: {}, train_loss: {}'.format(epo+1, sum(loss_li)/len(loss_li)))
49 | torch.save(net.state_dict(), r'{}/model_pre.pkl'.format(opt.save_model_dir))
50 | logger.info('warming-up phase done')
51 |
52 | # Phase Training
53 | best_psnr = 0
54 | epochs = int(opt.config['train']['epoch'])
55 | optim = torch.optim.Adam(net.parameters(), lr, weight_decay=0)
56 | lr_sch = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optim, 50, 2, 1e-7)
57 |
58 | logger.info('start training')
59 | for epo in range(epochs):
60 | loss_li = []
61 | test_psnr = []
62 | net.train()
63 | for img_inp, img_gt, _ in tqdm(train_loader, ncols=80):
64 | out = net(img_inp)
65 | loss = loss_training(out, img_gt)
66 | optim.zero_grad()
67 | loss.backward()
68 | optim.step()
69 | loss_li.append(loss.item())
70 | lr_sch.step()
71 |
72 | # Validation
73 | net.eval()
74 | for img_inp, img_gt, _ in tqdm(valid_loader, ncols=80):
75 | with torch.no_grad():
76 | out = net(img_inp)
77 | mse = ((out - img_gt)**2).mean((2, 3))
78 | psnr = (1 / mse).log10().mean() * 10
79 | test_psnr.append(psnr.item())
80 | mean_psnr = sum(test_psnr)/len(test_psnr)
81 |
82 | if (epo+1) % int(opt.config['train']['save_every']) == 0:
83 | torch.save(net.state_dict(), r'{}/model_{}.pkl'.format(opt.save_model_dir, epo+1))
84 |
85 | logger.info('epoch: {}, training loss: {}, validation psnr: {}'.format(
86 | epo+1, sum(loss_li) / len(loss_li), sum(test_psnr) / len(test_psnr)
87 | ))
88 |
89 | if mean_psnr > best_psnr:
90 | best_psnr = mean_psnr
91 | torch.save(net.state_dict(), r'{}/model_best.pkl'.format(opt.save_model_dir))
92 | if opt.config['train']['save_slim']:
93 | net_slim = net.slim().to(opt.device)
94 | torch.save(net_slim.state_dict(), r'{}/model_best_slim.pkl'.format(opt.save_model_dir))
95 | logger.info('best model saved and re-parameterized in epoch {}'.format(epo+1))
96 | else:
97 |                 logger.info('best model saved in epoch {}'.format(epo+1))
98 |
99 | logger.info('training done')
100 |
101 |
102 | def test(opt, logger):
103 | test_loader = import_loader(opt)
104 | net = import_model(opt)
105 | net.eval()
106 | psnr_list = []
107 | logger.info('start testing')
108 | for (img_inp, img_gt, img_name) in test_loader:
109 |
110 | with torch.no_grad():
111 | out = net(img_inp)
112 | mse = ((out - img_gt)**2).mean((2, 3))
113 | psnr = (1 / mse).log10().mean() * 10
114 |
115 | if opt.config['test']['save']:
116 | out_img = (out.clip(0, 1)[0] * 255).permute([1, 2, 0]).cpu().numpy().astype(np.uint8)[..., ::-1]
117 | cv2.imwrite(r'{}/{}.png'.format(opt.save_image_dir, img_name[0]), out_img)
118 |
119 | psnr_list.append(psnr.item())
120 | logger.info('image name: {}, test psnr: {}'.format(img_name[0], psnr))
121 |
122 | logger.info('testing done, overall psnr: {}'.format(sum(psnr_list) / len(psnr_list)))
123 |
124 |
125 | def demo(opt, logger):
126 | demo_loader = import_loader(opt)
127 | net = import_model(opt)
128 | net.eval()
129 | logger.info('start demonstration')
130 | for img_inp, img_name in demo_loader:
131 |
132 | with torch.no_grad():
133 | out = net(img_inp)
134 | out_img = (out.clip(0, 1)[0] * 255).permute([1, 2, 0]).cpu().numpy().astype(np.uint8)[..., ::-1]
135 | cv2.imwrite(r'{}/{}.png'.format(opt.save_image_dir, img_name[0]), out_img)
136 | logger.info('image name: {} output generated'.format(img_name[0]))
137 | logger.info('demonstration done')
138 |
139 |
140 | if __name__ == "__main__":
141 | mp.set_start_method('spawn')
142 |
143 | opt = get_option()
144 | logger = Logger(opt)
145 |
146 | if opt.task == 'train':
147 | train(opt, logger)
148 | elif opt.task == 'test':
149 | test(opt, logger)
150 | elif opt.task == 'demo':
151 | demo(opt, logger)
152 | else:
153 | raise ValueError('unknown task, please choose from [train, test, demo].')
154 |
--------------------------------------------------------------------------------
/onnx2tflite/onnx2tflite/converter.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | from .components import load_onnx_modelproto, keras_builder, tflite_builder, get_elements_error
4 |
5 | logging.basicConfig(level=logging.INFO)
6 | LOG = logging.getLogger("converter running:")
7 |
8 | def onnx_converter(onnx_model_path:str, output_path:str=None,
9 | input_node_names:list=None, output_node_names:list=None,
10 | need_simplify:bool=True, target_formats:list = ['keras', 'tflite'],
11 | native_groupconv:bool=False,
12 | weight_quant:bool=False, fp16_model:bool=False, int8_model:bool=False, image_root:str=None,
13 | int8_mean:list or float = [123.675, 116.28, 103.53], int8_std:list or float = [58.395, 57.12, 57.375])->float:
14 | """
15 | Converts an ONNX model to various target formats with optional optimizations.
16 |
17 | Parameters:
18 | onnx_model_path (str): Path to the input ONNX model file.
19 | output_path (str, optional): Path to save the converted model(s). If None, the converted model(s) will be saved in the same directory as the input model.
20 | input_node_names (list, optional): List of input node names. If None, the default input nodes of the ONNX model are used.
21 | output_node_names (list, optional): List of output node names. If None, the default output nodes of the ONNX model are used.
22 | need_simplify (bool, optional): If True, the ONNX model will be simplified before conversion. Default is True.
23 | target_formats (list, optional): List of target formats to convert the ONNX model to. Default is ['keras', 'tflite'].
24 | native_groupconv (bool, optional): If True, retains native group convolution operations during conversion. Default is False.
25 | weight_quant (bool, optional): If True, applies weight quantization to the converted model. Default is False.
26 | fp16_model (bool, optional): If True, converts the model to use FP16 precision. Default is False.
27 | int8_model (bool, optional): If True, converts the model to use INT8 precision. Default is False.
28 | image_root (str, optional): Path to the root directory of images for calibration if INT8 quantization is enabled. Default is None.
29 | int8_mean (list or float, optional): Mean values for INT8 quantization. Default is [123.675, 116.28, 103.53].
30 | int8_std (list or float, optional): Standard deviation values for INT8 quantization. Default is [58.395, 57.12, 57.375].
31 |
32 | Returns:
33 |         dict: Conversion result with the saved model paths and the maximum element-wise errors.
34 |
35 | Note:
36 | - The function supports multiple target formats for conversion and allows for various optimizations such as simplification, quantization, and precision reduction.
37 | - When INT8 quantization is enabled, 'image_root', 'int8_mean', and 'int8_std' parameters are used for calibration.
38 | """
39 |     # Make sure target_formats is a list containing 'keras' or 'tflite'
40 |     if not isinstance(target_formats, list) or ('keras' not in target_formats and 'tflite' not in target_formats):
41 |         raise KeyError("target_formats must be a list containing 'keras' and/or 'tflite'")
42 |
43 |     # 1. Load and parse the ONNX model
44 | model_proto = load_onnx_modelproto(onnx_model_path, input_node_names, output_node_names, need_simplify)
45 |
46 |     # 2. Convert the ONNX model to a Keras model
47 | keras_model, input_layout, output_layout = keras_builder(model_proto, native_groupconv)
48 |
49 |     # 3. If the target formats include 'tflite', further convert to TFLite
50 | if 'tflite' in target_formats:
51 | tflite_model = tflite_builder(keras_model, weight_quant, fp16_model, int8_model, image_root, int8_mean, int8_std)
52 |
53 |     # 4. Resolve the output path
54 |     onnx_path, model_name = os.path.split(onnx_model_path)  # directory containing the ONNX model
55 |     if output_path is None:
56 |         output_path = onnx_path  # if no output_path is given, save next to the ONNX model
57 |     output_path = os.path.join(output_path, model_name.split('.')[0])  # base path for the output files
58 |
59 |     # 5. For FP16 or INT8 quantized models, adjust the output path
60 | if fp16_model:
61 | output_path = output_path + "_fp16"
62 | elif int8_model:
63 | output_path = output_path + "_int8"
64 |
65 |     # 6. Save the Keras model
66 | keras_model_path = None
67 | if 'keras' in target_formats:
68 | keras_model_path = output_path + ".h5"
69 | keras_model.save(keras_model_path)
70 | LOG.info(f"keras model saved in {keras_model_path}")
71 |
72 |     # 7. Save the TFLite model
73 | tflite_model_path = None
74 | if 'tflite' in target_formats:
75 | tflite_model_path = output_path + ".tflite"
76 | with open(tflite_model_path, "wb") as fp:
77 | fp.write(tflite_model)
78 |
79 |     # 8. Record the conversion results
80 | convert_result = {"keras":keras_model_path, "tflite":tflite_model_path, "keras_error":0, "tflite_error":0}
81 | # ignore quantization model
82 | if int8_model:
83 | return convert_result
84 |
85 | error_dict = {}
86 | try:
87 | error_dict = get_elements_error(model_proto, keras_model_path, tflite_model_path, input_layout, output_layout)
88 | keras_error, tflite_error = error_dict.get("keras", None), error_dict.get("tflite", None)
89 | if keras_error:
90 | if keras_error > 1e-2:
91 | LOG.error("h5 model elements' max error has reached {:^.4E}, but convert is done, please check {} carefully!".format(keras_error, keras_model_path))
92 | elif keras_error > 1e-4:
93 | LOG.warning("h5 model elements' max error is {:^.4E}, pass, h5 saved in {}".format(keras_error, keras_model_path))
94 | else:
95 | LOG.info("h5 model elements' max error is {:^.4E}, pass, h5 saved in {}".format(keras_error, keras_model_path))
96 | if tflite_error:
97 | if tflite_error > 1e-2:
98 | LOG.error("tflite model elements' max error has reached {:^.4E}, but convert is done, please check {} carefully!".format(tflite_error, tflite_model_path))
99 | elif tflite_error > 1e-4:
100 | LOG.warning("tflite model elements' max error is {:^.4E}, pass, tflite saved in {}".format(tflite_error, tflite_model_path))
101 | else:
102 | LOG.info("tflite model elements' max error is {:^.4E}, pass, tflite saved in {}".format(tflite_error, tflite_model_path))
103 | except:
104 | LOG.warning("convert is successed, but model running is failed, please check carefully!")
105 |
106 | convert_result["keras_error"] = error_dict.get("keras", None)
107 | convert_result["tflite_error"] = error_dict.get("tflite", None)
108 | return convert_result
--------------------------------------------------------------------------------
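A hedged usage sketch for `onnx_converter` above, targeting the released LLE model (paths refer to files shipped in `pretrain/`; the package is assumed to be installed via `setup.py` so the import resolves):

```python
from onnx2tflite.converter import onnx_converter

result = onnx_converter(
    onnx_model_path="pretrain/lolv1.onnx",   # e.g. a model produced by torch_to_onnx.py
    output_path="pretrain",
    target_formats=["keras", "tflite"],
    need_simplify=True,
)
# result holds the saved paths and the maximum element-wise errors, e.g.
# {'keras': 'pretrain/lolv1.h5', 'tflite': 'pretrain/lolv1.tflite', ...}
print(result)
```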
/onnx2tflite/onnx2tflite/layers/activations_layers.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 | from tensorflow import keras
4 |
5 | from onnx2tflite.utils.definitions import Layout
6 | from onnx2tflite.utils import OPERATOR, channel_to_last_dimension, tensor_NCD_to_NDC_format
7 |
8 | @OPERATOR.register_operator("Relu")
9 | class TFRelu():
10 | def __init__(self, *args, **kwargs) -> None:
11 | super().__init__()
12 |
13 | def __call__(self, inputs):
14 | return keras.activations.relu(inputs)
15 |
16 | @OPERATOR.register_operator("HardSigmoid")
17 | class TFHardSigmoid():
18 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, *args, **kwargs) -> None:
19 | super().__init__()
20 | self.alpha = node_attribute.get("alpha", 0.2)
21 | self.beta = node_attribute.get("beta", 0.5)
22 |
23 | def __call__(self, inputs):
24 | return tf.clip_by_value(self.alpha*inputs+self.beta, 0, 1)
25 |
26 | @OPERATOR.register_operator("HardSwish")
27 | class TFHardSwish():
28 | def __init__(self, *args, **kwargs) -> None:
29 | super().__init__()
30 |
31 | def __call__(self, inputs):
32 | return inputs*tf.clip_by_value(inputs/6+0.5, 0, 1)
33 |
34 | @OPERATOR.register_operator("Mish")
35 | class TFMish():
36 | def __init__(self, *args, **kwargs) -> None:
37 | super().__init__()
38 |
39 | def __call__(self, inputs):
40 | return inputs*tf.tanh(tf.math.log(tf.math.exp(inputs)+1))
41 |
42 | @OPERATOR.register_operator("Sigmoid")
43 | class TFSigmoid():
44 | def __init__(self, *args, **kwargs) -> None:
45 | super().__init__()
46 |
47 | def __call__(self, inputs):
48 | return keras.activations.sigmoid(inputs)
49 |
50 | @OPERATOR.register_operator("LeakyRelu")
51 | class TFLeakyRelu():
52 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, *args, **kwargs) -> None:
53 | super().__init__()
54 | self.alpha = node_attribute.get('alpha', 0.01)
55 |
56 | def __call__(self, inputs):
57 | return keras.activations.relu(inputs, alpha=self.alpha)
58 |
59 | @OPERATOR.register_operator("PRelu")
60 | class TFPRelu():
61 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, node_outputs, layout_dict, *args, **kwargs) -> None:
62 | super().__init__()
63 | if 'slope' in node_attribute:
64 | self.slope = node_attribute['slope']
65 | elif node_inputs[1] in node_weights:
66 | self.slope = node_weights[node_inputs[1]]
67 | else:
68 | self.slope = tensor_grap[node_inputs[1]]
69 | input_tensor_shape = tensor_grap[node_inputs[0]].shape
70 | channel_last = layout_dict[node_inputs[0]] == Layout.Channel_Last
71 | if isinstance(self.slope, np.ndarray):
72 | while self.slope.ndim < input_tensor_shape.ndims:
73 | self.slope = self.slope[np.newaxis, :]
74 | if channel_last:
75 | self.slope = tensor_NCD_to_NDC_format(self.slope)
76 | if self.slope.ndim > 1:
77 | # remove batchsize
78 | self.slope = self.slope[0]
79 | axes = [i for i in range(1, input_tensor_shape.ndims-1)] if channel_last else [i for i in range(2, input_tensor_shape.ndims)]
80 | self.PRelu = tf.keras.layers.PReLU(weights=[self.slope], shared_axes = axes)
81 |
82 | def __call__(self, inputs):
83 | return self.PRelu(inputs)
84 |
85 | @OPERATOR.register_operator("Sin")
86 | class TFSin():
87 | def __init__(self, *args, **kwargs) -> None:
88 | super().__init__()
89 |
90 | def __call__(self, inputs):
91 | return tf.sin(inputs)
92 |
93 | @OPERATOR.register_operator("Sinh")
94 | class TFSinh():
95 | def __init__(self, *args, **kwargs) -> None:
96 | super().__init__()
97 |
98 | def __call__(self, inputs):
99 | return tf.sinh(inputs)
100 |
101 | @OPERATOR.register_operator("Cos")
102 | class TFCos():
103 | def __init__(self, *args, **kwargs) -> None:
104 | super().__init__()
105 |
106 | def __call__(self, inputs):
107 | return tf.cos(inputs)
108 |
109 | @OPERATOR.register_operator("Cosh")
110 | class TFCosh():
111 | def __init__(self, *args, **kwargs) -> None:
112 | super().__init__()
113 |
114 | def __call__(self, inputs):
115 | return tf.cosh(inputs)
116 |
117 | @OPERATOR.register_operator("Tan")
118 | class TFTan():
119 | def __init__(self, *args, **kwargs) -> None:
120 | super().__init__()
121 |
122 | def __call__(self, inputs):
123 | return tf.tan(inputs)
124 |
125 | @OPERATOR.register_operator("Tanh")
126 | class TFTanh():
127 | def __init__(self, *args, **kwargs) -> None:
128 | super().__init__()
129 |
130 | def __call__(self, inputs):
131 | return tf.tanh(inputs)
132 |
133 | @OPERATOR.register_operator("Softmax")
134 | class TFSoftmax():
135 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, node_outputs, layout_dict, *args, **kwargs) -> None:
136 | super().__init__()
137 | self.axis = node_attribute.get('axis', -1)
138 | if self.axis == -1:
139 | self.axis = len(tensor_grap[node_inputs[0]].shape.as_list()) - 1
140 | if layout_dict[node_inputs[0]] == Layout.Channel_Last:
141 | self.axis = channel_to_last_dimension(self.axis)
142 |
143 | def __call__(self, inputs):
144 | return keras.activations.softmax(inputs, axis=self.axis)
145 |
146 | @OPERATOR.register_operator("Softplus")
147 | class TFSoftplus():
148 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, *args, **kwargs) -> None:
149 | super().__init__()
150 |
151 | def __call__(self, inputs):
152 | return keras.activations.softplus(inputs)
153 |
154 | @OPERATOR.register_operator("Softsign")
155 | class TFSoftsign():
156 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, *args, **kwargs) -> None:
157 | super().__init__()
158 |
159 | def __call__(self, inputs):
160 | return keras.activations.softsign(inputs)
161 |
162 | @OPERATOR.register_operator("Selu")
163 | class TFSelu():
164 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, *args, **kwargs) -> None:
165 | super().__init__()
166 |
167 | def __call__(self, inputs):
168 | return keras.activations.selu(inputs)
169 |
170 | @OPERATOR.register_operator("Elu")
171 | class TFElu():
172 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, *args, **kwargs) -> None:
173 | super().__init__()
174 |
175 | def __call__(self, inputs):
176 | return keras.activations.elu(inputs)
177 |
178 | @OPERATOR.register_operator("Celu")
179 | class TFCelu():
180 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, *args, **kwargs) -> None:
181 | super().__init__()
182 | self.alpha = node_attribute.get("alpha", 1.0)
183 |
184 | def __call__(self, inputs):
185 | return tf.maximum(inputs, 0) + tf.minimum(0, self.alpha*(tf.exp(inputs/self.alpha)-1))
--------------------------------------------------------------------------------
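The registry pattern above makes missing ONNX activations straightforward to add. A hedged sketch following the same recipe ("Erf" is used purely as an example op; it is not claimed to be present or absent elsewhere in the codebase):

```python
import tensorflow as tf
from onnx2tflite.utils import OPERATOR

@OPERATOR.register_operator("Erf")
class TFErf():
    def __init__(self, *args, **kwargs) -> None:
        super().__init__()

    def __call__(self, inputs):
        # Map the ONNX Erf op onto TensorFlow's built-in erf.
        return tf.math.erf(inputs)
```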
/onnx2tflite/onnx2tflite/components/builder1.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
3 |
4 | import tensorflow as tf
5 | from tensorflow import keras
6 | from onnx import numpy_helper
7 | from .dataloader import RandomLoader, ImageLoader
8 |
9 | from onnx2tflite.utils import OPERATOR
10 | from onnx2tflite.layers import conv_layers
11 | from onnx2tflite.utils.definitions import *
12 | from onnx2tflite.utils.graph_tools import build_tf_inputs, decode_node_attribute
13 |
14 | import logging
15 |
16 | # Set up logging
17 | LOG = logging.getLogger("keras_builder")
18 | LOG.setLevel(logging.INFO)  # set the log level
19 |
20 | # Add a log handler (if none is attached yet)
21 | if not LOG.hasHandlers():
22 |     handler = logging.StreamHandler()  # log to the console
23 |     formatter = logging.Formatter("%(levelname)s: %(message)s")  # set the log format
24 | handler.setFormatter(formatter)
25 | LOG.addHandler(handler)
26 |
27 |
28 | def keras_builder(onnx_model, native_groupconv:bool=False):
29 | """
30 | 将 ONNX 模型转换为 Keras 模型。
31 |
32 | 参数:
33 | - onnx_model (onnx.ModelProto): 需要转换的 ONNX 模型。
34 | - native_groupconv (bool, 可选): 是否保持 ONNX 原生的分组卷积操作。默认 False。
35 |
36 | 返回:
37 | - keras_model (keras.Model): 生成的 Keras 模型。
38 | - input_layout (dict): ONNX 输入张量的布局信息。
39 | - output_layout (dict): ONNX 输出张量的布局信息。
40 | """
41 | # 设置全局变量,控制是否使用原生 ONNX 分组卷积
42 | conv_layers.USE_NATIVE_GROUP_CONV = native_groupconv
43 |
44 | # 解析 ONNX 计算图
45 | model_graph = onnx_model.graph
46 | layout_dict, tf_tensor = {}, {} # 存储 ONNX 层的布局信息 & TensorFlow 层的映射
47 |
48 |     '''
49 |     initialize the ONNX model's built-in weight tensors
50 |     '''
51 |
52 |     onnx_weights = dict()
53 |
54 |     for initializer in model_graph.initializer:
55 |         # Convert the ONNX weight to a NumPy array
56 |         onnx_weights[initializer.name] = numpy_helper.to_array(initializer)
57 |
58 |     '''
59 |     parse the ONNX input nodes and convert them to TensorFlow input layers
60 |     '''
61 |     input_nodes = build_tf_inputs(model_graph, layout_dict)  # parse the ONNX inputs
62 |     tf_tensor.update(input_nodes)  # update the TensorFlow tensor mapping
63 |
64 |     '''
65 |     iterate over all nodes of the ONNX graph and convert them to TensorFlow layers
66 |     '''
67 |
68 | ###########################################################################################################
69 | for node in model_graph.node:
70 |
71 | op_name, node_inputs, node_outputs = node.op_type, node.input, node.output
72 |         op_attr = decode_node_attribute(node)  # parse the attributes of the ONNX node
73 |
74 |         # Look up the corresponding TensorFlow operator
75 |         tf_operator = OPERATOR.get(op_name)
76 |         if tf_operator is None:
77 |             raise KeyError(f"{op_name} not implemented yet")
78 |
79 |         _inputs = None
80 |         if len(node_inputs) > 0:  # use the tensor from `tf_tensor` if present, otherwise take it from `onnx_weights`
81 |             _inputs = tf_tensor[node_inputs[0]] if node_inputs[0] in tf_tensor else onnx_weights[node_inputs[0]]
82 |
83 |         # Initialize the layout (data format, e.g. NHWC)
84 |         for index in range(len(node_outputs)):
85 |             layout_dict[node_outputs[index]] = layout_dict.get(node_inputs[0], Layout.Default)
86 |
87 |         # Run the conversion: ONNX node -> TensorFlow layer
88 |         res = tf_operator(tf_tensor, onnx_weights, node_inputs, op_attr, node_outputs, layout_dict)(_inputs)
89 |
90 | ###########################################################################################################
91 |         if isinstance(res, list):  # handle multiple outputs
92 | for index in range(len(node_outputs)):
93 | tf_tensor[node_outputs[index]] = res[index]
94 | else:
95 | tf_tensor[node_outputs[0]] = res
96 |
97 |     '''
98 |     Build the Keras model. Example log output:
99 |     INFO: Keras model input shape: (1, 256, 256, 4)
100 |     INFO:keras_builder:Keras model input shape: (1, 256, 256, 4)
101 |     INFO: Keras model output shape: (1, 512, 512, 3)
102 |     INFO:keras_builder:Keras model output shape: (1, 512, 512, 3)
103 |     '''
104 |     input_nodes = [tf_tensor[x.name] for x in model_graph.input]  # ONNX inputs
105 |     outputs_nodes = [tf_tensor[x.name] for x in model_graph.output]  # ONNX outputs
106 |     keras_model = keras.Model(inputs=input_nodes, outputs=outputs_nodes)  # build the Keras model
107 |     keras_model.trainable = False  # mark as non-trainable
108 |     # keras_model.summary()  # optional, print the model structure
109 | # print(layout_dict)
110 |
111 | ####################################################################
112 |     '''
113 |     Before returning the model, check whether it still contains dynamic inputs/outputs
114 |     '''
115 |     # 1. Get the input and output shapes of the Keras model
116 |     input_shape = keras_model.input_shape
117 |     output_shape = keras_model.output_shape
118 |
119 |     # 2. Check for dynamic inputs (None denotes a dynamic dimension)
120 |     if any(dim is None for dim in input_shape):
121 |         LOG.warning(f"Keras model still contains dynamic inputs: {input_shape}")
122 |
123 |     # 3. Check for dynamic outputs
124 |     if any(dim is None for dim in output_shape):
125 |         LOG.warning(f"Keras model still contains dynamic outputs: {output_shape}")
126 |
127 |     # 4. Log the shapes
128 |     LOG.info(f"Keras model input shape: {input_shape}")
129 |     LOG.info(f"Keras model output shape: {output_shape}")
130 |
131 |     ####################################################################
132 |     # Record the ONNX input and output layouts
133 |     input_layout, output_layout = {}, {}
134 |     for inp in model_graph.input:
135 |         input_layout[inp.name] = layout_dict[inp.name]
136 |     for oup in model_graph.output:
137 |         output_layout[oup.name] = layout_dict[oup.name]
138 |
139 |     return keras_model, input_layout, output_layout # return the Keras model and layout info
140 |
141 |
142 | def tflite_builder(keras_model, weight_quant:bool=False, fp16_model=False, int8_model:bool=False, image_root:str=None,
143 | int8_mean:list or float = [123.675, 116.28, 103.53], int8_std:list or float = [58.395, 57.12, 57.375]):
144 |
145 | """
146 | 将 Keras 模型转换为 TFLite 模型,并支持不同的量化模式。
147 |
148 | 参数:
149 | - keras_model (keras.Model): 需要转换的 Keras 模型。
150 | - weight_quant (bool, 可选): 是否进行权重量化。默认 False。
151 | - fp16_model (bool, 可选): 是否转换为 FP16 精度(适用于部分硬件优化)。默认 False。
152 | - int8_model (bool, 可选): 是否转换为 INT8 量化模型(适用于边缘设备)。默认 False。
153 | - image_root (str, 可选): 如果使用 INT8 量化,提供用于校准的图像数据目录。
154 | - int8_mean (list or float, 可选): INT8 量化校准的均值,默认 `[123.675, 116.28, 103.53]`。
155 | - int8_std (list or float, 可选): INT8 量化校准的标准差,默认 `[58.395, 57.12, 57.375]`。
156 |
157 | 返回:
158 | - tflite_model (bytes): 转换后的 TFLite 模型。
159 | """
160 | # 1. 创建 TensorFlow Lite 转换器
161 | converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
162 |
163 | # 2. 设定转换支持的运算 # TensorFlow Lite 内置算子 # 允许使用部分 TensorFlow 原生算子
164 | converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]
165 |
166 | # 3. 启用量化选项
167 | if weight_quant or int8_model or fp16_model:
168 | converter.experimental_new_converter = True # 使用新的 TFLite 转换器
169 | converter.optimizations = [tf.lite.Optimize.DEFAULT] # 启用优化
170 |
171 | # 4. 处理 FP16 量化(半精度浮点数)
172 | if fp16_model:
173 | converter.target_spec.supported_types = [tf.float16]
174 | converter.inference_input_type = tf.float32
175 | converter.inference_output_type = tf.float32
176 |
177 | # 5. 处理 INT8 量化
178 | elif int8_model:
179 | assert len(keras_model.inputs) == 1, f"help want, only support single input model."
180 | # 获取输入形状
181 | shape = list(keras_model.inputs[0].shape)
182 | # 选择数据集:使用 `image_root` 进行 INT8 量化校准
183 | dataset = RandomLoader(shape) if image_root is None else ImageLoader(image_root, shape, int8_mean, int8_std)
184 | # 设定代表性数据集(TFLite 量化需要一个校准数据集)
185 | converter.representative_dataset = lambda: dataset
186 | # 使用 INT8 计算
187 | converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8, tf.lite.OpsSet.SELECT_TF_OPS]
188 | converter.target_spec.supported_types = []
189 | converter.inference_input_type = tf.uint8
190 | converter.inference_output_type = tf.uint8
191 | converter.experimental_new_converter = True # 启用新的转换器
192 | # 6. 进行 TFLite 转换
193 | tflite_model = converter.convert()
194 | return tflite_model # 返回 TFLite 模型
--------------------------------------------------------------------------------
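A minimal usage sketch for the two builders above (file names are placeholders; `builder1.py` is the logging-instrumented variant of the stock builder, and the package is assumed to be installed as described in the README below). `keras_builder` turns a loaded ONNX graph into a Keras model plus per-tensor layout dictionaries, and `tflite_builder` serializes that model to a TFLite flatbuffer; for end-to-end conversion the packaged `onnx_converter` entry point is the usual route.

```python
import onnx
from onnx2tflite.components.builder1 import keras_builder, tflite_builder

# Load the ONNX graph (path is a placeholder; static input shapes assumed).
onnx_model = onnx.load("model.onnx")

# ONNX graph -> Keras model, plus layout dicts for the input/output tensors.
keras_model, input_layout, output_layout = keras_builder(onnx_model, native_groupconv=False)

# Keras model -> FP16 TFLite flatbuffer. For full INT8 quantization, pass
# int8_model=True and point image_root at a folder of calibration images.
tflite_bytes = tflite_builder(keras_model, weight_quant=False, fp16_model=True)

with open("model.tflite", "wb") as f:
    f.write(tflite_bytes)
```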
/onnx2tflite/onnx2tflite.egg-info/PKG-INFO:
--------------------------------------------------------------------------------
1 | Metadata-Version: 2.1
2 | Name: onnx2tflite
3 | Version: 2.0
4 | Summary: onnx to keras/tensorflow lite
5 | Author: MPolaris
6 | License: Apache-2.0
7 | Platform: Windows
8 | Platform: linux
9 | Description-Content-Type: text/markdown
10 | Requires-Dist: onnx
11 | Requires-Dist: onnxruntime
12 | Requires-Dist: onnx-simplifier
13 | Requires-Dist: numpy<=1.24
14 | Requires-Dist: tensorflow<2.13,>=2.5
15 | Requires-Dist: opencv-python
16 |
17 | # ONNX->Keras and ONNX->TFLite tools
18 | ## Welcome
19 | If you have good ideas, feel free to open a discussion or submit a PR to the project.
20 |
21 | ## Install
22 | ```cmd
23 | git clone https://github.com/MPolaris/onnx2tflite.git
24 | cd onnx2tflite
25 | python setup.py install
26 | ```
27 | ```python
28 | from onnx2tflite import onnx_converter
29 | res = onnx_converter(
30 | onnx_model_path = "./model.onnx",
31 | need_simplify = True,
32 | output_path = "./models/",
33 | target_formats = ['tflite'],
34 | )
35 | ```
36 | ---
37 | ```cmd
38 | # base
39 | python -m onnx2tflite --weights "./your_model.onnx"
40 |
41 | # give save path
42 | python -m onnx2tflite --weights "./your_model.onnx" --outpath "./save_path"
43 |
44 | # save tflite model
45 | python -m onnx2tflite --weights "./your_model.onnx" --outpath "./save_path" --formats "tflite"
46 |
47 | # save keras and tflite model
48 | python -m onnx2tflite --weights "./your_model.onnx" --outpath "./save_path" --formats "tflite" "keras"
49 |
50 | # cut off the model: redefine inputs and outputs, middle layers are supported
51 | python -m onnx2tflite --weights "./your_model.onnx" --outpath "./save_path" --formats "tflite" --input-node-names "layer_inputname" --output-node-names "layer_outname1" "layer_outname2"
52 |
53 | # quantize model weights (weights only)
54 | python -m onnx2tflite --weights "./your_model.onnx" --formats "tflite" --weigthquant
55 |
56 | # quantize the model, including input and output
57 | ## fp16
58 | python -m onnx2tflite --weights "./your_model.onnx" --formats "tflite" --fp16
59 | ## recommend
60 | python -m onnx2tflite --weights "./your_model.onnx" --formats "tflite" --int8 --imgroot "./dataset_path" --int8mean 0 0 0 --int8std 255 255 255
61 | ## generate random calibration data instead of reading from image files
62 | python -m onnx2tflite --weights "./your_model.onnx" --formats "tflite" --int8
63 | ```
64 | ---
65 | ## Features
66 | - High Consistency. Compared with the ONNX outputs, the average error is less than 1e-5 per element.
67 | - Faster. Produces the tensorflow-lite model about 30% faster than [onnx_tf](https://github.com/onnx/onnx-tensorflow).
68 | - Auto Channel Align. Automatically converts the pytorch format (NCHW) to the tensorflow format (NHWC).
69 | - Deployment Support. Supports exporting quantized models, including fp16 quantization and uint8 quantization.
70 | - Code Friendly. I've been trying to keep the code structure simple and clear.
71 | ---
72 |
73 | ## Pytorch -> ONNX -> Tensorflow-Keras -> Tensorflow-Lite
74 |
75 | - ### From torchvision to tensorflow-lite
76 | ```python
77 | import torch
78 | import torchvision
79 | _input = torch.randn(1, 3, 224, 224)
80 | model = torchvision.models.mobilenet_v2(True)
81 | # use default settings is ok
82 | torch.onnx.export(model, _input, './mobilenetV2.onnx', opset_version=11)# or opset_version=13
83 |
84 | from converter import onnx_converter
85 | onnx_converter(
86 | onnx_model_path = "./mobilenetV2.onnx",
87 | need_simplify = True,
88 | output_path = "./",
89 | target_formats = ['tflite'], # or ['keras'], ['keras', 'tflite']
90 | weight_quant = False,
91 | fp16_model=False,
92 | int8_model = False,
93 | int8_mean = None,
94 | int8_std = None,
95 | image_root = None
96 | )
97 | ```
98 | - ### From custom pytorch model to tensorflow-lite-int8
99 | ```python
100 | import torch
101 | import torch.nn as nn
102 | import torch.nn.functional as F
103 |
104 | class MyModel(nn.Module):
105 |     def __init__(self):
106 |         super().__init__()
107 |         self.conv = nn.Sequential(
108 |             nn.Conv2d(3, 64, kernel_size=3, padding=1),
109 |             nn.BatchNorm2d(64),
110 |             nn.ReLU(inplace=True))
111 |
112 | def forward(self, x):
113 | return self.conv(x)
114 |
115 | model = MyModel()
116 | model.load_state_dict(torch.load("model_checkpoint.pth", map_location="cpu"))
117 |
118 | _input = torch.randn(1, 3, 224, 224)
119 | torch.onnx.export(model, _input, './mymodel.onnx', opset_version=11)# or opset_version=13
120 |
121 | from converter import onnx_converter
122 | onnx_converter(
123 | onnx_model_path = "./mymodel.onnx",
124 | need_simplify = True,
125 | output_path = "./",
126 | target_formats = ['tflite'], #or ['keras'], ['keras', 'tflite']
127 | weight_quant = False,
128 | int8_model = True, # do quantification
129 | int8_mean = [123.675, 116.28, 103.53], # give mean of image preprocessing
130 | int8_std = [58.395, 57.12, 57.375], # give std of image preprocessing
131 | image_root = "./dataset/train" # give image folder of train
132 | )
133 | ```
134 | ---
135 | ## Validated models
136 | - [SSD](https://github.com/qfgaohao/pytorch-ssd)
137 | - [HRNet](HRNet-Facial-Landmark-Detection)
138 | - [YOLOX](https://github.com/Megvii-BaseDetection/YOLOX)
139 | - [YOLOV3](https://github.com/ultralytics/yolov3)
140 | - [YOLOV4](https://github.com/Tianxiaomo/pytorch-YOLOv4)
141 | - [YOLOV5](https://github.com/ultralytics/yolov5)
142 | - [YOLOV6](https://github.com/meituan/YOLOv6)
143 | - [YOLOV7](https://github.com/WongKinYiu/yolov7)
144 | - [YOLOV10](https://github.com/THU-MIG/yolov10)
145 | - [MoveNet](https://github.com/fire717/movenet.pytorch)
146 | - [UNet\FPN](https://github.com/bigmb/Unet-Segmentation-Pytorch-Nest-of-Unets)
147 | - ViT(torchvision)
148 | - [SwinTransformerV1](https://github.com/microsoft/Swin-Transformer)
149 | - MLP(custom)
150 | - DCGAN(custom)
151 | - [AutoEncoder/VAE](https://github.com/AntixK/PyTorch-VAE)
152 | - all torchvision classification models
153 | - some segmentation models in torchvision
154 | - 1D or 2D CNN without special operators(custom)
155 | ---
156 | ## Add operator by yourself
157 | When you encounter an unsupported operator, you can either add it yourself or open an issue.
158 | It's very simple to implement a new operator parser by following the steps below.
159 | Step 0: Select the corresponding layer code file in the [layers folder](./onnx2tflite/layers/), e.g. activations_layers.py for 'HardSigmoid'.
160 | Step 1: Open it, and edit it:
161 | ```python
162 | # all operators are registered through the OPERATOR registry.
163 | # the registered name is the ONNX operator name.
164 | @OPERATOR.register_operator("HardSigmoid")
165 | class TFHardSigmoid():
166 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, node_outputs, layout_dict, *args, **kwargs) -> None:
167 |         '''
168 |         :param tensor_grap: dict, key is node name, value is the tensorflow-keras output tensor of that node.
169 |         :param node_weights: dict, key is node name, value is static data such as weight/bias/constant; weights usually need to be transformed by dimension_utils.tensor_NCD_to_NDC_format.
170 |         :param node_inputs: List[str], node input names, indicating which nodes the inputs come from; they may refer to tensor_grap or node_weights.
171 |         :param node_attribute: dict, key is attribute name, such as 'axis' or 'perm'. The value type varies, e.g. List[int], int or float. Note that an 'axis' value should be adjusted from NCHW to NHWC by dimension_utils.channel_to_last_dimension or dimension_utils.shape_NCD_to_NDC_format.
172 |         :param node_outputs: List[str], node output names.
173 |         :param layout_dict: dict, stores the layout of every preceding node.
174 |         '''
175 | super().__init__()
176 | self.alpha = node_attribute.get("alpha", 0.2)
177 | self.beta = node_attribute.get("beta", 0.5)
178 |
179 | def __call__(self, inputs):
180 | return tf.clip_by_value(self.alpha*inputs+self.beta, 0, 1)
181 | ```
182 | Step 2: Make it work without error.
183 | Step 3: Convert the model to tflite without any quantization.
184 |
185 | ---
186 |
187 | # License
188 | This software is covered by Apache-2.0 license.
189 |
--------------------------------------------------------------------------------
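The "High Consistency" claim in the README above can be checked directly by feeding the same input to onnxruntime and to the TFLite interpreter. A sketch, assuming the fp32 mobilenetV2 files produced by the README example (adjust the paths to wherever the converted files were written), and remembering that the converter transposes inputs from NCHW to NHWC:

```python
import numpy as np
import onnxruntime as ort
import tensorflow as tf

x_nchw = np.random.rand(1, 3, 224, 224).astype(np.float32)

# ONNX reference output.
sess = ort.InferenceSession("mobilenetV2.onnx")
onnx_out = sess.run(None, {sess.get_inputs()[0].name: x_nchw})[0]

# TFLite output (the converted model expects NHWC input).
interp = tf.lite.Interpreter(model_path="mobilenetV2.tflite")
interp.allocate_tensors()
inp = interp.get_input_details()[0]
out = interp.get_output_details()[0]
interp.set_tensor(inp["index"], x_nchw.transpose(0, 2, 3, 1))
interp.invoke()
tflite_out = interp.get_tensor(out["index"])

print(np.abs(onnx_out - tflite_out).mean())  # expected to be small, around 1e-5 for fp32 models
```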
/onnx2tflite/onnx2tflite/layers/mathematics_layers.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import numpy as np
3 | import tensorflow as tf
4 |
5 | from onnx2tflite.utils.definitions import Layout
6 | from onnx2tflite.utils import OPERATOR, dimension_utils, np2tf_type
7 |
8 | LOG = logging.getLogger("calculations_layers :")
9 |
10 | def np2tf(x):
11 | if isinstance(x, np.ndarray):
12 | x = tf.convert_to_tensor(x, dtype=np2tf_type[x.dtype.name])
13 | return x, False
14 | return x, True
15 |
16 | def match_tensor(x1:tf.Tensor or np.ndarray, x2:tf.Tensor or np.ndarray, x1_layout:Layout, x2_layout:Layout):
17 |
18 | x1, f1 = np2tf(x1)
19 | x2, f2 = np2tf(x2)
20 |
21 |     # no need to transpose if both vars are tensors; we assume tensors are produced by the graph.
22 | if f1 and f2:
23 | if x1_layout != x2_layout:
24 | if x1_layout == Layout.Channel_Last:
25 | x1 = dimension_utils.tensor_NDC_to_NCD_format(x1)
26 | elif x2_layout == Layout.Channel_Last:
27 | x2 = dimension_utils.tensor_NDC_to_NCD_format(x2)
28 | return x1, x2, Layout.Channel_First
29 |
30 | # ensure tensor is set to x1, const weights set to x2
31 | out_layout = x1_layout
32 | if f2:
33 | x1, x2 = x2, x1
34 | out_layout = x2_layout
35 |
36 |
37 | if out_layout == Layout.Channel_Last:
38 | if x1.shape.ndims != x2.shape.ndims:
39 | while x2.shape.ndims < x1.shape.ndims:
40 | x2 = tf.expand_dims(x2, axis=0)
41 | x2 = dimension_utils.tensor_NCD_to_NDC_format(x2)
42 |
43 | x2 = tf.cast(x2, x1.dtype)
44 | return (x2, x1, out_layout) if f2 else (x1, x2, out_layout)
45 |
46 | '''
47 | tensor(NDC) + const
48 | tensor(NCD) + const
49 | tensor(NDC) + tensor(NDC)
50 | tensor(NCD) + tensor(NCD)
51 | '''
52 |
53 | class BaseArithmetic:
54 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, node_outputs, layout_dict, *args, **kwargs):
55 | self.left_val, self.right_val = None, None
56 | left_layout, right_layout = Layout.Default, Layout.Default
57 |
58 | if node_inputs[0] in tensor_grap:
59 | self.left_val = tensor_grap[node_inputs[0]]
60 | left_layout = layout_dict[node_inputs[0]]
61 | else:
62 | self.left_val = node_weights[node_inputs[0]]
63 |
64 | if node_inputs[1] in tensor_grap:
65 | self.right_val = tensor_grap[node_inputs[1]]
66 | right_layout = layout_dict[node_inputs[1]]
67 | else:
68 | self.right_val = node_weights[node_inputs[1]]
69 |
70 | if left_layout == right_layout:
71 | return
72 |
73 | self.left_val, self.right_val, out_layout = match_tensor(self.left_val, self.right_val, left_layout, right_layout)
74 | layout_dict[node_outputs[0]] = out_layout
75 |
76 | @OPERATOR.register_operator("Add")
77 | class TFAdd(BaseArithmetic):
78 | def __init__(self, *args, **kwargs):
79 | super().__init__(*args, **kwargs)
80 |
81 | def __call__(self, *args, **kwargs):
82 | return self.left_val + self.right_val
83 |
84 | @OPERATOR.register_operator("Sub")
85 | class TFSub(BaseArithmetic):
86 | def __init__(self, *args, **kwargs):
87 | super().__init__(*args, **kwargs)
88 |
89 | def __call__(self, *args, **kwargs):
90 | return self.left_val - self.right_val
91 |
92 | @OPERATOR.register_operator("Mul")
93 | class TFMul(BaseArithmetic):
94 | def __init__(self,*args, **kwargs):
95 | super().__init__(*args, **kwargs)
96 |
97 | def __call__(self, *args, **kwargs):
98 | return self.left_val * self.right_val
99 |
100 | @OPERATOR.register_operator("Div")
101 | class TFDiv(BaseArithmetic):
102 | def __init__(self,*args, **kwargs):
103 | super().__init__(*args, **kwargs)
104 |
105 | def __call__(self, *args, **kwargs):
106 | return self.left_val / self.right_val
107 |
108 | @OPERATOR.register_operator("MatMul")
109 | class TFMatMul():
110 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, node_outputs, layout_dict, *args, **kwargs):
111 | super().__init__()
112 | if node_inputs[0] in tensor_grap:
113 | self.A = tensor_grap[node_inputs[0]]
114 | if layout_dict[node_inputs[0]] == Layout.Channel_Last:
115 | self.A = dimension_utils.tensor_NDC_to_NCD_format(self.A)
116 | else:
117 | self.A = node_weights[node_inputs[0]]
118 |
119 | if node_inputs[1] in tensor_grap:
120 | self.B = tensor_grap[node_inputs[1]]
121 | if layout_dict[node_inputs[1]] == Layout.Channel_Last:
122 | self.B = dimension_utils.tensor_NDC_to_NCD_format(self.B)
123 | else:
124 | self.B = node_weights[node_inputs[1]]
125 |
126 | self.dense = tf.keras.layers.Dense(self.B.shape[-1],
127 | weights=[self.B],
128 | use_bias=False)
129 |
130 | layout_dict[node_outputs[0]] = Layout.Channel_First
131 |
132 | def __call__(self, *args, **kwargs):
133 | # out = tf.matmul(self.A, self.B)
134 | try:
135 | out = self.dense(self.A)
136 | except Exception:
137 | out = tf.matmul(self.A, self.B)
138 | return out
139 |
140 | @OPERATOR.register_operator("Mod")
141 | class TFMod():
142 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, *args, **kwargs):
143 | super().__init__()
144 | self.fmod = bool(node_attribute.get("fmod", 0))
145 | self.mod_value = None
146 | if node_inputs[1] in node_weights:
147 | self.mod_value = node_weights[node_inputs[1]]
148 | else:
149 | self.mod_value = tensor_grap[node_inputs[1]]
150 |
151 | def __call__(self, inputs):
152 | if self.fmod:
153 | return tf.math.floormod(inputs, tf.cast(self.mod_value, inputs.dtype))
154 | else:
155 | return tf.math.mod(inputs, tf.cast(self.mod_value, inputs.dtype))
156 |
157 | @OPERATOR.register_operator("Pow")
158 | class TFPow():
159 | def __init__(self, tensor_grap, node_weights, node_inputs, *args, **kwargs):
160 | super().__init__()
161 | self.power_index = node_weights[node_inputs[1]]
162 |
163 | def __call__(self, inputs, *args, **kwargs):
164 | return tf.pow(inputs, self.power_index)
165 |
166 | @OPERATOR.register_operator("Reciprocal")
167 | class TFReciprocal():
168 | def __init__(self, *args, **kwargs):
169 | super().__init__()
170 |
171 | def __call__(self, inputs, *args, **kwargs):
172 | return 1/inputs
173 |
174 | @OPERATOR.register_operator("Sqrt")
175 | class TFSqrt():
176 | def __init__(self, *args, **kwargs):
177 | super().__init__()
178 |
179 | def __call__(self, inputs, *args, **kwargs):
180 | return tf.sqrt(inputs)
181 |
182 | @OPERATOR.register_operator("Exp")
183 | class TFExp():
184 | def __init__(self, *args, **kwargs):
185 | super().__init__()
186 |
187 | def __call__(self, inputs, *args, **kwargs):
188 | return tf.exp(inputs)
189 |
190 | @OPERATOR.register_operator("Log")
191 | class TFLog():
192 | def __init__(self, *args, **kwargs):
193 | super().__init__()
194 |
195 | def __call__(self, inputs, *args, **kwargs):
196 |         return tf.math.log(inputs)
197 |
198 | class ReduceBase:
199 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, node_outputs, layout_dict, *args, **kwargs):
200 | self.keep_dims = node_attribute.get("keepdims", 1) == 1
201 | input_shape_len = len(tensor_grap[node_inputs[0]].shape)
202 | if layout_dict[node_inputs[0]] == Layout.Channel_Last:
203 | self.axes = [dimension_utils.channel_to_last_dimension(i) if i >=0 else dimension_utils.channel_to_last_dimension(input_shape_len + i) for i in node_attribute.get("axes", [-1])]
204 | else:
205 | self.axes = [i if i >=0 else input_shape_len + i for i in node_attribute.get("axes", [-1])]
206 |
207 | @OPERATOR.register_operator("ReduceSum")
208 | class TFReduceSum(ReduceBase):
209 | def __init__(self, *args, **kwargs):
210 | super().__init__(*args, **kwargs)
211 |
212 | def __call__(self, inputs, *args, **kwargs):
213 | return tf.math.reduce_sum(inputs, axis=self.axes, keepdims=self.keep_dims)
214 |
215 | @OPERATOR.register_operator("ReduceMean")
216 | class TFReduceMean(ReduceBase):
217 | def __init__(self, *args, **kwargs):
218 | super().__init__(*args, **kwargs)
219 |
220 | def __call__(self, inputs, *args, **kwargs):
221 | return tf.math.reduce_mean(inputs, axis=self.axes, keepdims=self.keep_dims)
222 |
223 | @OPERATOR.register_operator("ReduceMax")
224 | class TFReduceMax(ReduceBase):
225 | def __init__(self, *args, **kwargs):
226 | super().__init__(*args, **kwargs)
227 |
228 | def __call__(self, inputs, *args, **kwargs):
229 | return tf.math.reduce_max(inputs, axis=self.axes, keepdims=self.keep_dims)
230 |
231 | @OPERATOR.register_operator("ReduceMin")
232 | class TFReduceMin(ReduceBase):
233 | def __init__(self, *args, **kwargs):
234 | super().__init__(*args, **kwargs)
235 |
236 | def __call__(self, inputs, *args, **kwargs):
237 | return tf.math.reduce_min(inputs, axis=self.axes, keepdims=self.keep_dims)
238 |
239 | @OPERATOR.register_operator("ArgMax")
240 | class TFArgMax():
241 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, node_outputs, layout_dict, *args, **kwargs):
242 | super().__init__()
243 | self.axis = node_attribute.get('axis', 0)
244 | if layout_dict[node_inputs[0]] == Layout.Channel_Last:
245 | self.axis = dimension_utils.channel_to_last_dimension(self.axis)
246 | self.keepdims = node_attribute.get("keepdims", 1) == 1
247 |
248 | def __call__(self, inputs, *args, **kwargs):
249 | _inputs = tf.argmax(inputs, axis=self.axis)
250 | if self.keepdims:
251 | _inputs = tf.expand_dims(_inputs, axis=self.axis)
252 | return _inputs
253 |
254 | @OPERATOR.register_operator("ArgMin")
255 | class TFArgMin():
256 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, node_outputs, layout_dict, *args, **kwargs):
257 | super().__init__()
258 | self.axis = node_attribute.get('axis', 0)
259 | if layout_dict[node_inputs[0]] == Layout.Channel_Last:
260 | self.axis = dimension_utils.channel_to_last_dimension(self.axis)
261 | self.keepdims = node_attribute.get("keepdims", 1) == 1
262 |
263 | def __call__(self, inputs, *args, **kwargs):
264 |         _inputs = tf.argmin(inputs, axis=self.axis)
265 | if self.keepdims:
266 | _inputs = tf.expand_dims(_inputs, axis=self.axis)
267 | return _inputs
268 |
269 | @OPERATOR.register_operator("Erf")
270 | class TFErf():
271 | def __init__(self, *args, **kwargs) -> None:
272 | pass
273 |
274 | def __call__(self, inputs):
275 | inputs = tf.math.erf(inputs)
276 | return inputs
--------------------------------------------------------------------------------
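In the arithmetic layers above, `match_tensor` has to reconcile operands living in different layouts: graph tensors are usually channel-last (NHWC) after conversion, while ONNX constants stay channel-first (NCHW). Below is a small, self-contained illustration of that alignment using plain TensorFlow rather than the module itself; it assumes `dimension_utils.tensor_NCD_to_NDC_format` is essentially this move-channel-last permutation.

```python
import numpy as np
import tensorflow as tf

# Graph tensor in channel-last (NHWC) layout, as produced by the converter.
x_nhwc = tf.random.normal((1, 8, 8, 16))

# ONNX-side constant stored in channel-first (NCHW) layout, e.g. a per-channel bias.
bias_nchw = np.ones((1, 16, 1, 1), dtype=np.float32)

# Move the channel axis last so broadcasting lines up with the NHWC tensor.
bias_nhwc = tf.convert_to_tensor(bias_nchw.transpose(0, 2, 3, 1))

y = x_nhwc + bias_nhwc  # what TFAdd effectively computes after match_tensor
print(y.shape)          # (1, 8, 8, 16)
```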
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/model/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | class MBRConv5(nn.Module):
6 | def __init__(self, in_channels, out_channels, rep_scale=4):
7 | super(MBRConv5, self).__init__()
8 | self.in_channels = in_channels
9 | self.out_channels = out_channels
10 | self.conv = nn.Conv2d(in_channels, out_channels * rep_scale, 5, 1, 2)
11 | self.conv_bn = nn.Sequential(
12 | nn.BatchNorm2d(out_channels * rep_scale)
13 | )
14 | self.conv1 = nn.Conv2d(in_channels, out_channels * rep_scale, 1)
15 | self.conv1_bn = nn.Sequential(
16 | nn.BatchNorm2d(out_channels * rep_scale)
17 | )
18 | self.conv2 = nn.Conv2d(in_channels, out_channels * rep_scale, 3, 1, 1)
19 | self.conv2_bn = nn.Sequential(
20 | nn.BatchNorm2d(out_channels * rep_scale)
21 | )
22 | self.conv_crossh = nn.Conv2d(in_channels, out_channels * rep_scale, (3, 1), 1, (1, 0))
23 | self.conv_crossh_bn = nn.Sequential(
24 | nn.BatchNorm2d(out_channels * rep_scale)
25 | )
26 | self.conv_crossv = nn.Conv2d(in_channels, out_channels * rep_scale, (1, 3), 1, (0, 1))
27 | self.conv_crossv_bn = nn.Sequential(
28 | nn.BatchNorm2d(out_channels * rep_scale)
29 | )
30 | self.conv_out = nn.Conv2d(out_channels * rep_scale * 10, out_channels, 1)
31 |
32 | def forward(self, inp):
33 | x1 = self.conv(inp)
34 | x2 = self.conv1(inp)
35 | x3 = self.conv2(inp)
36 | x4 = self.conv_crossh(inp)
37 | x5 = self.conv_crossv(inp)
38 | x = torch.cat(
39 | [x1, x2, x3, x4, x5,
40 | self.conv_bn(x1),
41 | self.conv1_bn(x2),
42 | self.conv2_bn(x3),
43 | self.conv_crossh_bn(x4),
44 | self.conv_crossv_bn(x5)],
45 | 1
46 | )
47 | out = self.conv_out(x)
48 | return out
49 |
50 | def slim(self):
51 | conv_weight = self.conv.weight
52 | conv_bias = self.conv.bias
53 |
54 | conv1_weight = self.conv1.weight
55 | conv1_bias = self.conv1.bias
56 | conv1_weight = nn.functional.pad(conv1_weight, (2, 2, 2, 2))
57 |
58 | conv2_weight = self.conv2.weight
59 | conv2_weight = nn.functional.pad(conv2_weight, (1, 1, 1, 1))
60 | conv2_bias = self.conv2.bias
61 |
62 | conv_crossv_weight = self.conv_crossv.weight
63 | conv_crossv_weight = nn.functional.pad(conv_crossv_weight, (1, 1, 2, 2))
64 | conv_crossv_bias = self.conv_crossv.bias
65 |
66 | conv_crossh_weight = self.conv_crossh.weight
67 | conv_crossh_weight = nn.functional.pad(conv_crossh_weight, (2, 2, 1, 1))
68 | conv_crossh_bias = self.conv_crossh.bias
69 |
70 | conv1_bn_weight = self.conv1.weight
71 | conv1_bn_weight = nn.functional.pad(conv1_bn_weight, (2, 2, 2, 2))
72 |
73 | conv2_bn_weight = self.conv2.weight
74 | conv2_bn_weight = nn.functional.pad(conv2_bn_weight, (1, 1, 1, 1))
75 |
76 | conv_crossv_bn_weight = self.conv_crossv.weight
77 | conv_crossv_bn_weight = nn.functional.pad(conv_crossv_bn_weight, (1, 1, 2, 2))
78 |
79 | conv_crossh_bn_weight = self.conv_crossh.weight
80 | conv_crossh_bn_weight = nn.functional.pad(conv_crossh_bn_weight, (2, 2, 1, 1))
81 |
82 | bn = self.conv_bn[0]
83 | k = 1 / (bn.running_var + bn.eps) ** .5
84 | b = - bn.running_mean / (bn.running_var + bn.eps) ** .5
85 |
86 | conv_bn_weight = self.conv.weight * k.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
87 | conv_bn_weight = conv_bn_weight * bn.weight.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
88 | conv_bn_bias = self.conv.bias * k + b
89 | conv_bn_bias = conv_bn_bias * bn.weight + bn.bias
90 |
91 | bn = self.conv1_bn[0]
92 | k = 1 / (bn.running_var + bn.eps) ** .5
93 | b = - bn.running_mean / (bn.running_var + bn.eps) ** .5
94 | conv1_bn_weight = conv1_bn_weight * k.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
95 | conv1_bn_weight = conv1_bn_weight * bn.weight.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
96 | conv1_bn_bias = self.conv1.bias * k + b
97 | conv1_bn_bias = conv1_bn_bias * bn.weight + bn.bias
98 |
99 | bn = self.conv2_bn[0]
100 | k = 1 / (bn.running_var + bn.eps) ** .5
101 | b = - bn.running_mean / (bn.running_var + bn.eps) ** .5
102 | conv2_bn_weight = conv2_bn_weight * k.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
103 | conv2_bn_weight = conv2_bn_weight * bn.weight.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
104 | conv2_bn_bias = self.conv2.bias * k + b
105 | conv2_bn_bias = conv2_bn_bias * bn.weight + bn.bias
106 |
107 | bn = self.conv_crossv_bn[0]
108 | k = 1 / (bn.running_var + bn.eps) ** .5
109 | b = - bn.running_mean / (bn.running_var + bn.eps) ** .5
110 | conv_crossv_bn_weight = conv_crossv_bn_weight * k.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
111 | conv_crossv_bn_weight = conv_crossv_bn_weight * bn.weight.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
112 | conv_crossv_bn_bias = self.conv_crossv.bias * k + b
113 | conv_crossv_bn_bias = conv_crossv_bn_bias * bn.weight + bn.bias
114 |
115 | bn = self.conv_crossh_bn[0]
116 | k = 1 / (bn.running_var + bn.eps) ** .5
117 | b = - bn.running_mean / (bn.running_var + bn.eps) ** .5
118 | conv_crossh_bn_weight = conv_crossh_bn_weight * k.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
119 | conv_crossh_bn_weight = conv_crossh_bn_weight * bn.weight.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
120 | conv_crossh_bn_bias = self.conv_crossh.bias * k + b
121 | conv_crossh_bn_bias = conv_crossh_bn_bias * bn.weight + bn.bias
122 |
123 | weight = torch.cat(
124 | [conv_weight, conv1_weight, conv2_weight,
125 | conv_crossh_weight, conv_crossv_weight,
126 | conv_bn_weight, conv1_bn_weight, conv2_bn_weight,
127 | conv_crossh_bn_weight, conv_crossv_bn_weight],
128 | 0
129 | )
130 | weight_compress = self.conv_out.weight.squeeze()
131 | weight = torch.matmul(weight_compress, weight.permute([2, 3, 0, 1])).permute([2, 3, 0, 1])
132 | bias_ = torch.cat(
133 | [conv_bias, conv1_bias, conv2_bias,
134 | conv_crossh_bias, conv_crossv_bias,
135 | conv_bn_bias, conv1_bn_bias, conv2_bn_bias,
136 | conv_crossh_bn_bias, conv_crossv_bn_bias],
137 | 0
138 | )
139 | bias = torch.matmul(weight_compress, bias_)
140 | if isinstance(self.conv_out.bias, torch.Tensor):
141 | bias = bias + self.conv_out.bias
142 | return weight, bias
143 |
144 |
145 | ##############################################################################################################
146 | class MBRConv3(nn.Module):
147 | def __init__(self, in_channels, out_channels, rep_scale=4):
148 | super(MBRConv3, self).__init__()
149 |
150 | self.in_channels = in_channels
151 | self.out_channels = out_channels
152 | self.rep_scale = rep_scale
153 |
154 | self.conv = nn.Conv2d(in_channels, out_channels * rep_scale, 3, 1, 1)
155 | self.conv_bn = nn.Sequential(
156 | nn.BatchNorm2d(out_channels * rep_scale)
157 | )
158 | self.conv1 = nn.Conv2d(in_channels, out_channels * rep_scale, 1)
159 | self.conv1_bn = nn.Sequential(
160 | nn.BatchNorm2d(out_channels * rep_scale)
161 | )
162 | self.conv_crossh = nn.Conv2d(in_channels, out_channels * rep_scale, (3, 1), 1, (1, 0))
163 | self.conv_crossh_bn = nn.Sequential(
164 | nn.BatchNorm2d(out_channels * rep_scale)
165 | )
166 | self.conv_crossv = nn.Conv2d(in_channels, out_channels * rep_scale, (1, 3), 1, (0, 1))
167 | self.conv_crossv_bn = nn.Sequential(
168 | nn.BatchNorm2d(out_channels * rep_scale)
169 | )
170 | self.conv_out = nn.Conv2d(out_channels * rep_scale * 8, out_channels, 1)
171 |
172 | def forward(self, inp):
173 | x0 = self.conv(inp)
174 | x1 = self.conv1(inp)
175 | x2 = self.conv_crossh(inp)
176 | x3 = self.conv_crossv(inp)
177 | x = torch.cat(
178 | [ x0,x1,x2,x3,
179 | self.conv_bn(x0),
180 | self.conv1_bn(x1),
181 | self.conv_crossh_bn(x2),
182 | self.conv_crossv_bn(x3)],
183 | 1
184 | )
185 | out = self.conv_out(x)
186 | return out
187 |
188 | def slim(self):
189 | conv_weight = self.conv.weight
190 | conv_bias = self.conv.bias
191 |
192 | conv1_weight = self.conv1.weight
193 | conv1_bias = self.conv1.bias
194 | conv1_weight = F.pad(conv1_weight, (1, 1, 1, 1))
195 |
196 | conv_crossh_weight = self.conv_crossh.weight
197 | conv_crossh_bias = self.conv_crossh.bias
198 | conv_crossh_weight = F.pad(conv_crossh_weight, (1, 1, 0, 0))
199 |
200 | conv_crossv_weight = self.conv_crossv.weight
201 | conv_crossv_bias = self.conv_crossv.bias
202 | conv_crossv_weight = F.pad(conv_crossv_weight, (0, 0, 1, 1))
203 |
204 | # conv_bn
205 | bn = self.conv_bn[0]
206 | k = 1 / torch.sqrt(bn.running_var + bn.eps)
207 | conv_bn_weight = self.conv.weight * k.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
208 | conv_bn_weight = conv_bn_weight * bn.weight.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
209 | conv_bn_bias = self.conv.bias * k + (-bn.running_mean * k)
210 | conv_bn_bias = conv_bn_bias * bn.weight + bn.bias
211 |
212 | # conv1_bn
213 | bn = self.conv1_bn[0]
214 | k = 1 / torch.sqrt(bn.running_var + bn.eps)
215 | conv1_bn_weight = self.conv1.weight * k.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
216 | conv1_bn_weight = conv1_bn_weight * bn.weight.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
217 | conv1_bn_weight = F.pad(conv1_bn_weight, (1, 1, 1, 1))
218 | conv1_bn_bias = self.conv1.bias * k + (-bn.running_mean * k)
219 | conv1_bn_bias = conv1_bn_bias * bn.weight + bn.bias
220 |
221 | # conv_crossh_bn
222 | bn = self.conv_crossh_bn[0]
223 | k = 1 / torch.sqrt(bn.running_var + bn.eps)
224 | conv_crossh_bn_weight = self.conv_crossh.weight * k.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
225 | conv_crossh_bn_weight = conv_crossh_bn_weight * bn.weight.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
226 | conv_crossh_bn_weight = F.pad(conv_crossh_bn_weight, (1, 1, 0, 0))
227 | conv_crossh_bn_bias = self.conv_crossh.bias * k + (-bn.running_mean * k)
228 | conv_crossh_bn_bias = conv_crossh_bn_bias * bn.weight + bn.bias
229 |
230 | # conv_crossv_bn
231 | bn = self.conv_crossv_bn[0]
232 | k = 1 / torch.sqrt(bn.running_var + bn.eps)
233 | conv_crossv_bn_weight = self.conv_crossv.weight * k.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
234 | conv_crossv_bn_weight = conv_crossv_bn_weight * bn.weight.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
235 | conv_crossv_bn_weight = F.pad(conv_crossv_bn_weight, (0, 0, 1, 1))
236 | conv_crossv_bn_bias = self.conv_crossv.bias * k + (-bn.running_mean * k)
237 | conv_crossv_bn_bias = conv_crossv_bn_bias * bn.weight + bn.bias
238 |
239 | weight = torch.cat([
240 | conv_weight,
241 | conv1_weight,
242 | conv_crossh_weight,
243 | conv_crossv_weight,
244 | conv_bn_weight,
245 | conv1_bn_weight,
246 | conv_crossh_bn_weight,
247 | conv_crossv_bn_weight
248 | ], dim=0)
249 |
250 | bias = torch.cat([
251 | conv_bias,
252 | conv1_bias,
253 | conv_crossh_bias,
254 | conv_crossv_bias,
255 | conv_bn_bias,
256 | conv1_bn_bias,
257 | conv_crossh_bn_bias,
258 | conv_crossv_bn_bias
259 | ], dim=0)
260 |
261 | weight_compress = self.conv_out.weight.squeeze()
262 | weight = torch.matmul(weight_compress, weight.view(weight.size(0), -1))
263 | weight = weight.view(self.conv_out.out_channels, self.in_channels, 3, 3)
264 |
265 | bias = torch.matmul(weight_compress, bias.unsqueeze(-1)).squeeze(-1)
266 | if self.conv_out.bias is not None:
267 | bias += self.conv_out.bias
268 |
269 | return weight, bias
270 |
271 | ######################################################################################################
272 | class MBRConv1(nn.Module):
273 | def __init__(self, in_channels, out_channels, rep_scale=4):
274 | super(MBRConv1, self).__init__()
275 |
276 | self.in_channels = in_channels
277 | self.out_channels = out_channels
278 | self.rep_scale = rep_scale
279 |
280 | self.conv = nn.Conv2d(in_channels, out_channels * rep_scale, 1)
281 | self.conv_bn = nn.Sequential(
282 | nn.BatchNorm2d(out_channels * rep_scale)
283 | )
284 | self.conv_out = nn.Conv2d(out_channels * rep_scale * 2, out_channels, 1)
285 |
286 | def forward(self, inp):
287 | x0 = self.conv(inp)
288 | x = torch.cat([x0, self.conv_bn(x0)], 1)
289 | out = self.conv_out(x)
290 | return out
291 |
292 | def slim(self):
293 | conv_weight = self.conv.weight
294 | conv_bias = self.conv.bias
295 |
296 | bn = self.conv_bn[0]
297 | k = 1 / (bn.running_var + bn.eps) ** .5
298 | b = - bn.running_mean / (bn.running_var + bn.eps) ** .5
299 | conv_bn_weight = self.conv.weight * k.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
300 | conv_bn_weight = conv_bn_weight * bn.weight.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
301 | conv_bn_bias = self.conv.bias * k + b
302 | conv_bn_bias = conv_bn_bias * bn.weight + bn.bias
303 |
304 | weight = torch.cat([conv_weight, conv_bn_weight], 0)
305 | weight_compress = self.conv_out.weight.squeeze()
306 | weight = torch.matmul(weight_compress, weight.permute([2, 3, 0, 1])).permute([2, 3, 0, 1])
307 |
308 | bias = torch.cat([conv_bias, conv_bn_bias], 0)
309 | bias = torch.matmul(weight_compress, bias)
310 |
311 | if isinstance(self.conv_out.bias, torch.Tensor):
312 | bias = bias + self.conv_out.bias
313 | return weight, bias
314 |
315 | class FST(nn.Module):
316 | def __init__(self, block1, channels):
317 | super(FST, self).__init__()
318 | self.block1 = block1
319 | self.weight1 = nn.Parameter(torch.randn(1))
320 | self.weight2 = nn.Parameter(torch.randn(1))
321 | self.bias = nn.Parameter(torch.randn((1, channels, 1, 1)))
322 |
323 | def forward(self, x):
324 | x1 = self.block1(x)
325 | weighted_block1 = self.weight1 * x1
326 | weighted_block2 = self.weight2 * x1
327 | return weighted_block1 * weighted_block2 + self.bias
328 |
329 | class FSTS(nn.Module):
330 | def __init__(self, block1, channels):
331 | super(FSTS, self).__init__()
332 | self.block1 = block1
333 | self.weight1 = nn.Parameter(torch.randn(1))
334 | self.weight2 = nn.Parameter(torch.randn(1))
335 | self.bias = nn.Parameter(torch.randn((1, channels, 1, 1)))
336 |
337 | def forward(self, x):
338 | x1 = self.block1(x)
339 | weighted_block1 = self.weight1 * x1
340 | weighted_block2 = self.weight2 * x1
341 | return weighted_block1 * weighted_block2 + self.bias
342 | ##################################################################################
343 | class DropBlock(nn.Module):
344 | def __init__(self, block_size, p=0.5):
345 | super(DropBlock, self).__init__()
346 | self.block_size = block_size
347 | self.p = p / block_size / block_size
348 |
349 | def forward(self, x):
350 | mask = 1 - (torch.rand_like(x[:, :1]) >= self.p).float()
351 | mask = nn.functional.max_pool2d(mask, self.block_size, 1, self.block_size // 2)
352 | return x * (1 - mask)
353 |
--------------------------------------------------------------------------------
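The MBRConv blocks above are structural re-parameterization modules: during training each block runs several parallel convolution branches plus their BatchNorm copies and mixes them with a 1x1 `conv_out`; `slim()` folds all of that into a single convolution's weight and bias. A quick equivalence check, sketched under the assumptions that the repository root is on `PYTHONPATH` (so `model.utils` is importable) and that the block is in eval mode so BatchNorm uses its running statistics:

```python
import torch
import torch.nn as nn
from model.utils import MBRConv3

torch.manual_seed(0)
m = MBRConv3(in_channels=8, out_channels=16, rep_scale=2).eval()

# Fold the parallel branches + BN into a single 3x3 convolution.
w, b = m.slim()
fused = nn.Conv2d(8, 16, kernel_size=3, stride=1, padding=1)
with torch.no_grad():
    fused.weight.copy_(w)
    fused.bias.copy_(b)

x = torch.randn(1, 8, 32, 32)
with torch.no_grad():
    diff = (m(x) - fused(x)).abs().max()
print(diff)  # expected to be on the order of 1e-6 (pure floating-point error)
```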
/onnx2tflite/onnx2tflite/layers/deformation_layers.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import tensorflow as tf
3 |
4 | from onnx2tflite.utils.definitions import Layout
5 | from onnx2tflite.utils import OPERATOR, dimension_utils
6 |
7 | LOG = logging.getLogger("deformation_layers :")
8 |
9 | @OPERATOR.register_operator("Transpose")
10 | class TFTranspose():
11 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, node_outputs, layout_dict, *args, **kwargs)->None:
12 | super().__init__()
13 | for nop in node_outputs:
14 | layout_dict[nop] = Layout.Channel_First
15 |         self.trans_in = None  # always defined, even when a perm_list is passed in directly
16 |         if kwargs.get("perm_list"):
17 |             self.perm_list = kwargs.get("perm_list")
18 |             return
19 |         self.perm_list = [i for i in node_attribute['perm']]
20 | if layout_dict[node_inputs[0]] == Layout.Channel_Last:
21 | # LOG.info("Transpose will process tensor after change back to NCHW format.")
22 | shape_len = len(tensor_grap[node_inputs[0]].shape)
23 | self.trans_in = [0, shape_len-1] + [n for n in range(1, shape_len-1)]
24 |
25 | def __call__(self, inputs):
26 | if self.trans_in:
27 | inputs = tf.transpose(inputs, perm=self.trans_in)
28 | return tf.transpose(inputs, perm=self.perm_list)
29 |
30 | @OPERATOR.register_operator("Slice")
31 | class TFSlice():
32 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, node_outputs, layout_dict, *args, **kwargs) -> None:
33 | super().__init__()
34 | if len(node_inputs) == 1:
35 | self.starts = node_attribute['starts'][0]
36 | self.ends = node_attribute['ends'][0]
37 | self.axis = node_attribute['axes'][0]
38 | self.steps = 1
39 | else:
40 | self.starts = node_weights[node_inputs[1]][0] if node_inputs[1] in node_weights else tensor_grap[node_inputs[1]][0]
41 | self.axis = node_weights[node_inputs[3]][0] if node_inputs[3] in node_weights else tensor_grap[node_inputs[3]][0]
42 | self.ends = node_weights[node_inputs[2]][0] if node_inputs[2] in node_weights else tensor_grap[node_inputs[2]][0]
43 | self.ends = min(self.ends, tensor_grap[node_inputs[0]].shape[self.axis])
44 | if len(node_inputs) < 5:
45 | self.steps = 1
46 | else:
47 | self.steps = node_weights[node_inputs[4]][0] if node_inputs[4] in node_weights else tensor_grap[node_inputs[4]][0]
48 |
49 | shape = tensor_grap[node_inputs[0]].shape.as_list()
50 | if self.starts < 0:
51 | self.starts = shape[self.axis] + self.starts
52 | if self.ends < 0:
53 | self.ends = shape[self.axis] + self.ends
54 |
55 | if layout_dict[node_inputs[0]] == Layout.Channel_Last:
56 | self.axis = dimension_utils.channel_to_last_dimension(self.axis)
57 |
58 | def __call__(self, inputs):
59 | indices = tf.keras.backend.arange(self.starts, self.ends, step=self.steps)
60 | return tf.gather(inputs, indices, axis=self.axis)
61 |
62 | @OPERATOR.register_operator("Gather")
63 | class TFGather():
64 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, node_outputs, layout_dict, *args, **kwargs) -> None:
65 | super().__init__()
66 | self.axis = node_attribute.get('axis', 0)
67 | self.indices = tensor_grap[node_inputs[1]] if node_inputs[1] in tensor_grap else node_weights[node_inputs[1]]
68 | if layout_dict[node_inputs[0]] == Layout.Channel_Last:
69 | self.axis = dimension_utils.channel_to_last_dimension(self.axis)
70 |
71 | def __call__(self, inputs):
72 | return tf.gather(inputs, self.indices, axis=self.axis)
73 |
74 | @OPERATOR.register_operator("Concat")
75 | class TFConcat():
76 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, node_outputs, layout_dict, *args, **kwargs):
77 | super().__init__()
78 |         # TODO: could be optimized by inspecting the following node, e.g. when a conv forces channel-last.
79 |         self._axis = node_attribute['axis']
80 |         # use `count` to tally how many inputs are channel-last versus channel-first
81 | count = 0
82 | for inp in node_inputs:
83 | if inp in node_weights:
84 | count -= 1
85 | elif layout_dict[inp] == Layout.Channel_Last:
86 | count += 1
87 | else:
88 | count -= 1
89 |
90 | self._gather = []
91 | if count < 0:
92 | # align to Channel_First
93 | layout_dict[node_outputs[0]] = Layout.Channel_First
94 | for inp in node_inputs:
95 | if inp in tensor_grap:
96 | if layout_dict[inp] == Layout.Channel_Last:
97 | tensor_grap[inp] = dimension_utils.tensor_NDC_to_NCD_format(tensor_grap[inp])
98 | self._gather.append(tensor_grap[inp])
99 | else:
100 | self._gather.append(node_weights[inp])
101 | else:
102 | # align to Channel_Last
103 | layout_dict[node_outputs[0]] = Layout.Channel_Last
104 | self._axis = dimension_utils.channel_to_last_dimension(self._axis)
105 | for inp in node_inputs:
106 | if inp in tensor_grap:
107 | if layout_dict[inp] != Layout.Channel_Last:
108 | tensor_grap[inp] = dimension_utils.tensor_NCD_to_NDC_format(tensor_grap[inp])
109 | self._gather.append(tensor_grap[inp])
110 | else:
111 | self._gather.append(dimension_utils.tensor_NCD_to_NDC_format(node_weights[inp]))
112 |
113 | def __call__(self, *args, **kwargs):
114 | return tf.concat(self._gather, axis=self._axis)
115 |
116 | @OPERATOR.register_operator("Reshape")
117 | class TFReshape():
118 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, node_outputs, layout_dict, *args, **kwargs):
119 | super().__init__()
120 | self.out_shape = node_weights[node_inputs[1]]
121 | self.trans_in = None
122 | # LOG.info("Reshape will process tensor after change back to NCHW format.")
123 | if layout_dict[node_inputs[0]] == Layout.Channel_Last:
124 | shape_len = len(tensor_grap[node_inputs[0]].shape)
125 | self.trans_in = [0, shape_len-1] + [n for n in range(1, shape_len-1)]
126 | for nop in node_outputs:
127 | layout_dict[nop] = Layout.Channel_First
128 |
129 | def __call__(self, inputs):
130 | if self.trans_in:
131 | inputs = tf.transpose(inputs, perm=self.trans_in)
132 | inputs = tf.reshape(inputs, shape=self.out_shape)
133 | return inputs
134 |
135 | @OPERATOR.register_operator("Flatten")
136 | class TFFlatten():
137 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, node_outputs, layout_dict, *args, **kwargs)->None:
138 | super().__init__()
139 | num_elements = int(tensor_grap[node_inputs[0]].shape.num_elements()/tensor_grap[node_inputs[0]].shape[0])
140 | input_shape = tensor_grap[node_inputs[0]].shape
141 | self.flat = tf.keras.layers.Flatten()
142 |         '''
143 |         Ensure the memory order matches. For example:
144 |             onnx = (B, 2, 3, 4).reshape(B, -1)
145 |             tflite = (B, 3, 4, 2).reshape(B, -1)
146 |         We can observe that:
147 |             onnx.shape == tflite.shape, but np.sum(onnx-tflite) != 0
148 |         This happens because the memory order of the two tensors differs, so the tflite tensor must be transposed back to the onnx order first.
149 |         That is the general case; the special case below (common in CNNs) needs no transpose:
150 |             onnx = (B, 512, 1, 1)
151 |             tflite = (B, 1, 1, 512)
152 |             or = (B, 1, 512, 1)
153 |         Here the memory orders are all the same.
154 |         '''
155 | self.perm = None
156 | if layout_dict[node_inputs[0]] == Layout.Channel_Last and num_elements != max(input_shape[1:]):
157 | self.perm = [0, len(input_shape)-1]
158 | for i in range(len(input_shape)-2):
159 | self.perm.append(i+1)
160 |
161 | def __call__(self, inputs):
162 | if self.perm:
163 | inputs = tf.transpose(inputs, perm=self.perm)
164 | return self.flat(inputs)
165 |
166 | @OPERATOR.register_operator("Split")
167 | class TFSplit():
168 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, node_outputs, layout_dict, *args, **kwargs)->None:
169 | super().__init__()
170 | self.outputs_nums = len(node_outputs)
171 | self.axis = node_attribute.get("axis", 0)
172 | if layout_dict[node_inputs[0]] == Layout.Channel_Last:
173 | self.axis = dimension_utils.channel_to_last_dimension(self.axis)
174 | split_args = None
175 | if 'split' in node_attribute:
176 | split_args = node_attribute['split']
177 | else:
178 | assert len(node_inputs) == 2 and node_inputs[1] in node_weights
179 | split_args = node_weights[node_inputs[1]]
180 |
181 | self.indices = []
182 | start, end = 0, 0
183 | for i in range(self.outputs_nums):
184 | end = start + int(split_args[i])
185 | self.indices.append(tf.keras.backend.arange(start, end, 1))
186 | start = end
187 |
188 | def __call__(self, inputs):
189 | return [tf.gather(inputs, indices=indice, axis=self.axis) for indice in self.indices]
190 |
191 | @OPERATOR.register_operator("Expand")
192 | class TFExpand():
193 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, node_outputs, layout_dict, *args, **kwargs)->None:
194 | super().__init__()
195 | self.shape = node_weights[node_inputs[1]]
196 | if layout_dict[node_inputs[0]] == Layout.Channel_Last:
197 | self.shape = dimension_utils.shape_NCD_to_NDC_format(self.shape)
198 | def __call__(self, inputs):
199 | for i in range(len(self.shape)):
200 | if int(self.shape[i]//inputs.shape[i]) > 1:
201 | inputs = tf.repeat(inputs, repeats=int(self.shape[i]//inputs.shape[i]), axis=i)
202 | elif self.shape[i] < inputs.shape[i] and self.shape[i] != 1:
203 | inputs = tf.repeat(inputs, repeats=int(self.shape[i]), axis=i)
204 | return inputs
205 |
206 | @OPERATOR.register_operator("GatherElements")
207 | class TFGatherElements():
208 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, node_outputs, layout_dict, *args, **kwargs) -> None:
209 | super().__init__()
210 | self.axis = node_attribute.get("axis", 1)
211 | self.indices = None
212 | if 'indices' in node_attribute:
213 | self.indices = node_attribute['indices']
214 | self.indices = dimension_utils.tensor_NCD_to_NDC_format(self.indices)
215 | elif node_inputs[1] in node_weights:
216 | self.indices = node_weights[node_inputs[1]]
217 | self.indices = dimension_utils.tensor_NCD_to_NDC_format(self.indices)
218 | else:
219 | self.indices = tensor_grap[node_inputs[1]]
220 | if layout_dict[node_inputs[0]] == Layout.Channel_Last:
221 | self.axis = dimension_utils.channel_to_last_dimension(self.axis)
222 | if len(node_inputs) == 1 or layout_dict[node_inputs[1]] != Layout.Channel_Last:
223 | self.indices = dimension_utils.tensor_NCD_to_NDC_format(self.indices)
224 |
225 | def gather_elements(self, input_tensor, indices, axis):
226 | # Get the shape of the input tensor and the indices tensor
227 | input_shape = tf.shape(input_tensor)
228 | indices_shape = tf.shape(indices)
229 |
230 | # Create indices for all dimensions
231 | idx = tf.meshgrid(*[tf.range(s) for s in indices_shape], indexing='ij')
232 | idx = [tf.cast(i, tf.int64) for i in idx]
233 |
234 | # Replace the axis index with the provided indices
235 | idx[axis] = tf.cast(indices, tf.int64)
236 |
237 | # Stack indices to form the final gather indices
238 | gather_indices = tf.stack(idx, axis=-1)
239 |
240 | # Use tf.gather_nd to gather elements
241 | output_tensor = tf.gather_nd(input_tensor, gather_indices)
242 |
243 | return output_tensor
244 |
245 | def __call__(self, inputs):
246 | return self.gather_elements(inputs, self.indices, self.axis)
247 |
248 | @OPERATOR.register_operator("Tile")
249 | class TFTile():
250 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, node_outputs, layout_dict, *args, **kwargs)->None:
251 | super().__init__()
252 | self.repeats = node_attribute['repeats'] if 'repeats' in node_attribute else node_weights[node_inputs[1]]
253 | if layout_dict[node_inputs[0]] == Layout.Channel_Last:
254 | self.repeats = dimension_utils.shape_NCD_to_NDC_format(self.repeats)
255 |
256 | def __call__(self, inputs):
257 | for i in range(len(self.repeats)):
258 | if self.repeats[i] > 1:
259 | inputs = tf.repeat(inputs, self.repeats[i], axis=i)
260 | return inputs
261 |
262 | @OPERATOR.register_operator("Unsqueeze")
263 | class TFUnsqueeze():
264 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, node_outputs, layout_dict, *args, **kwargs)->None:
265 | super().__init__()
266 | self.axis = node_attribute['axes'] if 'axes' in node_attribute else node_weights[node_inputs[1]]
267 | if not isinstance(self.axis, int):
268 | self.axis = int(self.axis[0])
269 | input_shape = tensor_grap[node_inputs[0]].shape
270 | if len(input_shape) == 1:
271 | layout_dict[node_outputs[0]] = Layout.Channel_None
272 | elif len(input_shape) == 2:
273 | layout_dict[node_outputs[0]] = Layout.Channel_First
274 | else:
275 | layout_dict[node_outputs[0]] = layout_dict[node_inputs[0]]
276 | if layout_dict[node_inputs[0]] == Layout.Channel_Last:
277 | self.axis = dimension_utils.channel_to_last_dimension(self.axis)
278 |
279 | def __call__(self, inputs):
280 | return tf.expand_dims(inputs, self.axis)
281 |
282 | @OPERATOR.register_operator("Squeeze")
283 | class TFSqueeze():
284 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, node_outputs, layout_dict, *args, **kwargs)->None:
285 | super().__init__()
286 | self.axis = node_attribute['axes'] if 'axes' in node_attribute else node_weights[node_inputs[1]]
287 | if not isinstance(self.axis, int):
288 | self.axis = int(self.axis[0])
289 | input_shape = tensor_grap[node_inputs[0]].shape
290 | if len(input_shape) <= 3:
291 | layout_dict[node_outputs[0]] = Layout.Channel_None
292 | if len(input_shape) > 2 and layout_dict[node_inputs[0]] == Layout.Channel_Last:
293 | self.axis = dimension_utils.channel_to_last_dimension(self.axis)
294 |
295 | def __call__(self, inputs):
296 | return tf.squeeze(inputs, self.axis)
297 |
298 | @OPERATOR.register_operator("DepthToSpace")
299 | class TFDepthToSpace():
300 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, node_outputs, layout_dict, *args, **kwargs)->None:
301 | super().__init__()
302 | self.block_size = node_attribute.get("blocksize", 2)
303 | self.mode = node_attribute.get("mode", "DCR")
304 | self.channel_last = layout_dict[node_inputs[0]] == Layout.Channel_Last
305 |
306 | def __call__(self, inputs):
307 | if not self.channel_last:
308 | inputs = dimension_utils.tensor_NDC_to_NCD_format(inputs)
309 | if self.mode == "DCR":
310 | return tf.nn.depth_to_space(inputs, self.block_size)
311 | elif self.mode == "CRD":
312 |             # Help wanted: native TensorFlow does not support CRD mode; this fallback creates higher-rank (6-D) intermediate ops.
313 | b, h, w, c = inputs.shape
314 | inputs = tf.reshape(inputs, [b, h, w, c//(self.block_size * self.block_size), self.block_size, self.block_size])
315 | inputs = tf.transpose(inputs, perm=[0, 1, 4, 2, 5, 3])
316 | inputs = tf.reshape(inputs, [b, h*self.block_size, w*self.block_size, c//(self.block_size * self.block_size)])
317 | return inputs
318 | else:
319 | raise KeyError(f"For DepthToSpace, mode must be [DCR, CRD], not {self.mode}")
320 |
--------------------------------------------------------------------------------
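
The Reshape and Flatten handlers above transpose channel-last tensors back to ONNX's channel-first memory order before reshaping, because a reshape over an NHWC tensor visits elements in a different order than the NCHW reshape it has to reproduce. A minimal NumPy sketch of that effect (the shapes here are illustrative, not taken from the repository):

```python
import numpy as np

# ONNX side: a (B, C, H, W) tensor flattened to (B, C*H*W).
x_nchw = np.arange(2 * 2 * 3 * 4, dtype=np.float32).reshape(2, 2, 3, 4)
onnx_flat = x_nchw.reshape(2, -1)

# TFLite side: the same data stored channel-last as (B, H, W, C).
x_nhwc = np.transpose(x_nchw, (0, 2, 3, 1))

# Naive flatten: identical shape, different element order.
naive_flat = x_nhwc.reshape(2, -1)
assert naive_flat.shape == onnx_flat.shape
assert not np.array_equal(naive_flat, onnx_flat)

# Transposing back to channel-first first (perm = [0, ndim-1, 1, ..., ndim-2])
# restores the ONNX memory order, which is what TFReshape/TFFlatten do above.
perm = [0, x_nhwc.ndim - 1] + list(range(1, x_nhwc.ndim - 1))
onnx_like_flat = np.transpose(x_nhwc, perm).reshape(2, -1)
assert np.array_equal(onnx_like_flat, onnx_flat)
```

The `perm` list matches the one built in `TFReshape.__init__` and `TFFlatten.__init__` above.
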
/onnx2tflite/onnx2tflite/layers/common_layers.py:
--------------------------------------------------------------------------------
1 | import math
2 | import logging
3 | import numpy as np
4 | import tensorflow as tf
5 | from tensorflow import keras
6 |
7 | from onnx2tflite.utils.definitions import Layout
8 | from onnx2tflite.utils import OPERATOR, intfloat_to_list, dimension_utils
9 |
10 | LOG = logging.getLogger("common_layers :")
11 |
12 | @OPERATOR.register_operator("BatchNormalization")
13 | class TFBatchNormalization():
14 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, *args, **kwargs):
15 | super().__init__()
16 | epsilon = node_attribute.get("epsilon", 1e-5)
17 | momentum = node_attribute.get("momentum", 0.9)
18 |
19 | self.bn = keras.layers.BatchNormalization(
20 | gamma_initializer=keras.initializers.Constant(node_weights[node_inputs[1]]),
21 | beta_initializer=keras.initializers.Constant(node_weights[node_inputs[2]]),
22 | moving_mean_initializer=keras.initializers.Constant(node_weights[node_inputs[3]]),
23 | moving_variance_initializer=keras.initializers.Constant(node_weights[node_inputs[4]]),
24 | epsilon=epsilon,
25 | momentum=momentum)
26 |
27 | def __call__(self, inputs):
28 | return self.bn(inputs)
29 |
30 | @OPERATOR.register_operator("InstanceNormalization")
31 | class TFInstanceNormalization():
32 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, node_outputs, layout_dict, *args, **kwargs):
33 | super().__init__()
34 | self.epsilon = node_attribute.get("epsilon", 1e-5)
35 | self.scale = node_weights[node_inputs[1]]
36 | self.bias = node_weights[node_inputs[2]]
37 | self.channel_last = layout_dict[node_inputs[0]] == Layout.Channel_Last
38 |
39 | def __call__(self, inputs):
40 | axes = tuple(range(1, len(inputs.shape)-1)) if self.channel_last else tuple(range(2, len(inputs.shape)))
41 | mean = tf.reduce_mean(inputs, axis=axes, keepdims=True)
42 | var = tf.math.reduce_variance(inputs, axis= axes, keepdims=True)
43 | return self.scale*(inputs - mean)/tf.sqrt(var + self.epsilon) + self.bias
44 |
45 | @OPERATOR.register_operator("Pad")
46 | class TFPad():
47 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, node_outputs, layout_dict, *args, **kwargs):
48 | super().__init__()
49 | if node_attribute.get("pads") is not None:
50 | pads = node_attribute['pads']
51 | elif node_inputs[1] in node_weights:
52 | pads = node_weights[node_inputs[1]]
53 | else:
54 | pads = tensor_grap[node_inputs[1]]
55 | self.pad = [[pads[0], pads[4]], [pads[2], pads[6]], [pads[3], pads[7]], [pads[1], pads[5]]]
56 | self.model = node_attribute.get("mode", "constant").upper()
57 | self.channel_last = layout_dict[node_inputs[0]] == Layout.Channel_Last
58 | layout_dict[node_outputs[0]] = Layout.Channel_Last
59 |
60 | def __call__(self, inputs):
61 | if not self.channel_last:
62 | inputs = dimension_utils.tensor_NCD_to_NDC_format(inputs)
63 | return tf.pad(inputs, self.pad, mode=self.model)
64 |
65 | @OPERATOR.register_operator("Clip")
66 | class TFClip():
67 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, *args, **kwargs):
68 | super().__init__()
69 | if "min" in node_attribute:
70 | self.min = node_attribute.get("min")
71 | else:
72 | self.min = tensor_grap[node_inputs[1]] if node_inputs[1] in tensor_grap else node_weights[node_inputs[1]]
73 | if "max" in node_attribute:
74 | self.max = node_attribute.get("max")
75 | else:
76 | self.max = tensor_grap[node_inputs[2]] if node_inputs[2] in tensor_grap else node_weights[node_inputs[2]]
77 |
78 | def __call__(self, inputs):
79 | if float(self.min) == 0 and float(self.max) == 6:
80 | return tf.nn.relu6(inputs)
81 | return tf.clip_by_value(inputs, self.min, self.max)
82 |
83 | @OPERATOR.register_operator("GlobalMaxPool")
84 | class TFGlobalMaxPool():
85 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, node_outputs, layout_dict, *args, **kwargs) -> None:
86 | super().__init__()
87 | self.channel_last = layout_dict[node_inputs[0]] == Layout.Channel_Last
88 |
89 | def __call__(self, inputs):
90 | if self.channel_last:
91 | return tf.reduce_max(inputs, axis=[i for i in range(1, len(inputs.shape)-1)], keepdims=True)
92 | else:
93 | return tf.reduce_max(inputs, axis=[i for i in range(2, len(inputs.shape))], keepdims=True)
94 |
95 | @OPERATOR.register_operator("GlobalAveragePool")
96 | class TFGlobalAveragePool():
97 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, node_outputs, layout_dict, *args, **kwargs) -> None:
98 | super().__init__()
99 | self.channel_last = layout_dict[node_inputs[0]] == Layout.Channel_Last
100 |
101 | def __call__(self, inputs):
102 | if self.channel_last:
103 | return tf.reduce_mean(inputs, axis=[i for i in range(1, len(inputs.shape)-1)], keepdims=True)
104 | else:
105 | return tf.reduce_mean(inputs, axis=[i for i in range(2, len(inputs.shape))], keepdims=True)
106 |
107 | @OPERATOR.register_operator("AveragePool")
108 | class TFAveragePool():
109 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, node_outputs, layout_dict, *args, **kwargs) -> None:
110 | super().__init__()
111 | kernel_shape = intfloat_to_list(node_attribute.get("kernel_shape", [2, 2]), 2)
112 | strides = intfloat_to_list(node_attribute.get("strides", [1, 1]), 2)
113 | dilations = intfloat_to_list(node_attribute.get("dilations", [1, 1]), 2)
114 | ceil_mode = node_attribute.get("ceil_mode", 0)
115 | pads = intfloat_to_list(node_attribute.get("pads", [0, 0, 0, 0]), 4)
116 |
117 | func = math.floor if ceil_mode == 0 else math.ceil
118 |
119 | pad_mode = "SAME"
120 | input_shape = tensor_grap[node_inputs[0]].shape
121 | for i in range(len(input_shape)-2):
122 | pad_shape = pads[i] + pads[i+2]
123 | onnx_output_shape = func((input_shape[1+i]+pad_shape-((kernel_shape[i]-1)*dilations[i]+1))/strides[i]+1)
124 | tf_output_shape = math.floor((input_shape[1+i] - kernel_shape[i]) / strides[i]) + 1
125 | pads[2+i] = max(onnx_output_shape-tf_output_shape, pads[2+i]) # right_down pad
126 | if pad_mode == "SAME" and onnx_output_shape != input_shape[1+i]:
127 | pad_mode = "VALID"
128 | self.avg_pool = keras.layers.AveragePooling2D(pool_size=kernel_shape, strides=strides, padding=pad_mode)
129 |
130 | self.pad = None
131 |         if pad_mode == "VALID" and np.sum(pads) > 0:
132 |             self.pad = keras.layers.ZeroPadding2D(
133 |                 padding=((pads[0], pads[2]), (pads[1], pads[3])))
134 |
135 | self.channel_last = layout_dict[node_inputs[0]] == Layout.Channel_Last
136 | layout_dict[node_outputs[0]] = Layout.Channel_Last
137 |
138 | def __call__(self, inputs):
139 | if not self.channel_last:
140 | inputs = dimension_utils.tensor_NCD_to_NDC_format(inputs)
141 | if self.pad:
142 | inputs = self.pad(inputs)
143 | return self.avg_pool(inputs)
144 |
145 | @OPERATOR.register_operator("MaxPool")
146 | class TFMaxPool():
147 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, node_outputs, layout_dict, *args, **kwargs) -> None:
148 | super().__init__()
149 | kernel_shape = intfloat_to_list(node_attribute.get("kernel_shape", [2, 2]), 2)
150 | strides = intfloat_to_list(node_attribute.get("strides", [1, 1]), 2)
151 | dilations = intfloat_to_list(node_attribute.get("dilations", [1, 1]), 2)
152 | ceil_mode = node_attribute.get("ceil_mode", 0)
153 | pads = intfloat_to_list(node_attribute.get("pads", [0, 0, 0, 0]), 4)
154 |
155 | func = math.floor if ceil_mode == 0 else math.ceil
156 |
157 | pad_mode = "SAME"
158 | input_shape = tensor_grap[node_inputs[0]].shape
159 | for i in range(len(input_shape)-2):
160 | pad_shape = pads[i] + pads[i+2]
161 | onnx_output_shape = func((input_shape[1+i]+pad_shape-((kernel_shape[i]-1)*dilations[i]+1))/strides[i]+1)
162 | tf_output_shape = math.floor((input_shape[1+i] - kernel_shape[i]) / strides[i]) + 1
163 | pads[2+i] = max(onnx_output_shape-tf_output_shape, pads[2+i]) # right_down pad
164 | if pad_mode == "SAME" and onnx_output_shape != input_shape[1+i]:
165 | pad_mode = "VALID"
166 | self.max_pool = keras.layers.MaxPool2D(pool_size=kernel_shape, strides=strides, padding=pad_mode)
167 |
168 | self.pad = None
169 |         if pad_mode == "VALID" and np.sum(pads) > 0:
170 |             self.pad = keras.layers.ZeroPadding2D(
171 |                 padding=((pads[0], pads[2]), (pads[1], pads[3])))
172 |
173 | self.channel_last = layout_dict[node_inputs[0]] == Layout.Channel_Last
174 | layout_dict[node_outputs[0]] = Layout.Channel_Last
175 |
176 | def __call__(self, inputs):
177 | if not self.channel_last:
178 | inputs = dimension_utils.tensor_NCD_to_NDC_format(inputs)
179 | if self.pad:
180 | inputs = self.pad(inputs)
181 | return self.max_pool(inputs)
182 |
183 | @OPERATOR.register_operator("Upsample")
184 | class TFUpsample():
185 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, node_outputs, layout_dict, *args, **kwargs):
186 | super().__init__()
187 | _, h, w, _ = tensor_grap[node_inputs[0]].shape
188 | scale = node_weights[node_inputs[1]]
189 |
190 | self.scale = (int(h*scale[2]), int(w*scale[3]))
191 | if node_attribute.get("mode", "nearest").lower() == 'nearest':
192 | self.method = tf.image.ResizeMethod.NEAREST_NEIGHBOR
193 | else:
194 | self.method = tf.image.ResizeMethod.BILINEAR
195 |
196 | self.channel_last = layout_dict[node_inputs[0]] == Layout.Channel_Last
197 | layout_dict[node_outputs[0]] = Layout.Channel_Last
198 |
199 | def __call__(self, inputs):
200 | if not self.channel_last:
201 | inputs = dimension_utils.tensor_NCD_to_NDC_format(inputs)
202 | return tf.image.resize(inputs, self.scale, method=self.method)
203 |
204 | @OPERATOR.register_operator("Constant")
205 | class TFConstant():
206 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, *args, **kwargs):
207 | super().__init__()
208 | self.val = node_attribute['value']
209 |
210 | def __call__(self, *args, **kwargs):
211 | return self.val
212 |
213 | @OPERATOR.register_operator("ScatterND")
214 | class TFScatterND():
215 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, node_outputs, layout_dict, *args, **kwargs):
216 | super().__init__()
217 | self.indices = node_weights[node_inputs[1]]
218 |         self.channel_last = layout_dict[node_inputs[0]] == Layout.Channel_Last
219 |         if node_inputs[2] in tensor_grap:
220 |             self.updates = tensor_grap[node_inputs[2]]
221 |             if self.channel_last:
222 |                 self.updates = dimension_utils.tensor_NDC_to_NCD_format(self.updates)
223 |         else:
224 |             self.updates = node_weights[node_inputs[2]]
225 | 
226 |         layout_dict[node_outputs[0]] = Layout.Channel_First
227 | 
228 |     def __call__(self, inputs):
229 |         if self.channel_last:
230 | inputs = dimension_utils.tensor_NDC_to_NCD_format(inputs)
231 | inputs = tf.tensor_scatter_nd_update(inputs, self.indices, self.updates)
232 | return inputs
233 |
234 | @OPERATOR.register_operator("Resize")
235 | class TFResize():
236 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, node_outputs, layout_dict, *args, **kwargs):
237 | super().__init__()
238 | if node_inputs[-1] in node_weights:
239 | _, _, nh, nw = node_weights[node_inputs[-1]]
240 | if len(node_inputs) != 4:
241 | _, h, w, _ = tensor_grap[node_inputs[0]].shape
242 | nh, nw = int(h*nh), int(w*nw)
243 | self.scale = (nh, nw)
244 | else:
245 | scales = tensor_grap[node_inputs[0]].shape[1:3]*tensor_grap[node_inputs[2]][2:3]
246 | self.scale = scales
247 |
248 | if node_attribute.get("mode", "nearest").lower() == 'nearest':
249 | self.method = tf.image.ResizeMethod.NEAREST_NEIGHBOR
250 | else:
251 | self.method = tf.image.ResizeMethod.BILINEAR
252 |
253 | self.channel_last = layout_dict[node_inputs[0]] == Layout.Channel_Last
254 | layout_dict[node_outputs[0]] = Layout.Channel_Last
255 |
256 | def __call__(self, inputs):
257 | if not self.channel_last:
258 | inputs = dimension_utils.tensor_NCD_to_NDC_format(inputs)
259 | return tf.image.resize(inputs, self.scale, method=self.method)
260 |
261 | @OPERATOR.register_operator("Gemm")
262 | class TFGemm():
263 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, node_outputs, layout_dict, *args, **kwargs) -> None:
264 | super().__init__()
265 | if len(node_inputs) > 2:
266 | weights = [node_weights[node_inputs[1]].T, node_weights[node_inputs[2]]]
267 | else:
268 | weights = [node_weights[node_inputs[1]].T]
269 |
270 | self.dense = keras.layers.Dense(weights[0].shape[1],
271 | weights=weights,
272 | use_bias=len(weights)==2)
273 |
274 | self.channel_last = layout_dict[node_inputs[0]] == Layout.Channel_Last
275 | layout_dict[node_outputs[0]] = Layout.Channel_Last
276 |
277 | def __call__(self, inputs):
278 | if not self.channel_last:
279 | inputs = dimension_utils.tensor_NCD_to_NDC_format(inputs)
280 | return self.dense(inputs)
281 |
282 | @OPERATOR.register_operator("Identity")
283 | class TFIdentity():
284 | def __init__(self, *args, **kwargs):
285 | super().__init__()
286 |
287 | def __call__(self, inputs):
288 | return inputs
289 |
290 | @OPERATOR.register_operator("Dropout")
291 | class TFDropout():
292 | '''
293 | Dropout will be ignored in deployment.
294 | '''
295 | def __init__(self, *args, **kwargs):
296 | super().__init__()
297 |
298 | def __call__(self, inputs):
299 | return inputs
300 |
301 | @OPERATOR.register_operator("TopK")
302 | class TFTopK():
303 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, *args, **kwargs) -> None:
304 |
305 | self.axis = node_attribute.get("axis", -1)
306 | self.largest = node_attribute.get("largest", 1)
307 | self.sorted = bool(node_attribute.get("sorted", 1))
308 |         self.K = node_attribute.get('k') if len(node_inputs)==1 else node_weights[node_inputs[1]][0]  # ONNX TopK opset-1 stores k as the lowercase attribute 'k'
309 |
310 | def __call__(self, inputs):
311 | res = tf.math.top_k(inputs, k=self.K, sorted=self.sorted)
312 | return [res[0], res[1]]
313 |
314 | @OPERATOR.register_operator("Cast")
315 | class TFCast():
316 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, *args, **kwargs):
317 | super().__init__()
318 | self.cast_to = int(node_attribute.get("to", 1))
319 | assert self.cast_to > 0 and self.cast_to < 12, f"Unknown cast type [{self.cast_to}]"
320 | self.np_cast_map = {
321 | 1: np.float32,
322 | 2: np.uint8,
323 | 3: np.int8,
324 | 5: np.int16,
325 | 6: np.int32,
326 | 7: np.int64,
327 | 9: np.bool_,
328 | 10: np.float16,
329 | 11: np.double,
330 | }
331 | self.tf_cast_map = {
332 | 1: tf.float32,
333 | 2: tf.uint8,
334 | 3: tf.int8,
335 | 5: tf.int16,
336 | 6: tf.int32,
337 | 7: tf.int64,
338 | 9: tf.bool,
339 | 10: tf.float16,
340 | 11: tf.double,
341 | }
342 |
343 | def __call__(self, inputs):
344 | if isinstance(inputs, list):
345 | for i in range(len(inputs)):
346 | if isinstance(inputs[i], np.ndarray) or isinstance(inputs[i], np.generic):
347 | inputs[i] = self.np_cast_map[self.cast_to](inputs[i])
348 | else:
349 |                     inputs[i] = tf.cast(inputs[i], dtype=self.tf_cast_map[self.cast_to])
350 | else:
351 | if isinstance(inputs, np.ndarray) or isinstance(inputs, np.generic):
352 | inputs = self.np_cast_map[self.cast_to](inputs)
353 | else:
354 | inputs = tf.cast(inputs, dtype=self.tf_cast_map[self.cast_to])
355 |
356 | return inputs
357 |
--------------------------------------------------------------------------------
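
TFAveragePool and TFMaxPool above keep Keras "SAME" padding only when the ONNX output size equals the input size; otherwise they switch to "VALID" and add an explicit ZeroPadding2D sized so that valid pooling reproduces the ONNX output. A small worked example of that size arithmetic (the numbers are illustrative, not from a specific model):

```python
import math

# Illustrative numbers: 2x2 pooling, stride 2, input width 7,
# ONNX pads = (begin=0, end=1), ceil_mode = 0.
in_size, kernel, stride, dilation = 7, 2, 2, 1
pad_begin, pad_end, ceil_mode = 0, 1, 0

func = math.floor if ceil_mode == 0 else math.ceil
onnx_out = func((in_size + pad_begin + pad_end - ((kernel - 1) * dilation + 1)) / stride + 1)  # -> 4
tf_valid_out = math.floor((in_size - kernel) / stride) + 1                                     # -> 3

# Mirror of the __init__ logic above: "SAME" survives only when the ONNX
# output keeps the input size; otherwise pad explicitly and pool with "VALID".
pad_mode = "SAME" if onnx_out == in_size else "VALID"
pad_end = max(onnx_out - tf_valid_out, pad_end)   # extra right/bottom padding -> 1

print(pad_mode, pad_end)  # VALID 1
```

With one extra column of zeros on the right, valid pooling over width 8 yields 4 outputs, matching the ONNX result.
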
/model/utils_IWO.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | class MBRConv5(nn.Module):
6 | def __init__(self, in_channels, out_channels, rep_scale=4):
7 | super(MBRConv5, self).__init__()
8 | self.in_channels = in_channels
9 | self.out_channels = out_channels
10 | self.conv = nn.Conv2d(in_channels, out_channels * rep_scale, 5, 1, 2)
11 | self.conv_bn = nn.Sequential(
12 | nn.BatchNorm2d(out_channels * rep_scale)
13 | )
14 | self.conv1 = nn.Conv2d(in_channels, out_channels * rep_scale, 1)
15 | self.conv1_bn = nn.Sequential(
16 | nn.BatchNorm2d(out_channels * rep_scale)
17 | )
18 | self.conv2 = nn.Conv2d(in_channels, out_channels * rep_scale, 3, 1, 1)
19 | self.conv2_bn = nn.Sequential(
20 | nn.BatchNorm2d(out_channels * rep_scale)
21 | )
22 | self.conv_crossh = nn.Conv2d(in_channels, out_channels * rep_scale, (3, 1), 1, (1, 0))
23 | self.conv_crossh_bn = nn.Sequential(
24 | nn.BatchNorm2d(out_channels * rep_scale)
25 | )
26 | self.conv_crossv = nn.Conv2d(in_channels, out_channels * rep_scale, (1, 3), 1, (0, 1))
27 | self.conv_crossv_bn = nn.Sequential(
28 | nn.BatchNorm2d(out_channels * rep_scale)
29 | )
30 | self.conv_out = nn.Conv2d(out_channels * rep_scale * 10, out_channels, 1)
31 | self.conv_out.weight.requires_grad = False
32 | self.weight1 = nn.Parameter(torch.zeros_like(self.conv_out.weight))
33 | nn.init.xavier_normal_(self.weight1)
34 |
35 | def forward(self, inp):
36 | x1 = self.conv(inp)
37 | x2 = self.conv1(inp)
38 | x3 = self.conv2(inp)
39 | x4 = self.conv_crossh(inp)
40 | x5 = self.conv_crossv(inp)
41 | x = torch.cat(
42 | [x1, x2, x3, x4, x5,
43 | self.conv_bn(x1),
44 | self.conv1_bn(x2),
45 | self.conv2_bn(x3),
46 | self.conv_crossh_bn(x4),
47 | self.conv_crossv_bn(x5)],
48 | 1
49 | )
50 | final_weight = self.conv_out.weight + self.weight1
51 | out = F.conv2d(x, final_weight, self.conv_out.bias)
52 | return out
53 |
54 | def slim(self):
55 | conv_weight = self.conv.weight
56 | conv_bias = self.conv.bias
57 |
58 | conv1_weight = self.conv1.weight
59 | conv1_bias = self.conv1.bias
60 | conv1_weight = nn.functional.pad(conv1_weight, (2, 2, 2, 2))
61 |
62 | conv2_weight = self.conv2.weight
63 | conv2_weight = nn.functional.pad(conv2_weight, (1, 1, 1, 1))
64 | conv2_bias = self.conv2.bias
65 |
66 | conv_crossv_weight = self.conv_crossv.weight
67 | conv_crossv_weight = nn.functional.pad(conv_crossv_weight, (1, 1, 2, 2))
68 | conv_crossv_bias = self.conv_crossv.bias
69 |
70 | conv_crossh_weight = self.conv_crossh.weight
71 | conv_crossh_weight = nn.functional.pad(conv_crossh_weight, (2, 2, 1, 1))
72 | conv_crossh_bias = self.conv_crossh.bias
73 |
74 | conv1_bn_weight = self.conv1.weight
75 | conv1_bn_weight = nn.functional.pad(conv1_bn_weight, (2, 2, 2, 2))
76 |
77 | conv2_bn_weight = self.conv2.weight
78 | conv2_bn_weight = nn.functional.pad(conv2_bn_weight, (1, 1, 1, 1))
79 |
80 | conv_crossv_bn_weight = self.conv_crossv.weight
81 | conv_crossv_bn_weight = nn.functional.pad(conv_crossv_bn_weight, (1, 1, 2, 2))
82 |
83 | conv_crossh_bn_weight = self.conv_crossh.weight
84 | conv_crossh_bn_weight = nn.functional.pad(conv_crossh_bn_weight, (2, 2, 1, 1))
85 |
86 | bn = self.conv_bn[0]
87 | k = 1 / (bn.running_var + bn.eps) ** .5
88 | b = - bn.running_mean / (bn.running_var + bn.eps) ** .5
89 |
90 | conv_bn_weight = self.conv.weight * k.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
91 | conv_bn_weight = conv_bn_weight * bn.weight.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
92 | conv_bn_bias = self.conv.bias * k + b
93 | conv_bn_bias = conv_bn_bias * bn.weight + bn.bias
94 |
95 | bn = self.conv1_bn[0]
96 | k = 1 / (bn.running_var + bn.eps) ** .5
97 | b = - bn.running_mean / (bn.running_var + bn.eps) ** .5
98 | conv1_bn_weight = conv1_bn_weight * k.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
99 | conv1_bn_weight = conv1_bn_weight * bn.weight.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
100 | conv1_bn_bias = self.conv1.bias * k + b
101 | conv1_bn_bias = conv1_bn_bias * bn.weight + bn.bias
102 |
103 | bn = self.conv2_bn[0]
104 | k = 1 / (bn.running_var + bn.eps) ** .5
105 | b = - bn.running_mean / (bn.running_var + bn.eps) ** .5
106 | conv2_bn_weight = conv2_bn_weight * k.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
107 | conv2_bn_weight = conv2_bn_weight * bn.weight.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
108 | conv2_bn_bias = self.conv2.bias * k + b
109 | conv2_bn_bias = conv2_bn_bias * bn.weight + bn.bias
110 |
111 | bn = self.conv_crossv_bn[0]
112 | k = 1 / (bn.running_var + bn.eps) ** .5
113 | b = - bn.running_mean / (bn.running_var + bn.eps) ** .5
114 | conv_crossv_bn_weight = conv_crossv_bn_weight * k.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
115 | conv_crossv_bn_weight = conv_crossv_bn_weight * bn.weight.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
116 | conv_crossv_bn_bias = self.conv_crossv.bias * k + b
117 | conv_crossv_bn_bias = conv_crossv_bn_bias * bn.weight + bn.bias
118 |
119 | bn = self.conv_crossh_bn[0]
120 | k = 1 / (bn.running_var + bn.eps) ** .5
121 | b = - bn.running_mean / (bn.running_var + bn.eps) ** .5
122 | conv_crossh_bn_weight = conv_crossh_bn_weight * k.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
123 | conv_crossh_bn_weight = conv_crossh_bn_weight * bn.weight.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
124 | conv_crossh_bn_bias = self.conv_crossh.bias * k + b
125 | conv_crossh_bn_bias = conv_crossh_bn_bias * bn.weight + bn.bias
126 |
127 | weight = torch.cat(
128 | [conv_weight, conv1_weight, conv2_weight,
129 | conv_crossh_weight, conv_crossv_weight,
130 | conv_bn_weight, conv1_bn_weight, conv2_bn_weight,
131 | conv_crossh_bn_weight, conv_crossv_bn_weight],
132 | 0
133 | )
134 | #weight_compress = self.conv_out.weight.squeeze()
135 | weight_compress = (self.conv_out.weight + self.weight1).squeeze()
136 | weight = torch.matmul(weight_compress, weight.permute([2, 3, 0, 1])).permute([2, 3, 0, 1])
137 | bias_ = torch.cat(
138 | [conv_bias, conv1_bias, conv2_bias,
139 | conv_crossh_bias, conv_crossv_bias,
140 | conv_bn_bias, conv1_bn_bias, conv2_bn_bias,
141 | conv_crossh_bn_bias, conv_crossv_bn_bias],
142 | 0
143 | )
144 | bias = torch.matmul(weight_compress, bias_)
145 | if isinstance(self.conv_out.bias, torch.Tensor):
146 | bias = bias + self.conv_out.bias
147 | return weight, bias
148 |
149 |
150 | ##############################################################################################################
151 | class MBRConv3(nn.Module):
152 | def __init__(self, in_channels, out_channels, rep_scale=4):
153 | super(MBRConv3, self).__init__()
154 |
155 | self.in_channels = in_channels
156 | self.out_channels = out_channels
157 | self.rep_scale = rep_scale
158 |
159 | self.conv = nn.Conv2d(in_channels, out_channels * rep_scale, 3, 1, 1)
160 | self.conv_bn = nn.Sequential(
161 | nn.BatchNorm2d(out_channels * rep_scale)
162 | )
163 | self.conv1 = nn.Conv2d(in_channels, out_channels * rep_scale, 1)
164 | self.conv1_bn = nn.Sequential(
165 | nn.BatchNorm2d(out_channels * rep_scale)
166 | )
167 | self.conv_crossh = nn.Conv2d(in_channels, out_channels * rep_scale, (3, 1), 1, (1, 0))
168 | self.conv_crossh_bn = nn.Sequential(
169 | nn.BatchNorm2d(out_channels * rep_scale)
170 | )
171 | self.conv_crossv = nn.Conv2d(in_channels, out_channels * rep_scale, (1, 3), 1, (0, 1))
172 | self.conv_crossv_bn = nn.Sequential(
173 | nn.BatchNorm2d(out_channels * rep_scale)
174 | )
175 | self.conv_out = nn.Conv2d(out_channels * rep_scale * 8, out_channels, 1)
176 | self.conv_out.weight.requires_grad = False
177 | self.weight1 = nn.Parameter(torch.zeros_like(self.conv_out.weight))
178 | nn.init.xavier_normal_(self.weight1)
179 |
180 | def forward(self, inp):
181 | x0 = self.conv(inp)
182 | x1 = self.conv1(inp)
183 | x2 = self.conv_crossh(inp)
184 | x3 = self.conv_crossv(inp)
185 | x = torch.cat(
186 | [ x0,x1,x2,x3,
187 | self.conv_bn(x0),
188 | self.conv1_bn(x1),
189 | self.conv_crossh_bn(x2),
190 | self.conv_crossv_bn(x3)],
191 | 1
192 | )
193 | final_weight = self.conv_out.weight + self.weight1
194 | out = F.conv2d(x, final_weight, self.conv_out.bias)
195 | return out
196 |
197 | def slim(self):
198 | conv_weight = self.conv.weight
199 | conv_bias = self.conv.bias
200 |
201 | conv1_weight = self.conv1.weight
202 | conv1_bias = self.conv1.bias
203 | conv1_weight = F.pad(conv1_weight, (1, 1, 1, 1))
204 |
205 | conv_crossh_weight = self.conv_crossh.weight
206 | conv_crossh_bias = self.conv_crossh.bias
207 | conv_crossh_weight = F.pad(conv_crossh_weight, (1, 1, 0, 0))
208 |
209 | conv_crossv_weight = self.conv_crossv.weight
210 | conv_crossv_bias = self.conv_crossv.bias
211 | conv_crossv_weight = F.pad(conv_crossv_weight, (0, 0, 1, 1))
212 |
213 | # conv_bn
214 | bn = self.conv_bn[0]
215 | k = 1 / torch.sqrt(bn.running_var + bn.eps)
216 | conv_bn_weight = self.conv.weight * k.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
217 | conv_bn_weight = conv_bn_weight * bn.weight.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
218 | conv_bn_bias = self.conv.bias * k + (-bn.running_mean * k)
219 | conv_bn_bias = conv_bn_bias * bn.weight + bn.bias
220 |
221 | # conv1_bn
222 | bn = self.conv1_bn[0]
223 | k = 1 / torch.sqrt(bn.running_var + bn.eps)
224 | conv1_bn_weight = self.conv1.weight * k.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
225 | conv1_bn_weight = conv1_bn_weight * bn.weight.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
226 | conv1_bn_weight = F.pad(conv1_bn_weight, (1, 1, 1, 1))
227 | conv1_bn_bias = self.conv1.bias * k + (-bn.running_mean * k)
228 | conv1_bn_bias = conv1_bn_bias * bn.weight + bn.bias
229 |
230 | # conv_crossh_bn
231 | bn = self.conv_crossh_bn[0]
232 | k = 1 / torch.sqrt(bn.running_var + bn.eps)
233 | conv_crossh_bn_weight = self.conv_crossh.weight * k.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
234 | conv_crossh_bn_weight = conv_crossh_bn_weight * bn.weight.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
235 | conv_crossh_bn_weight = F.pad(conv_crossh_bn_weight, (1, 1, 0, 0))
236 | conv_crossh_bn_bias = self.conv_crossh.bias * k + (-bn.running_mean * k)
237 | conv_crossh_bn_bias = conv_crossh_bn_bias * bn.weight + bn.bias
238 |
239 | # conv_crossv_bn
240 | bn = self.conv_crossv_bn[0]
241 | k = 1 / torch.sqrt(bn.running_var + bn.eps)
242 | conv_crossv_bn_weight = self.conv_crossv.weight * k.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
243 | conv_crossv_bn_weight = conv_crossv_bn_weight * bn.weight.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
244 | conv_crossv_bn_weight = F.pad(conv_crossv_bn_weight, (0, 0, 1, 1))
245 | conv_crossv_bn_bias = self.conv_crossv.bias * k + (-bn.running_mean * k)
246 | conv_crossv_bn_bias = conv_crossv_bn_bias * bn.weight + bn.bias
247 |
248 | weight = torch.cat([
249 | conv_weight,
250 | conv1_weight,
251 | conv_crossh_weight,
252 | conv_crossv_weight,
253 | conv_bn_weight,
254 | conv1_bn_weight,
255 | conv_crossh_bn_weight,
256 | conv_crossv_bn_weight
257 | ], dim=0)
258 |
259 | bias = torch.cat([
260 | conv_bias,
261 | conv1_bias,
262 | conv_crossh_bias,
263 | conv_crossv_bias,
264 | conv_bn_bias,
265 | conv1_bn_bias,
266 | conv_crossh_bn_bias,
267 | conv_crossv_bn_bias
268 | ], dim=0)
269 |
270 | #weight_compress = self.conv_out.weight.squeeze()
271 | weight_compress = (self.conv_out.weight + self.weight1).squeeze()
272 | weight = torch.matmul(weight_compress, weight.view(weight.size(0), -1))
273 | weight = weight.view(self.conv_out.out_channels, self.in_channels, 3, 3)
274 |
275 | bias = torch.matmul(weight_compress, bias.unsqueeze(-1)).squeeze(-1)
276 | if self.conv_out.bias is not None:
277 | bias += self.conv_out.bias
278 |
279 | return weight, bias
280 |
281 | ######################################################################################################
282 | class MBRConv1(nn.Module):
283 | def __init__(self, in_channels, out_channels, rep_scale=4):
284 | super(MBRConv1, self).__init__()
285 |
286 | self.in_channels = in_channels
287 | self.out_channels = out_channels
288 | self.rep_scale = rep_scale
289 |
290 | self.conv = nn.Conv2d(in_channels, out_channels * rep_scale, 1)
291 | self.conv_bn = nn.Sequential(
292 | nn.BatchNorm2d(out_channels * rep_scale)
293 | )
294 | self.conv_out = nn.Conv2d(out_channels * rep_scale * 2, out_channels, 1)
295 | self.conv_out.weight.requires_grad = False
296 |
297 | self.weight1 = nn.Parameter(torch.zeros_like(self.conv_out.weight))
298 | nn.init.xavier_normal_(self.weight1)
299 |
300 | def forward(self, inp):
301 | x1 = self.conv(inp)
302 | x = torch.cat([x1, self.conv_bn(x1)], 1)
303 | final_weight = self.conv_out.weight + self.weight1
304 | out = F.conv2d(x, final_weight, self.conv_out.bias)
305 | return out
306 |
307 | def slim(self):
308 | conv_weight = self.conv.weight
309 | conv_bias = self.conv.bias
310 |
311 | bn = self.conv_bn[0]
312 | k = 1 / (bn.running_var + bn.eps) ** .5
313 | b = - bn.running_mean / (bn.running_var + bn.eps) ** .5
314 | conv_bn_weight = self.conv.weight * k.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
315 | conv_bn_weight = conv_bn_weight * bn.weight.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
316 | conv_bn_bias = self.conv.bias * k + b
317 | conv_bn_bias = conv_bn_bias * bn.weight + bn.bias
318 |
319 | weight = torch.cat([conv_weight, conv_bn_weight], 0)
320 | #weight_compress = self.conv_out.weight.squeeze()
321 | weight_compress = (self.conv_out.weight + self.weight1).squeeze()
322 | weight = torch.matmul(weight_compress, weight.permute([2, 3, 0, 1])).permute([2, 3, 0, 1])
323 |
324 | bias = torch.cat([conv_bias, conv_bn_bias], 0)
325 | bias = torch.matmul(weight_compress, bias)
326 |
327 | if isinstance(self.conv_out.bias, torch.Tensor):
328 | bias = bias + self.conv_out.bias
329 | return weight, bias
330 |
331 | class FST(nn.Module):
332 | def __init__(self, block1, channels):
333 | super(FST, self).__init__()
334 | self.block1 = block1
335 | self.weight1 = nn.Parameter(torch.randn(1))
336 | self.weight2 = nn.Parameter(torch.randn(1))
337 | self.bias = nn.Parameter(torch.randn((1, channels, 1, 1)))
338 |
339 | def forward(self, x):
340 | x1 = self.block1(x)
341 | weighted_block1 = self.weight1 * x1
342 | weighted_block2 = self.weight2 * x1
343 | return weighted_block1 * weighted_block2 + self.bias
344 |
345 | class FSTS(nn.Module):
346 | def __init__(self, block1, channels):
347 | super(FSTS, self).__init__()
348 | self.block1 = block1
349 | self.weight1 = nn.Parameter(torch.randn(1))
350 | self.weight2 = nn.Parameter(torch.randn(1))
351 | self.bias = nn.Parameter(torch.randn((1, channels, 1, 1)))
352 |
353 | def forward(self, x):
354 | x1 = self.block1(x)
355 | weighted_block1 = self.weight1 * x1
356 | weighted_block2 = self.weight2 * x1
357 | return weighted_block1 * weighted_block2 + self.bias
358 | ##################################################################################
359 | class DropBlock(nn.Module):
360 | def __init__(self, block_size, p=0.5):
361 | super(DropBlock, self).__init__()
362 | self.block_size = block_size
363 | self.p = p / block_size / block_size
364 |
365 | def forward(self, x):
366 | mask = 1 - (torch.rand_like(x[:, :1]) >= self.p).float()
367 | mask = nn.functional.max_pool2d(mask, self.block_size, 1, self.block_size // 2)
368 | return x * (1 - mask)
369 |
--------------------------------------------------------------------------------
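
The MBRConv blocks above are multi-branch training-time modules whose `slim()` methods fold every branch, including its BatchNorm statistics, into a single plain convolution. A minimal sketch of how the fused weights could be checked against the training-time forward pass, assuming the file is importable as `model.utils_IWO` and the module is in `eval()` mode so BatchNorm uses its running statistics (the channel counts and the `fused` name are illustrative):

```python
import torch
import torch.nn as nn

from model.utils_IWO import MBRConv3  # assumes the repo root is on PYTHONPATH

# Fold a (randomly initialised) MBRConv3 into a single 3x3 convolution.
mbr = MBRConv3(in_channels=8, out_channels=16, rep_scale=4).eval()

with torch.no_grad():
    weight, bias = mbr.slim()
    fused = nn.Conv2d(8, 16, kernel_size=3, stride=1, padding=1)
    fused.weight.copy_(weight)
    fused.bias.copy_(bias)

    x = torch.randn(1, 8, 32, 32)
    # In eval() mode BatchNorm uses running statistics, so the multi-branch
    # forward and the fused convolution should agree up to floating-point error.
    print(torch.allclose(mbr(x), fused(x), atol=1e-4))
```

Only the fused convolution needs to be kept for deployment; the multi-branch structure exists solely to enrich training.
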