├── nics_fix_pt ├── VERSION ├── nn_fix.py ├── nn_fix_inner.py ├── consts.py ├── __init__.py ├── utils.py ├── quant.py └── fix_modules.py ├── MANIFEST.in ├── examples ├── cifar10 │ ├── train.sh │ ├── finetune.sh │ ├── net.py │ └── main.py └── mnist │ └── train_mnist.py ├── coverage.svg ├── LICENSE ├── .gitignore ├── setup.py ├── tests ├── test_fix_bn.py ├── conftest.py ├── test_quant.py ├── test_utils.py └── test_fix_module.py └── README.md /nics_fix_pt/VERSION: -------------------------------------------------------------------------------- 1 | 0.4.0 2 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include tests/*.py 2 | recursive-include examples *.py *.sh 3 | -------------------------------------------------------------------------------- /nics_fix_pt/nn_fix.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # A wrapper module for all registered fix modules 3 | -------------------------------------------------------------------------------- /nics_fix_pt/nn_fix_inner.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from .fix_modules import register_fix_module 3 | 4 | register_fix_module(nn.Conv2d) 5 | register_fix_module(nn.Linear) 6 | register_fix_module(nn.BatchNorm1d) 7 | register_fix_module(nn.BatchNorm2d) 8 | -------------------------------------------------------------------------------- /examples/cifar10/train.sh: -------------------------------------------------------------------------------- 1 | #!bin/sh 2 | 3 | gpu=${GPU:-0} 4 | result_dir=$(date +%Y%m%d-%H%M%S) 5 | mkdir -p ${result_dir} 6 | 7 | CUDA_VISIBLE_DEVICES=$gpu python main.py \ 8 | --save-dir ${result_dir} \ 9 | $@ 2>&1 | tee ${result_dir}/train.log 10 | 11 | -------------------------------------------------------------------------------- /nics_fix_pt/consts.py: -------------------------------------------------------------------------------- 1 | class RangeMethod: 2 | RANGE_MAX = 0 3 | RANGE_3SIGMA = 1 4 | RANGE_MAX_TENPERCENT = 2 5 | RANGE_SWEEP = 3 6 | 7 | 8 | class QuantizeMethod: 9 | # quantize methods 10 | FIX_NONE = 0 11 | FIX_AUTO = 1 12 | FIX_FIXED = 2 13 | -------------------------------------------------------------------------------- /examples/cifar10/finetune.sh: -------------------------------------------------------------------------------- 1 | #!bin/sh 2 | #arch="vgg11_web" 3 | #arch="vgg11_elegant" 4 | arch="vgg11_ugly" 5 | 6 | gpu=1 7 | data=$(date +%m%d) 8 | method="8_conv1233'44'55'_6" 9 | lr=0.0001 10 | # model="save_fix/checkpoint_auto8_90.730.tar" 11 | model="save_fix/checkpoint_8_conv1233'44'5_6_90.820.tar" 12 | epoches=30 13 | 14 | CUDA_VISIBLE_DEVICES=$gpu python main.py \ 15 | --pretrained $model \ 16 | --arch $arch \ 17 | --lr $lr \ 18 | --epoches $epoches \ 19 | --prefix $method \ 20 | --batch-size 128 \ 21 | --test-batch-size 1000 \ 22 | --save-dir save_fix \ 23 | 2>&1 | tee logs/log-"$method"_1222.log 24 | -------------------------------------------------------------------------------- /coverage.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | coverage 17 | coverage 18 | 73% 19 | 73% 20 | 21 | 22 | -------------------------------------------------------------------------------- /LICENSE: 
-------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Xuefei Ning 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /nics_fix_pt/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | with open(os.path.join(os.path.dirname(__file__), "VERSION")) as f: 4 | __version__ = f.read().strip() 5 | 6 | from nics_fix_pt.consts import QuantizeMethod, RangeMethod 7 | from nics_fix_pt.quant import * 8 | import nics_fix_pt.nn_fix_inner 9 | from nics_fix_pt import nn_fix 10 | from nics_fix_pt.fix_modules import register_fix_module 11 | 12 | FIX_NONE = QuantizeMethod.FIX_NONE 13 | FIX_AUTO = QuantizeMethod.FIX_AUTO 14 | FIX_FIXED = QuantizeMethod.FIX_FIXED 15 | 16 | RANGE_MAX = RangeMethod.RANGE_MAX 17 | RANGE_3SIGMA = RangeMethod.RANGE_3SIGMA 18 | 19 | 20 | class nn_auto_register(object): 21 | """ 22 | An auto register helper that automatically register all not-registered modules 23 | by proxing to modules in torch.nn. 24 | 25 | NOTE: We do not guarantee all auto-registered fixed nn modules will well behave, 26 | as they are not tested. Although, I thought it will work in normal cases. 27 | Use with care! 28 | 29 | Usage: from nics_fix_pt import NAR as nnf 30 | then e.g. `nnf.Bilinear_fix` and `nnf.Bilinear` can all be used as a fixed-point module. 
31 | """ 32 | 33 | def __getattr__(self, name): 34 | import torch 35 | 36 | attr = getattr(nn_fix, name, None) 37 | if attr is None: 38 | if name.endswith("_fix"): 39 | ori_name = name[:-4] 40 | else: 41 | ori_name = name 42 | ori_cls = getattr(torch.nn, ori_name) 43 | register_fix_module(ori_cls, register_name=ori_name + "_fix") 44 | return getattr(nn_fix, ori_name + "_fix", None) 45 | return attr 46 | 47 | 48 | NAR = nn_auto_register() 49 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | 27 | # PyInstaller 28 | # Usually these files are written by a python script from a template 29 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 30 | *.manifest 31 | *.spec 32 | 33 | # Installer logs 34 | pip-log.txt 35 | pip-delete-this-directory.txt 36 | 37 | # Unit test / coverage reports 38 | htmlcov/ 39 | .tox/ 40 | .coverage 41 | .coverage.* 42 | .cache 43 | nosetests.xml 44 | coverage.xml 45 | *.cover 46 | .hypothesis/ 47 | 48 | # Translations 49 | *.mo 50 | *.pot 51 | 52 | # Django stuff: 53 | *.log 54 | local_settings.py 55 | 56 | # Flask stuff: 57 | instance/ 58 | .webassets-cache 59 | 60 | # Scrapy stuff: 61 | .scrapy 62 | 63 | # Sphinx documentation 64 | docs/_build/ 65 | 66 | # PyBuilder 67 | target/ 68 | 69 | # Jupyter Notebook 70 | .ipynb_checkpoints 71 | 72 | # pyenv 73 | .python-version 74 | 75 | # celery beat schedule file 76 | celerybeat-schedule 77 | 78 | # SageMath parsed files 79 | *.sage.py 80 | 81 | # Environments 82 | .env 83 | .venv 84 | env/ 85 | venv/ 86 | ENV/ 87 | 88 | # Spyder project settings 89 | .spyderproject 90 | .spyproject 91 | 92 | # Rope project settings 93 | .ropeproject 94 | 95 | # mkdocs documentation 96 | /site 97 | 98 | # mypy 99 | .mypy_cache/ 100 | 101 | examples/data 102 | 103 | tmpdir/ -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | import sys 4 | from setuptools import setup, find_packages 5 | from setuptools.command.test import test as TestCommand 6 | 7 | here = os.path.dirname(os.path.abspath((__file__))) 8 | 9 | # meta infos 10 | NAME = "nics_fix_pytorch" 11 | DESCRIPTION = "Fixed point trainig framework on PyTorch" 12 | with open(os.path.join(os.path.dirname(__file__), "nics_fix_pt/VERSION")) as f: 13 | VERSION = f.read().strip() 14 | 15 | 16 | AUTHOR = "foxfi" 17 | EMAIL = "foxdoraame@gmail.com" 18 | 19 | # package contents 20 | MODULES = [] 21 | PACKAGES = find_packages(exclude=["tests.*", "tests"]) 22 | 23 | # dependencies 24 | INSTALL_REQUIRES = ["six", "numpy"] 25 | TESTS_REQUIRE = ["pytest", "pytest-cov", "PyYAML"] # for mnist example 26 | 27 | # entry points 28 | ENTRY_POINTS = """""" 29 | 30 | 31 | def read_long_description(filename): 32 | path = os.path.join(here, filename) 33 | if os.path.exists(path): 34 | return open(path).read() 35 | return "" 36 | 37 | 38 | class PyTest(TestCommand): 39 | def finalize_options(self): 40 | 
TestCommand.finalize_options(self) 41 | self.test_args = ["tests/", "--junitxml", "unittest.xml", "--cov"] 42 | self.test_suite = True 43 | 44 | def run_tests(self): 45 | # import here, cause outside the eggs aren"t loaded 46 | import pytest 47 | 48 | errno = pytest.main(self.test_args) 49 | sys.exit(errno) 50 | 51 | 52 | setup( 53 | name=NAME, 54 | version=VERSION, 55 | description=DESCRIPTION, 56 | long_description=read_long_description("README.md"), 57 | author=AUTHOR, 58 | author_email=EMAIL, 59 | py_modules=MODULES, 60 | packages=PACKAGES, 61 | entry_points=ENTRY_POINTS, 62 | zip_safe=False, 63 | install_requires=INSTALL_REQUIRES, 64 | tests_require=TESTS_REQUIRE, 65 | cmdclass={"test": PyTest}, 66 | package_data={ 67 | "nics_fix_pt": ["VERSION"] 68 | } 69 | ) 70 | -------------------------------------------------------------------------------- /tests/test_fix_bn.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.optim as optim 7 | 8 | import nics_fix_pt as nfp 9 | from nics_fix_pt import nn_fix as nnf 10 | 11 | 12 | @pytest.mark.parametrize( 13 | "case", [{"input_num": 3, "momentum": 0.5, "inputs": [[1, 1, 0], [2, 1, 2]]}] 14 | ) 15 | def test_fix_bn_test_auto(case): 16 | # TEST: the first update is the same (not quantized) 17 | bn_fix = nnf.BatchNorm1d_fix( 18 | case["input_num"], 19 | nf_fix_params={ 20 | "running_mean": { 21 | "method": nfp.FIX_AUTO, 22 | "bitwidth": torch.tensor([2]), 23 | "scale": torch.tensor([0]), 24 | }, 25 | "running_var": { 26 | "method": nfp.FIX_AUTO, 27 | "bitwidth": torch.tensor([2]), 28 | "scale": torch.tensor([0]), 29 | }, 30 | }, 31 | affine=False, 32 | momentum=case["momentum"], 33 | ) 34 | bn = nn.BatchNorm1d(case["input_num"], affine=False, momentum=case["momentum"]) 35 | bn_fix.train() 36 | bn.train() 37 | inputs = torch.autograd.Variable( 38 | torch.tensor(case["inputs"]).float(), requires_grad=True 39 | ) 40 | out_fix = bn_fix(inputs) 41 | out = bn(inputs) 42 | assert (bn.running_mean == bn_fix.running_mean).all() # not quantized here 43 | assert (bn.running_var == bn_fix.running_var).all() 44 | assert (out == out_fix).all() 45 | 46 | # TEST: Quantitized on the next forward 47 | bn_fix.train(False) 48 | bn.train(False) 49 | out_fix = bn_fix(inputs) 50 | # Let's explicit quantize the mean/var of the normal BN model for comparison 51 | object.__setattr__( 52 | bn, 53 | "running_mean", 54 | nfp.quant.quantize_cfg( 55 | bn.running_mean, 56 | **{ 57 | "method": nfp.FIX_AUTO, 58 | "bitwidth": torch.tensor([2]), 59 | "scale": torch.tensor([0]), 60 | } 61 | )[0], 62 | ) 63 | object.__setattr__( 64 | bn, 65 | "running_var", 66 | nfp.quant.quantize_cfg( 67 | bn.running_var, 68 | **{ 69 | "method": nfp.FIX_AUTO, 70 | "bitwidth": torch.tensor([2]), 71 | "scale": torch.tensor([0]), 72 | } 73 | )[0], 74 | ) 75 | assert (bn.running_mean == bn_fix.running_mean).all() 76 | assert (bn.running_var == bn_fix.running_var).all() 77 | 78 | out = bn(inputs) 79 | assert (out == out_fix).all() 80 | 81 | # TEST: the running mean/var update is on the quantized running mean 82 | bn_fix.train() 83 | bn.train() 84 | out_fix = bn_fix(inputs) 85 | out = bn(inputs) 86 | assert ( 87 | bn.running_mean == bn_fix.running_mean 88 | ).all() # quantized on the next forward 89 | assert (bn.running_var == bn_fix.running_var).all() 90 | 91 | # runnig_mean_should = np.mean(inputs.detach().numpy(), axis=0) * case["momentum"] 92 | # runnig_var_should = 
np.var(inputs.detach().numpy(), axis=0) * case["momentum"] + np.ones(case["input_num"]) * (1 - case["momentum"]) 93 | -------------------------------------------------------------------------------- /nics_fix_pt/utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from collections import OrderedDict 4 | import copy 5 | import inspect 6 | from contextlib import contextmanager 7 | from functools import wraps 8 | 9 | from six import iteritems 10 | import numpy as np 11 | import torch 12 | 13 | 14 | def try_parse_variable(something): 15 | if isinstance(something, (dict, OrderedDict)): # recursive parse into dict values 16 | return {k: try_parse_variable(v) for k, v in iteritems(something)} 17 | try: 18 | return torch.autograd.Variable( 19 | torch.IntTensor(np.array([something])), requires_grad=False 20 | ) 21 | except ValueError: 22 | return something 23 | 24 | 25 | def get_int(something): 26 | if torch.is_tensor(something): 27 | v = int(something.numpy()[0]) 28 | elif isinstance(something, torch.autograd.Variable): 29 | v = int(something.data.numpy()[0]) 30 | else: 31 | assert isinstance(something, (int, np.int)) 32 | v = int(something) 33 | return v 34 | 35 | 36 | def try_parse_int(something): 37 | if isinstance(something, (dict, OrderedDict)): # recursive parse into dict values 38 | return {k: try_parse_int(v) for k, v in iteritems(something)} 39 | try: 40 | return int(something) 41 | except ValueError: 42 | return something 43 | 44 | 45 | def cache(format_str): 46 | def _cache(func): 47 | _cache_dct = {} 48 | sig = inspect.signature(func) 49 | default_kwargs = { 50 | n: v.default 51 | for n, v in iteritems(sig.parameters) 52 | if v.default != inspect._empty 53 | } 54 | 55 | @wraps(func) 56 | def _func(*args, **kwargs): 57 | args_dct = copy.copy(default_kwargs) 58 | args_dct.update(dict(zip(sig.parameters.keys(), args))) 59 | args_dct.update(kwargs) 60 | cache_str = format_str.format(**args_dct) 61 | if cache_str not in _cache_dct: 62 | _cache_dct[cache_str] = func(*args, **kwargs) 63 | return _cache_dct[cache_str] 64 | 65 | return _func 66 | 67 | return _cache 68 | 69 | 70 | def _generate_default_fix_cfg(names, scale=0, bitwidth=8, method=0): 71 | return { 72 | n: { 73 | "method": torch.autograd.Variable( 74 | torch.IntTensor(np.array([method])), requires_grad=False 75 | ), 76 | "scale": torch.autograd.Variable( 77 | torch.FloatTensor(np.array([scale])), requires_grad=False 78 | ), 79 | "bitwidth": torch.autograd.Variable( 80 | torch.IntTensor(np.array([bitwidth])), requires_grad=False 81 | ), 82 | } 83 | for n in names 84 | } 85 | 86 | 87 | _context = {} 88 | DEFAULT_KWARGS_KEY = "__fix_module_default_kwargs__" 89 | 90 | 91 | def get_kwargs(cls): 92 | return _context.get(DEFAULT_KWARGS_KEY, {}).get(cls.__name__, {}) 93 | 94 | 95 | @contextmanager 96 | def fix_kwargs_scope(_override=False, **kwargs): 97 | old_kwargs = copy.copy(_context.get(DEFAULT_KWARGS_KEY, None)) 98 | if _override or DEFAULT_KWARGS_KEY not in _context: 99 | _context[DEFAULT_KWARGS_KEY] = kwargs 100 | else: 101 | _context[DEFAULT_KWARGS_KEY].update(kwargs) 102 | yield 103 | if old_kwargs is None: 104 | _context.pop(DEFAULT_KWARGS_KEY, None) 105 | else: 106 | _context[DEFAULT_KWARGS_KEY] = old_kwargs 107 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import numpy as np 3 | import torch 4 | 
from torch.nn import Module 5 | from torch.nn import functional as F 6 | from torch.nn.parameter import Parameter 7 | import nics_fix_pt.nn_fix as nnf 8 | 9 | def _generate_default_fix_cfg(names, scale=0, bitwidth=8, method=0): 10 | return { 11 | n: { 12 | "method": torch.autograd.Variable( 13 | torch.IntTensor(np.array([method])), requires_grad=False 14 | ), 15 | "scale": torch.autograd.Variable( 16 | torch.IntTensor(np.array([scale])), requires_grad=False 17 | ), 18 | "bitwidth": torch.autograd.Variable( 19 | torch.IntTensor(np.array([bitwidth])), requires_grad=False 20 | ), 21 | } 22 | for n in names 23 | } 24 | 25 | class TestNetwork(nnf.FixTopModule): 26 | def __init__(self): 27 | super(TestNetwork, self).__init__() 28 | self.fix_params = {} 29 | for conv_name in ["conv1", "conv2"]: 30 | self.fix_params[conv_name] = _generate_default_fix_cfg( 31 | ["weight", "bias"], method=1, bitwidth=8) 32 | for bn_name in ["bn1", "bn2"]: 33 | self.fix_params[bn_name] = _generate_default_fix_cfg( 34 | ["weight", "bias", "running_mean", "running_var"], method=1, bitwidth=8) 35 | self.conv1 = nnf.Conv2d_fix(3, 64, (3, 3), padding=1, 36 | nf_fix_params=self.fix_params["conv1"]) 37 | self.bn1 = nnf.BatchNorm2d_fix(64, nf_fix_params=self.fix_params["bn1"]) 38 | self.conv2 = nnf.Conv2d_fix(64, 128, (3, 3), padding=1, 39 | nf_fix_params=self.fix_params["conv2"]) 40 | self.bn2 = nnf.BatchNorm2d_fix(128, nf_fix_params=self.fix_params["bn2"]) 41 | 42 | @pytest.fixture 43 | def test_network(): 44 | return TestNetwork() 45 | 46 | # ---- 47 | class TestModule(Module): 48 | def __init__(self, input_num): 49 | super(TestModule, self).__init__() 50 | self.param = Parameter(torch.Tensor(1, input_num)) 51 | self.reset_parameters() 52 | 53 | def reset_parameters(self): 54 | # fake data 55 | with torch.no_grad(): 56 | self.param.fill_(0) 57 | self.param[0, 0] = 0.25111 58 | self.param[0, 1] = 0.5 59 | 60 | def forward(self, input): 61 | # print("input: ", input, "param: ", self.param) 62 | return F.linear(input, self.param, None) 63 | 64 | 65 | @pytest.fixture 66 | def module_cfg(request): 67 | import nics_fix_pt as nfp 68 | import nics_fix_pt.nn_fix as nnf 69 | 70 | nfp.fix_modules.register_fix_module(TestModule) 71 | # default data/grad fix cfg for the parameter `param` of TestModule 72 | data_cfg = nfp.utils._generate_default_fix_cfg( 73 | ["param"], scale=-1, bitwidth=2, method=nfp.FIX_AUTO 74 | ) 75 | grad_cfg = nfp.utils._generate_default_fix_cfg( 76 | ["param"], scale=-1, bitwidth=2, method=nfp.FIX_NONE 77 | ) 78 | # the specified overriding cfgs: input_num, data fix cfg, grad fix cfg 79 | update_cfg = getattr(request, "param", {}) 80 | input_num = update_cfg.pop("input_num", 3) 81 | data_update_cfg = update_cfg.get("data_cfg", {}) 82 | grad_update_cfg = update_cfg.get("grad_cfg", {}) 83 | data_cfg["param"].update(data_update_cfg) 84 | grad_cfg["param"].update(grad_update_cfg) 85 | module = nnf.TestModule_fix( 86 | input_num=input_num, nf_fix_params=data_cfg, nf_fix_params_grad=grad_cfg 87 | ) 88 | return module, data_cfg, grad_cfg 89 | -------------------------------------------------------------------------------- /tests/test_quant.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import numpy as np 3 | import torch 4 | 5 | import nics_fix_pt as nfp 6 | import nics_fix_pt.quant as nfpq 7 | 8 | 9 | @pytest.mark.parametrize( 10 | "case", 11 | [ 12 | { 13 | "data": [0.2513, -0.5, 0], 14 | "scale": 0.5, 15 | "bitwidth": 2, 16 | "method": nfp.FIX_FIXED, 17 | 
"output": ([0.25, -0.5, 0], 0.25), 18 | }, 19 | { 20 | "data": [0.2513, -0.5, 0], 21 | "scale": 0.5, 22 | "bitwidth": 16, 23 | "method": nfp.FIX_FIXED, 24 | "output": ( 25 | [np.round(0.2513 / (2 ** (-16))) * 2 ** (-16), -0.5, 0], 26 | 2 ** (-16), 27 | ), 28 | }, 29 | { 30 | "data": [0.2513, -0.5, 0], 31 | "scale": 0.5, 32 | "bitwidth": 2, 33 | "method": nfp.FIX_AUTO, 34 | "output": ([0.25, -0.5, 0], 0.25), 35 | }, 36 | { 37 | "data": [0.2513, -0.52, 0], 38 | "scale": 0.5, 39 | "bitwidth": 2, 40 | "method": nfp.FIX_AUTO, 41 | "out_scale": 1, 42 | "output": ([0.5, -0.5, 0], 0.5), 43 | }, 44 | { 45 | "data": [0.2513, -0.52, 0], 46 | "scale": 0.5, 47 | "bitwidth": 4, 48 | "method": nfp.FIX_AUTO, 49 | "range_method": nfp.RANGE_3SIGMA, 50 | "output": ([0.25, -0.5, 0], 0.25), 51 | }, 52 | # test the float scale 53 | { 54 | "data": [0.2513, -0.52, 0.0], 55 | "scale": 0.52, 56 | "bitwidth": 4, 57 | "method": nfp.FIX_AUTO, 58 | "output": ([0.26, -0.52, 0.0], 0.065), 59 | "float_scale": True, 60 | }, 61 | { 62 | "data": [0.2513, -0.52, 0.0], 63 | "scale": 0.5, 64 | "bitwidth": 4, 65 | "method": nfp.FIX_AUTO, 66 | "output": ([(0.2513 + 0.52) * 5 / 16, -0.52, 0.0], (0.2513 + 0.52) / 16), 67 | "float_scale": True, 68 | "zero_point": True, 69 | }, 70 | { 71 | "data": [[[[0.2513]], [[-0.52]]], [[[0.3321]], [[-0.4532]]]], 72 | "scale": [ 73 | 0.2513, 74 | 0.3321, 75 | ], # max_data = data.view(data.shape[0],-1).max(dim=1)[0] 76 | "bitwidth": 4, 77 | "method": nfp.FIX_AUTO, 78 | "output": ( 79 | [[[[4 * 0.52 / 8]], [[-0.52]]], [[[6 * 0.4532 / 8]], [[-0.4532]]]], 80 | [0.52 / 8, 0.4532 / 8], 81 | ), 82 | "float_scale": True, 83 | "group": "batch", 84 | }, 85 | ], 86 | ) 87 | def test_quantize_cfg(case): 88 | scale_tensor = torch.tensor([case["scale"]]) 89 | out = nfpq.quantize_cfg( 90 | torch.tensor(case["data"]), 91 | scale_tensor, 92 | torch.tensor(case["bitwidth"]), 93 | case["method"], 94 | case["range_method"] if "range_method" in case.keys() else nfp.RANGE_MAX, 95 | stochastic=case["stochastic"] if "stochastic" in case.keys() else False, 96 | float_scale=case["float_scale"] if "float_scale" in case.keys() else False, 97 | zero_point=case["zero_point"] if "zero_point" in case.keys() else False, 98 | group=case["group"] if "group" in case.keys() else False, 99 | ) 100 | assert np.isclose(out[0], case["output"][0]).all() 101 | assert np.isclose(out[1].view(-1), case["output"][1]).all() 102 | if "out_scale" in case: 103 | assert bool(scale_tensor == case["out_scale"]) 104 | 105 | 106 | def test_straight_through_gradient(): 107 | inputs = torch.autograd.Variable(torch.tensor([1.1, 0.9]), requires_grad=True) 108 | outputs = nfpq.StraightThroughRound().apply(inputs) 109 | outputs.sum().backward() 110 | assert np.isclose(inputs._grad, [1, 1]).all() 111 | 112 | # when Round is applied without straight through, there is no gradient 113 | inputs.grad.detach_() 114 | inputs.grad.zero_() 115 | output_nost = inputs.round() 116 | assert np.isclose(inputs._grad, [0, 0]).all() 117 | 118 | # Stochastic rounding 119 | inputs = torch.autograd.Variable(torch.Tensor(100).fill_(0.5), requires_grad=True) 120 | outputs = nfpq.StraightThroughStochasticRound().apply(inputs) 121 | assert outputs.max() > 0.9 and outputs.min() < 0.1 122 | 123 | 124 | def test_quantize_gradient(): 125 | quant_grad = nfpq.QuantitizeGradient() 126 | scale = torch.Tensor([0]) 127 | inputs = torch.autograd.Variable(torch.tensor([1.1, 0.9]), requires_grad=True) 128 | quanted = quant_grad.apply(inputs, scale, torch.tensor(2), nfp.FIX_AUTO) 129 | output = 
(
130 | quanted * torch.autograd.Variable(torch.tensor([0.5, 0.26]), requires_grad=True)
131 | ).sum()
132 | output.backward()
133 | assert np.isclose(inputs._grad, [0.5, 0.25]).all()
134 | assert scale.item() == 0.5
135 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Fixed Point Training Simulation Framework on PyTorch
2 |
3 | ### Core Functionality
4 | - Parameter/buffer fix: use the `nics_fix_pt.nn_fix.<module_name>_fix` modules;
5 | parameters are those to be trained, buffers are those to be persisted but not treated as parameters
6 | - Activation fix: use the `Activation_fix` module
7 | - Data fix vs. gradient fix: supply the `nf_fix_params`/`nf_fix_params_grad` keyword arguments
8 | when constructing a `nics_fix_pt.nn_fix.<module_name>_fix` or `Activation_fix` module
9 |
10 | > NOTE: When constructing a fixed-point module, the dictionary you pass in as the `nf_fix_params` argument is used directly, so if you pass the same configuration dict to two modules, the two modules will share the same configuration. If you want each module to be quantized independently, construct a separate configuration dict for each module.
11 |
12 | The underlying quantization of each datum is implemented in `nics_fix_pt.quant._do_quantitize`, and the code is easy to follow. If you need a different plug-in quantize method than this default one, or a configurable behavior of `_do_quantitize` (e.g. in simulation of simultaneous computation on different types of devices), please contribute a pluggable and configurable `_do_quantitize` (I believe this will be a simple change), or contact me.
13 |
14 | ### Examples
15 | See `examples/mnist/train_mnist.py` for an MNIST fixed-point training example and some demo usage of the utilities.
16 |
17 | ### Usage Explained
18 |
19 | **How to make a module fixed-point**
20 |
21 | If you have implemented a new module type, e.g. a masked convolution module `MaskedConv2d` (a convolution with a mask generated by pruning), and you want a fixed-point version of it, just do:
22 | ```python
23 | from nics_fix_pt import register_fix_module
24 | register_fix_module(MaskedConv2d)
25 | ```
26 |
27 | Then you can construct the fixed-point module and use it as shown below. All of its functionality stays the same, except that the weights (and optionally the gradients of the weights) are converted to fixed-point before each usage.
28 |
29 | ```python
30 | import nics_fix_pt.nn_fix as nnf
31 | masked_conv = nnf.MaskedConv2d_fix(...your parameters, including fixed-point configs...)
32 | ```
33 |
34 | The internally registered fixed-point modules are already tested in our test examples and our other works:
35 |
36 | * `torch.nn.Linear`
37 | * `torch.nn.Conv2d`
38 | * `torch.nn.BatchNorm1d`
39 | * `torch.nn.BatchNorm2d`
40 |
41 | > NOTE: We expect most `torch.nn` modules to work well without special handling, so you can use the AutoRegister functionality, which auto-registers non-registered `torch.nn` modules. However, as we have not tested all of these modules thoroughly, if you find that some module fails to work in your use case, please tell us or contribute special handling for it.
42 |
43 | **How to inspect the floating-point precision/quantized datum**
44 |
45 | Network parameters are saved as floating-point values in `module._parameters[<param_name>]`, and `module.<param_name>` is the fixed-point/quantized version. You can use either `module._parameters` or `<param>.nfp_actual_data` (e.g. `masked_conv.weights.nfp_actual_data`) to access the floating-point precision datum.
46 |
47 | As `module._parameters` and `module._buffers` are used by `model.state_dict`, the network parameters saved via `model.state_dict` when you dump checkpoints to disk are in floating-point precision:
48 | * In your use cases (e.g. fixed-point hardware simulation), together with the saved floating-point precision parameters, you might need to dump and then load/modify the fix configurations of the variables using `model.get_fix_configs` and `model.load_fix_configs`. Check `examples/mnist/train_mnist.py` for an example.
49 | * Or, if you want to directly dump the latest used version of the parameters (which might be a quantized tensor, depending on your latest configuration), use `nnf.FixTopModule.fix_state_dict(module)` for dumping instead.
50 |
51 | **How to config the fixed-point behavior**
52 |
53 | Every fixed-point module needs a fixed-point configuration, and optionally a fixed-point configuration for the gradients.
54 |
55 | A config for a module should be a dict whose keys are the parameter or buffer names, and whose values are dicts of torch tensors (holding the current "bitwidth", "method", and "scale"), which are modified in place by calls to `nics_fix_pt.quant.quantize_cfg`; see the sketch below.
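For concreteness, here is a minimal sketch of building such a config dict by hand and passing it to a fixed-point module. The `make_fix_cfg` helper is not part of the package; it just mirrors the `_generate_default_fix_cfg` helpers used in `tests/conftest.py`, `examples/cifar10/net.py`, and `examples/mnist/train_mnist.py`. The layer shapes and bitwidths are made up for illustration:

```python
import numpy as np
import torch

import nics_fix_pt as nfp
import nics_fix_pt.nn_fix as nnf


def make_fix_cfg(names, scale=0, bitwidth=8, method=nfp.FIX_AUTO):
    # One entry per parameter/buffer name; the "method"/"scale"/"bitwidth"
    # tensors are read and updated in place by the quantization routines.
    return {
        n: {
            "method": torch.autograd.Variable(
                torch.IntTensor(np.array([method])), requires_grad=False),
            "scale": torch.autograd.Variable(
                torch.FloatTensor(np.array([scale])), requires_grad=False),
            "bitwidth": torch.autograd.Variable(
                torch.IntTensor(np.array([bitwidth])), requires_grad=False),
        }
        for n in names
    }


# Separate dicts so each module is quantized independently (see the NOTE above).
conv1 = nnf.Conv2d_fix(3, 16, kernel_size=3, padding=1,
                       nf_fix_params=make_fix_cfg(["weight", "bias"]))
conv2 = nnf.Conv2d_fix(16, 32, kernel_size=3, padding=1,
                       nf_fix_params=make_fix_cfg(["weight", "bias"]),
                       nf_fix_params_grad=make_fix_cfg(["weight", "bias"], bitwidth=16))
```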
56 |
57 | **How to check/manage the configuration**
58 |
59 | 1. For each quantized datum, you can use `<param>.data_cfg` or `<param>.grad_cfg`, e.g. `masked_conv.weights.data_cfg` or `masked_conv.bias.data_cfg`.
60 | 2. You can also use `FixTopModule.get_fix_configs(module)` to get the configs of multiple modules in one OrderedDict.
61 |
62 | You can modify the config tensors in place to change the behavior.
63 |
64 | ### Utilities
65 |
66 | - FixTopModule: dump/load fix configurations to/from file; print fix configs. (A short worked example is given at the end of this README.)
67 |
68 | FixTopModule is just a wrapper that gathers the config print/load/dump/setting utilities. These utilities also work with nested plain `nn.Module` containers as intermediates, e.g. an `nn.Sequential` of fixed modules will work; you do not need a subclass that multiply inherits from `nn.Sequential` and `nnf.FixTopModule`!
69 |
70 | - AutoRegister: automatically registers the corresponding fixed-point module for any not-yet-registered module in `torch.nn`, by proxying to the modules in `torch.nn`. Example usage:
71 | ```python
72 | from nics_fix_pt import NAR as nnf_auto
73 | bilinear = nnf_auto.Bilinear_fix(...parameters...)
74 | ```
75 |
76 | ### Test cases
77 |
78 | Tested with Python 2.7, 3.5, 3.6.1+.
79 |
80 | PyTorch 0.4.1, 1.0.0, 1.1.0, 1.4.0. Note that fixed-point simulation using DataParallel is currently not supported with PyTorch >= 1.5.0!
81 |
82 | ![coverage percentage](./coverage.svg)
83 |
84 | Run `python setup.py test` to run the pytest test cases.
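For reference, here is a hedged sketch of the config dump/load round trip and quantized-parameter dump mentioned in the Utilities section above. It mirrors `examples/mnist/train_mnist.py`; `model` is assumed to be a `FixTopModule` (such as the MNIST `Net`), and the file names are only illustrative:

```python
import torch
import yaml

import nics_fix_pt as nfp

model.print_fix_configs()  # inspect the current fix configs

# dump the data/grad fix configs to a YAML file
fix_cfg = {
    "data": model.get_fix_configs(data_only=True),
    "grad": model.get_fix_configs(grad=True, data_only=True),
}
with open("fix_config.yaml", "w") as wf:
    yaml.dump(fix_cfg, wf, default_flow_style=False)

# ... later: reload the configs and evaluate with the scales held fixed
with open("fix_config.yaml", "r") as rf:
    fix_cfg = yaml.load(rf)  # newer PyYAML may need an explicit Loader argument
model.load_fix_configs(fix_cfg["data"])
model.load_fix_configs(fix_cfg["grad"], grad=True)
model.set_fix_method(nfp.FIX_FIXED)

# dump the quantized (rather than floating-point) parameters
torch.save({"model": model.fix_state_dict()}, "fixed_ckpt.pt")
```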
85 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | 4 | def test_auto_register(): 5 | import torch 6 | from nics_fix_pt import nn_fix as nnf 7 | from nics_fix_pt import NAR as nnf_auto 8 | 9 | torch.nn.Linear_another2 = torch.nn.Linear 10 | 11 | class Linear_another(torch.nn.Linear): # a stub test case 12 | pass 13 | 14 | torch.nn.Linear_another = Linear_another 15 | with pytest.raises(AttributeError) as _: 16 | fix_module = nnf.Linear_another_fix( 17 | 3, 1, nf_fix_params={} 18 | ) # not already registered 19 | fix_module = nnf_auto.Linear_another( 20 | 3, 1, nf_fix_params={} 21 | ) # will automatically register this fix module; `NAR.linear_another_fix` also works. 22 | fix_module = nnf_auto.Linear_another2_fix( 23 | 3, 1, nf_fix_params={} 24 | ) # will automatically register this fix module; `NAR.linear_another2` also works. 25 | fix_module = nnf.Linear_another_fix(3, 1, nf_fix_params={}) 26 | fix_module = nnf.Linear_another2_fix(3, 1, nf_fix_params={}) 27 | 28 | 29 | def test_save_state_dict(tmp_path): 30 | import os 31 | import numpy as np 32 | import torch 33 | from torch import nn 34 | 35 | import nics_fix_pt as nfp 36 | from nics_fix_pt import nn_fix as nnf 37 | from nics_fix_pt.utils import _generate_default_fix_cfg 38 | 39 | class _View(nn.Module): 40 | def __init__(self): 41 | super(_View, self).__init__() 42 | 43 | def forward(self, inputs): 44 | return inputs.view(inputs.shape[0], -1) 45 | data = torch.tensor(np.random.rand(8, 3, 4, 4).astype(np.float32)).cuda() 46 | ckpt = os.path.join(tmp_path, "tmp.pt") 47 | 48 | model = nn.Sequential(*[ 49 | nnf.Conv2d_fix(3, 10, kernel_size=3, padding=1, 50 | nf_fix_params=_generate_default_fix_cfg( 51 | ["weight", "bias"], scale=2.**np.random.randint(low=-10, high=10), 52 | method=nfp.FIX_FIXED)), 53 | nnf.Conv2d_fix(10, 20, kernel_size=3, padding=1, 54 | nf_fix_params=_generate_default_fix_cfg( 55 | ["weight", "bias"], scale=2.**np.random.randint(low=-10, high=10), 56 | method=nfp.FIX_FIXED)), 57 | nnf.Activation_fix( 58 | nf_fix_params=_generate_default_fix_cfg( 59 | ["activation"], scale=2.**np.random.randint(low=-10, high=10), 60 | method=nfp.FIX_FIXED)), 61 | nn.AdaptiveAvgPool2d(1), 62 | _View(), 63 | nnf.Linear_fix(20, 10, nf_fix_params=_generate_default_fix_cfg( 64 | ["weight", "bias"], scale=2.**np.random.randint(low=-10, high=10), 65 | method=nfp.FIX_FIXED)) 66 | ]) 67 | model.cuda() 68 | pre_results = model(data) 69 | torch.save(model.state_dict(), ckpt) 70 | model2 = nn.Sequential(*[ 71 | nnf.Conv2d_fix(3, 10, kernel_size=3, padding=1, 72 | nf_fix_params=_generate_default_fix_cfg( 73 | ["weight", "bias"], scale=2.**np.random.randint(low=-10, high=10), 74 | method=nfp.FIX_FIXED)), 75 | nnf.Conv2d_fix(10, 20, kernel_size=3, padding=1, 76 | nf_fix_params=_generate_default_fix_cfg( 77 | ["weight", "bias"], scale=2.**np.random.randint(low=-10, high=10), 78 | method=nfp.FIX_FIXED)), 79 | nnf.Activation_fix( 80 | nf_fix_params=_generate_default_fix_cfg( 81 | ["activation"], scale=2.**np.random.randint(low=-10, high=10), 82 | method=nfp.FIX_FIXED)), 83 | nn.AdaptiveAvgPool2d(1), 84 | _View(), 85 | nnf.Linear_fix(20, 10, nf_fix_params=_generate_default_fix_cfg( 86 | ["weight", "bias"], scale=2.**np.random.randint(low=-10, high=10), 87 | method=nfp.FIX_FIXED)) 88 | ]) 89 | model2.cuda() 90 | model2.load_state_dict(torch.load(ckpt)) 91 | post_results = model2(data) 92 | assert 
(post_results - pre_results < 1e-2).all() 93 | 94 | 95 | def test_fix_state_dict(module_cfg): 96 | import torch 97 | from nics_fix_pt.nn_fix import FixTopModule 98 | import nics_fix_pt.quant as nfpq 99 | 100 | module, cfg, _ = module_cfg 101 | dct = FixTopModule.fix_state_dict(module) 102 | assert (dct["param"] == module._parameters["param"]).all() # not already fixed 103 | 104 | # forward the module once 105 | res = module.forward(torch.tensor([0, 0, 0]).float()) 106 | dct = FixTopModule.fix_state_dict(module) 107 | dct_vars = FixTopModule.fix_state_dict(module, keep_vars=True) 108 | quantized, _ = nfpq.quantize_cfg(module._parameters["param"], **cfg["param"]) 109 | dct_vars = FixTopModule.fix_state_dict(module, keep_vars=True) 110 | assert (dct["param"] == quantized).all() # already fixed 111 | assert (dct_vars["param"] == quantized).all() # already fixed 112 | assert ( 113 | dct_vars["param"].nfp_actual_data == module._parameters["param"] 114 | ).all() # underhood float-point data 115 | 116 | 117 | def test_set_fix_method(test_network): 118 | test_network.set_fix_method(method=0, method_by_type={ 119 | "BatchNorm2d_fix": {"running_mean": 2, "running_var": 2, "weight": 1, "bias": 0} 120 | }, method_by_name={"conv1": {"weight": 2}}) 121 | assert int(test_network.bn1.nf_fix_params["weight"]["method"]) == 1 122 | assert int(test_network.bn1.nf_fix_params["bias"]["method"]) == 0 123 | assert int(test_network.bn1.nf_fix_params["running_mean"]["method"]) == 2 124 | assert int(test_network.conv1.nf_fix_params["weight"]["method"]) == 2 125 | assert int(test_network.conv1.nf_fix_params["bias"]["method"]) == 1 126 | assert int(test_network.conv2.nf_fix_params["weight"]["method"]) == 0 127 | assert int(test_network.conv2.nf_fix_params["bias"]["method"]) == 0 128 | -------------------------------------------------------------------------------- /examples/cifar10/net.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch 3 | from torch import nn 4 | from torch.nn import functional as F 5 | 6 | import nics_fix_pt as nfp 7 | import nics_fix_pt.nn_fix as nnf 8 | import numpy as np 9 | 10 | def _generate_default_fix_cfg(names, scale=0, bitwidth=8, method=0, 11 | range_method=nfp.RangeMethod.RANGE_MAX): 12 | return { 13 | n: { 14 | "method": torch.autograd.Variable( 15 | torch.IntTensor(np.array([method])), requires_grad=False 16 | ), 17 | "scale": torch.autograd.Variable( 18 | torch.IntTensor(np.array([scale])), requires_grad=False 19 | ), 20 | "bitwidth": torch.autograd.Variable( 21 | torch.IntTensor(np.array([bitwidth])), requires_grad=False 22 | ), 23 | "range_method": range_method 24 | } 25 | for n in names 26 | } 27 | 28 | class FixNet(nnf.FixTopModule): 29 | def __init__(self, fix_bn=True, fix_grad=True, bitwidth_data=8, bitwidth_grad=16, 30 | range_method=nfp.RangeMethod.RANGE_MAX, 31 | grad_range_method=nfp.RangeMethod.RANGE_MAX): 32 | super(FixNet, self).__init__() 33 | 34 | print("fix bn: {}; fix grad: {}; range method: {}; grad range method: {}".format( 35 | fix_bn, fix_grad, range_method, grad_range_method 36 | )) 37 | 38 | # fix configurations (data/grad) for parameters/buffers 39 | self.fix_param_cfgs = {} 40 | self.fix_grad_cfgs = {} 41 | layers = [("conv1_1", 128, 3), ("bn1_1",), ("conv1_2", 128, 3), ("bn1_2",), 42 | ("conv1_3", 128, 3), ("bn1_3",), ("conv2_1", 256, 3), ("bn2_1",), 43 | ("conv2_2", 256, 3), ("bn2_2",), ("conv2_3", 256, 3), ("bn2_3",), 44 | ("conv3_1", 512, 3), ("bn3_1",), ("nin3_2", 256, 1), 
("bn3_2",), 45 | ("nin3_3", 128, 1), ("bn3_3",), ("fc4", 10)] 46 | for layer_cfg in layers: 47 | name = layer_cfg[0] 48 | if "bn" in name and not fix_bn: 49 | continue 50 | # data fix config 51 | self.fix_param_cfgs[name] = _generate_default_fix_cfg( 52 | ["weight", "bias", "running_mean", "running_var"] \ 53 | if "bn" in name else ["weight", "bias"], 54 | method=1, bitwidth=bitwidth_data, range_method=range_method 55 | ) 56 | if fix_grad: 57 | # grad fix config 58 | self.fix_grad_cfgs[name] = _generate_default_fix_cfg( 59 | ["weight", "bias"], method=1, bitwidth=bitwidth_grad, 60 | range_method=grad_range_method 61 | ) 62 | 63 | # fix configurations for activations 64 | # data fix config 65 | self.fix_act_cfgs = [ 66 | _generate_default_fix_cfg(["activation"], method=1, bitwidth=bitwidth_data, 67 | range_method=range_method) 68 | for _ in range(20) 69 | ] 70 | if fix_grad: 71 | # grad fix config 72 | self.fix_act_grad_cfgs = [ 73 | _generate_default_fix_cfg(["activation"], method=1, bitwidth=bitwidth_grad, 74 | range_method=grad_range_method) 75 | for _ in range(20) 76 | ] 77 | 78 | # construct layers 79 | cin = 3 80 | for layer_cfg in layers: 81 | name = layer_cfg[0] 82 | if "conv" in name or "nin" in name: 83 | # convolution layers 84 | cout, kernel_size = layer_cfg[1:] 85 | layer = nnf.Conv2d_fix( 86 | cin, cout, 87 | nf_fix_params=self.fix_param_cfgs[name], 88 | nf_fix_params_grad=self.fix_grad_cfgs[name] if fix_grad else None, 89 | kernel_size=kernel_size, 90 | padding=(kernel_size - 1) // 2 if name != "conv3_1" else 0) 91 | cin = cout 92 | elif "bn" in name: 93 | # bn layers 94 | if fix_bn: 95 | layer = nnf.BatchNorm2d_fix( 96 | cin, 97 | nf_fix_params=self.fix_param_cfgs[name], 98 | nf_fix_params_grad=self.fix_grad_cfgs[name] if fix_grad else None) 99 | else: 100 | layer = nn.BatchNorm2d(cin) 101 | elif "fc" in name: 102 | # fully-connected layers 103 | cout = layer_cfg[1] 104 | layer = nnf.Linear_fix( 105 | cin, cout, 106 | nf_fix_params=self.fix_param_cfgs[name], 107 | nf_fix_params_grad=self.fix_grad_cfgs[name] if fix_grad else None) 108 | cin = cout 109 | # call setattr 110 | setattr(self, name, layer) 111 | 112 | for i in range(20): 113 | setattr(self, "fix" + str(i), nnf.Activation_fix( 114 | nf_fix_params=self.fix_act_cfgs[i], 115 | nf_fix_params_grad=self.fix_act_grad_cfgs[i] if fix_grad else None)) 116 | 117 | self.pool1 = nn.MaxPool2d((2, 2)) 118 | self.pool2 = nn.MaxPool2d((2, 2)) 119 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 120 | 121 | def forward(self, x): 122 | x = self.fix0(x) 123 | x = self.fix2(F.relu(self.bn1_1(self.fix1(self.conv1_1(x))))) 124 | x = self.fix4(F.relu(self.bn1_2(self.fix3(self.conv1_2(x))))) 125 | x = self.pool1(self.fix6(F.relu(self.bn1_3(self.fix5(self.conv1_3(x)))))) 126 | x = self.fix8(F.relu(self.bn2_1(self.fix7(self.conv2_1(x))))) 127 | x = self.fix10(F.relu(self.bn2_2(self.fix9(self.conv2_2(x))))) 128 | x = self.pool2(self.fix12(F.relu(self.bn2_3(self.fix11(self.conv2_3(x)))))) 129 | x = self.fix14(F.relu(self.bn3_1(self.fix13(self.conv3_1(x))))) 130 | x = self.fix16(F.relu(self.bn3_2(self.fix15(self.nin3_2(x))))) 131 | x = self.fix18(F.relu(self.bn3_3(self.fix17(self.nin3_3(x))))) 132 | # x = self.fix2(F.relu(self.bn1_1(self.conv1_1(x)))) 133 | # x = self.fix4(F.relu(self.bn1_2(self.conv1_2(x)))) 134 | # x = self.pool1(self.fix6(F.relu(self.bn1_3(self.conv1_3(x))))) 135 | # x = self.fix8(F.relu(self.bn2_1(self.conv2_1(x)))) 136 | # x = self.fix10(F.relu(self.bn2_2(self.conv2_2(x)))) 137 | # x = 
self.pool2(self.fix12(F.relu(self.bn2_3(self.conv2_3(x))))) 138 | # x = self.fix14(F.relu(self.bn3_1(self.conv3_1(x)))) 139 | # x = self.fix16(F.relu(self.bn3_2(self.nin3_2(x)))) 140 | # x = self.fix18(F.relu(self.bn3_3(self.nin3_3(x)))) 141 | x = self.avg_pool(x) 142 | x = x.view(-1, 128) 143 | x = self.fix19(self.fc4(x)) 144 | 145 | return x 146 | -------------------------------------------------------------------------------- /tests/test_fix_module.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import numpy as np 4 | import torch 5 | from torch import nn 6 | import torch.optim as optim 7 | 8 | import nics_fix_pt as nfp 9 | 10 | # When module_cfg's nf_fix_paramparam is set , it means scale=-1, bitwidth=2, method=FIX_AUTO, see the default config in conftest module_cfg fixture. 11 | @pytest.mark.parametrize( 12 | "module_cfg, case", 13 | [ 14 | ( 15 | {"input_num": 3}, 16 | { 17 | "inputs": [1, 1, 0], 18 | "data": [0.2513, -0.52, 0], 19 | "out_scale": 1, 20 | "result": 0, 21 | "output": [0.5, -0.5, 0], # quantized parameters, step 0.5 22 | }, 23 | ), 24 | ( 25 | {"input_num": 3}, 26 | { 27 | "inputs": [1, 1, 0], 28 | "data": [0.2513, -0.5, 0], 29 | "out_scale": 0.5, 30 | "result": -0.25, 31 | "output": [0.25, -0.5, 0], # quantized parameters, step 0.25 32 | }, 33 | ), 34 | ], 35 | indirect=["module_cfg"], 36 | ) 37 | def test_fix_forward_auto(module_cfg, case): 38 | module, cfg, _ = module_cfg 39 | if "data" in case: 40 | module.param[0, :] = torch.tensor(case["data"]) 41 | with torch.no_grad(): 42 | res = module.forward(torch.tensor(case["inputs"]).float()) 43 | assert np.isclose(res, case["result"]) # calc output 44 | assert np.isclose(module.param, case["output"]).all() # quantized parameter 45 | assert cfg["param"]["scale"] == case["out_scale"] # scale 46 | 47 | @pytest.mark.parametrize( 48 | "module_cfg, case", 49 | [ 50 | ( 51 | {"input_num": 3}, 52 | { 53 | "inputs": [[1, 1, 0], [1, 2, 0]], 54 | "data": [0.2513, -0.52, 0], 55 | "out_scale": 1, 56 | "result": [[0], [-0.5]], 57 | "output": [0.5, -0.5, 0], # quantized parameters, step 0.5 58 | }, 59 | ), 60 | ( 61 | {"input_num": 3}, 62 | { 63 | "inputs": [[1, 1, 0], [1, 1, 0]], 64 | "data": [0.2513, -0.52, 0], 65 | "out_scale": 1, 66 | "result": [[0], [0]], 67 | "output": [0.5, -0.5, 0], # quantized parameters, step 0.5 68 | }, 69 | ), 70 | ( 71 | {"input_num": 3}, 72 | { 73 | "inputs": [[1, 1, 0], [1, 1, 0]], 74 | "data": [0.2513, -0.5, 0], 75 | "out_scale": 0.5, 76 | "result": [[-0.25], [-0.25]], 77 | "output": [0.25, -0.5, 0], # quantized parameters, step 0.25 78 | }, 79 | ), 80 | ], 81 | indirect=["module_cfg"], 82 | ) 83 | def test_fix_forward_parallel_gpu(module_cfg, case): 84 | module, cfg, _ = module_cfg 85 | if "data" in case: 86 | module.param.data[0, :] = torch.tensor(case["data"]) 87 | model = nn.DataParallel(module.cuda(), [0, 1]) 88 | with torch.no_grad(): 89 | res = model(torch.tensor(case["inputs"]).float().cuda()) 90 | assert cfg["param"]["scale"] == case["out_scale"] # scale 91 | assert np.isclose(res.cpu(), case["result"]).all() # calc output 92 | # assert np.isclose(module.param.cpu(), case["output"]).all() # quantized parameter 93 | # this will not change, 94 | # but the gradient will still be accumulated in module_parameters[name].grad 95 | 96 | @pytest.mark.parametrize( 97 | "module_cfg, case", 98 | [ 99 | ( 100 | {"input_num": 3, "grad_cfg": {"method": nfp.FIX_AUTO}}, 101 | { 102 | "inputs": [0.52, -0.27, 0], 103 | "data": [0, 0, 0], 104 | 
"grad_scale": 1, 105 | "output": [0.5, -0.5, 0], 106 | }, 107 | ), 108 | ( 109 | {"input_num": 3, "grad_cfg": {"method": nfp.FIX_AUTO}}, 110 | { 111 | "inputs": [0.5, -0.27, 0], 112 | "data": [0, 0, 0], 113 | "grad_scale": 0.5, 114 | "output": [0.5, -0.25, 0], # quantized gradients 115 | }, 116 | ), 117 | ], 118 | indirect=["module_cfg"], 119 | ) 120 | def test_fix_backward_auto(module_cfg, case): 121 | module, _, cfg = module_cfg 122 | if "data" in case: 123 | module.param.data[0, :] = torch.tensor(case["data"]) 124 | res = module.forward(torch.tensor(case["inputs"]).float()) 125 | res.backward() 126 | assert np.isclose( 127 | module._parameters["param"].grad, case["output"] 128 | ).all() # quantized gradient 129 | assert cfg["param"]["scale"] == case["grad_scale"] # scale 130 | 131 | @pytest.mark.parametrize( 132 | "module_cfg, case", 133 | [ 134 | ( 135 | {"input_num": 3, "data_cfg": {"method": nfp.FIX_NONE}, 136 | "grad_cfg": {"method": nfp.FIX_AUTO}}, 137 | { 138 | "inputs": [[0.52, -0.27, 0], [0.52, -0.27, 0]], 139 | "data": [0, 0, 0], 140 | "grad_scale": 1, 141 | "output": [0.5, -0.5, 0], 142 | }, 143 | ), 144 | ( 145 | {"input_num": 3, "grad_cfg": {"method": nfp.FIX_AUTO}}, 146 | { 147 | "inputs": [[0.5, -0.27, 0], [0.5, -0.27, 0]], 148 | "data": [0, 0, 0], 149 | "grad_scale": 0.5, 150 | "output": [0.5, -0.25, 0], # quantized gradients 151 | }, 152 | ), 153 | ], 154 | indirect=["module_cfg"], 155 | ) 156 | def test_fix_backward_parallel_gpu(module_cfg, case): 157 | module, _, cfg = module_cfg 158 | if "data" in case: 159 | module.param.data[0, :] = torch.tensor(case["data"]) 160 | model = nn.DataParallel(module.cuda(), [0, 1]) 161 | res = torch.sum(model(torch.tensor(case["inputs"]).float().cuda())) 162 | res.backward() 163 | assert np.isclose( 164 | module._parameters["param"].grad.cpu(), 2 * np.array(case["output"]) 165 | ).all() # quantized gradient, 2 batch, grad x 2 166 | assert cfg["param"]["scale"] == case["grad_scale"] # scale 167 | 168 | @pytest.mark.parametrize( 169 | "module_cfg, case", 170 | [ 171 | ( 172 | {"input_num": 3, "grad_cfg": {"method": nfp.FIX_AUTO}}, 173 | { 174 | "inputs": [0.52, -0.27, 0], 175 | "data": [0, 0, 0], 176 | "grad_scale": 1, 177 | "output": [0.5, -0.5, 0], 178 | }, 179 | ), 180 | ( 181 | {"input_num": 3, "grad_cfg": {"method": nfp.FIX_AUTO}}, 182 | { 183 | "inputs": [0.5, -0.27, 0], 184 | "data": [0, 0, 0], 185 | "grad_scale": 0.5, 186 | "output": [0.5, -0.25, 0], # quantized gradients 187 | }, 188 | ), 189 | ], 190 | indirect=["module_cfg"], 191 | ) 192 | def test_fix_update_auto(module_cfg, case): 193 | module, _, cfg = module_cfg 194 | if "data" in case: 195 | module.param.data[0, :] = torch.tensor(case["data"]) 196 | optimizer = optim.SGD(module.parameters(), lr=1.0, momentum=0) 197 | res = module.forward(torch.tensor(case["inputs"]).float()) 198 | res.backward() 199 | optimizer.step() 200 | assert np.isclose( 201 | -module._parameters["param"].detach(), case["output"] 202 | ).all() # updated parameter should be - lr * gradient 203 | assert cfg["param"]["scale"] == case["grad_scale"] # scale 204 | 205 | def test_ConvBN_fix(): 206 | from nics_fix_pt.nn_fix import ConvBN_fix 207 | # float forward and combine forward get the same results 208 | module = ConvBN_fix(3, 32, nf_fix_params={}).cuda() 209 | module.train() 210 | data = torch.tensor(np.random.rand(128, 3, 32, 32).astype(np.float32)).cuda() 211 | comb_out = module(data) 212 | float_out = module.bn(module.conv(data)) 213 | assert (float_out - comb_out < 1e-3).all() 214 | 215 | 
module.eval() 216 | module = ConvBN_fix(3, 32, nf_fix_params={}).cuda() 217 | data = torch.tensor(np.random.rand(128, 3, 32, 32).astype(np.float32)).cuda() 218 | comb_out = module(data) 219 | float_out = module.bn(module.conv(data)) 220 | assert (float_out - comb_out < 1e-3).all() 221 | 222 | -------------------------------------------------------------------------------- /examples/mnist/train_mnist.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from __future__ import print_function 4 | 5 | import argparse 6 | import yaml 7 | 8 | import numpy as np 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | import torch.optim as optim 13 | from torchvision import datasets, transforms 14 | from torch.autograd import Variable 15 | 16 | import nics_fix_pt as nfp 17 | import nics_fix_pt.nn_fix as nnf 18 | 19 | # Training settings 20 | parser = argparse.ArgumentParser(description="PyTorch MNIST Example") 21 | parser.add_argument( 22 | "--batch-size", 23 | type=int, 24 | default=64, 25 | metavar="N", 26 | help="input batch size for training (default: 64)", 27 | ) 28 | parser.add_argument( 29 | "--test-batch-size", 30 | type=int, 31 | default=1000, 32 | metavar="N", 33 | help="input batch size for testing (default: 1000)", 34 | ) 35 | parser.add_argument( 36 | "--epochs", 37 | type=int, 38 | default=1, 39 | metavar="N", 40 | help="number of epochs to train (default: 1)", 41 | ) 42 | parser.add_argument( 43 | "--lr", type=float, default=0.01, metavar="LR", help="learning rate (default: 0.01)" 44 | ) 45 | parser.add_argument( 46 | "--momentum", 47 | type=float, 48 | default=0.5, 49 | metavar="M", 50 | help="SGD momentum (default: 0.5)", 51 | ) 52 | parser.add_argument( 53 | "--no-cuda", action="store_true", default=False, help="disables CUDA training" 54 | ) 55 | parser.add_argument( 56 | "--seed", type=int, default=1, metavar="S", help="random seed (default: 1)" 57 | ) 58 | parser.add_argument( 59 | "--log-interval", 60 | type=int, 61 | default=10, 62 | metavar="N", 63 | help="how many batches to wait before logging training status", 64 | ) 65 | parser.add_argument( 66 | "--float", 67 | action="store_true", 68 | default=False, 69 | help="use float point training/testing", 70 | ) 71 | parser.add_argument("--save", default=None, help="save fixed-point paramters to file") 72 | args = parser.parse_args() 73 | args.cuda = not args.no_cuda and torch.cuda.is_available() 74 | 75 | torch.manual_seed(args.seed) 76 | if args.cuda: 77 | torch.cuda.manual_seed(args.seed) 78 | 79 | 80 | kwargs = {"num_workers": 1, "pin_memory": True} if args.cuda else {} 81 | train_loader = torch.utils.data.DataLoader( 82 | datasets.MNIST( 83 | "../data", 84 | train=True, 85 | download=True, 86 | transform=transforms.Compose( 87 | [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))] 88 | ), 89 | ), 90 | batch_size=args.batch_size, 91 | shuffle=True, 92 | **kwargs 93 | ) 94 | test_loader = torch.utils.data.DataLoader( 95 | datasets.MNIST( 96 | "../data", 97 | train=False, 98 | transform=transforms.Compose( 99 | [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))] 100 | ), 101 | ), 102 | batch_size=args.test_batch_size, 103 | shuffle=True, 104 | **kwargs 105 | ) 106 | 107 | 108 | def _generate_default_fix_cfg(names, scale=0, bitwidth=8, method=0): 109 | return { 110 | n: { 111 | "method": torch.autograd.Variable( 112 | torch.IntTensor(np.array([method])), requires_grad=False 113 | ), 114 | "scale": 
torch.autograd.Variable( 115 | torch.IntTensor(np.array([scale])), requires_grad=False 116 | ), 117 | "bitwidth": torch.autograd.Variable( 118 | torch.IntTensor(np.array([bitwidth])), requires_grad=False 119 | ), 120 | } 121 | for n in names 122 | } 123 | 124 | 125 | BITWIDTH = 4 126 | 127 | 128 | class Net(nnf.FixTopModule): 129 | def __init__(self): 130 | super(Net, self).__init__() 131 | # initialize some fix configurations 132 | self.fc1_fix_params = _generate_default_fix_cfg( 133 | ["weight", "bias"], method=1, bitwidth=BITWIDTH 134 | ) 135 | self.bn_fc1_params = _generate_default_fix_cfg( 136 | ["weight", "bias", "running_mean", "running_var"], 137 | method=1, 138 | bitwidth=BITWIDTH, 139 | ) 140 | self.fc2_fix_params = _generate_default_fix_cfg( 141 | ["weight", "bias"], method=1, bitwidth=BITWIDTH 142 | ) 143 | self.fix_params = [ 144 | _generate_default_fix_cfg(["activation"], method=1, bitwidth=BITWIDTH) 145 | for _ in range(4) 146 | ] 147 | # initialize modules 148 | self.fc1 = nnf.Linear_fix(784, 100, nf_fix_params=self.fc1_fix_params) 149 | # self.bn_fc1 = nnf.BatchNorm1d_fix(100, nf_fix_params=self.bn_fc1_params) 150 | self.fc2 = nnf.Linear_fix(100, 10, nf_fix_params=self.fc2_fix_params) 151 | self.fix0 = nnf.Activation_fix(nf_fix_params=self.fix_params[0]) 152 | # self.fix0_bn = nnf.Activation_fix(nf_fix_params=self.fix_params[1]) 153 | self.fix1 = nnf.Activation_fix(nf_fix_params=self.fix_params[2]) 154 | self.fix2 = nnf.Activation_fix(nf_fix_params=self.fix_params[3]) 155 | 156 | def forward(self, x): 157 | x = self.fix0(x.view(-1, 784)) 158 | x = F.relu(self.fix1(self.fc1(x))) 159 | # x = F.relu(self.fix0_bn(self.bn_fc1(self.fix1(self.fc1(x))))) 160 | self.logits = x = self.fix2(self.fc2(x)) 161 | return F.log_softmax(x, dim=-1) 162 | 163 | 164 | model = Net() 165 | if args.cuda: 166 | model.cuda() 167 | 168 | optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum) 169 | 170 | 171 | def train(epoch, fix_method=nfp.FIX_AUTO): 172 | model.set_fix_method(fix_method) 173 | model.train() 174 | for batch_idx, (data, target) in enumerate(train_loader): 175 | if args.cuda: 176 | data, target = data.cuda(), target.cuda() 177 | data, target = Variable(data), Variable(target) 178 | optimizer.zero_grad() 179 | output = model(data) 180 | loss = F.nll_loss(output, target) 181 | loss.backward() 182 | optimizer.step() 183 | if batch_idx % args.log_interval == 0: 184 | print( 185 | "\rTrain Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format( 186 | epoch, 187 | batch_idx * len(data), 188 | len(train_loader.dataset), 189 | 100.0 * batch_idx / len(train_loader), 190 | loss.data.item(), 191 | ), 192 | end="", 193 | ) 194 | print("") 195 | 196 | 197 | def test(fix_method=nfp.FIX_FIXED): 198 | model.set_fix_method(fix_method) 199 | model.eval() 200 | test_loss = 0 201 | correct = 0 202 | with torch.no_grad(): 203 | for data, target in test_loader: 204 | if args.cuda: 205 | data, target = data.cuda(), target.cuda() 206 | data, target = Variable(data), Variable(target) 207 | output = model(data) 208 | test_loss += F.nll_loss( 209 | output, target, size_average=False 210 | ).data.item() # sum up batch loss 211 | pred = output.data.max(1, keepdim=True)[ 212 | 1 213 | ] # get the index of the max log-probability 214 | correct += pred.eq(target.data.view_as(pred)).sum().item() 215 | 216 | test_loss /= len(test_loader.dataset) 217 | print( 218 | "\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format( 219 | test_loss, 220 | correct, 221 | len(test_loader.dataset), 
222 | 100.0 * correct / len(test_loader.dataset), 223 | ) 224 | ) 225 | 226 | 227 | for epoch in range(1, args.epochs + 1): 228 | train(epoch, nfp.FIX_NONE if args.float else nfp.FIX_AUTO) 229 | test(nfp.FIX_NONE if args.float else nfp.FIX_FIXED) 230 | 231 | model.print_fix_configs() 232 | fix_cfg = { 233 | "data": model.get_fix_configs(data_only=True), 234 | "grad": model.get_fix_configs(grad=True, data_only=True), 235 | } 236 | with open("mnist_fix_config.yaml", "w") as wf: 237 | yaml.dump(fix_cfg, wf, default_flow_style=False) 238 | 239 | if args.save: 240 | state = {"model": model.fix_state_dict(), "epoch": args.epochs} 241 | torch.save(state, args.save) 242 | print("Saving fixed state dict to", args.save) 243 | 244 | # Let's try float test 245 | print("test float: ", end="") 246 | test(nfp.FIX_NONE) # after 1 epoch: ~ 92% 247 | 248 | if not args.float: 249 | # Let's load the fix config again, and test it using FIX_FIXED 250 | print("load from the yaml config and test fixed again: ", end="") 251 | with open("mnist_fix_config.yaml", "r") as rf: 252 | fix_cfg = yaml.load(rf) 253 | model.load_fix_configs(fix_cfg["data"]) 254 | model.load_fix_configs(fix_cfg["grad"], grad=True) 255 | test(nfp.FIX_FIXED) # after 1 epoch: ~ 89% 256 | -------------------------------------------------------------------------------- /examples/cifar10/main.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import argparse 4 | import os 5 | import shutil 6 | import time 7 | import math 8 | 9 | import numpy as np 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.parallel 13 | import torch.backends.cudnn as cudnn 14 | import torch.nn.functional as F 15 | import torch.optim 16 | import torch.utils.data 17 | import torchvision.transforms as transforms 18 | import torchvision.datasets as datasets 19 | 20 | import nics_fix_pt as nfp 21 | import nics_fix_pt.nn_fix as nnf 22 | from net import FixNet 23 | 24 | parser = argparse.ArgumentParser(description="PyTorch Cifar10 Fixed-Point Training") 25 | parser.add_argument( 26 | "--save-dir", 27 | required=True, 28 | help="The directory used to save the trained models", 29 | type=str, 30 | ) 31 | parser.add_argument( 32 | "--gpu", 33 | metavar="GPUs", 34 | default="0", 35 | help="The gpu devices to use" 36 | ) 37 | parser.add_argument( 38 | "--epoch", 39 | default=100, 40 | type=int, 41 | metavar="N", 42 | help="number of total epochs to run", 43 | ) 44 | parser.add_argument( 45 | "--start-epoch", 46 | default=0, 47 | type=int, 48 | metavar="N", 49 | help="manual epoch number (useful on restarts)", 50 | ) 51 | parser.add_argument( 52 | "-b", 53 | "--batch-size", 54 | default=128, 55 | type=int, 56 | metavar="N", 57 | help="mini-batch size (default: 128)", 58 | ) 59 | parser.add_argument( 60 | "--test-batch-size", 61 | type=int, 62 | default=32, 63 | metavar="N", 64 | help="input batch size for testing (default: 1000)", 65 | ) 66 | parser.add_argument( 67 | "--lr", 68 | "--learning-rate", 69 | default=0.05, 70 | type=float, 71 | metavar="LR", 72 | help="initial learning rate", 73 | ) 74 | parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum") 75 | parser.add_argument( 76 | "--weight-decay", 77 | "--wd", 78 | default=5e-4, 79 | type=float, 80 | metavar="W", 81 | help="weight decay (default: 5e-4)", 82 | ) 83 | parser.add_argument( 84 | "--print-freq", 85 | "-p", 86 | default=40, 87 | type=int, 88 | metavar="N", 89 | help="print frequency (default: 
40)", 90 | ) 91 | parser.add_argument( 92 | "--resume", 93 | default="", 94 | type=str, 95 | metavar="PATH", 96 | help="path to latest checkpoint (default: none)", 97 | ) 98 | parser.add_argument( 99 | "--prefix", 100 | default="", 101 | type=str, 102 | metavar="PREFIX", 103 | help="checkpoint prefix (default: none)", 104 | ) 105 | parser.add_argument( 106 | "--float-bn", 107 | default=False, 108 | action="store_true", 109 | help="quantize the bn layer" 110 | ) 111 | parser.add_argument( 112 | "--fix-grad", 113 | default=False, 114 | action="store_true", 115 | help="quantize the gradients" 116 | ) 117 | parser.add_argument( 118 | "--range-method", 119 | default=0, 120 | choices=[0, 1, 3], 121 | help=("range methods of data (including parameters, buffers, activations). " 122 | "0: RANGE_MAX, 1: RANGE_3SIGMA, 3: RANGE_SWEEP") 123 | ) 124 | parser.add_argument( 125 | "--grad-range-method", 126 | default=0, 127 | choices=[0, 1, 3], 128 | help=("range methods of gradients (including parameters, activations)." 129 | " 0: RANGE_MAX, 1: RANGE_3SIGMA, 3: RANGE_SWEEP") 130 | ) 131 | parser.add_argument( 132 | "-e", 133 | "--evaluate", 134 | action="store_true", 135 | help="evaluate model on validation set", 136 | ) 137 | parser.add_argument( 138 | "--pretrained", default="", type=str, metavar="PATH", help="use pre-trained model" 139 | ) 140 | parser.add_argument( 141 | "--bitwidth-data", default=8, type=int, help="the bitwidth of parameters/buffers/activations" 142 | ) 143 | parser.add_argument( 144 | "--bitwidth-grad", default=16, type=int, help="the bitwidth of gradients of parameters/activations" 145 | ) 146 | 147 | best_prec1 = 90 148 | start = time.time() 149 | 150 | def _set_fix_method_train_ori(model): 151 | model.set_fix_method(nfp.FIX_AUTO) 152 | 153 | def _set_fix_method_eval_ori(model): 154 | model.set_fix_method(nfp.FIX_FIXED) 155 | 156 | ## -------- 157 | ## When bitwidth is small, bn fix would prevent the model from learning. 158 | ## Could use this following config: 159 | ## Note that batchnorm2d_fix buffers (running_mean, running_var) are handled specially here. 160 | ## The running_mean and running_var are not quantized during training forward process, 161 | ## only quantized during test process. This could help avoid the buffer accumulation problem 162 | ## when the bitwidth is too small. 
163 | def _set_fix_method_train(model): 164 | model.set_fix_method( 165 | nfp.FIX_AUTO, 166 | method_by_type={ 167 | "BatchNorm2d_fix": { 168 | "weight": nfp.FIX_AUTO, 169 | "bias": nfp.FIX_AUTO, 170 | "running_mean": nfp.FIX_NONE, 171 | "running_var": nfp.FIX_NONE} 172 | }) 173 | 174 | def _set_fix_method_eval(model): 175 | model.set_fix_method( 176 | nfp.FIX_FIXED, 177 | method_by_type={ 178 | "BatchNorm2d_fix": { 179 | "weight": nfp.FIX_FIXED, 180 | "bias": nfp.FIX_FIXED, 181 | "running_mean": nfp.FIX_AUTO, 182 | "running_var": nfp.FIX_AUTO} 183 | }) 184 | ## -------- 185 | 186 | 187 | def main(): 188 | global args, best_prec1 189 | args = parser.parse_args() 190 | print("cmd line arguments: ", args) 191 | 192 | gpus = [int(d) for d in args.gpu.split(",")] 193 | torch.cuda.set_device(gpus[0]) 194 | 195 | # Check the save_dir exists or not 196 | if not os.path.exists(args.save_dir): 197 | os.makedirs(args.save_dir) 198 | 199 | model = FixNet( 200 | fix_bn=not args.float_bn, 201 | fix_grad=args.fix_grad, 202 | range_method=args.range_method, 203 | grad_range_method=args.grad_range_method, 204 | bitwidth_data=args.bitwidth_data, 205 | bitwidth_grad=args.bitwidth_grad 206 | ) 207 | model.print_fix_configs() 208 | 209 | model.cuda() 210 | if len(gpus) > 1: 211 | parallel_model = torch.nn.DataParallel(model, gpus) 212 | else: 213 | parallel_model = model 214 | 215 | # optionally resume from a checkpoint 216 | if args.resume: 217 | if os.path.isfile(args.resume): 218 | print("=> loading checkpoint '{}'".format(args.resume)) 219 | checkpoint = torch.load(args.resume) 220 | args.start_epoch = checkpoint["epoch"] 221 | best_prec1 = checkpoint["best_prec1"] 222 | model.load_state_dict(checkpoint["state_dict"]) 223 | print( 224 | "=> loaded checkpoint '{}' (epoch {})".format( 225 | args.evaluate, checkpoint["epoch"] 226 | ) 227 | ) 228 | else: 229 | print("=> no checkpoint found at '{}'".format(args.resume)) 230 | assert os.path.isfile(args.resume) 231 | 232 | if args.pretrained: 233 | if os.path.isfile(args.pretrained): 234 | print("=> fintune from checkpoint '{}'".format(args.pretrained)) 235 | checkpoint = torch.load(args.pretrained) 236 | # args.start_epoch = checkpoint['epoch'] 237 | # best_prec1 = checkpoint['best_prec1'] 238 | model.load_state_dict(checkpoint["state_dict"]) 239 | else: 240 | print("=> no checkpoint found at '{}'".format(args.resume)) 241 | assert os.path.isfile(args.pretrained) 242 | 243 | # cudnn.benchmark = True 244 | 245 | normalize = transforms.Normalize( 246 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 247 | ) 248 | 249 | train_loader = torch.utils.data.DataLoader( 250 | datasets.CIFAR10( 251 | root="../data/cifar10", 252 | train=True, 253 | transform=transforms.Compose( 254 | [ 255 | transforms.RandomHorizontalFlip(), 256 | transforms.RandomCrop(32, 4), 257 | transforms.ToTensor(), 258 | normalize, 259 | ] 260 | ), 261 | download=True, 262 | ), 263 | batch_size=args.batch_size, 264 | shuffle=True, 265 | num_workers=4, 266 | pin_memory=True, 267 | ) 268 | 269 | val_loader = torch.utils.data.DataLoader( 270 | datasets.CIFAR10( 271 | root="../data/cifar10", 272 | train=False, 273 | transform=transforms.Compose([transforms.ToTensor(), normalize]), 274 | ), 275 | batch_size=args.test_batch_size, 276 | shuffle=False, 277 | num_workers=2, 278 | pin_memory=True, 279 | ) 280 | 281 | # define loss function (criterion) and pptimizer 282 | criterion = nn.CrossEntropyLoss().cuda() 283 | 284 | optimizer = torch.optim.SGD( 285 | model.parameters(), 286 | lr=args.lr, 287 | 
momentum=args.momentum, 288 | weight_decay=args.weight_decay, 289 | ) 290 | 291 | if args.evaluate: 292 | validate(val_loader, model, parallel_model, criterion) 293 | return 294 | 295 | for epoch in range(args.start_epoch, args.epoch): 296 | adjust_learning_rate(optimizer, epoch) 297 | 298 | # train for one epoch 299 | train(train_loader, model, parallel_model, criterion, optimizer, epoch) 300 | 301 | # evaluate on validation set 302 | prec1 = validate(val_loader, model, parallel_model, criterion) 303 | 304 | # remember best prec@1 and save checkpoint 305 | is_best = prec1 > best_prec1 306 | best_prec1 = max(prec1, best_prec1) 307 | if best_prec1 > 90 and is_best: 308 | save_checkpoint( 309 | { 310 | "epoch": epoch + 1, 311 | "state_dict": model.state_dict(), 312 | "best_prec1": best_prec1, 313 | }, 314 | is_best, 315 | filename=os.path.join( 316 | args.save_dir, 317 | "checkpoint_{}_{:.3f}.tar".format(args.prefix, best_prec1), 318 | ), 319 | ) 320 | model.print_fix_configs() 321 | 322 | print("Best acc: {}".format(best_prec1)) 323 | 324 | 325 | def train(train_loader, model, p_model, criterion, optimizer, epoch): 326 | """ 327 | Run one train epoch 328 | """ 329 | losses = AverageMeter() 330 | top1 = AverageMeter() 331 | 332 | # switch to train mode 333 | _set_fix_method_train_ori(model) 334 | model.train() 335 | 336 | for i, (input, target) in enumerate(train_loader): 337 | target = target.cuda(async=True) 338 | input_var = torch.autograd.Variable(input).cuda() 339 | target_var = torch.autograd.Variable(target) 340 | 341 | # compute output 342 | output = p_model(input_var) 343 | loss = criterion(output, target_var) 344 | 345 | # compute gradient and do SGD step 346 | optimizer.zero_grad() 347 | loss.backward() 348 | optimizer.step() 349 | 350 | output = output.float() 351 | loss = loss.float() 352 | # measure accuracy and record loss 353 | prec1 = accuracy(output.data, target)[0] 354 | losses.update(loss.item(), input.size(0)) 355 | top1.update(prec1.item(), input.size(0)) 356 | 357 | if i % args.print_freq == 0: 358 | print( 359 | "\rEpoch: [{0}][{1}/{2}]\t" 360 | "Time {t}\t" 361 | "Loss {loss.val:.4f} ({loss.avg:.4f})\t" 362 | "Prec@1 {top1.val:.3f}% ({top1.avg:.3f}%)".format( 363 | epoch, 364 | i, 365 | len(train_loader), 366 | t=time.time() - start, 367 | loss=losses, 368 | top1=top1, 369 | ), 370 | end="", 371 | ) 372 | 373 | 374 | def validate(val_loader, model, p_model, criterion): 375 | """ 376 | Run evaluation 377 | """ 378 | losses = AverageMeter() 379 | top1 = AverageMeter() 380 | 381 | # switch to evaluate mode 382 | _set_fix_method_eval_ori(model) 383 | model.eval() 384 | 385 | with torch.no_grad(): 386 | for i, (input, target) in enumerate(val_loader): 387 | target = target.cuda(async=True) 388 | input_var = torch.autograd.Variable(input).cuda() 389 | target_var = torch.autograd.Variable(target) 390 | 391 | # compute output 392 | output = p_model(input_var) 393 | loss = criterion(output, target_var) 394 | 395 | output = output.float() 396 | loss = loss.float() 397 | 398 | # measure accuracy and record loss 399 | prec1 = accuracy(output.data, target)[0] 400 | losses.update(loss.item(), input.size(0)) 401 | top1.update(prec1.item(), input.size(0)) 402 | 403 | if i % args.print_freq == 0: 404 | print( 405 | "Test: [{0}/{1}]\t" 406 | "Time {t}\t" 407 | "Loss {loss.val:.4f} ({loss.avg:.4f})\t" 408 | "Prec@1 {top1.val:.3f}% ({top1.avg:.3f}%)".format( 409 | i, len(val_loader), t=time.time() - start, loss=losses, top1=top1 410 | ) 411 | ) 412 | 413 | print( 414 | " * Prec@1 
{top1.avg:.3f}%\tBest Prec@1 {best_prec1:.3f}%".format( 415 | top1=top1, best_prec1=best_prec1 416 | ) 417 | ) 418 | 419 | return top1.avg 420 | 421 | 422 | def save_checkpoint(state, is_best, filename="checkpoint.pth.tar"): 423 | """ 424 | Save the training model 425 | """ 426 | torch.save(state, filename) 427 | 428 | 429 | class AverageMeter(object): 430 | """Computes and stores the average and current value""" 431 | 432 | def __init__(self): 433 | self.reset() 434 | 435 | def reset(self): 436 | self.val = 0 437 | self.avg = 0 438 | self.sum = 0 439 | self.count = 0 440 | 441 | def update(self, val, n=1): 442 | self.val = val 443 | self.sum += val * n 444 | self.count += n 445 | self.avg = self.sum / self.count 446 | 447 | 448 | def adjust_learning_rate(optimizer, epoch): 449 | """Sets the learning rate to the initial LR decayed by 0.5 every 10 epochs""" 450 | lr = args.lr * (0.5 ** (epoch // 10)) 451 | for param_group in optimizer.param_groups: 452 | param_group["lr"] = lr 453 | print("Epoch {}: lr: {}".format(epoch, lr)) 454 | 455 | def accuracy(output, target, topk=(1,)): 456 | """Computes the precision@k for the specified values of k""" 457 | maxk = max(topk) 458 | batch_size = target.size(0) 459 | 460 | _, pred = output.topk(maxk, 1, True, True) 461 | pred = pred.t() 462 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 463 | 464 | res = [] 465 | for k in topk: 466 | correct_k = correct[:k].view(-1).float().sum(0) 467 | res.append(correct_k.mul_(100.0 / batch_size)) 468 | return res 469 | 470 | 471 | if __name__ == "__main__": 472 | main() 473 | -------------------------------------------------------------------------------- /nics_fix_pt/quant.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from __future__ import print_function 4 | 5 | import copy 6 | 7 | import numpy as np 8 | import torch 9 | 10 | from nics_fix_pt.utils import get_int 11 | from nics_fix_pt.consts import QuantizeMethod, RangeMethod 12 | 13 | __all__ = ["quantize"] 14 | 15 | 16 | def _do_quantize(data, scale, bit_width, symmetric=True, stochastic=False, group=False): 17 | """ 18 | when sym is not true, the input scale will be 2-value tensor [min,max] 19 | by defalutm the data_to_devise.device, bit_width.devicescale is a single fp-value, denoting range [-max, max] 20 | """ 21 | # The grouping are only applied for conv/fc weight/data, with the shape of [x,x,x,x] 22 | # bn params & bias with the shape of [x,x] 23 | if len(data.shape) < 4: 24 | group = False 25 | 26 | scale = scale.to(data.device) 27 | bit_width = bit_width.to(data.device) 28 | tensor_2 = torch.autograd.Variable( 29 | torch.FloatTensor([2.0]), requires_grad=False 30 | ).to(data.device) 31 | if symmetric: 32 | dynamic_range = 2 * scale 33 | maxs = scale 34 | mins = scale * (-1) 35 | else: 36 | """ 37 | actually in real hardware implmention, the asymmetric quantization 38 | will be implemented through 1-fp scale and 1-int zero-point to 39 | constraint the range, here for simplicity of software simulation, 40 | we simply define its range 41 | """ 42 | dynamic_range = scale[1] - scale[0] 43 | maxs = scale[1] 44 | mins = scale[0] 45 | 46 | if group == "batch": 47 | maxs = maxs.reshape(-1, 1, 1, 1).expand_as(data) 48 | mins = mins.reshape(-1, 1, 1, 1).expand_as(data) 49 | data_to_devise = dynamic_range.reshape(-1, 1, 1, 1) 50 | # MAYBE we should split activation/weight also the grad for this 51 | elif group == "channel": 52 | maxs = maxs.reshape(1, -1, 1, 1).expand_as(data) 53 | 
mins = mins.reshape(1, -1, 1, 1).expand_as(data) 54 | data_to_devise = dynamic_range.reshape(1, -1, 1, 1) 55 | else: 56 | data_to_devise = dynamic_range 57 | 58 | step = data_to_devise / torch.pow(2, bit_width).float() 59 | 60 | if stochastic: 61 | output = StraightThroughStochasticRound.apply(data / step) * step 62 | else: 63 | output = StraightThroughRound.apply(data / step) * step 64 | # torch.clamp dose not support clamp with multiple value, so using max(min) as alternative 65 | if group is not False: 66 | output = torch.min(torch.max(mins, output), maxs) 67 | else: 68 | output = torch.clamp(output, float(mins), float(maxs)) 69 | 70 | return ( 71 | output, 72 | step, 73 | ) 74 | 75 | 76 | def quantize_cfg( 77 | data, 78 | scale, 79 | bitwidth, 80 | method, 81 | range_method=RangeMethod.RANGE_MAX, 82 | stochastic=False, 83 | float_scale=False, 84 | zero_point=False, 85 | group=False, 86 | ): 87 | """ 88 | stochastic - stochastic rounding 89 | range_method - how to decide dynamic range 90 | float_scale - whether the scale is chosen to be 2^K 91 | zero_point - symm/asymm quantize 92 | """ 93 | if ( 94 | not isinstance(method, torch.autograd.Variable) 95 | and not torch.is_tensor(method) 96 | and method == QuantizeMethod.FIX_NONE 97 | ): 98 | return data, None 99 | 100 | if group == "batch" and len(data.shape) == 4: 101 | # only applied for conv units 102 | max_data = data.view(data.shape[0], -1).max(dim=1)[0] 103 | min_data = data.view(data.shape[0], -1).min(dim=1)[0] 104 | elif group == "channel" and len(data.shape) == 4: 105 | max_data = data.view(data.shape[1], -1).max(dim=1)[0] 106 | min_data = data.view(data.shape[1], -1).min(dim=1)[0] 107 | else: 108 | max_data = data.max() 109 | min_data = data.min() 110 | 111 | # Avoid extreme value 112 | EPS = torch.FloatTensor(max_data.shape).fill_(1e-5) 113 | EPS = EPS.to(max_data.device) 114 | if not zero_point: 115 | max_data = torch.max(torch.max(max_data.abs(), min_data.abs()), EPS) 116 | 117 | method_v = get_int(method) 118 | 119 | if method_v == QuantizeMethod.FIX_NONE: 120 | return data, None 121 | elif method_v == QuantizeMethod.FIX_AUTO: 122 | range_method_v = get_int(range_method) 123 | if float_scale and range_method_v != RangeMethod.RANGE_MAX: 124 | raise NotImplementedError("Now Only Support Float_Scale with Range-Max") 125 | if range_method_v == RangeMethod.RANGE_MAX: 126 | if float_scale: 127 | if zero_point: 128 | new_scale = torch.stack([min_data, max_data]) 129 | scale.data = new_scale 130 | return _do_quantize( 131 | data, 132 | scale, 133 | bitwidth, 134 | stochastic=stochastic, 135 | symmetric=False, 136 | group=group, 137 | ) 138 | else: 139 | scale.data = max_data 140 | return _do_quantize( 141 | data, 142 | scale, 143 | bitwidth, 144 | stochastic=stochastic, 145 | symmetric=True, 146 | group=group, 147 | ) 148 | else: 149 | new_scale = torch.pow( 150 | 2, 151 | torch.ceil( 152 | torch.log(max_data) 153 | / torch.FloatTensor([1]).fill_(np.log(2.0)).to(max_data.device) 154 | ), 155 | ) 156 | 157 | scale.data = new_scale 158 | return _do_quantize( 159 | data, scale, bitwidth, stochastic=stochastic, group=group 160 | ) 161 | 162 | elif range_method_v == RangeMethod.RANGE_MAX_TENPERCENT: 163 | # FIXME: Too slow 164 | scale = torch.pow( 165 | 2, 166 | torch.ceil( 167 | torch.log( 168 | torch.max( 169 | # torch.kthvalue(torch.abs(data.view(-1)), 9 * (data.nelement() // 10))[0], 170 | torch.topk(torch.abs(data.view(-1)), data.nelement() // 10)[ 171 | 0 172 | ][-1], 173 | # torch.tensor(EPS).float().to(data.device)) 174 | 
torch.FloatTensor(1).fill_(EPS).to(data.device), 175 | ) 176 | ) 177 | / torch.cuda.FloatTensor([1]).fill_(np.log(2.0)) 178 | ), 179 | ) 180 | return _do_quantize(data, scale, bitwidth, stochastic=stochastic) 181 | 182 | elif range_method_v == RangeMethod.RANGE_3SIGMA: 183 | new_boundary = torch.max( 184 | 3 * torch.std(data) + torch.abs(torch.mean(data)), 185 | torch.tensor(EPS).float().to(data.device), 186 | ) 187 | new_scale = torch.pow(2, torch.ceil(torch.log(new_boundary) / np.log(2.0))) 188 | scale.data = new_scale 189 | return _do_quantize( 190 | data, scale, bitwidth, stochastic=stochastic, symmetric=not zero_point 191 | ) 192 | 193 | elif range_method_v == RangeMethod.RANGE_SWEEP: 194 | # Iterat through other scale to find the proper scale to minimize error 195 | # Noted that the scale is [(MAX - SWEEP),MAX] 196 | SWEEP = 3 197 | temp_scale = torch.ceil( 198 | torch.log( 199 | torch.max( 200 | torch.max(abs(data)), torch.tensor(EPS).float().to(data.device) 201 | ) 202 | ) 203 | / np.log(2.0) 204 | ) 205 | for i in range(SWEEP): 206 | errors[i] = torch.abs( 207 | _do_quantize(data, temp_scale - i, bitwidth)[0] - data 208 | ).sum() 209 | new_scale = torch.pow(2, temp_scale - errors.argmin()) 210 | scale.data = new_scale 211 | return _do_quantize(data, scale, bitwidth, stochastic=stochastic) 212 | 213 | else: 214 | raise NotImplementedError() 215 | 216 | elif method_v == QuantizeMethod.FIX_FIXED: 217 | 218 | if group == "batch" and len(data.shape) == 4: 219 | max_data = data.view(data.shape[0], -1).max(dim=1)[0] 220 | min_data = data.view(data.shape[0], -1).min(dim=1)[0] 221 | if group == "channel" and len(data.shape) == 4: 222 | max_data = data.view(data.shape[1], -1).max(dim=1)[0] 223 | min_data = data.view(data.shape[1], -1).min(dim=1)[0] 224 | else: 225 | max_data = data.max() 226 | min_data = data.min() 227 | 228 | EPS = torch.FloatTensor(max_data.shape).fill_(1e-5) 229 | EPS = EPS.to(max_data.device) 230 | if not zero_point: 231 | max_data = torch.max(torch.max(max_data.abs(), min_data.abs()), EPS) 232 | 233 | # TODO: Check whether float_scale automatically adjust through inference 234 | # If float_scale, do as FIX_AUTO does 235 | if float_scale: 236 | if zero_point: 237 | new_scale = torch.stack([min_data, max_data]) 238 | scale.data = new_scale 239 | return _do_quantize( 240 | data, scale, bitwidth, stochastic=stochastic, symmetric=False 241 | ) 242 | else: 243 | scale.data = max_data 244 | return _do_quantize( 245 | data, scale, bitwidth, stochastic=stochastic, symmetric=True 246 | ) 247 | else: 248 | # donot use new_scale when using power-of-2 scale 249 | return _do_quantize( 250 | data, 251 | scale, 252 | bitwidth, 253 | stochastic=stochastic, 254 | symmetric=not zero_point, 255 | group=group, 256 | ) 257 | 258 | raise Exception("Quantitize method not legal: {}".format(method_v)) 259 | 260 | 261 | # https://discuss.pytorch.org/t/how-to-override-the-gradients-for-parameters/3417/6 262 | class StraightThroughRound(torch.autograd.Function): 263 | @staticmethod 264 | def forward(ctx, x): 265 | return x.round() 266 | 267 | @staticmethod 268 | def backward(ctx, g): 269 | return g 270 | 271 | 272 | class StraightThroughStochasticRound(torch.autograd.Function): 273 | @staticmethod 274 | def forward(ctx, x): 275 | # FIXME: Wrong Stochatsic Method, independent stochastic for each element, could lead to even worse perf. 
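        # NOTE (reference sketch, not the implementation used below): the usual per-element
        # stochastic rounding rounds up with probability equal to the fractional part, e.g.
        #     frac = x - x.floor()
        #     return x.floor() + (torch.rand_like(x) < frac).float()
        # whereas the code below adds uniform noise in [-0.5, 0.5) to `x` in place and
        # returns it without an explicit round.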
276 | # The Binary tensor denoting whether ceil or not, closer to ceil means for probabily choose ceil 277 | # return x.floor() + (torch.rand(x.shape).to(x.device) > x.ceil() - x)*torch.ones(x.shape).to(x.device) 278 | # out = x.floor() + (torch.cuda.FloatTensor(x.shape).uniform_() > x.ceil() - x)*torch.cuda.FloatTensor(x.shape).fill_(1.) 279 | # out = x.floor() + ((x.ceil() - x) < torch.cuda.FloatTensor([1]).fill_(np.random.uniform()))*torch.cuda.FloatTensor(x.shape).fill_(1.) 280 | noise = torch.FloatTensor(x.shape).uniform_(-0.5, 0.5).to(x.device) 281 | x.add_(noise) 282 | return x 283 | 284 | @staticmethod 285 | def backward(ctx, g): 286 | return g 287 | 288 | 289 | class QuantitizeGradient(torch.autograd.Function): 290 | @staticmethod 291 | def forward( 292 | ctx, 293 | x, 294 | scale, 295 | bitwidth, 296 | method, 297 | range_method=RangeMethod.RANGE_MAX, 298 | stochastic=False, 299 | float_scale=False, 300 | zero_point=False, 301 | group=False, 302 | ): 303 | # FIXME: save the tensor/variables for backward, 304 | # maybe should use `ctx.save_for_backward` for standard practice 305 | # but `save_for_backward` requires scale/bitwidth/method all being of type `Variable`... 306 | ctx.saved = ( 307 | scale, 308 | bitwidth, 309 | method, 310 | range_method, 311 | stochastic, 312 | float_scale, 313 | zero_point, 314 | group, 315 | ) 316 | return x 317 | 318 | @staticmethod 319 | def backward(ctx, g): 320 | return ( 321 | quantize_cfg(g, *ctx.saved)[0], 322 | None, 323 | None, 324 | None, 325 | None, 326 | None, 327 | None, 328 | None, 329 | None, 330 | ) 331 | 332 | 333 | def quantize(param, fix_cfg={}, fix_grad_cfg={}, kwarg_cfg={}, name=""): 334 | # fix_cfg/fix_grad_cfg is the configuration saved; 335 | # kwarg_cfg is the overriding configuration supplied for each `forward` call 336 | data_cfg = copy.copy(fix_cfg) 337 | data_cfg.update(kwarg_cfg.get(name + "_fix", {})) 338 | grad_cfg = copy.copy(fix_grad_cfg) 339 | grad_cfg.update(kwarg_cfg.get(name + "_grad_fix", {})) 340 | method = data_cfg.get("method", QuantizeMethod.FIX_NONE) 341 | 342 | step = 0 343 | # quantize data 344 | out_param = param 345 | 346 | if ( 347 | isinstance(method, torch.autograd.Variable) 348 | or torch.is_tensor(method) 349 | or method != QuantizeMethod.FIX_NONE 350 | ): 351 | out_param, stepp = quantize_cfg( 352 | out_param, 353 | data_cfg["scale"], 354 | data_cfg["bitwidth"], 355 | data_cfg["method"], 356 | data_cfg.get("range_method", RangeMethod.RANGE_MAX), 357 | data_cfg.get("stochastic", False), 358 | data_cfg.get("float_scale", False), 359 | data_cfg.get("zero_point", False), 360 | data_cfg.get("group", False), 361 | ) 362 | 363 | # quantize gradient 364 | method = grad_cfg.get("method", QuantizeMethod.FIX_NONE) 365 | if ( 366 | isinstance(method, torch.autograd.Variable) 367 | or torch.is_tensor(method) 368 | or method != QuantizeMethod.FIX_NONE 369 | ): 370 | out_param = QuantitizeGradient().apply( 371 | out_param, 372 | grad_cfg["scale"], 373 | grad_cfg["bitwidth"], 374 | grad_cfg["method"], 375 | grad_cfg.get("range_method", RangeMethod.RANGE_MAX), 376 | grad_cfg.get("stochastic", False), 377 | grad_cfg.get("float_scale", False), 378 | grad_cfg.get("zero_point", False), 379 | data_cfg.get("group", False), 380 | ) 381 | 382 | out_param.data_cfg = data_cfg 383 | out_param.grad_cfg = grad_cfg 384 | if param is not out_param: 385 | # avoid memory leaking: old `buffer` tensors could remain referenced unexpectedly 386 | if hasattr(param, "nfp_actual_data"): 387 | del param.nfp_actual_data 388 | del 
param.data_cfg 389 | del param.grad_cfg 390 | out_param.nfp_actual_data = param # avoid loop ref 391 | # NOTE: the returned step is data fix stepsize, not gradient fix step size; 392 | return out_param, step 393 | -------------------------------------------------------------------------------- /nics_fix_pt/fix_modules.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from __future__ import print_function 4 | 5 | import warnings 6 | from collections import OrderedDict 7 | 8 | import six 9 | 10 | import torch 11 | from torch.nn import Module 12 | import torch.nn.functional as F 13 | from . import nn_fix, utils, quant 14 | 15 | 16 | # ---- helpers ---- 17 | def _get_kwargs(self, true_kwargs): 18 | default_kwargs = utils.get_kwargs(self.__class__) 19 | if not default_kwargs: 20 | return true_kwargs 21 | # NOTE: here we do not deep copy the default values, 22 | # so non-atom type default value such as dict/list/tensor will be shared 23 | kwargs = {k: v for k, v in six.iteritems(default_kwargs)} 24 | kwargs.update(true_kwargs) 25 | return kwargs 26 | 27 | 28 | def _get_fix_cfg(self, name, grad=False): 29 | if not grad: 30 | cfg = self.nf_fix_params.get(name, {}) 31 | if "scale" in cfg: 32 | cfg["scale"] = self._buffers["{}_fp_scale".format(name)] 33 | else: 34 | cfg = self.nf_fix_params_grad.get(name, {}) 35 | if "scale" in cfg: 36 | cfg["scale"] = self._buffers["{}_grad_fp_scale".format(name)] 37 | return cfg 38 | 39 | 40 | def _register_fix_buffers(self, patch_register=True): 41 | # register scale tensors as buffers, for correct use in multi-gpu data parallel model 42 | avail_keys = list(self._parameters.keys()) + list(self._buffers.keys()) 43 | for name, cfg in six.iteritems(self.nf_fix_params): 44 | if patch_register and name not in avail_keys: 45 | warnings.warn( 46 | ( 47 | "{} not available in {}, this specific fixed config " 48 | "will not have effects" 49 | ).format(name, self) 50 | ) 51 | if "scale" in cfg: 52 | self.register_buffer("{}_fp_scale".format(name), cfg["scale"]) 53 | 54 | avail_keys = list(self._parameters.keys()) 55 | for name, cfg in six.iteritems(self.nf_fix_params_grad): 56 | if patch_register and name not in avail_keys: 57 | warnings.warn( 58 | ( 59 | "{} not available in {}, this specific grads fixed config " 60 | "will not have effects" 61 | ).format(name, self) 62 | ) 63 | if "scale" in cfg: 64 | self.register_buffer("{}_grad_fp_scale".format(name), cfg["scale"]) 65 | 66 | 67 | # -------- 68 | 69 | 70 | def get_fix_forward(cur_cls): 71 | # pylint: disable=protected-access 72 | def fix_forward(self, inputs, **kwargs): 73 | if not isinstance(inputs, dict): 74 | inputs = {"inputs": inputs} 75 | for n, param in six.iteritems(self._parameters): 76 | # NOTE: Since Pytorch>=1.5.0, parameters in DataParallel replica are no longer 77 | # registered in the _parameters dict, so this mechanism will no longer work. 78 | # Thus for now, only Pytorch<1.5.0 versions are supported if DataParallel is used! 
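            # For each parameter (e.g. "weight", "bias"): fetch its data/grad fix configs,
            # quantize it through quant.quantize, and shadow the original nn.Parameter by
            # setting the quantized tensor as a plain instance attribute via
            # object.__setattr__, so the parent class's forward() sees the quantized values
            # while the float master copy stays in self._parameters for the optimizer to update.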
79 | if not isinstance(param, (torch.Tensor, torch.autograd.Variable)): 80 | continue 81 | fix_cfg = _get_fix_cfg(self, n) 82 | fix_grad_cfg = _get_fix_cfg(self, n, grad=True) 83 | set_n, _ = quant.quantize( 84 | param, fix_cfg, fix_grad_cfg, kwarg_cfg=inputs, name=n 85 | ) 86 | object.__setattr__(self, n, set_n) 87 | for n, param in six.iteritems(self._buffers): 88 | if not isinstance(param, (torch.Tensor, torch.autograd.Variable)): 89 | continue 90 | fix_cfg = _get_fix_cfg(self, n) 91 | fix_grad_cfg = _get_fix_cfg(self, n, grad=True) 92 | set_n, _ = quant.quantize( 93 | param, fix_cfg, fix_grad_cfg, kwarg_cfg=inputs, name=n 94 | ) 95 | object.__setattr__(self, n, set_n) 96 | res = super(cur_cls, self).forward(inputs["inputs"], **kwargs) 97 | for n, param in six.iteritems(self._buffers): 98 | # set buffer back, as there will be no gradient, just in-place modification 99 | # FIXME: For fixed-point batch norm, 100 | # the running mean/var accumulattion is on quantized mean/var, 101 | # which means it might fail to update the running mean/var 102 | # if the updating momentum is too small 103 | updated_buffer = getattr(self, n) 104 | if updated_buffer is not self._buffers[n]: 105 | self._buffers[n].copy_(updated_buffer) 106 | return res 107 | 108 | return fix_forward 109 | 110 | 111 | class FixMeta(type): 112 | def __new__(mcs, name, bases, attrs): 113 | # Construct class name 114 | if not attrs.get("__register_name__", None): 115 | attrs["__register_name__"] = bases[0].__name__ + "_fix" 116 | name = attrs["__register_name__"] 117 | cls = super(FixMeta, mcs).__new__(mcs, name, bases, attrs) 118 | # if already subclass 119 | if not isinstance(bases[0], FixMeta): 120 | cls.forward = get_fix_forward(cur_cls=cls) 121 | setattr(nn_fix, name, cls) 122 | return cls 123 | 124 | 125 | def register_fix_module(cls, register_name=None): 126 | @six.add_metaclass(FixMeta) 127 | class __a_not_use_name(cls): 128 | __register_name__ = register_name 129 | 130 | def __init__(self, *args, **kwargs): 131 | kwargs = _get_kwargs(self, kwargs) 132 | # Pop and parse fix configuration from kwargs 133 | assert "nf_fix_params" in kwargs and isinstance( 134 | kwargs["nf_fix_params"], dict 135 | ), ( 136 | "Must specifiy `nf_fix_params` keyword arguments, " 137 | "and `nf_fix_params_grad` is optional." 138 | ) 139 | self.nf_fix_params = kwargs.pop("nf_fix_params") 140 | self.nf_fix_params_grad = kwargs.pop("nf_fix_params_grad", {}) or {} 141 | cls.__init__(self, *args, **kwargs) 142 | _register_fix_buffers(self, patch_register=True) 143 | # avail_keys = list(self._parameters.keys()) + list(self._buffers.keys()) 144 | # self.nf_fix_params = {k: self.nf_fix_params[k] 145 | # for k in avail_keys if k in self.nf_fix_params} 146 | # self.nf_fix_params_grad = {k: self.nf_fix_params_grad[k] 147 | # for k in avail_keys if k in self.nf_fix_params_grad} 148 | 149 | 150 | class Activation_fix(Module): 151 | def __init__(self, **kwargs): 152 | super(Activation_fix, self).__init__() 153 | kwargs = _get_kwargs(self, kwargs) 154 | assert "nf_fix_params" in kwargs and isinstance( 155 | kwargs["nf_fix_params"], dict 156 | ), "Must specifiy `nf_fix_params` keyword arguments, and `nf_fix_params_grad` is optional." 
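        # Activation_fix quantizes the tensor passed through it in forward(); its fix config
        # is looked up under the single key "activation", so `nf_fix_params` is expected to
        # look like {"activation": {"method": ..., "bitwidth": ..., "scale": ...}}.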
157 | self.nf_fix_params = kwargs.pop("nf_fix_params") 158 | self.nf_fix_params_grad = kwargs.pop("nf_fix_params_grad", {}) or {} 159 | self.activation = None 160 | 161 | # register scale as buffers 162 | _register_fix_buffers(self, patch_register=False) 163 | 164 | def forward(self, inputs): 165 | if not isinstance(inputs, dict): 166 | inputs = {"inputs": inputs} 167 | name = "activation" 168 | fix_cfg = self.nf_fix_params.get(name, {}) 169 | fix_grad_cfg = self.nf_fix_params_grad.get(name, {}) 170 | self.activation, _ = quant.quantize( 171 | inputs["inputs"], fix_cfg, fix_grad_cfg, kwarg_cfg=inputs, name=name 172 | ) 173 | return self.activation 174 | 175 | 176 | class ConvBN_fix(Module): 177 | def __init__( 178 | self, 179 | in_channels, 180 | out_channels, 181 | kernel_size=3, 182 | stride=1, 183 | padding=0, 184 | dilation=1, 185 | groups=1, 186 | eps=1e-05, 187 | momentum=0.1, 188 | affine=True, 189 | track_running_stats=True, 190 | **kwargs 191 | ): 192 | super(ConvBN_fix, self).__init__() 193 | kwargs = _get_kwargs(self, kwargs) 194 | assert "nf_fix_params" in kwargs and isinstance( 195 | kwargs["nf_fix_params"], dict 196 | ), "Must specifiy `nf_fix_params` keyword arguments, and `nf_fix_params_grad` is optional." 197 | self.nf_fix_params = kwargs.pop("nf_fix_params") 198 | self.nf_fix_params_grad = kwargs.pop("nf_fix_params_grad", {}) or {} 199 | if self.nf_fix_params_grad: 200 | warnings.warn( 201 | "Gradient fixed-point cfgs will NOT take effect! Because, " 202 | "Gradient quantization is usually used to simulate training on hardware. " 203 | "However, merged ConvBN is designed for mitigating the discrepancy between " 204 | "training and behaviour on deploy-only hardware; " 205 | "and enable relatively more accurate running mean/var accumulation " 206 | "during software training. Use these two together might not make sense." 
207 | ) 208 | 209 | # init the two floating-point sub-modules 210 | self.conv = torch.nn.Conv2d( 211 | in_channels, 212 | out_channels, 213 | kernel_size, 214 | stride, 215 | padding, 216 | dilation, 217 | groups, 218 | bias=False, 219 | ) 220 | self.bn = torch.nn.BatchNorm2d( 221 | out_channels, eps, momentum, affine, track_running_stats 222 | ) 223 | 224 | # conv and bn attributes 225 | self.stride = stride 226 | self.padding = padding 227 | self.dilation = dilation 228 | self.groups = groups 229 | self.kernel_size = self.conv.kernel_size 230 | self.in_channels = in_channels 231 | self.out_channels = out_channels 232 | self.eps = eps 233 | self.momentum = momentum 234 | self.affine = affine 235 | self.track_running_stats = track_running_stats 236 | 237 | # the quantized combined weights and bias 238 | self.weight = self.conv.weight 239 | self.bias = self.bn.bias 240 | 241 | # register scale as buffers 242 | _register_fix_buffers(self, patch_register=False) 243 | 244 | def forward(self, inputs): 245 | if self.training: 246 | out = self.conv(inputs) 247 | # dummy output, just to accumulate running mean and running var (floating-point) 248 | _ = self.bn(out) 249 | 250 | # calculate batch var/mean 251 | mean = torch.mean(out, dim=[0, 2, 3]) 252 | var = torch.var(out, dim=[0, 2, 3]) 253 | else: 254 | mean = self.bn.running_mean 255 | var = self.bn.running_var 256 | 257 | inputs = {"inputs": inputs} 258 | # parameters/buffers to be combined 259 | bn_scale = self.bn.weight 260 | bn_bias = self.bn.bias 261 | bn_eps = self.bn.eps 262 | conv_weight = self.conv.weight 263 | conv_bias = self.conv.bias or 0.0 # could be None 264 | 265 | # combine new weights/bias 266 | comb_weight = conv_weight * (bn_scale / torch.sqrt(var + bn_eps)).view( 267 | -1, 1, 1, 1 268 | ) 269 | comb_bias = bn_bias + (conv_bias - mean) * bn_scale / torch.sqrt(var + bn_eps) 270 | 271 | # quantize the combined weights/bias (as what would be done in hardware deploy scenario) 272 | comb_weight, _ = quant.quantize( 273 | comb_weight, 274 | self.nf_fix_params.get("weight", {}), 275 | {}, 276 | kwarg_cfg=inputs, 277 | name="weight", 278 | ) 279 | comb_bias, _ = quant.quantize( 280 | comb_bias, 281 | self.nf_fix_params.get("bias", {}), 282 | {}, 283 | kwarg_cfg=inputs, 284 | name="bias", 285 | ) 286 | 287 | # run the fixed-point combined convbn 288 | convbn_out = F.conv2d( 289 | inputs["inputs"], 290 | comb_weight, 291 | comb_bias, 292 | self.stride, 293 | self.padding, 294 | self.dilation, 295 | self.groups, 296 | ) 297 | object.__setattr__(self, "weight", comb_weight) 298 | object.__setattr__(self, "bias", comb_bias) 299 | return convbn_out 300 | 301 | 302 | class FixTopModule(Module): 303 | """ 304 | A module with some simple fix configuration manage utilities. 305 | """ 306 | 307 | def __init__(self, *args, **kwargs): 308 | super(FixTopModule, self).__init__(*args, **kwargs) 309 | 310 | # To be portable between python2/3, use staticmethod for these utility methods, 311 | # and patch instance method here. 
312 | # As Python2 do not support binding instance method to a class that is not a FixTopModule 313 | self.fix_state_dict = FixTopModule.fix_state_dict.__get__(self) 314 | self.load_fix_configs = FixTopModule.load_fix_configs.__get__(self) 315 | self.get_fix_configs = FixTopModule.get_fix_configs.__get__(self) 316 | self.print_fix_configs = FixTopModule.print_fix_configs.__get__(self) 317 | self.set_fix_method = FixTopModule.set_fix_method.__get__(self) 318 | 319 | @staticmethod 320 | def fix_state_dict(self, destination=None, prefix="", keep_vars=False): 321 | r"""FIXME: maybe do another quantization to make sure all vars are quantized? 322 | 323 | Returns a dictionary containing a whole fixed-point state of the module. 324 | 325 | Both parameters and persistent buffers (e.g. running averages) are 326 | included. Keys are corresponding parameter and buffer names. 327 | 328 | Returns: 329 | dict: 330 | a dictionary containing a whole state of the module 331 | 332 | Example:: 333 | 334 | >>> module.state_dict().keys() 335 | ['bias', 'weight'] 336 | 337 | """ 338 | if destination is None: 339 | destination = OrderedDict() 340 | destination._metadata = OrderedDict() 341 | destination._metadata[prefix[:-1]] = local_metadata = dict( 342 | version=self._version 343 | ) 344 | for name, param in self._parameters.items(): 345 | if param is not None: 346 | if isinstance(self.__class__, FixMeta): # A fixed-point module 347 | # Get the last used version of the parameters 348 | thevar = getattr(self, name) 349 | else: 350 | thevar = param 351 | destination[prefix + name] = thevar if keep_vars else thevar.data 352 | for name, buf in self._buffers.items(): 353 | if buf is not None: 354 | if isinstance(self.__class__, FixMeta): # A fixed-point module 355 | # Get the last saved version of the buffers, 356 | # which can be of float precision 357 | # (as buffers will be turned into fixed-point precision on the next forward) 358 | thevar = getattr(self, name) 359 | else: 360 | thevar = buf 361 | destination[prefix + name] = thevar if keep_vars else thevar.data 362 | for name, module in self._modules.items(): 363 | if module is not None: 364 | FixTopModule.fix_state_dict( 365 | module, destination, prefix + name + ".", keep_vars=keep_vars 366 | ) 367 | for hook in self._state_dict_hooks.values(): 368 | hook_result = hook(self, destination, prefix, local_metadata) 369 | if hook_result is not None: 370 | destination = hook_result 371 | return destination 372 | 373 | @staticmethod 374 | def load_fix_configs(self, cfgs, grad=False): 375 | assert isinstance(cfgs, (OrderedDict, dict)) 376 | for name, module in six.iteritems(self._modules): 377 | if isinstance(module.__class__, FixMeta) or isinstance( 378 | module, Activation_fix 379 | ): 380 | if name not in cfgs: 381 | print( 382 | ( 383 | "WARNING: Fix configuration for {} not found in the configuration! " 384 | "Make sure you know why this happened or " 385 | "there might be some subtle error!" 
386 | ).format(name) 387 | ) 388 | else: 389 | setattr( 390 | module, 391 | "nf_fix_params" if not grad else "nf_fix_params_grad", 392 | utils.try_parse_variable(cfgs[name]), 393 | ) 394 | elif isinstance(module, FixTopModule): 395 | module.load_fix_configs(cfgs[name], grad=grad) 396 | else: 397 | FixTopModule.load_fix_configs(module, cfgs[name], grad=grad) 398 | 399 | @staticmethod 400 | def get_fix_configs(self, grad=False, data_only=False): 401 | """ 402 | get_fix_configs: 403 | 404 | Parameters: 405 | grad: BOOLEAN(default False), whether or not to get the gradient configs 406 | instead of data configs. 407 | data_only: BOOLEAN(default False), whether or not to get the numbers instead 408 | of `torch.Tensor` (which can be modified in place). 409 | """ 410 | cfg_dct = OrderedDict() 411 | for name, module in six.iteritems(self._modules): 412 | if isinstance(module.__class__, FixMeta) or isinstance( 413 | module, Activation_fix 414 | ): 415 | cfg_dct[name] = getattr( 416 | module, "nf_fix_params" if not grad else "nf_fix_params_grad" 417 | ) 418 | if data_only: 419 | cfg_dct[name] = utils.try_parse_int(cfg_dct[name]) 420 | elif isinstance(module, FixTopModule): 421 | cfg_dct[name] = module.get_fix_configs(grad=grad, data_only=data_only) 422 | else: 423 | cfg_dct[name] = FixTopModule.get_fix_configs( 424 | module, grad=grad, data_only=data_only 425 | ) 426 | return cfg_dct 427 | 428 | @staticmethod 429 | def print_fix_configs(self, data_fix_cfg=None, grad_fix_cfg=None, prefix_spaces=0): 430 | if data_fix_cfg is None: 431 | data_fix_cfg = self.get_fix_configs(grad=False) 432 | if grad_fix_cfg is None: 433 | grad_fix_cfg = self.get_fix_configs(grad=True) 434 | 435 | def _print(string, **kwargs): 436 | print( 437 | "\n".join([" " * prefix_spaces + line for line in string.split("\n")]) 438 | + "\n", 439 | **kwargs 440 | ) 441 | 442 | for key in data_fix_cfg: 443 | _print(key) 444 | d_cfg = data_fix_cfg[key] 445 | g_cfg = grad_fix_cfg[key] 446 | if isinstance(d_cfg, OrderedDict): 447 | self.print_fix_configs(d_cfg, g_cfg, prefix_spaces=2) 448 | else: 449 | # a dict of configs 450 | keys = set(d_cfg.keys()).union(g_cfg.keys()) 451 | for param_name in keys: 452 | d_bw = utils.try_parse_int( 453 | d_cfg.get(param_name, {}).get("bitwidth", "f") 454 | ) 455 | g_bw = utils.try_parse_int( 456 | g_cfg.get(param_name, {}).get("bitwidth", "f") 457 | ) 458 | d_sc = utils.try_parse_int( 459 | d_cfg.get(param_name, {}).get("scale", "f") 460 | ) 461 | g_sc = utils.try_parse_int( 462 | g_cfg.get(param_name, {}).get("scale", "f") 463 | ) 464 | d_mt = utils.try_parse_int( 465 | d_cfg.get(param_name, {}).get("method", 0) 466 | ) 467 | g_mt = utils.try_parse_int( 468 | g_cfg.get(param_name, {}).get("method", 0) 469 | ) 470 | _print( 471 | ( 472 | " {param_name:10}: d: bitwidth: {d_bw:3}; " 473 | "scale: {d_sc:3}; method: {d_mt:3}\n" 474 | + " " * 14 475 | + "g: bitwidth: {g_bw:3}; scale: {g_sc:3}; method: {g_mt:3}" 476 | ).format( 477 | param_name=param_name, 478 | d_bw=d_bw, 479 | g_bw=g_bw, 480 | d_sc=d_sc, 481 | g_sc=g_sc, 482 | d_mt=d_mt, 483 | g_mt=g_mt, 484 | ) 485 | ) 486 | 487 | @staticmethod 488 | def set_fix_method( 489 | self, method=None, method_by_type=None, method_by_name=None, grad=False 490 | ): 491 | for module_name, module in six.iteritems(self._modules): 492 | if isinstance(module.__class__, FixMeta) or isinstance( 493 | module, Activation_fix 494 | ): 495 | fix_params = getattr( 496 | module, "nf_fix_params" if not grad else "nf_fix_params_grad" 497 | ) 498 | if method_by_name is not None and 
module_name in method_by_name: 499 | for param_n, param_method in six.iteritems( 500 | method_by_name[module_name] 501 | ): 502 | assert ( 503 | param_n in fix_params 504 | ), "{} is not a quantized parameter of module {}".format( 505 | param_n, module_name 506 | ) 507 | _set_method(fix_params[param_n], param_method) 508 | else: 509 | if method_by_type is not None: 510 | param_method_cfg = method_by_type.get( 511 | type(module).__name__, 512 | method_by_type.get(type(module), None), 513 | ) 514 | else: 515 | param_method_cfg = None 516 | if param_method_cfg is not None: 517 | # specifiedd by method_by_type 518 | for param_n, param_method in six.iteritems(param_method_cfg): 519 | assert ( 520 | param_n in fix_params 521 | ), "{} is not a quantized parameter of module {}".format( 522 | param_n, module_name 523 | ) 524 | _set_method(fix_params[param_n], param_method) 525 | elif method is not None: 526 | for param_n in fix_params: 527 | _set_method(fix_params[param_n], method) 528 | elif isinstance(module, FixTopModule): 529 | module.set_fix_method(method, grad=grad) 530 | else: 531 | FixTopModule.set_fix_method(module, method, grad=grad) 532 | 533 | 534 | # helpers 535 | def _set_method(param_cfg, new_method): 536 | if new_method is None: 537 | return 538 | if "method" in param_cfg: 539 | ori_method = param_cfg["method"] 540 | if isinstance(ori_method, torch.autograd.Variable): 541 | ori_method.data.numpy()[0] = new_method 542 | elif torch.is_tensor(ori_method): 543 | ori_method.numpy()[0] = new_method 544 | else: 545 | param_cfg["method"] = new_method 546 | 547 | 548 | nn_fix.Activation_fix = Activation_fix 549 | nn_fix.FixTopModule = FixTopModule 550 | nn_fix.ConvBN_fix = ConvBN_fix 551 | --------------------------------------------------------------------------------
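Usage sketch (illustrative only, not part of the repository sources): the snippet below shows how the pieces above are meant to fit together — a FixTopModule subclass built from the registered fixed modules in nics_fix_pt.nn_fix, with per-parameter fix configs holding method/scale/bitwidth tensors as consumed by quant.quantize_cfg, switched between FIX_AUTO for training and FIX_FIXED for evaluation via set_fix_method. The helper _fix_cfg and the TinyFixNet layer sizes are made up for illustration; the maintained, tested examples are examples/mnist/train_mnist.py and examples/cifar10/net.py.

import torch
import nics_fix_pt as nfp
import nics_fix_pt.nn_fix as nnf


def _fix_cfg(names, bitwidth=8, method=nfp.FIX_AUTO):
    # One config dict per parameter/buffer/activation name; method/scale/bitwidth are kept
    # as tensors so that FIX_AUTO can update the scale in place during forward.
    return {
        n: {
            "method": torch.IntTensor([method]),    # FIX_NONE / FIX_AUTO / FIX_FIXED
            "scale": torch.FloatTensor([0.0]),      # filled in by quantize_cfg under FIX_AUTO
            "bitwidth": torch.IntTensor([bitwidth]),
        }
        for n in names
    }


class TinyFixNet(nnf.FixTopModule):
    def __init__(self):
        super(TinyFixNet, self).__init__()
        self.fc = nnf.Linear_fix(784, 10, nf_fix_params=_fix_cfg(["weight", "bias"]))
        self.fc_act = nnf.Activation_fix(nf_fix_params=_fix_cfg(["activation"]))

    def forward(self, x):
        return self.fc_act(self.fc(x.view(x.size(0), -1)))


net = TinyFixNet()
net.set_fix_method(nfp.FIX_AUTO)    # training: scales are updated from the data range
out = net(torch.randn(2, 1, 28, 28))
net.print_fix_configs()
net.set_fix_method(nfp.FIX_FIXED)   # evaluation: reuse the recorded scales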