├── nics_fix_pt
│   ├── VERSION
│   ├── nn_fix.py
│   ├── nn_fix_inner.py
│   ├── consts.py
│   ├── __init__.py
│   ├── utils.py
│   ├── quant.py
│   └── fix_modules.py
├── MANIFEST.in
├── examples
│   ├── cifar10
│   │   ├── train.sh
│   │   ├── finetune.sh
│   │   ├── net.py
│   │   └── main.py
│   └── mnist
│       └── train_mnist.py
├── coverage.svg
├── LICENSE
├── .gitignore
├── setup.py
├── tests
│   ├── test_fix_bn.py
│   ├── conftest.py
│   ├── test_quant.py
│   ├── test_utils.py
│   └── test_fix_module.py
└── README.md
/nics_fix_pt/VERSION:
--------------------------------------------------------------------------------
1 | 0.4.0
2 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include tests/*.py
2 | recursive-include examples *.py *.sh
3 |
--------------------------------------------------------------------------------
/nics_fix_pt/nn_fix.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # A wrapper module for all registered fix modules
3 |
--------------------------------------------------------------------------------
/nics_fix_pt/nn_fix_inner.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | from .fix_modules import register_fix_module
3 |
4 | register_fix_module(nn.Conv2d)
5 | register_fix_module(nn.Linear)
6 | register_fix_module(nn.BatchNorm1d)
7 | register_fix_module(nn.BatchNorm2d)
8 |
--------------------------------------------------------------------------------
/examples/cifar10/train.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | gpu=${GPU:-0}
4 | result_dir=$(date +%Y%m%d-%H%M%S)
5 | mkdir -p ${result_dir}
6 |
7 | CUDA_VISIBLE_DEVICES=$gpu python main.py \
8 | --save-dir ${result_dir} \
9 | "$@" 2>&1 | tee ${result_dir}/train.log
10 |
11 |
--------------------------------------------------------------------------------
/nics_fix_pt/consts.py:
--------------------------------------------------------------------------------
1 | class RangeMethod:
2 | RANGE_MAX = 0
3 | RANGE_3SIGMA = 1
4 | RANGE_MAX_TENPERCENT = 2
5 | RANGE_SWEEP = 3
6 |
7 |
8 | class QuantizeMethod:
9 | # quantize methods
10 | FIX_NONE = 0
11 | FIX_AUTO = 1
12 | FIX_FIXED = 2
13 |
--------------------------------------------------------------------------------
/examples/cifar10/finetune.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 | #arch="vgg11_web"
3 | #arch="vgg11_elegant"
4 | arch="vgg11_ugly"
5 |
6 | gpu=1
7 | data=$(date +%m%d)
8 | method="8_conv1233'44'55'_6"
9 | lr=0.0001
10 | # model="save_fix/checkpoint_auto8_90.730.tar"
11 | model="save_fix/checkpoint_8_conv1233'44'5_6_90.820.tar"
12 | epoches=30
13 |
14 | CUDA_VISIBLE_DEVICES=$gpu python main.py \
15 | --pretrained $model \
16 | --arch $arch \
17 | --lr $lr \
18 | --epoches $epoches \
19 | --prefix $method \
20 | --batch-size 128 \
21 | --test-batch-size 1000 \
22 | --save-dir save_fix \
23 | 2>&1 | tee logs/log-"$method"_1222.log
24 |
--------------------------------------------------------------------------------
/coverage.svg:
--------------------------------------------------------------------------------
(coverage badge SVG; markup omitted)
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 Xuefei Ning
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/nics_fix_pt/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | with open(os.path.join(os.path.dirname(__file__), "VERSION")) as f:
4 | __version__ = f.read().strip()
5 |
6 | from nics_fix_pt.consts import QuantizeMethod, RangeMethod
7 | from nics_fix_pt.quant import *
8 | import nics_fix_pt.nn_fix_inner
9 | from nics_fix_pt import nn_fix
10 | from nics_fix_pt.fix_modules import register_fix_module
11 |
12 | FIX_NONE = QuantizeMethod.FIX_NONE
13 | FIX_AUTO = QuantizeMethod.FIX_AUTO
14 | FIX_FIXED = QuantizeMethod.FIX_FIXED
15 |
16 | RANGE_MAX = RangeMethod.RANGE_MAX
17 | RANGE_3SIGMA = RangeMethod.RANGE_3SIGMA
18 |
19 |
20 | class nn_auto_register(object):
21 | """
22 | An auto-register helper that automatically registers any not-yet-registered module
23 | by proxying to the corresponding module in torch.nn.
24 |
25 | NOTE: We do not guarantee that all auto-registered fixed nn modules will behave well,
26 | as they are not tested, although I expect them to work in normal cases.
27 | Use with care!
28 |
29 | Usage: from nics_fix_pt import NAR as nnf
30 | then e.g. `nnf.Bilinear_fix` and `nnf.Bilinear` can both be used as a fixed-point module.
31 | """
32 |
33 | def __getattr__(self, name):
34 | import torch
35 |
36 | attr = getattr(nn_fix, name, None)
37 | if attr is None:
38 | if name.endswith("_fix"):
39 | ori_name = name[:-4]
40 | else:
41 | ori_name = name
42 | ori_cls = getattr(torch.nn, ori_name)
43 | register_fix_module(ori_cls, register_name=ori_name + "_fix")
44 | return getattr(nn_fix, ori_name + "_fix", None)
45 | return attr
46 |
47 |
48 | NAR = nn_auto_register()
49 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 |
27 | # PyInstaller
28 | # Usually these files are written by a python script from a template
29 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
30 | *.manifest
31 | *.spec
32 |
33 | # Installer logs
34 | pip-log.txt
35 | pip-delete-this-directory.txt
36 |
37 | # Unit test / coverage reports
38 | htmlcov/
39 | .tox/
40 | .coverage
41 | .coverage.*
42 | .cache
43 | nosetests.xml
44 | coverage.xml
45 | *.cover
46 | .hypothesis/
47 |
48 | # Translations
49 | *.mo
50 | *.pot
51 |
52 | # Django stuff:
53 | *.log
54 | local_settings.py
55 |
56 | # Flask stuff:
57 | instance/
58 | .webassets-cache
59 |
60 | # Scrapy stuff:
61 | .scrapy
62 |
63 | # Sphinx documentation
64 | docs/_build/
65 |
66 | # PyBuilder
67 | target/
68 |
69 | # Jupyter Notebook
70 | .ipynb_checkpoints
71 |
72 | # pyenv
73 | .python-version
74 |
75 | # celery beat schedule file
76 | celerybeat-schedule
77 |
78 | # SageMath parsed files
79 | *.sage.py
80 |
81 | # Environments
82 | .env
83 | .venv
84 | env/
85 | venv/
86 | ENV/
87 |
88 | # Spyder project settings
89 | .spyderproject
90 | .spyproject
91 |
92 | # Rope project settings
93 | .ropeproject
94 |
95 | # mkdocs documentation
96 | /site
97 |
98 | # mypy
99 | .mypy_cache/
100 |
101 | examples/data
102 |
103 | tmpdir/
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import os
3 | import sys
4 | from setuptools import setup, find_packages
5 | from setuptools.command.test import test as TestCommand
6 |
7 | here = os.path.dirname(os.path.abspath((__file__)))
8 |
9 | # meta infos
10 | NAME = "nics_fix_pytorch"
11 | DESCRIPTION = "Fixed point training framework on PyTorch"
12 | with open(os.path.join(os.path.dirname(__file__), "nics_fix_pt/VERSION")) as f:
13 | VERSION = f.read().strip()
14 |
15 |
16 | AUTHOR = "foxfi"
17 | EMAIL = "foxdoraame@gmail.com"
18 |
19 | # package contents
20 | MODULES = []
21 | PACKAGES = find_packages(exclude=["tests.*", "tests"])
22 |
23 | # dependencies
24 | INSTALL_REQUIRES = ["six", "numpy"]
25 | TESTS_REQUIRE = ["pytest", "pytest-cov", "PyYAML"] # for mnist example
26 |
27 | # entry points
28 | ENTRY_POINTS = """"""
29 |
30 |
31 | def read_long_description(filename):
32 | path = os.path.join(here, filename)
33 | if os.path.exists(path):
34 | return open(path).read()
35 | return ""
36 |
37 |
38 | class PyTest(TestCommand):
39 | def finalize_options(self):
40 | TestCommand.finalize_options(self)
41 | self.test_args = ["tests/", "--junitxml", "unittest.xml", "--cov"]
42 | self.test_suite = True
43 |
44 | def run_tests(self):
45 | # import here, because outside the eggs aren't loaded
46 | import pytest
47 |
48 | errno = pytest.main(self.test_args)
49 | sys.exit(errno)
50 |
51 |
52 | setup(
53 | name=NAME,
54 | version=VERSION,
55 | description=DESCRIPTION,
56 | long_description=read_long_description("README.md"),
57 | author=AUTHOR,
58 | author_email=EMAIL,
59 | py_modules=MODULES,
60 | packages=PACKAGES,
61 | entry_points=ENTRY_POINTS,
62 | zip_safe=False,
63 | install_requires=INSTALL_REQUIRES,
64 | tests_require=TESTS_REQUIRE,
65 | cmdclass={"test": PyTest},
66 | package_data={
67 | "nics_fix_pt": ["VERSION"]
68 | }
69 | )
70 |
--------------------------------------------------------------------------------
/tests/test_fix_bn.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | import torch.optim as optim
7 |
8 | import nics_fix_pt as nfp
9 | from nics_fix_pt import nn_fix as nnf
10 |
11 |
12 | @pytest.mark.parametrize(
13 | "case", [{"input_num": 3, "momentum": 0.5, "inputs": [[1, 1, 0], [2, 1, 2]]}]
14 | )
15 | def test_fix_bn_test_auto(case):
16 | # TEST: the first update is the same (not quantized)
17 | bn_fix = nnf.BatchNorm1d_fix(
18 | case["input_num"],
19 | nf_fix_params={
20 | "running_mean": {
21 | "method": nfp.FIX_AUTO,
22 | "bitwidth": torch.tensor([2]),
23 | "scale": torch.tensor([0]),
24 | },
25 | "running_var": {
26 | "method": nfp.FIX_AUTO,
27 | "bitwidth": torch.tensor([2]),
28 | "scale": torch.tensor([0]),
29 | },
30 | },
31 | affine=False,
32 | momentum=case["momentum"],
33 | )
34 | bn = nn.BatchNorm1d(case["input_num"], affine=False, momentum=case["momentum"])
35 | bn_fix.train()
36 | bn.train()
37 | inputs = torch.autograd.Variable(
38 | torch.tensor(case["inputs"]).float(), requires_grad=True
39 | )
40 | out_fix = bn_fix(inputs)
41 | out = bn(inputs)
42 | assert (bn.running_mean == bn_fix.running_mean).all() # not quantized here
43 | assert (bn.running_var == bn_fix.running_var).all()
44 | assert (out == out_fix).all()
45 |
46 | # TEST: Quantitized on the next forward
47 | bn_fix.train(False)
48 | bn.train(False)
49 | out_fix = bn_fix(inputs)
50 | # Let's explicit quantize the mean/var of the normal BN model for comparison
51 | object.__setattr__(
52 | bn,
53 | "running_mean",
54 | nfp.quant.quantize_cfg(
55 | bn.running_mean,
56 | **{
57 | "method": nfp.FIX_AUTO,
58 | "bitwidth": torch.tensor([2]),
59 | "scale": torch.tensor([0]),
60 | }
61 | )[0],
62 | )
63 | object.__setattr__(
64 | bn,
65 | "running_var",
66 | nfp.quant.quantize_cfg(
67 | bn.running_var,
68 | **{
69 | "method": nfp.FIX_AUTO,
70 | "bitwidth": torch.tensor([2]),
71 | "scale": torch.tensor([0]),
72 | }
73 | )[0],
74 | )
75 | assert (bn.running_mean == bn_fix.running_mean).all()
76 | assert (bn.running_var == bn_fix.running_var).all()
77 |
78 | out = bn(inputs)
79 | assert (out == out_fix).all()
80 |
81 | # TEST: the running mean/var update is on the quantized running mean
82 | bn_fix.train()
83 | bn.train()
84 | out_fix = bn_fix(inputs)
85 | out = bn(inputs)
86 | assert (
87 | bn.running_mean == bn_fix.running_mean
88 | ).all() # quantized on the next forward
89 | assert (bn.running_var == bn_fix.running_var).all()
90 |
91 | # runnig_mean_should = np.mean(inputs.detach().numpy(), axis=0) * case["momentum"]
92 | # runnig_var_should = np.var(inputs.detach().numpy(), axis=0) * case["momentum"] + np.ones(case["input_num"]) * (1 - case["momentum"])
93 |
--------------------------------------------------------------------------------
/nics_fix_pt/utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | from collections import OrderedDict
4 | import copy
5 | import inspect
6 | from contextlib import contextmanager
7 | from functools import wraps
8 |
9 | from six import iteritems
10 | import numpy as np
11 | import torch
12 |
13 |
14 | def try_parse_variable(something):
15 | if isinstance(something, (dict, OrderedDict)): # recursive parse into dict values
16 | return {k: try_parse_variable(v) for k, v in iteritems(something)}
17 | try:
18 | return torch.autograd.Variable(
19 | torch.IntTensor(np.array([something])), requires_grad=False
20 | )
21 | except ValueError:
22 | return something
23 |
24 |
25 | def get_int(something):
26 | if torch.is_tensor(something):
27 | v = int(something.numpy()[0])
28 | elif isinstance(something, torch.autograd.Variable):
29 | v = int(something.data.numpy()[0])
30 | else:
31 | assert isinstance(something, (int, np.integer))  # np.int was removed in newer numpy
32 | v = int(something)
33 | return v
34 |
35 |
36 | def try_parse_int(something):
37 | if isinstance(something, (dict, OrderedDict)): # recursive parse into dict values
38 | return {k: try_parse_int(v) for k, v in iteritems(something)}
39 | try:
40 | return int(something)
41 | except ValueError:
42 | return something
43 |
44 |
45 | def cache(format_str):
46 | def _cache(func):
47 | _cache_dct = {}
48 | sig = inspect.signature(func)
49 | default_kwargs = {
50 | n: v.default
51 | for n, v in iteritems(sig.parameters)
52 | if v.default != inspect._empty
53 | }
54 |
55 | @wraps(func)
56 | def _func(*args, **kwargs):
57 | args_dct = copy.copy(default_kwargs)
58 | args_dct.update(dict(zip(sig.parameters.keys(), args)))
59 | args_dct.update(kwargs)
60 | cache_str = format_str.format(**args_dct)
61 | if cache_str not in _cache_dct:
62 | _cache_dct[cache_str] = func(*args, **kwargs)
63 | return _cache_dct[cache_str]
64 |
65 | return _func
66 |
67 | return _cache
68 |
69 |
70 | def _generate_default_fix_cfg(names, scale=0, bitwidth=8, method=0):
71 | return {
72 | n: {
73 | "method": torch.autograd.Variable(
74 | torch.IntTensor(np.array([method])), requires_grad=False
75 | ),
76 | "scale": torch.autograd.Variable(
77 | torch.FloatTensor(np.array([scale])), requires_grad=False
78 | ),
79 | "bitwidth": torch.autograd.Variable(
80 | torch.IntTensor(np.array([bitwidth])), requires_grad=False
81 | ),
82 | }
83 | for n in names
84 | }
85 |
86 |
87 | _context = {}
88 | DEFAULT_KWARGS_KEY = "__fix_module_default_kwargs__"
89 |
90 |
91 | def get_kwargs(cls):
92 | return _context.get(DEFAULT_KWARGS_KEY, {}).get(cls.__name__, {})
93 |
94 |
95 | @contextmanager
96 | def fix_kwargs_scope(_override=False, **kwargs):
97 | old_kwargs = copy.copy(_context.get(DEFAULT_KWARGS_KEY, None))
98 | if _override or DEFAULT_KWARGS_KEY not in _context:
99 | _context[DEFAULT_KWARGS_KEY] = kwargs
100 | else:
101 | _context[DEFAULT_KWARGS_KEY].update(kwargs)
102 | yield
103 | if old_kwargs is None:
104 | _context.pop(DEFAULT_KWARGS_KEY, None)
105 | else:
106 | _context[DEFAULT_KWARGS_KEY] = old_kwargs
107 |
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import numpy as np
3 | import torch
4 | from torch.nn import Module
5 | from torch.nn import functional as F
6 | from torch.nn.parameter import Parameter
7 | import nics_fix_pt.nn_fix as nnf
8 |
9 | def _generate_default_fix_cfg(names, scale=0, bitwidth=8, method=0):
10 | return {
11 | n: {
12 | "method": torch.autograd.Variable(
13 | torch.IntTensor(np.array([method])), requires_grad=False
14 | ),
15 | "scale": torch.autograd.Variable(
16 | torch.IntTensor(np.array([scale])), requires_grad=False
17 | ),
18 | "bitwidth": torch.autograd.Variable(
19 | torch.IntTensor(np.array([bitwidth])), requires_grad=False
20 | ),
21 | }
22 | for n in names
23 | }
24 |
25 | class TestNetwork(nnf.FixTopModule):
26 | def __init__(self):
27 | super(TestNetwork, self).__init__()
28 | self.fix_params = {}
29 | for conv_name in ["conv1", "conv2"]:
30 | self.fix_params[conv_name] = _generate_default_fix_cfg(
31 | ["weight", "bias"], method=1, bitwidth=8)
32 | for bn_name in ["bn1", "bn2"]:
33 | self.fix_params[bn_name] = _generate_default_fix_cfg(
34 | ["weight", "bias", "running_mean", "running_var"], method=1, bitwidth=8)
35 | self.conv1 = nnf.Conv2d_fix(3, 64, (3, 3), padding=1,
36 | nf_fix_params=self.fix_params["conv1"])
37 | self.bn1 = nnf.BatchNorm2d_fix(64, nf_fix_params=self.fix_params["bn1"])
38 | self.conv2 = nnf.Conv2d_fix(64, 128, (3, 3), padding=1,
39 | nf_fix_params=self.fix_params["conv2"])
40 | self.bn2 = nnf.BatchNorm2d_fix(128, nf_fix_params=self.fix_params["bn2"])
41 |
42 | @pytest.fixture
43 | def test_network():
44 | return TestNetwork()
45 |
46 | # ----
47 | class TestModule(Module):
48 | def __init__(self, input_num):
49 | super(TestModule, self).__init__()
50 | self.param = Parameter(torch.Tensor(1, input_num))
51 | self.reset_parameters()
52 |
53 | def reset_parameters(self):
54 | # fake data
55 | with torch.no_grad():
56 | self.param.fill_(0)
57 | self.param[0, 0] = 0.25111
58 | self.param[0, 1] = 0.5
59 |
60 | def forward(self, input):
61 | # print("input: ", input, "param: ", self.param)
62 | return F.linear(input, self.param, None)
63 |
64 |
65 | @pytest.fixture
66 | def module_cfg(request):
67 | import nics_fix_pt as nfp
68 | import nics_fix_pt.nn_fix as nnf
69 |
70 | nfp.fix_modules.register_fix_module(TestModule)
71 | # default data/grad fix cfg for the parameter `param` of TestModule
72 | data_cfg = nfp.utils._generate_default_fix_cfg(
73 | ["param"], scale=-1, bitwidth=2, method=nfp.FIX_AUTO
74 | )
75 | grad_cfg = nfp.utils._generate_default_fix_cfg(
76 | ["param"], scale=-1, bitwidth=2, method=nfp.FIX_NONE
77 | )
78 | # the specified overriding cfgs: input_num, data fix cfg, grad fix cfg
79 | update_cfg = getattr(request, "param", {})
80 | input_num = update_cfg.pop("input_num", 3)
81 | data_update_cfg = update_cfg.get("data_cfg", {})
82 | grad_update_cfg = update_cfg.get("grad_cfg", {})
83 | data_cfg["param"].update(data_update_cfg)
84 | grad_cfg["param"].update(grad_update_cfg)
85 | module = nnf.TestModule_fix(
86 | input_num=input_num, nf_fix_params=data_cfg, nf_fix_params_grad=grad_cfg
87 | )
88 | return module, data_cfg, grad_cfg
89 |
--------------------------------------------------------------------------------
/tests/test_quant.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import numpy as np
3 | import torch
4 |
5 | import nics_fix_pt as nfp
6 | import nics_fix_pt.quant as nfpq
7 |
8 |
9 | @pytest.mark.parametrize(
10 | "case",
11 | [
12 | {
13 | "data": [0.2513, -0.5, 0],
14 | "scale": 0.5,
15 | "bitwidth": 2,
16 | "method": nfp.FIX_FIXED,
17 | "output": ([0.25, -0.5, 0], 0.25),
18 | },
19 | {
20 | "data": [0.2513, -0.5, 0],
21 | "scale": 0.5,
22 | "bitwidth": 16,
23 | "method": nfp.FIX_FIXED,
24 | "output": (
25 | [np.round(0.2513 / (2 ** (-16))) * 2 ** (-16), -0.5, 0],
26 | 2 ** (-16),
27 | ),
28 | },
29 | {
30 | "data": [0.2513, -0.5, 0],
31 | "scale": 0.5,
32 | "bitwidth": 2,
33 | "method": nfp.FIX_AUTO,
34 | "output": ([0.25, -0.5, 0], 0.25),
35 | },
36 | {
37 | "data": [0.2513, -0.52, 0],
38 | "scale": 0.5,
39 | "bitwidth": 2,
40 | "method": nfp.FIX_AUTO,
41 | "out_scale": 1,
42 | "output": ([0.5, -0.5, 0], 0.5),
43 | },
44 | {
45 | "data": [0.2513, -0.52, 0],
46 | "scale": 0.5,
47 | "bitwidth": 4,
48 | "method": nfp.FIX_AUTO,
49 | "range_method": nfp.RANGE_3SIGMA,
50 | "output": ([0.25, -0.5, 0], 0.25),
51 | },
52 | # test the float scale
53 | {
54 | "data": [0.2513, -0.52, 0.0],
55 | "scale": 0.52,
56 | "bitwidth": 4,
57 | "method": nfp.FIX_AUTO,
58 | "output": ([0.26, -0.52, 0.0], 0.065),
59 | "float_scale": True,
60 | },
61 | {
62 | "data": [0.2513, -0.52, 0.0],
63 | "scale": 0.5,
64 | "bitwidth": 4,
65 | "method": nfp.FIX_AUTO,
66 | "output": ([(0.2513 + 0.52) * 5 / 16, -0.52, 0.0], (0.2513 + 0.52) / 16),
67 | "float_scale": True,
68 | "zero_point": True,
69 | },
70 | {
71 | "data": [[[[0.2513]], [[-0.52]]], [[[0.3321]], [[-0.4532]]]],
72 | "scale": [
73 | 0.2513,
74 | 0.3321,
75 | ], # max_data = data.view(data.shape[0],-1).max(dim=1)[0]
76 | "bitwidth": 4,
77 | "method": nfp.FIX_AUTO,
78 | "output": (
79 | [[[[4 * 0.52 / 8]], [[-0.52]]], [[[6 * 0.4532 / 8]], [[-0.4532]]]],
80 | [0.52 / 8, 0.4532 / 8],
81 | ),
82 | "float_scale": True,
83 | "group": "batch",
84 | },
85 | ],
86 | )
87 | def test_quantize_cfg(case):
88 | scale_tensor = torch.tensor([case["scale"]])
89 | out = nfpq.quantize_cfg(
90 | torch.tensor(case["data"]),
91 | scale_tensor,
92 | torch.tensor(case["bitwidth"]),
93 | case["method"],
94 | case["range_method"] if "range_method" in case.keys() else nfp.RANGE_MAX,
95 | stochastic=case["stochastic"] if "stochastic" in case.keys() else False,
96 | float_scale=case["float_scale"] if "float_scale" in case.keys() else False,
97 | zero_point=case["zero_point"] if "zero_point" in case.keys() else False,
98 | group=case["group"] if "group" in case.keys() else False,
99 | )
100 | assert np.isclose(out[0], case["output"][0]).all()
101 | assert np.isclose(out[1].view(-1), case["output"][1]).all()
102 | if "out_scale" in case:
103 | assert bool(scale_tensor == case["out_scale"])
104 |
105 |
106 | def test_straight_through_gradient():
107 | inputs = torch.autograd.Variable(torch.tensor([1.1, 0.9]), requires_grad=True)
108 | outputs = nfpq.StraightThroughRound().apply(inputs)
109 | outputs.sum().backward()
110 | assert np.isclose(inputs._grad, [1, 1]).all()
111 |
112 | # when Round is applied without straight through, there is no gradient
113 | inputs.grad.detach_()
114 | inputs.grad.zero_()
115 | output_nost = inputs.round()
116 | assert np.isclose(inputs._grad, [0, 0]).all()
117 |
118 | # Stochastic rounding
119 | inputs = torch.autograd.Variable(torch.Tensor(100).fill_(0.5), requires_grad=True)
120 | outputs = nfpq.StraightThroughStochasticRound().apply(inputs)
121 | assert outputs.max() > 0.9 and outputs.min() < 0.1
122 |
123 |
124 | def test_quantize_gradient():
125 | quant_grad = nfpq.QuantitizeGradient()
126 | scale = torch.Tensor([0])
127 | inputs = torch.autograd.Variable(torch.tensor([1.1, 0.9]), requires_grad=True)
128 | quanted = quant_grad.apply(inputs, scale, torch.tensor(2), nfp.FIX_AUTO)
129 | output = (
130 | quanted * torch.autograd.Variable(torch.tensor([0.5, 0.26]), requires_grad=True)
131 | ).sum()
132 | output.backward()
133 | assert np.isclose(inputs._grad, [0.5, 0.25]).all()
134 | assert scale.item() == 0.5
135 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Fixed Point Training Simulation Framework on PyTorch
2 |
3 | ### Core Functionality
4 | - Parameter/Buffer fix: by using the `*_fix` modules in `nics_fix_pt.nn_fix`;
5 | parameters are those to be trained, while buffers are those to be persisted but not regarded as parameters
6 | - Activation fix: by using the `Activation_fix` module
7 | - Data fix vs. gradient fix: by supplying the `nf_fix_params`/`nf_fix_params_grad` keyword arguments
8 | when constructing a `*_fix` or `Activation_fix` module
9 |
10 | > NOTE: When constructing a fixed-point module, the dictionary you pass in as the `nf_fix_params` argument is used directly, so if you pass the same configuration dict to two modules, the two modules will share the same configuration. If you want each module to be quantized independently, construct a separate configuration dict for each module.
11 |
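For example, a minimal sketch using the internal helper `_generate_default_fix_cfg` from `nics_fix_pt.utils` (as this repo's own tests and examples do; the layer sizes are arbitrary):

```python
import nics_fix_pt as nfp
import nics_fix_pt.nn_fix as nnf
from nics_fix_pt.utils import _generate_default_fix_cfg

# call the helper once per module, so that each module owns an independent config dict
conv1 = nnf.Conv2d_fix(3, 16, kernel_size=3, padding=1,
                       nf_fix_params=_generate_default_fix_cfg(
                           ["weight", "bias"], bitwidth=8, method=nfp.FIX_AUTO))
conv2 = nnf.Conv2d_fix(16, 32, kernel_size=3, padding=1,
                       nf_fix_params=_generate_default_fix_cfg(
                           ["weight", "bias"], bitwidth=8, method=nfp.FIX_AUTO))
```

Passing one shared dict object to both convolutions would instead make them share, and mutually overwrite, the same method/scale/bitwidth tensors.
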
12 | The under-the-hood quantization of each datum is implemented in `nics_fix_pt.quant._do_quantitize`; the code is easy to follow. If you need a different plug-in quantization method, or a configurable behavior of `_do_quantitize` (e.g. for simulating computation on different types of devices), please contribute a pluggable and configurable `_do_quantitize` (this should be a simple change), or contact me.
13 |
14 | ### Examples
15 | See `examples/mnist/train_mnist.py` for an MNIST fixed-point training example and a demo of the utilities.
16 |
17 | ### Usage Explained
18 |
19 | **How to make a module fixed-point**
20 |
21 | If you have implemented a new module type, e.g. a masked convolution `MaskedConv2d` (a convolution with a mask generated by pruning), and you want a fixed-point version of it, just do:
22 | ```python
23 | from nics_fix_pt import register_fix_module
24 | register_fix_module(MaskedConv2d)
25 | ```
26 |
27 | Then you can construct the fixed-point module and use it as in the following code. All of its functionality stays the same, except that the weights (and, optionally, the gradients of the weights) are converted to fixed-point before each use.
28 |
29 | ```python
30 | import nics_fix_pt.nn_fix as nnf
31 | masked_conv = nnf.MaskedConv2d_fix(...your parameters, including fixed-point configs...)
32 | ```
33 |
34 | The internally registered fixed-point modules are already tested in our test cases and our other works:
35 |
36 | * `torch.nn.Linear`
37 | * `torch.nn.Conv2d`
38 | * `torch.nn.BatchNorm1d`
39 | * `torch.nn.BatchNorm2d`
40 |
41 | > NOTE: We expect most `torch.nn` modules to work well without special handling, so you can use the AutoRegister functionality, which auto-registers non-registered `torch.nn` modules. However, since we have not tested all of these modules thoroughly, if you find that some module fails to work in your use case, please tell us or contribute special handling for it.
42 |
43 | **How to inspect the float-point precision/quantized datum**
44 |
45 | Network parameters are saved as floating-point tensors in `module._parameters`, while the corresponding module attribute (e.g. `masked_conv.weight`) is the fixed-point/quantized version. You can use either `module._parameters` or `.nfp_actual_data` (e.g. `masked_conv.weight.nfp_actual_data`) to access the floating-point precision data.
46 |
47 | Since `module._parameters` and `module._buffers` are what `model.state_dict` uses, the network parameters saved to disk via `model.state_dict` are in floating-point precision:
48 | * In your use cases (e.g. fixed-point hardware simulation), together with the saved floating-point precision parameters, you might need to dump and then load/modify the fix configurations of the variables using `model.get_fix_configs` and `model.load_fix_configs`. Check `examples/mnist/train_mnist.py` for an example.
49 | * Or, if you want to directly dump the latest used version of the parameters (which might be a quantized tensor, depending on your latest configuration), use `nnf.FixTopModule.fix_state_dict(module)` for dumping instead, as sketched below.
50 |
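For illustration, a minimal sketch reusing the `masked_conv` module from above (the checkpoint file name is arbitrary):

```python
import torch
from nics_fix_pt.nn_fix import FixTopModule

# after a forward pass, masked_conv.weight holds the quantized tensor, while the
# floating-point master copy lives in masked_conv._parameters["weight"]
# (also reachable as masked_conv.weight.nfp_actual_data)
float_w = masked_conv._parameters["weight"]
quant_w = masked_conv.weight

# state_dict() saves the floating-point parameters ...
torch.save(masked_conv.state_dict(), "float_ckpt.pt")
# ... while fix_state_dict() returns the latest (possibly quantized) tensors
fixed_dict = FixTopModule.fix_state_dict(masked_conv)
```
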
51 | **How to config the fixed-point behavior**
52 |
53 | Every fixed-point module needs a fixed-point configuration, and optionally a fixed-point configuration for the gradients.
54 |
55 | A config for a module should be a dict whose keys are the parameter or buffer names, and whose values are dicts of torch tensors (for the current "bitwidth", "method", and "scale"), which are modified in place by the calls to `nics_fix_pt.quant.quantize_cfg`.
56 |
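For example, a hand-built sketch of such a config (the `make_cfg` helper is hypothetical; it mirrors what `_generate_default_fix_cfg` in `nics_fix_pt.utils` produces):

```python
import torch
import nics_fix_pt as nfp
import nics_fix_pt.nn_fix as nnf

def make_cfg(names, bitwidth):
    # one entry per parameter/buffer name; the tensors are modified in place by
    # the quantization calls (e.g. "scale" reflects the most recently used scale)
    return {
        n: {
            "method": torch.IntTensor([nfp.FIX_AUTO]),
            "scale": torch.FloatTensor([0]),
            "bitwidth": torch.IntTensor([bitwidth]),
        }
        for n in names
    }

fc = nnf.Linear_fix(
    100, 10,
    nf_fix_params=make_cfg(["weight", "bias"], bitwidth=8),        # data fix config
    nf_fix_params_grad=make_cfg(["weight", "bias"], bitwidth=16),  # optional gradient fix config
)
```
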
57 | **How to check/manage the configuration**
58 |
59 | 1. For each quantized datum, you can use `.data_cfg` or `.grad_cfg`, e.g. `masked_conv.weight.data_cfg` or `masked_conv.bias.data_cfg`.
60 | 2. You can also use `FixTopModule.get_fix_configs(module)` to get the configs of multiple modules in one OrderedDict.
61 |
62 | You can modify the config tensors in place to change the behavior.
63 |
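A short sketch, assuming the `fc` module constructed above (and, in the comments, a `FixTopModule`-based `model`):

```python
import nics_fix_pt as nfp

# 1. per-datum view: the quantized tensor carries its config (after a forward pass)
print(fc.weight.data_cfg)

# 2. per-module view: nf_fix_params holds the live config dict passed at construction
fc.nf_fix_params["weight"]["bitwidth"].fill_(4)            # e.g. narrow the weight to 4 bits
fc.nf_fix_params["weight"]["method"].fill_(nfp.FIX_FIXED)  # and freeze the current scale

# 3. whole-model view on a FixTopModule-based network:
# model.print_fix_configs()
# model.set_fix_method(nfp.FIX_FIXED)
```
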
64 | ### Utilities
65 |
66 | - FixTopModule: dump/load fix configuration to/from file; print fix configs.
67 |
68 | FixTopModule is just a wrapper that gathers the config print/load/dump/setting utilities. These utilities work with nested plain nn.Module instances as intermediate module containers; e.g. an `nn.Sequential` of fixed modules will also work (see the sketch below), and you do not need a subclass multi-inherited from `nn.Sequential` and `nnf.FixTopModule`!
69 |
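For instance, a minimal sketch (layer sizes arbitrary; `fix_state_dict` is called unbound on the plain container):

```python
import torch.nn as nn
import nics_fix_pt as nfp
import nics_fix_pt.nn_fix as nnf
from nics_fix_pt.nn_fix import FixTopModule
from nics_fix_pt.utils import _generate_default_fix_cfg

# a plain nn.Sequential of fixed modules -- no FixTopModule subclass is needed
seq = nn.Sequential(
    nnf.Conv2d_fix(3, 16, kernel_size=3, padding=1,
                   nf_fix_params=_generate_default_fix_cfg(
                       ["weight", "bias"], method=nfp.FIX_AUTO)),
    nnf.BatchNorm2d_fix(16, nf_fix_params=_generate_default_fix_cfg(
        ["weight", "bias", "running_mean", "running_var"], method=nfp.FIX_AUTO)),
)
# the utilities can be applied unbound to the container
fixed_dict = FixTopModule.fix_state_dict(seq)
```
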
70 | - AutoRegister: auto-registers corresponding fixed-point modules for modules in `torch.nn`: it automatically registers any not-yet-registered module by proxying to the module in `torch.nn`. Example usage:
71 | ```python
72 | from nics_fix_pt import NAR as nnf_auto
73 | bilinear = nnf_auto.Bilinear_fix(...parameters...)
74 | ```
75 |
76 | ### Test cases
77 |
78 | Tested with Python 2.7, 3.5, 3.6.1+.
79 |
80 | Tested with PyTorch 0.4.1, 1.0.0, 1.1.0, and 1.4.0. Note that fixed-point simulation using DataParallel is currently not supported with PyTorch >= 1.5.0!
81 |
82 | 
83 |
84 | Run `python setup.py test` to run the pytest test cases.
85 |
--------------------------------------------------------------------------------
/tests/test_utils.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 |
4 | def test_auto_register():
5 | import torch
6 | from nics_fix_pt import nn_fix as nnf
7 | from nics_fix_pt import NAR as nnf_auto
8 |
9 | torch.nn.Linear_another2 = torch.nn.Linear
10 |
11 | class Linear_another(torch.nn.Linear): # a stub test case
12 | pass
13 |
14 | torch.nn.Linear_another = Linear_another
15 | with pytest.raises(AttributeError) as _:
16 | fix_module = nnf.Linear_another_fix(
17 | 3, 1, nf_fix_params={}
18 | ) # not already registered
19 | fix_module = nnf_auto.Linear_another(
20 | 3, 1, nf_fix_params={}
21 | ) # will automatically register this fix module; `NAR.Linear_another_fix` also works.
22 | fix_module = nnf_auto.Linear_another2_fix(
23 | 3, 1, nf_fix_params={}
24 | ) # will automatically register this fix module; `NAR.linear_another2` also works.
25 | fix_module = nnf.Linear_another_fix(3, 1, nf_fix_params={})
26 | fix_module = nnf.Linear_another2_fix(3, 1, nf_fix_params={})
27 |
28 |
29 | def test_save_state_dict(tmp_path):
30 | import os
31 | import numpy as np
32 | import torch
33 | from torch import nn
34 |
35 | import nics_fix_pt as nfp
36 | from nics_fix_pt import nn_fix as nnf
37 | from nics_fix_pt.utils import _generate_default_fix_cfg
38 |
39 | class _View(nn.Module):
40 | def __init__(self):
41 | super(_View, self).__init__()
42 |
43 | def forward(self, inputs):
44 | return inputs.view(inputs.shape[0], -1)
45 | data = torch.tensor(np.random.rand(8, 3, 4, 4).astype(np.float32)).cuda()
46 | ckpt = os.path.join(tmp_path, "tmp.pt")
47 |
48 | model = nn.Sequential(*[
49 | nnf.Conv2d_fix(3, 10, kernel_size=3, padding=1,
50 | nf_fix_params=_generate_default_fix_cfg(
51 | ["weight", "bias"], scale=2.**np.random.randint(low=-10, high=10),
52 | method=nfp.FIX_FIXED)),
53 | nnf.Conv2d_fix(10, 20, kernel_size=3, padding=1,
54 | nf_fix_params=_generate_default_fix_cfg(
55 | ["weight", "bias"], scale=2.**np.random.randint(low=-10, high=10),
56 | method=nfp.FIX_FIXED)),
57 | nnf.Activation_fix(
58 | nf_fix_params=_generate_default_fix_cfg(
59 | ["activation"], scale=2.**np.random.randint(low=-10, high=10),
60 | method=nfp.FIX_FIXED)),
61 | nn.AdaptiveAvgPool2d(1),
62 | _View(),
63 | nnf.Linear_fix(20, 10, nf_fix_params=_generate_default_fix_cfg(
64 | ["weight", "bias"], scale=2.**np.random.randint(low=-10, high=10),
65 | method=nfp.FIX_FIXED))
66 | ])
67 | model.cuda()
68 | pre_results = model(data)
69 | torch.save(model.state_dict(), ckpt)
70 | model2 = nn.Sequential(*[
71 | nnf.Conv2d_fix(3, 10, kernel_size=3, padding=1,
72 | nf_fix_params=_generate_default_fix_cfg(
73 | ["weight", "bias"], scale=2.**np.random.randint(low=-10, high=10),
74 | method=nfp.FIX_FIXED)),
75 | nnf.Conv2d_fix(10, 20, kernel_size=3, padding=1,
76 | nf_fix_params=_generate_default_fix_cfg(
77 | ["weight", "bias"], scale=2.**np.random.randint(low=-10, high=10),
78 | method=nfp.FIX_FIXED)),
79 | nnf.Activation_fix(
80 | nf_fix_params=_generate_default_fix_cfg(
81 | ["activation"], scale=2.**np.random.randint(low=-10, high=10),
82 | method=nfp.FIX_FIXED)),
83 | nn.AdaptiveAvgPool2d(1),
84 | _View(),
85 | nnf.Linear_fix(20, 10, nf_fix_params=_generate_default_fix_cfg(
86 | ["weight", "bias"], scale=2.**np.random.randint(low=-10, high=10),
87 | method=nfp.FIX_FIXED))
88 | ])
89 | model2.cuda()
90 | model2.load_state_dict(torch.load(ckpt))
91 | post_results = model2(data)
92 | assert (post_results - pre_results < 1e-2).all()
93 |
94 |
95 | def test_fix_state_dict(module_cfg):
96 | import torch
97 | from nics_fix_pt.nn_fix import FixTopModule
98 | import nics_fix_pt.quant as nfpq
99 |
100 | module, cfg, _ = module_cfg
101 | dct = FixTopModule.fix_state_dict(module)
102 | assert (dct["param"] == module._parameters["param"]).all() # not already fixed
103 |
104 | # forward the module once
105 | res = module.forward(torch.tensor([0, 0, 0]).float())
106 | dct = FixTopModule.fix_state_dict(module)
107 | dct_vars = FixTopModule.fix_state_dict(module, keep_vars=True)
108 | quantized, _ = nfpq.quantize_cfg(module._parameters["param"], **cfg["param"])
109 | dct_vars = FixTopModule.fix_state_dict(module, keep_vars=True)
110 | assert (dct["param"] == quantized).all() # already fixed
111 | assert (dct_vars["param"] == quantized).all() # already fixed
112 | assert (
113 | dct_vars["param"].nfp_actual_data == module._parameters["param"]
114 | ).all() # underhood float-point data
115 |
116 |
117 | def test_set_fix_method(test_network):
118 | test_network.set_fix_method(method=0, method_by_type={
119 | "BatchNorm2d_fix": {"running_mean": 2, "running_var": 2, "weight": 1, "bias": 0}
120 | }, method_by_name={"conv1": {"weight": 2}})
121 | assert int(test_network.bn1.nf_fix_params["weight"]["method"]) == 1
122 | assert int(test_network.bn1.nf_fix_params["bias"]["method"]) == 0
123 | assert int(test_network.bn1.nf_fix_params["running_mean"]["method"]) == 2
124 | assert int(test_network.conv1.nf_fix_params["weight"]["method"]) == 2
125 | assert int(test_network.conv1.nf_fix_params["bias"]["method"]) == 1
126 | assert int(test_network.conv2.nf_fix_params["weight"]["method"]) == 0
127 | assert int(test_network.conv2.nf_fix_params["bias"]["method"]) == 0
128 |
--------------------------------------------------------------------------------
/examples/cifar10/net.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import torch
3 | from torch import nn
4 | from torch.nn import functional as F
5 |
6 | import nics_fix_pt as nfp
7 | import nics_fix_pt.nn_fix as nnf
8 | import numpy as np
9 |
10 | def _generate_default_fix_cfg(names, scale=0, bitwidth=8, method=0,
11 | range_method=nfp.RangeMethod.RANGE_MAX):
12 | return {
13 | n: {
14 | "method": torch.autograd.Variable(
15 | torch.IntTensor(np.array([method])), requires_grad=False
16 | ),
17 | "scale": torch.autograd.Variable(
18 | torch.IntTensor(np.array([scale])), requires_grad=False
19 | ),
20 | "bitwidth": torch.autograd.Variable(
21 | torch.IntTensor(np.array([bitwidth])), requires_grad=False
22 | ),
23 | "range_method": range_method
24 | }
25 | for n in names
26 | }
27 |
28 | class FixNet(nnf.FixTopModule):
29 | def __init__(self, fix_bn=True, fix_grad=True, bitwidth_data=8, bitwidth_grad=16,
30 | range_method=nfp.RangeMethod.RANGE_MAX,
31 | grad_range_method=nfp.RangeMethod.RANGE_MAX):
32 | super(FixNet, self).__init__()
33 |
34 | print("fix bn: {}; fix grad: {}; range method: {}; grad range method: {}".format(
35 | fix_bn, fix_grad, range_method, grad_range_method
36 | ))
37 |
38 | # fix configurations (data/grad) for parameters/buffers
39 | self.fix_param_cfgs = {}
40 | self.fix_grad_cfgs = {}
41 | layers = [("conv1_1", 128, 3), ("bn1_1",), ("conv1_2", 128, 3), ("bn1_2",),
42 | ("conv1_3", 128, 3), ("bn1_3",), ("conv2_1", 256, 3), ("bn2_1",),
43 | ("conv2_2", 256, 3), ("bn2_2",), ("conv2_3", 256, 3), ("bn2_3",),
44 | ("conv3_1", 512, 3), ("bn3_1",), ("nin3_2", 256, 1), ("bn3_2",),
45 | ("nin3_3", 128, 1), ("bn3_3",), ("fc4", 10)]
46 | for layer_cfg in layers:
47 | name = layer_cfg[0]
48 | if "bn" in name and not fix_bn:
49 | continue
50 | # data fix config
51 | self.fix_param_cfgs[name] = _generate_default_fix_cfg(
52 | ["weight", "bias", "running_mean", "running_var"] \
53 | if "bn" in name else ["weight", "bias"],
54 | method=1, bitwidth=bitwidth_data, range_method=range_method
55 | )
56 | if fix_grad:
57 | # grad fix config
58 | self.fix_grad_cfgs[name] = _generate_default_fix_cfg(
59 | ["weight", "bias"], method=1, bitwidth=bitwidth_grad,
60 | range_method=grad_range_method
61 | )
62 |
63 | # fix configurations for activations
64 | # data fix config
65 | self.fix_act_cfgs = [
66 | _generate_default_fix_cfg(["activation"], method=1, bitwidth=bitwidth_data,
67 | range_method=range_method)
68 | for _ in range(20)
69 | ]
70 | if fix_grad:
71 | # grad fix config
72 | self.fix_act_grad_cfgs = [
73 | _generate_default_fix_cfg(["activation"], method=1, bitwidth=bitwidth_grad,
74 | range_method=grad_range_method)
75 | for _ in range(20)
76 | ]
77 |
78 | # construct layers
79 | cin = 3
80 | for layer_cfg in layers:
81 | name = layer_cfg[0]
82 | if "conv" in name or "nin" in name:
83 | # convolution layers
84 | cout, kernel_size = layer_cfg[1:]
85 | layer = nnf.Conv2d_fix(
86 | cin, cout,
87 | nf_fix_params=self.fix_param_cfgs[name],
88 | nf_fix_params_grad=self.fix_grad_cfgs[name] if fix_grad else None,
89 | kernel_size=kernel_size,
90 | padding=(kernel_size - 1) // 2 if name != "conv3_1" else 0)
91 | cin = cout
92 | elif "bn" in name:
93 | # bn layers
94 | if fix_bn:
95 | layer = nnf.BatchNorm2d_fix(
96 | cin,
97 | nf_fix_params=self.fix_param_cfgs[name],
98 | nf_fix_params_grad=self.fix_grad_cfgs[name] if fix_grad else None)
99 | else:
100 | layer = nn.BatchNorm2d(cin)
101 | elif "fc" in name:
102 | # fully-connected layers
103 | cout = layer_cfg[1]
104 | layer = nnf.Linear_fix(
105 | cin, cout,
106 | nf_fix_params=self.fix_param_cfgs[name],
107 | nf_fix_params_grad=self.fix_grad_cfgs[name] if fix_grad else None)
108 | cin = cout
109 | # call setattr
110 | setattr(self, name, layer)
111 |
112 | for i in range(20):
113 | setattr(self, "fix" + str(i), nnf.Activation_fix(
114 | nf_fix_params=self.fix_act_cfgs[i],
115 | nf_fix_params_grad=self.fix_act_grad_cfgs[i] if fix_grad else None))
116 |
117 | self.pool1 = nn.MaxPool2d((2, 2))
118 | self.pool2 = nn.MaxPool2d((2, 2))
119 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
120 |
121 | def forward(self, x):
122 | x = self.fix0(x)
123 | x = self.fix2(F.relu(self.bn1_1(self.fix1(self.conv1_1(x)))))
124 | x = self.fix4(F.relu(self.bn1_2(self.fix3(self.conv1_2(x)))))
125 | x = self.pool1(self.fix6(F.relu(self.bn1_3(self.fix5(self.conv1_3(x))))))
126 | x = self.fix8(F.relu(self.bn2_1(self.fix7(self.conv2_1(x)))))
127 | x = self.fix10(F.relu(self.bn2_2(self.fix9(self.conv2_2(x)))))
128 | x = self.pool2(self.fix12(F.relu(self.bn2_3(self.fix11(self.conv2_3(x))))))
129 | x = self.fix14(F.relu(self.bn3_1(self.fix13(self.conv3_1(x)))))
130 | x = self.fix16(F.relu(self.bn3_2(self.fix15(self.nin3_2(x)))))
131 | x = self.fix18(F.relu(self.bn3_3(self.fix17(self.nin3_3(x)))))
132 | # x = self.fix2(F.relu(self.bn1_1(self.conv1_1(x))))
133 | # x = self.fix4(F.relu(self.bn1_2(self.conv1_2(x))))
134 | # x = self.pool1(self.fix6(F.relu(self.bn1_3(self.conv1_3(x)))))
135 | # x = self.fix8(F.relu(self.bn2_1(self.conv2_1(x))))
136 | # x = self.fix10(F.relu(self.bn2_2(self.conv2_2(x))))
137 | # x = self.pool2(self.fix12(F.relu(self.bn2_3(self.conv2_3(x)))))
138 | # x = self.fix14(F.relu(self.bn3_1(self.conv3_1(x))))
139 | # x = self.fix16(F.relu(self.bn3_2(self.nin3_2(x))))
140 | # x = self.fix18(F.relu(self.bn3_3(self.nin3_3(x))))
141 | x = self.avg_pool(x)
142 | x = x.view(-1, 128)
143 | x = self.fix19(self.fc4(x))
144 |
145 | return x
146 |
--------------------------------------------------------------------------------
/tests/test_fix_module.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | import numpy as np
4 | import torch
5 | from torch import nn
6 | import torch.optim as optim
7 |
8 | import nics_fix_pt as nfp
9 |
10 | # Unless overridden by the test case, module_cfg uses the default data fix config scale=-1, bitwidth=2, method=FIX_AUTO; see the module_cfg fixture in conftest.py.
11 | @pytest.mark.parametrize(
12 | "module_cfg, case",
13 | [
14 | (
15 | {"input_num": 3},
16 | {
17 | "inputs": [1, 1, 0],
18 | "data": [0.2513, -0.52, 0],
19 | "out_scale": 1,
20 | "result": 0,
21 | "output": [0.5, -0.5, 0], # quantized parameters, step 0.5
22 | },
23 | ),
24 | (
25 | {"input_num": 3},
26 | {
27 | "inputs": [1, 1, 0],
28 | "data": [0.2513, -0.5, 0],
29 | "out_scale": 0.5,
30 | "result": -0.25,
31 | "output": [0.25, -0.5, 0], # quantized parameters, step 0.25
32 | },
33 | ),
34 | ],
35 | indirect=["module_cfg"],
36 | )
37 | def test_fix_forward_auto(module_cfg, case):
38 | module, cfg, _ = module_cfg
39 | if "data" in case:
40 | module.param[0, :] = torch.tensor(case["data"])
41 | with torch.no_grad():
42 | res = module.forward(torch.tensor(case["inputs"]).float())
43 | assert np.isclose(res, case["result"]) # calc output
44 | assert np.isclose(module.param, case["output"]).all() # quantized parameter
45 | assert cfg["param"]["scale"] == case["out_scale"] # scale
46 |
47 | @pytest.mark.parametrize(
48 | "module_cfg, case",
49 | [
50 | (
51 | {"input_num": 3},
52 | {
53 | "inputs": [[1, 1, 0], [1, 2, 0]],
54 | "data": [0.2513, -0.52, 0],
55 | "out_scale": 1,
56 | "result": [[0], [-0.5]],
57 | "output": [0.5, -0.5, 0], # quantized parameters, step 0.5
58 | },
59 | ),
60 | (
61 | {"input_num": 3},
62 | {
63 | "inputs": [[1, 1, 0], [1, 1, 0]],
64 | "data": [0.2513, -0.52, 0],
65 | "out_scale": 1,
66 | "result": [[0], [0]],
67 | "output": [0.5, -0.5, 0], # quantized parameters, step 0.5
68 | },
69 | ),
70 | (
71 | {"input_num": 3},
72 | {
73 | "inputs": [[1, 1, 0], [1, 1, 0]],
74 | "data": [0.2513, -0.5, 0],
75 | "out_scale": 0.5,
76 | "result": [[-0.25], [-0.25]],
77 | "output": [0.25, -0.5, 0], # quantized parameters, step 0.25
78 | },
79 | ),
80 | ],
81 | indirect=["module_cfg"],
82 | )
83 | def test_fix_forward_parallel_gpu(module_cfg, case):
84 | module, cfg, _ = module_cfg
85 | if "data" in case:
86 | module.param.data[0, :] = torch.tensor(case["data"])
87 | model = nn.DataParallel(module.cuda(), [0, 1])
88 | with torch.no_grad():
89 | res = model(torch.tensor(case["inputs"]).float().cuda())
90 | assert cfg["param"]["scale"] == case["out_scale"] # scale
91 | assert np.isclose(res.cpu(), case["result"]).all() # calc output
92 | # assert np.isclose(module.param.cpu(), case["output"]).all() # quantized parameter
93 | # this will not change,
94 | # but the gradient will still be accumulated in module_parameters[name].grad
95 |
96 | @pytest.mark.parametrize(
97 | "module_cfg, case",
98 | [
99 | (
100 | {"input_num": 3, "grad_cfg": {"method": nfp.FIX_AUTO}},
101 | {
102 | "inputs": [0.52, -0.27, 0],
103 | "data": [0, 0, 0],
104 | "grad_scale": 1,
105 | "output": [0.5, -0.5, 0],
106 | },
107 | ),
108 | (
109 | {"input_num": 3, "grad_cfg": {"method": nfp.FIX_AUTO}},
110 | {
111 | "inputs": [0.5, -0.27, 0],
112 | "data": [0, 0, 0],
113 | "grad_scale": 0.5,
114 | "output": [0.5, -0.25, 0], # quantized gradients
115 | },
116 | ),
117 | ],
118 | indirect=["module_cfg"],
119 | )
120 | def test_fix_backward_auto(module_cfg, case):
121 | module, _, cfg = module_cfg
122 | if "data" in case:
123 | module.param.data[0, :] = torch.tensor(case["data"])
124 | res = module.forward(torch.tensor(case["inputs"]).float())
125 | res.backward()
126 | assert np.isclose(
127 | module._parameters["param"].grad, case["output"]
128 | ).all() # quantized gradient
129 | assert cfg["param"]["scale"] == case["grad_scale"] # scale
130 |
131 | @pytest.mark.parametrize(
132 | "module_cfg, case",
133 | [
134 | (
135 | {"input_num": 3, "data_cfg": {"method": nfp.FIX_NONE},
136 | "grad_cfg": {"method": nfp.FIX_AUTO}},
137 | {
138 | "inputs": [[0.52, -0.27, 0], [0.52, -0.27, 0]],
139 | "data": [0, 0, 0],
140 | "grad_scale": 1,
141 | "output": [0.5, -0.5, 0],
142 | },
143 | ),
144 | (
145 | {"input_num": 3, "grad_cfg": {"method": nfp.FIX_AUTO}},
146 | {
147 | "inputs": [[0.5, -0.27, 0], [0.5, -0.27, 0]],
148 | "data": [0, 0, 0],
149 | "grad_scale": 0.5,
150 | "output": [0.5, -0.25, 0], # quantized gradients
151 | },
152 | ),
153 | ],
154 | indirect=["module_cfg"],
155 | )
156 | def test_fix_backward_parallel_gpu(module_cfg, case):
157 | module, _, cfg = module_cfg
158 | if "data" in case:
159 | module.param.data[0, :] = torch.tensor(case["data"])
160 | model = nn.DataParallel(module.cuda(), [0, 1])
161 | res = torch.sum(model(torch.tensor(case["inputs"]).float().cuda()))
162 | res.backward()
163 | assert np.isclose(
164 | module._parameters["param"].grad.cpu(), 2 * np.array(case["output"])
165 | ).all() # quantized gradient, 2 batch, grad x 2
166 | assert cfg["param"]["scale"] == case["grad_scale"] # scale
167 |
168 | @pytest.mark.parametrize(
169 | "module_cfg, case",
170 | [
171 | (
172 | {"input_num": 3, "grad_cfg": {"method": nfp.FIX_AUTO}},
173 | {
174 | "inputs": [0.52, -0.27, 0],
175 | "data": [0, 0, 0],
176 | "grad_scale": 1,
177 | "output": [0.5, -0.5, 0],
178 | },
179 | ),
180 | (
181 | {"input_num": 3, "grad_cfg": {"method": nfp.FIX_AUTO}},
182 | {
183 | "inputs": [0.5, -0.27, 0],
184 | "data": [0, 0, 0],
185 | "grad_scale": 0.5,
186 | "output": [0.5, -0.25, 0], # quantized gradients
187 | },
188 | ),
189 | ],
190 | indirect=["module_cfg"],
191 | )
192 | def test_fix_update_auto(module_cfg, case):
193 | module, _, cfg = module_cfg
194 | if "data" in case:
195 | module.param.data[0, :] = torch.tensor(case["data"])
196 | optimizer = optim.SGD(module.parameters(), lr=1.0, momentum=0)
197 | res = module.forward(torch.tensor(case["inputs"]).float())
198 | res.backward()
199 | optimizer.step()
200 | assert np.isclose(
201 | -module._parameters["param"].detach(), case["output"]
202 | ).all() # updated parameter should be - lr * gradient
203 | assert cfg["param"]["scale"] == case["grad_scale"] # scale
204 |
205 | def test_ConvBN_fix():
206 | from nics_fix_pt.nn_fix import ConvBN_fix
207 | # float forward and combine forward get the same results
208 | module = ConvBN_fix(3, 32, nf_fix_params={}).cuda()
209 | module.train()
210 | data = torch.tensor(np.random.rand(128, 3, 32, 32).astype(np.float32)).cuda()
211 | comb_out = module(data)
212 | float_out = module.bn(module.conv(data))
213 | assert (float_out - comb_out < 1e-3).all()
214 |
215 | module.eval()
216 | module = ConvBN_fix(3, 32, nf_fix_params={}).cuda()
217 | data = torch.tensor(np.random.rand(128, 3, 32, 32).astype(np.float32)).cuda()
218 | comb_out = module(data)
219 | float_out = module.bn(module.conv(data))
220 | assert (float_out - comb_out < 1e-3).all()
221 |
222 |
--------------------------------------------------------------------------------
/examples/mnist/train_mnist.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | from __future__ import print_function
4 |
5 | import argparse
6 | import yaml
7 |
8 | import numpy as np
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | import torch.optim as optim
13 | from torchvision import datasets, transforms
14 | from torch.autograd import Variable
15 |
16 | import nics_fix_pt as nfp
17 | import nics_fix_pt.nn_fix as nnf
18 |
19 | # Training settings
20 | parser = argparse.ArgumentParser(description="PyTorch MNIST Example")
21 | parser.add_argument(
22 | "--batch-size",
23 | type=int,
24 | default=64,
25 | metavar="N",
26 | help="input batch size for training (default: 64)",
27 | )
28 | parser.add_argument(
29 | "--test-batch-size",
30 | type=int,
31 | default=1000,
32 | metavar="N",
33 | help="input batch size for testing (default: 1000)",
34 | )
35 | parser.add_argument(
36 | "--epochs",
37 | type=int,
38 | default=1,
39 | metavar="N",
40 | help="number of epochs to train (default: 1)",
41 | )
42 | parser.add_argument(
43 | "--lr", type=float, default=0.01, metavar="LR", help="learning rate (default: 0.01)"
44 | )
45 | parser.add_argument(
46 | "--momentum",
47 | type=float,
48 | default=0.5,
49 | metavar="M",
50 | help="SGD momentum (default: 0.5)",
51 | )
52 | parser.add_argument(
53 | "--no-cuda", action="store_true", default=False, help="disables CUDA training"
54 | )
55 | parser.add_argument(
56 | "--seed", type=int, default=1, metavar="S", help="random seed (default: 1)"
57 | )
58 | parser.add_argument(
59 | "--log-interval",
60 | type=int,
61 | default=10,
62 | metavar="N",
63 | help="how many batches to wait before logging training status",
64 | )
65 | parser.add_argument(
66 | "--float",
67 | action="store_true",
68 | default=False,
69 | help="use float point training/testing",
70 | )
71 | parser.add_argument("--save", default=None, help="save fixed-point paramters to file")
72 | args = parser.parse_args()
73 | args.cuda = not args.no_cuda and torch.cuda.is_available()
74 |
75 | torch.manual_seed(args.seed)
76 | if args.cuda:
77 | torch.cuda.manual_seed(args.seed)
78 |
79 |
80 | kwargs = {"num_workers": 1, "pin_memory": True} if args.cuda else {}
81 | train_loader = torch.utils.data.DataLoader(
82 | datasets.MNIST(
83 | "../data",
84 | train=True,
85 | download=True,
86 | transform=transforms.Compose(
87 | [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
88 | ),
89 | ),
90 | batch_size=args.batch_size,
91 | shuffle=True,
92 | **kwargs
93 | )
94 | test_loader = torch.utils.data.DataLoader(
95 | datasets.MNIST(
96 | "../data",
97 | train=False,
98 | transform=transforms.Compose(
99 | [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
100 | ),
101 | ),
102 | batch_size=args.test_batch_size,
103 | shuffle=True,
104 | **kwargs
105 | )
106 |
107 |
108 | def _generate_default_fix_cfg(names, scale=0, bitwidth=8, method=0):
109 | return {
110 | n: {
111 | "method": torch.autograd.Variable(
112 | torch.IntTensor(np.array([method])), requires_grad=False
113 | ),
114 | "scale": torch.autograd.Variable(
115 | torch.IntTensor(np.array([scale])), requires_grad=False
116 | ),
117 | "bitwidth": torch.autograd.Variable(
118 | torch.IntTensor(np.array([bitwidth])), requires_grad=False
119 | ),
120 | }
121 | for n in names
122 | }
123 |
124 |
125 | BITWIDTH = 4
126 |
127 |
128 | class Net(nnf.FixTopModule):
129 | def __init__(self):
130 | super(Net, self).__init__()
131 | # initialize some fix configurations
132 | self.fc1_fix_params = _generate_default_fix_cfg(
133 | ["weight", "bias"], method=1, bitwidth=BITWIDTH
134 | )
135 | self.bn_fc1_params = _generate_default_fix_cfg(
136 | ["weight", "bias", "running_mean", "running_var"],
137 | method=1,
138 | bitwidth=BITWIDTH,
139 | )
140 | self.fc2_fix_params = _generate_default_fix_cfg(
141 | ["weight", "bias"], method=1, bitwidth=BITWIDTH
142 | )
143 | self.fix_params = [
144 | _generate_default_fix_cfg(["activation"], method=1, bitwidth=BITWIDTH)
145 | for _ in range(4)
146 | ]
147 | # initialize modules
148 | self.fc1 = nnf.Linear_fix(784, 100, nf_fix_params=self.fc1_fix_params)
149 | # self.bn_fc1 = nnf.BatchNorm1d_fix(100, nf_fix_params=self.bn_fc1_params)
150 | self.fc2 = nnf.Linear_fix(100, 10, nf_fix_params=self.fc2_fix_params)
151 | self.fix0 = nnf.Activation_fix(nf_fix_params=self.fix_params[0])
152 | # self.fix0_bn = nnf.Activation_fix(nf_fix_params=self.fix_params[1])
153 | self.fix1 = nnf.Activation_fix(nf_fix_params=self.fix_params[2])
154 | self.fix2 = nnf.Activation_fix(nf_fix_params=self.fix_params[3])
155 |
156 | def forward(self, x):
157 | x = self.fix0(x.view(-1, 784))
158 | x = F.relu(self.fix1(self.fc1(x)))
159 | # x = F.relu(self.fix0_bn(self.bn_fc1(self.fix1(self.fc1(x)))))
160 | self.logits = x = self.fix2(self.fc2(x))
161 | return F.log_softmax(x, dim=-1)
162 |
163 |
164 | model = Net()
165 | if args.cuda:
166 | model.cuda()
167 |
168 | optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
169 |
170 |
171 | def train(epoch, fix_method=nfp.FIX_AUTO):
172 | model.set_fix_method(fix_method)
173 | model.train()
174 | for batch_idx, (data, target) in enumerate(train_loader):
175 | if args.cuda:
176 | data, target = data.cuda(), target.cuda()
177 | data, target = Variable(data), Variable(target)
178 | optimizer.zero_grad()
179 | output = model(data)
180 | loss = F.nll_loss(output, target)
181 | loss.backward()
182 | optimizer.step()
183 | if batch_idx % args.log_interval == 0:
184 | print(
185 | "\rTrain Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
186 | epoch,
187 | batch_idx * len(data),
188 | len(train_loader.dataset),
189 | 100.0 * batch_idx / len(train_loader),
190 | loss.data.item(),
191 | ),
192 | end="",
193 | )
194 | print("")
195 |
196 |
197 | def test(fix_method=nfp.FIX_FIXED):
198 | model.set_fix_method(fix_method)
199 | model.eval()
200 | test_loss = 0
201 | correct = 0
202 | with torch.no_grad():
203 | for data, target in test_loader:
204 | if args.cuda:
205 | data, target = data.cuda(), target.cuda()
206 | data, target = Variable(data), Variable(target)
207 | output = model(data)
208 | test_loss += F.nll_loss(
209 | output, target, size_average=False
210 | ).data.item() # sum up batch loss
211 | pred = output.data.max(1, keepdim=True)[
212 | 1
213 | ] # get the index of the max log-probability
214 | correct += pred.eq(target.data.view_as(pred)).sum().item()
215 |
216 | test_loss /= len(test_loader.dataset)
217 | print(
218 | "\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format(
219 | test_loss,
220 | correct,
221 | len(test_loader.dataset),
222 | 100.0 * correct / len(test_loader.dataset),
223 | )
224 | )
225 |
226 |
227 | for epoch in range(1, args.epochs + 1):
228 | train(epoch, nfp.FIX_NONE if args.float else nfp.FIX_AUTO)
229 | test(nfp.FIX_NONE if args.float else nfp.FIX_FIXED)
230 |
231 | model.print_fix_configs()
232 | fix_cfg = {
233 | "data": model.get_fix_configs(data_only=True),
234 | "grad": model.get_fix_configs(grad=True, data_only=True),
235 | }
236 | with open("mnist_fix_config.yaml", "w") as wf:
237 | yaml.dump(fix_cfg, wf, default_flow_style=False)
238 |
239 | if args.save:
240 | state = {"model": model.fix_state_dict(), "epoch": args.epochs}
241 | torch.save(state, args.save)
242 | print("Saving fixed state dict to", args.save)
243 |
244 | # Let's try float test
245 | print("test float: ", end="")
246 | test(nfp.FIX_NONE) # after 1 epoch: ~ 92%
247 |
248 | if not args.float:
249 | # Let's load the fix config again, and test it using FIX_FIXED
250 | print("load from the yaml config and test fixed again: ", end="")
251 | with open("mnist_fix_config.yaml", "r") as rf:
252 | fix_cfg = yaml.load(rf, Loader=yaml.FullLoader)  # PyYAML 6 requires an explicit Loader
253 | model.load_fix_configs(fix_cfg["data"])
254 | model.load_fix_configs(fix_cfg["grad"], grad=True)
255 | test(nfp.FIX_FIXED) # after 1 epoch: ~ 89%
256 |
--------------------------------------------------------------------------------
/examples/cifar10/main.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import argparse
4 | import os
5 | import shutil
6 | import time
7 | import math
8 |
9 | import numpy as np
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.parallel
13 | import torch.backends.cudnn as cudnn
14 | import torch.nn.functional as F
15 | import torch.optim
16 | import torch.utils.data
17 | import torchvision.transforms as transforms
18 | import torchvision.datasets as datasets
19 |
20 | import nics_fix_pt as nfp
21 | import nics_fix_pt.nn_fix as nnf
22 | from net import FixNet
23 |
24 | parser = argparse.ArgumentParser(description="PyTorch Cifar10 Fixed-Point Training")
25 | parser.add_argument(
26 | "--save-dir",
27 | required=True,
28 | help="The directory used to save the trained models",
29 | type=str,
30 | )
31 | parser.add_argument(
32 | "--gpu",
33 | metavar="GPUs",
34 | default="0",
35 | help="The gpu devices to use"
36 | )
37 | parser.add_argument(
38 | "--epoch",
39 | default=100,
40 | type=int,
41 | metavar="N",
42 | help="number of total epochs to run",
43 | )
44 | parser.add_argument(
45 | "--start-epoch",
46 | default=0,
47 | type=int,
48 | metavar="N",
49 | help="manual epoch number (useful on restarts)",
50 | )
51 | parser.add_argument(
52 | "-b",
53 | "--batch-size",
54 | default=128,
55 | type=int,
56 | metavar="N",
57 | help="mini-batch size (default: 128)",
58 | )
59 | parser.add_argument(
60 | "--test-batch-size",
61 | type=int,
62 | default=32,
63 | metavar="N",
64 | help="input batch size for testing (default: 32)",
65 | )
66 | parser.add_argument(
67 | "--lr",
68 | "--learning-rate",
69 | default=0.05,
70 | type=float,
71 | metavar="LR",
72 | help="initial learning rate",
73 | )
74 | parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
75 | parser.add_argument(
76 | "--weight-decay",
77 | "--wd",
78 | default=5e-4,
79 | type=float,
80 | metavar="W",
81 | help="weight decay (default: 5e-4)",
82 | )
83 | parser.add_argument(
84 | "--print-freq",
85 | "-p",
86 | default=40,
87 | type=int,
88 | metavar="N",
89 | help="print frequency (default: 40)",
90 | )
91 | parser.add_argument(
92 | "--resume",
93 | default="",
94 | type=str,
95 | metavar="PATH",
96 | help="path to latest checkpoint (default: none)",
97 | )
98 | parser.add_argument(
99 | "--prefix",
100 | default="",
101 | type=str,
102 | metavar="PREFIX",
103 | help="checkpoint prefix (default: none)",
104 | )
105 | parser.add_argument(
106 | "--float-bn",
107 | default=False,
108 | action="store_true",
109 | help="use floating-point BN (do not quantize the BN layers)"
110 | )
111 | parser.add_argument(
112 | "--fix-grad",
113 | default=False,
114 | action="store_true",
115 | help="quantize the gradients"
116 | )
117 | parser.add_argument(
118 | "--range-method",
119 | default=0,
120 | type=int, choices=[0, 1, 3],
121 | help=("range methods of data (including parameters, buffers, activations). "
122 | "0: RANGE_MAX, 1: RANGE_3SIGMA, 3: RANGE_SWEEP")
123 | )
124 | parser.add_argument(
125 | "--grad-range-method",
126 | default=0,
127 | type=int, choices=[0, 1, 3],
128 | help=("range methods of gradients (including parameters, activations)."
129 | " 0: RANGE_MAX, 1: RANGE_3SIGMA, 3: RANGE_SWEEP")
130 | )
131 | parser.add_argument(
132 | "-e",
133 | "--evaluate",
134 | action="store_true",
135 | help="evaluate model on validation set",
136 | )
137 | parser.add_argument(
138 | "--pretrained", default="", type=str, metavar="PATH", help="use pre-trained model"
139 | )
140 | parser.add_argument(
141 | "--bitwidth-data", default=8, type=int, help="the bitwidth of parameters/buffers/activations"
142 | )
143 | parser.add_argument(
144 | "--bitwidth-grad", default=16, type=int, help="the bitwidth of gradients of parameters/activations"
145 | )
146 |
147 | best_prec1 = 90
148 | start = time.time()
149 |
150 | def _set_fix_method_train_ori(model):
151 | model.set_fix_method(nfp.FIX_AUTO)
152 |
153 | def _set_fix_method_eval_ori(model):
154 | model.set_fix_method(nfp.FIX_FIXED)
155 |
156 | ## --------
157 | ## When the bitwidth is small, quantizing BN can prevent the model from learning.
158 | ## In that case, the following config can be used instead:
159 | ## the BatchNorm2d_fix buffers (running_mean, running_var) are handled specially here.
160 | ## They are not quantized during the training forward pass, only during the test pass.
161 | ## This helps avoid the buffer accumulation problem (small updates being rounded away)
162 | ## when the bitwidth is too small.
163 | def _set_fix_method_train(model):
164 | model.set_fix_method(
165 | nfp.FIX_AUTO,
166 | method_by_type={
167 | "BatchNorm2d_fix": {
168 | "weight": nfp.FIX_AUTO,
169 | "bias": nfp.FIX_AUTO,
170 | "running_mean": nfp.FIX_NONE,
171 | "running_var": nfp.FIX_NONE}
172 | })
173 |
174 | def _set_fix_method_eval(model):
175 | model.set_fix_method(
176 | nfp.FIX_FIXED,
177 | method_by_type={
178 | "BatchNorm2d_fix": {
179 | "weight": nfp.FIX_FIXED,
180 | "bias": nfp.FIX_FIXED,
181 | "running_mean": nfp.FIX_AUTO,
182 | "running_var": nfp.FIX_AUTO}
183 | })
184 | ## --------
185 |
186 |
187 | def main():
188 | global args, best_prec1
189 | args = parser.parse_args()
190 | print("cmd line arguments: ", args)
191 |
192 | gpus = [int(d) for d in args.gpu.split(",")]
193 | torch.cuda.set_device(gpus[0])
194 |
195 | # Check the save_dir exists or not
196 | if not os.path.exists(args.save_dir):
197 | os.makedirs(args.save_dir)
198 |
199 | model = FixNet(
200 | fix_bn=not args.float_bn,
201 | fix_grad=args.fix_grad,
202 | range_method=args.range_method,
203 | grad_range_method=args.grad_range_method,
204 | bitwidth_data=args.bitwidth_data,
205 | bitwidth_grad=args.bitwidth_grad
206 | )
207 | model.print_fix_configs()
208 |
209 | model.cuda()
210 | if len(gpus) > 1:
211 | parallel_model = torch.nn.DataParallel(model, gpus)
212 | else:
213 | parallel_model = model
214 |
215 | # optionally resume from a checkpoint
216 | if args.resume:
217 | if os.path.isfile(args.resume):
218 | print("=> loading checkpoint '{}'".format(args.resume))
219 | checkpoint = torch.load(args.resume)
220 | args.start_epoch = checkpoint["epoch"]
221 | best_prec1 = checkpoint["best_prec1"]
222 | model.load_state_dict(checkpoint["state_dict"])
223 | print(
224 | "=> loaded checkpoint '{}' (epoch {})".format(
225 | args.resume, checkpoint["epoch"]
226 | )
227 | )
228 | else:
229 | print("=> no checkpoint found at '{}'".format(args.resume))
230 | assert os.path.isfile(args.resume)
231 |
232 | if args.pretrained:
233 | if os.path.isfile(args.pretrained):
234 | print("=> fintune from checkpoint '{}'".format(args.pretrained))
235 | checkpoint = torch.load(args.pretrained)
236 | # args.start_epoch = checkpoint['epoch']
237 | # best_prec1 = checkpoint['best_prec1']
238 | model.load_state_dict(checkpoint["state_dict"])
239 | else:
240 | print("=> no checkpoint found at '{}'".format(args.pretrained))
241 | assert os.path.isfile(args.pretrained)
242 |
243 | # cudnn.benchmark = True
244 |
245 | normalize = transforms.Normalize(
246 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
247 | )
248 |
249 | train_loader = torch.utils.data.DataLoader(
250 | datasets.CIFAR10(
251 | root="../data/cifar10",
252 | train=True,
253 | transform=transforms.Compose(
254 | [
255 | transforms.RandomHorizontalFlip(),
256 | transforms.RandomCrop(32, 4),
257 | transforms.ToTensor(),
258 | normalize,
259 | ]
260 | ),
261 | download=True,
262 | ),
263 | batch_size=args.batch_size,
264 | shuffle=True,
265 | num_workers=4,
266 | pin_memory=True,
267 | )
268 |
269 | val_loader = torch.utils.data.DataLoader(
270 | datasets.CIFAR10(
271 | root="../data/cifar10",
272 | train=False,
273 | transform=transforms.Compose([transforms.ToTensor(), normalize]),
274 | ),
275 | batch_size=args.test_batch_size,
276 | shuffle=False,
277 | num_workers=2,
278 | pin_memory=True,
279 | )
280 |
281 | # define loss function (criterion) and optimizer
282 | criterion = nn.CrossEntropyLoss().cuda()
283 |
284 | optimizer = torch.optim.SGD(
285 | model.parameters(),
286 | lr=args.lr,
287 | momentum=args.momentum,
288 | weight_decay=args.weight_decay,
289 | )
290 |
291 | if args.evaluate:
292 | validate(val_loader, model, parallel_model, criterion)
293 | return
294 |
295 | for epoch in range(args.start_epoch, args.epoch):
296 | adjust_learning_rate(optimizer, epoch)
297 |
298 | # train for one epoch
299 | train(train_loader, model, parallel_model, criterion, optimizer, epoch)
300 |
301 | # evaluate on validation set
302 | prec1 = validate(val_loader, model, parallel_model, criterion)
303 |
304 | # remember best prec@1 and save checkpoint
305 | is_best = prec1 > best_prec1
306 | best_prec1 = max(prec1, best_prec1)
307 | if best_prec1 > 90 and is_best:
308 | save_checkpoint(
309 | {
310 | "epoch": epoch + 1,
311 | "state_dict": model.state_dict(),
312 | "best_prec1": best_prec1,
313 | },
314 | is_best,
315 | filename=os.path.join(
316 | args.save_dir,
317 | "checkpoint_{}_{:.3f}.tar".format(args.prefix, best_prec1),
318 | ),
319 | )
320 | model.print_fix_configs()
321 |
322 | print("Best acc: {}".format(best_prec1))
323 |
324 |
325 | def train(train_loader, model, p_model, criterion, optimizer, epoch):
326 | """
327 | Run one train epoch
328 | """
329 | losses = AverageMeter()
330 | top1 = AverageMeter()
331 |
332 | # switch to train mode
333 | _set_fix_method_train_ori(model)
334 | model.train()
335 |
336 | for i, (input, target) in enumerate(train_loader):
337 | target = target.cuda(non_blocking=True)
338 | input_var = torch.autograd.Variable(input).cuda()
339 | target_var = torch.autograd.Variable(target)
340 |
341 | # compute output
342 | output = p_model(input_var)
343 | loss = criterion(output, target_var)
344 |
345 | # compute gradient and do SGD step
346 | optimizer.zero_grad()
347 | loss.backward()
348 | optimizer.step()
349 |
350 | output = output.float()
351 | loss = loss.float()
352 | # measure accuracy and record loss
353 | prec1 = accuracy(output.data, target)[0]
354 | losses.update(loss.item(), input.size(0))
355 | top1.update(prec1.item(), input.size(0))
356 |
357 | if i % args.print_freq == 0:
358 | print(
359 | "\rEpoch: [{0}][{1}/{2}]\t"
360 | "Time {t}\t"
361 | "Loss {loss.val:.4f} ({loss.avg:.4f})\t"
362 | "Prec@1 {top1.val:.3f}% ({top1.avg:.3f}%)".format(
363 | epoch,
364 | i,
365 | len(train_loader),
366 | t=time.time() - start,
367 | loss=losses,
368 | top1=top1,
369 | ),
370 | end="",
371 | )
372 |
373 |
374 | def validate(val_loader, model, p_model, criterion):
375 | """
376 | Run evaluation
377 | """
378 | losses = AverageMeter()
379 | top1 = AverageMeter()
380 |
381 | # switch to evaluate mode
382 | _set_fix_method_eval_ori(model)
383 | model.eval()
384 |
385 | with torch.no_grad():
386 | for i, (input, target) in enumerate(val_loader):
387 | target = target.cuda(non_blocking=True)
388 | input_var = torch.autograd.Variable(input).cuda()
389 | target_var = torch.autograd.Variable(target)
390 |
391 | # compute output
392 | output = p_model(input_var)
393 | loss = criterion(output, target_var)
394 |
395 | output = output.float()
396 | loss = loss.float()
397 |
398 | # measure accuracy and record loss
399 | prec1 = accuracy(output.data, target)[0]
400 | losses.update(loss.item(), input.size(0))
401 | top1.update(prec1.item(), input.size(0))
402 |
403 | if i % args.print_freq == 0:
404 | print(
405 | "Test: [{0}/{1}]\t"
406 | "Time {t}\t"
407 | "Loss {loss.val:.4f} ({loss.avg:.4f})\t"
408 | "Prec@1 {top1.val:.3f}% ({top1.avg:.3f}%)".format(
409 | i, len(val_loader), t=time.time() - start, loss=losses, top1=top1
410 | )
411 | )
412 |
413 | print(
414 | " * Prec@1 {top1.avg:.3f}%\tBest Prec@1 {best_prec1:.3f}%".format(
415 | top1=top1, best_prec1=best_prec1
416 | )
417 | )
418 |
419 | return top1.avg
420 |
421 |
422 | def save_checkpoint(state, is_best, filename="checkpoint.pth.tar"):
423 | """
424 | Save the training model
425 | """
426 | torch.save(state, filename)
427 |
428 |
429 | class AverageMeter(object):
430 | """Computes and stores the average and current value"""
431 |
432 | def __init__(self):
433 | self.reset()
434 |
435 | def reset(self):
436 | self.val = 0
437 | self.avg = 0
438 | self.sum = 0
439 | self.count = 0
440 |
441 | def update(self, val, n=1):
442 | self.val = val
443 | self.sum += val * n
444 | self.count += n
445 | self.avg = self.sum / self.count
446 |
447 |
448 | def adjust_learning_rate(optimizer, epoch):
449 | """Sets the learning rate to the initial LR decayed by 0.5 every 10 epochs"""
450 | lr = args.lr * (0.5 ** (epoch // 10))
451 | for param_group in optimizer.param_groups:
452 | param_group["lr"] = lr
453 | print("Epoch {}: lr: {}".format(epoch, lr))
454 |
455 | def accuracy(output, target, topk=(1,)):
456 | """Computes the precision@k for the specified values of k"""
457 | maxk = max(topk)
458 | batch_size = target.size(0)
459 |
460 | _, pred = output.topk(maxk, 1, True, True)
461 | pred = pred.t()
462 | correct = pred.eq(target.view(1, -1).expand_as(pred))
463 |
464 | res = []
465 | for k in topk:
466 | correct_k = correct[:k].contiguous().view(-1).float().sum(0)
467 | res.append(correct_k.mul_(100.0 / batch_size))
468 | return res
469 |
470 |
471 | if __name__ == "__main__":
472 | main()
473 |
--------------------------------------------------------------------------------
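main.py switches between the `_set_fix_method_train*`/`_set_fix_method_eval*` variants to control how BatchNorm buffers are quantized. The sketch below (not from the repo) summarizes the three granularities that `FixTopModule.set_fix_method` accepts; the module name `"conv1"` is hypothetical.

```python
# Illustrative only: the three granularities of FixTopModule.set_fix_method,
# as used by _set_fix_method_train/_set_fix_method_eval in main.py above.
# "conv1" is a hypothetical child module name whose fix config is assumed to
# contain "weight" and "bias" entries.
import nics_fix_pt as nfp

def configure_fix_methods(model):
    # 1) One method for every quantized parameter/buffer (each call below
    #    overwrites the previous setting).
    model.set_fix_method(nfp.FIX_AUTO)

    # 2) Per module *type*: keep the BN running statistics floating-point
    #    during training, which is the low-bitwidth trick shown above.
    model.set_fix_method(
        nfp.FIX_AUTO,
        method_by_type={
            "BatchNorm2d_fix": {
                "weight": nfp.FIX_AUTO,
                "bias": nfp.FIX_AUTO,
                "running_mean": nfp.FIX_NONE,
                "running_var": nfp.FIX_NONE,
            }
        },
    )

    # 3) Per module *name*: e.g. leave one specific layer unquantized.
    model.set_fix_method(
        nfp.FIX_AUTO,
        method_by_name={"conv1": {"weight": nfp.FIX_NONE, "bias": nfp.FIX_NONE}},
    )
```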
/nics_fix_pt/quant.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | from __future__ import print_function
4 |
5 | import copy
6 |
7 | import numpy as np
8 | import torch
9 |
10 | from nics_fix_pt.utils import get_int
11 | from nics_fix_pt.consts import QuantizeMethod, RangeMethod
12 |
13 | __all__ = ["quantize"]
14 |
15 |
16 | def _do_quantize(data, scale, bit_width, symmetric=True, stochastic=False, group=False):
17 | """
18 | when sym is not true, the input scale will be 2-value tensor [min,max]
19 | by defalutm the data_to_devise.device, bit_width.devicescale is a single fp-value, denoting range [-max, max]
20 | """
21 | # The grouping are only applied for conv/fc weight/data, with the shape of [x,x,x,x]
22 | # bn params & bias with the shape of [x,x]
23 | if len(data.shape) < 4:
24 | group = False
25 |
26 | scale = scale.to(data.device)
27 | bit_width = bit_width.to(data.device)
28 | tensor_2 = torch.autograd.Variable(
29 | torch.FloatTensor([2.0]), requires_grad=False
30 | ).to(data.device)
31 | if symmetric:
32 | dynamic_range = 2 * scale
33 | maxs = scale
34 | mins = scale * (-1)
35 | else:
36 | """
37 | in a real hardware implementation, asymmetric quantization is
38 | implemented with one fp scale and one integer zero-point to
39 | constrain the range; here, for simplicity of software simulation,
40 | we simply define the range as [min, max]
41 | """
42 | dynamic_range = scale[1] - scale[0]
43 | maxs = scale[1]
44 | mins = scale[0]
45 |
46 | if group == "batch":
47 | maxs = maxs.reshape(-1, 1, 1, 1).expand_as(data)
48 | mins = mins.reshape(-1, 1, 1, 1).expand_as(data)
49 | data_to_devise = dynamic_range.reshape(-1, 1, 1, 1)
50 | # MAYBE we should split activation/weight also the grad for this
51 | elif group == "channel":
52 | maxs = maxs.reshape(1, -1, 1, 1).expand_as(data)
53 | mins = mins.reshape(1, -1, 1, 1).expand_as(data)
54 | data_to_devise = dynamic_range.reshape(1, -1, 1, 1)
55 | else:
56 | data_to_devise = dynamic_range
57 |
58 | step = data_to_devise / torch.pow(2, bit_width).float()
59 |
60 | if stochastic:
61 | output = StraightThroughStochasticRound.apply(data / step) * step
62 | else:
63 | output = StraightThroughRound.apply(data / step) * step
64 | # torch.clamp does not support per-element clamp bounds, so use max/min as an alternative
65 | if group is not False:
66 | output = torch.min(torch.max(mins, output), maxs)
67 | else:
68 | output = torch.clamp(output, float(mins), float(maxs))
69 |
70 | return (
71 | output,
72 | step,
73 | )
74 |
75 |
76 | def quantize_cfg(
77 | data,
78 | scale,
79 | bitwidth,
80 | method,
81 | range_method=RangeMethod.RANGE_MAX,
82 | stochastic=False,
83 | float_scale=False,
84 | zero_point=False,
85 | group=False,
86 | ):
87 | """
88 | stochastic - stochastic rounding
89 | range_method - how to decide dynamic range
90 | float_scale - whether the scale is chosen to be 2^K
91 | zero_point - symm/asymm quantize
92 | """
93 | if (
94 | not isinstance(method, torch.autograd.Variable)
95 | and not torch.is_tensor(method)
96 | and method == QuantizeMethod.FIX_NONE
97 | ):
98 | return data, None
99 |
100 | if group == "batch" and len(data.shape) == 4:
101 | # only applied for conv units
102 | max_data = data.view(data.shape[0], -1).max(dim=1)[0]
103 | min_data = data.view(data.shape[0], -1).min(dim=1)[0]
104 | elif group == "channel" and len(data.shape) == 4:
105 | max_data = data.view(data.shape[1], -1).max(dim=1)[0]
106 | min_data = data.view(data.shape[1], -1).min(dim=1)[0]
107 | else:
108 | max_data = data.max()
109 | min_data = data.min()
110 |
111 | # Avoid extreme value
112 | EPS = torch.FloatTensor(max_data.shape).fill_(1e-5)
113 | EPS = EPS.to(max_data.device)
114 | if not zero_point:
115 | max_data = torch.max(torch.max(max_data.abs(), min_data.abs()), EPS)
116 |
117 | method_v = get_int(method)
118 |
119 | if method_v == QuantizeMethod.FIX_NONE:
120 | return data, None
121 | elif method_v == QuantizeMethod.FIX_AUTO:
122 | range_method_v = get_int(range_method)
123 | if float_scale and range_method_v != RangeMethod.RANGE_MAX:
124 | raise NotImplementedError("Now Only Support Float_Scale with Range-Max")
125 | if range_method_v == RangeMethod.RANGE_MAX:
126 | if float_scale:
127 | if zero_point:
128 | new_scale = torch.stack([min_data, max_data])
129 | scale.data = new_scale
130 | return _do_quantize(
131 | data,
132 | scale,
133 | bitwidth,
134 | stochastic=stochastic,
135 | symmetric=False,
136 | group=group,
137 | )
138 | else:
139 | scale.data = max_data
140 | return _do_quantize(
141 | data,
142 | scale,
143 | bitwidth,
144 | stochastic=stochastic,
145 | symmetric=True,
146 | group=group,
147 | )
148 | else:
149 | new_scale = torch.pow(
150 | 2,
151 | torch.ceil(
152 | torch.log(max_data)
153 | / torch.FloatTensor([1]).fill_(np.log(2.0)).to(max_data.device)
154 | ),
155 | )
156 |
157 | scale.data = new_scale
158 | return _do_quantize(
159 | data, scale, bitwidth, stochastic=stochastic, group=group
160 | )
161 |
162 | elif range_method_v == RangeMethod.RANGE_MAX_TENPERCENT:
163 | # FIXME: Too slow. Use roughly the 90th percentile of |data| as the range boundary.
164 | new_scale = torch.pow(
165 | 2,
166 | torch.ceil(
167 | torch.log(
168 | torch.max(
169 | # torch.kthvalue(torch.abs(data.view(-1)), 9 * (data.nelement() // 10))[0],
170 | torch.topk(torch.abs(data.view(-1)), data.nelement() // 10)[
171 | 0
172 | ][-1],
173 | EPS,
174 | )
175 | )
176 | / torch.FloatTensor([1]).fill_(np.log(2.0)).to(data.device)
177 | ),
178 | )
179 | scale.data = new_scale
180 | return _do_quantize(data, scale, bitwidth, stochastic=stochastic)
181 |
182 | elif range_method_v == RangeMethod.RANGE_3SIGMA:
183 | new_boundary = torch.max(
184 | 3 * torch.std(data) + torch.abs(torch.mean(data)),
185 | torch.tensor(EPS).float().to(data.device),
186 | )
187 | new_scale = torch.pow(2, torch.ceil(torch.log(new_boundary) / np.log(2.0)))
188 | scale.data = new_scale
189 | return _do_quantize(
190 | data, scale, bitwidth, stochastic=stochastic, symmetric=not zero_point
191 | )
192 |
193 | elif range_method_v == RangeMethod.RANGE_SWEEP:
194 | # Iterate over several candidate power-of-2 scales and pick the one that minimizes
195 | # the quantization error. The swept exponents are [(MAX - SWEEP + 1), MAX].
196 | SWEEP = 3
197 | temp_scale = torch.ceil(
198 | torch.log(
199 | torch.max(
200 | torch.max(abs(data)), torch.tensor(EPS).float().to(data.device)
201 | )
202 | )
203 | / np.log(2.0)
204 | )
205 | errors = [
206 | torch.abs(_do_quantize(data, torch.pow(2, temp_scale - i), bitwidth)[0] - data).sum().item()
207 | for i in range(SWEEP)
208 | ]
209 | new_scale = torch.pow(2, temp_scale - int(np.argmin(errors)))
210 | scale.data = new_scale
211 | return _do_quantize(data, scale, bitwidth, stochastic=stochastic)
212 |
213 | else:
214 | raise NotImplementedError()
215 |
216 | elif method_v == QuantizeMethod.FIX_FIXED:
217 |
218 | if group == "batch" and len(data.shape) == 4:
219 | max_data = data.view(data.shape[0], -1).max(dim=1)[0]
220 | min_data = data.view(data.shape[0], -1).min(dim=1)[0]
221 | elif group == "channel" and len(data.shape) == 4:
222 | max_data = data.view(data.shape[1], -1).max(dim=1)[0]
223 | min_data = data.view(data.shape[1], -1).min(dim=1)[0]
224 | else:
225 | max_data = data.max()
226 | min_data = data.min()
227 |
228 | EPS = torch.FloatTensor(max_data.shape).fill_(1e-5)
229 | EPS = EPS.to(max_data.device)
230 | if not zero_point:
231 | max_data = torch.max(torch.max(max_data.abs(), min_data.abs()), EPS)
232 |
233 | # TODO: Check whether float_scale automatically adjust through inference
234 | # If float_scale, do as FIX_AUTO does
235 | if float_scale:
236 | if zero_point:
237 | new_scale = torch.stack([min_data, max_data])
238 | scale.data = new_scale
239 | return _do_quantize(
240 | data, scale, bitwidth, stochastic=stochastic, symmetric=False
241 | )
242 | else:
243 | scale.data = max_data
244 | return _do_quantize(
245 | data, scale, bitwidth, stochastic=stochastic, symmetric=True
246 | )
247 | else:
248 | # do not update the scale when using a power-of-2 scale; reuse the stored fixed scale
249 | return _do_quantize(
250 | data,
251 | scale,
252 | bitwidth,
253 | stochastic=stochastic,
254 | symmetric=not zero_point,
255 | group=group,
256 | )
257 |
258 | raise Exception("Quantize method not legal: {}".format(method_v))
259 |
260 |
261 | # https://discuss.pytorch.org/t/how-to-override-the-gradients-for-parameters/3417/6
262 | class StraightThroughRound(torch.autograd.Function):
263 | @staticmethod
264 | def forward(ctx, x):
265 | return x.round()
266 |
267 | @staticmethod
268 | def backward(ctx, g):
269 | return g
270 |
271 |
272 | class StraightThroughStochasticRound(torch.autograd.Function):
273 | @staticmethod
274 | def forward(ctx, x):
275 | # FIXME: wrong stochastic method; independent noise for each element could lead to even worse perf.
276 | # The binary tensor denotes whether to ceil or not; being closer to the ceiling means a higher probability of choosing ceil:
277 | # return x.floor() + (torch.rand(x.shape).to(x.device) > x.ceil() - x)*torch.ones(x.shape).to(x.device)
278 | # out = x.floor() + (torch.cuda.FloatTensor(x.shape).uniform_() > x.ceil() - x)*torch.cuda.FloatTensor(x.shape).fill_(1.)
279 | # out = x.floor() + ((x.ceil() - x) < torch.cuda.FloatTensor([1]).fill_(np.random.uniform()))*torch.cuda.FloatTensor(x.shape).fill_(1.)
280 | # inject uniform noise in [-0.5, 0.5]; note the result is not explicitly rounded onto the integer grid
281 | noise = torch.FloatTensor(x.shape).uniform_(-0.5, 0.5).to(x.device)
282 | return x + noise  # avoid modifying the function input in-place
283 |
284 | @staticmethod
285 | def backward(ctx, g):
286 | return g
287 |
288 |
289 | class QuantitizeGradient(torch.autograd.Function):
290 | @staticmethod
291 | def forward(
292 | ctx,
293 | x,
294 | scale,
295 | bitwidth,
296 | method,
297 | range_method=RangeMethod.RANGE_MAX,
298 | stochastic=False,
299 | float_scale=False,
300 | zero_point=False,
301 | group=False,
302 | ):
303 | # FIXME: save the tensor/variables for backward,
304 | # maybe should use `ctx.save_for_backward` for standard practice
305 | # but `save_for_backward` requires scale/bitwidth/method all being of type `Variable`...
306 | ctx.saved = (
307 | scale,
308 | bitwidth,
309 | method,
310 | range_method,
311 | stochastic,
312 | float_scale,
313 | zero_point,
314 | group,
315 | )
316 | return x
317 |
318 | @staticmethod
319 | def backward(ctx, g):
320 | return (
321 | quantize_cfg(g, *ctx.saved)[0],
322 | None,
323 | None,
324 | None,
325 | None,
326 | None,
327 | None,
328 | None,
329 | None,
330 | )
331 |
332 |
333 | def quantize(param, fix_cfg={}, fix_grad_cfg={}, kwarg_cfg={}, name=""):
334 | # fix_cfg/fix_grad_cfg is the configuration saved;
335 | # kwarg_cfg is the overriding configuration supplied for each `forward` call
336 | data_cfg = copy.copy(fix_cfg)
337 | data_cfg.update(kwarg_cfg.get(name + "_fix", {}))
338 | grad_cfg = copy.copy(fix_grad_cfg)
339 | grad_cfg.update(kwarg_cfg.get(name + "_grad_fix", {}))
340 | method = data_cfg.get("method", QuantizeMethod.FIX_NONE)
341 |
342 | step = 0
343 | # quantize data
344 | out_param = param
345 |
346 | if (
347 | isinstance(method, torch.autograd.Variable)
348 | or torch.is_tensor(method)
349 | or method != QuantizeMethod.FIX_NONE
350 | ):
351 | out_param, step = quantize_cfg(
352 | out_param,
353 | data_cfg["scale"],
354 | data_cfg["bitwidth"],
355 | data_cfg["method"],
356 | data_cfg.get("range_method", RangeMethod.RANGE_MAX),
357 | data_cfg.get("stochastic", False),
358 | data_cfg.get("float_scale", False),
359 | data_cfg.get("zero_point", False),
360 | data_cfg.get("group", False),
361 | )
362 |
363 | # quantize gradient
364 | method = grad_cfg.get("method", QuantizeMethod.FIX_NONE)
365 | if (
366 | isinstance(method, torch.autograd.Variable)
367 | or torch.is_tensor(method)
368 | or method != QuantizeMethod.FIX_NONE
369 | ):
370 | out_param = QuantitizeGradient().apply(
371 | out_param,
372 | grad_cfg["scale"],
373 | grad_cfg["bitwidth"],
374 | grad_cfg["method"],
375 | grad_cfg.get("range_method", RangeMethod.RANGE_MAX),
376 | grad_cfg.get("stochastic", False),
377 | grad_cfg.get("float_scale", False),
378 | grad_cfg.get("zero_point", False),
379 | data_cfg.get("group", False),
380 | )
381 |
382 | out_param.data_cfg = data_cfg
383 | out_param.grad_cfg = grad_cfg
384 | if param is not out_param:
385 | # avoid memory leaking: old `buffer` tensors could remain referenced unexpectedly
386 | if hasattr(param, "nfp_actual_data"):
387 | del param.nfp_actual_data
388 | del param.data_cfg
389 | del param.grad_cfg
390 | out_param.nfp_actual_data = param # avoid loop ref
391 | # NOTE: the returned step is data fix stepsize, not gradient fix step size;
392 | return out_param, step
393 |
--------------------------------------------------------------------------------
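The core of quant.py is `_do_quantize`: the dynamic range chosen by the range method is split into `2**bitwidth` steps, values are rounded to the nearest step with a straight-through estimator, and the result is clamped to the range. A toy numeric sketch of that arithmetic (standalone, not the library API):

```python
# A toy numeric sketch (standalone, not the library API) of the arithmetic in
# _do_quantize: for a symmetric scale `s` and bitwidth `k`, the step size is
# (2 * s) / 2**k, values are rounded to the nearest step and clamped to [-s, s].
import torch

def toy_quantize(data, scale=1.0, bitwidth=8):
    step = 2.0 * scale / (2 ** bitwidth)     # quantization step size
    out = torch.round(data / step) * step    # round to the nearest step
    return torch.clamp(out, -scale, scale)   # saturate to the dynamic range

x = torch.tensor([0.013, -0.6, 1.7])
print(toy_quantize(x, scale=1.0, bitwidth=8))
# -> approximately tensor([ 0.0156, -0.6016,  1.0000]); 1.7 saturates at the boundary
```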
/nics_fix_pt/fix_modules.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | from __future__ import print_function
4 |
5 | import warnings
6 | from collections import OrderedDict
7 |
8 | import six
9 |
10 | import torch
11 | from torch.nn import Module
12 | import torch.nn.functional as F
13 | from . import nn_fix, utils, quant
14 |
15 |
16 | # ---- helpers ----
17 | def _get_kwargs(self, true_kwargs):
18 | default_kwargs = utils.get_kwargs(self.__class__)
19 | if not default_kwargs:
20 | return true_kwargs
21 | # NOTE: here we do not deep copy the default values,
22 | # so non-atom type default value such as dict/list/tensor will be shared
23 | kwargs = {k: v for k, v in six.iteritems(default_kwargs)}
24 | kwargs.update(true_kwargs)
25 | return kwargs
26 |
27 |
28 | def _get_fix_cfg(self, name, grad=False):
29 | if not grad:
30 | cfg = self.nf_fix_params.get(name, {})
31 | if "scale" in cfg:
32 | cfg["scale"] = self._buffers["{}_fp_scale".format(name)]
33 | else:
34 | cfg = self.nf_fix_params_grad.get(name, {})
35 | if "scale" in cfg:
36 | cfg["scale"] = self._buffers["{}_grad_fp_scale".format(name)]
37 | return cfg
38 |
39 |
40 | def _register_fix_buffers(self, patch_register=True):
41 | # register scale tensors as buffers, for correct use in multi-gpu data parallel model
42 | avail_keys = list(self._parameters.keys()) + list(self._buffers.keys())
43 | for name, cfg in six.iteritems(self.nf_fix_params):
44 | if patch_register and name not in avail_keys:
45 | warnings.warn(
46 | (
47 | "{} not available in {}, this specific fixed config "
48 | "will not have effects"
49 | ).format(name, self)
50 | )
51 | if "scale" in cfg:
52 | self.register_buffer("{}_fp_scale".format(name), cfg["scale"])
53 |
54 | avail_keys = list(self._parameters.keys())
55 | for name, cfg in six.iteritems(self.nf_fix_params_grad):
56 | if patch_register and name not in avail_keys:
57 | warnings.warn(
58 | (
59 | "{} not available in {}, this specific grads fixed config "
60 | "will not have effects"
61 | ).format(name, self)
62 | )
63 | if "scale" in cfg:
64 | self.register_buffer("{}_grad_fp_scale".format(name), cfg["scale"])
65 |
66 |
67 | # --------
68 |
69 |
70 | def get_fix_forward(cur_cls):
71 | # pylint: disable=protected-access
72 | def fix_forward(self, inputs, **kwargs):
73 | if not isinstance(inputs, dict):
74 | inputs = {"inputs": inputs}
75 | for n, param in six.iteritems(self._parameters):
76 | # NOTE: Since Pytorch>=1.5.0, parameters in DataParallel replica are no longer
77 | # registered in the _parameters dict, so this mechanism will no longer work.
78 | # Thus for now, only Pytorch<1.5.0 versions are supported if DataParallel is used!
79 | if not isinstance(param, (torch.Tensor, torch.autograd.Variable)):
80 | continue
81 | fix_cfg = _get_fix_cfg(self, n)
82 | fix_grad_cfg = _get_fix_cfg(self, n, grad=True)
83 | set_n, _ = quant.quantize(
84 | param, fix_cfg, fix_grad_cfg, kwarg_cfg=inputs, name=n
85 | )
86 | object.__setattr__(self, n, set_n)
87 | for n, param in six.iteritems(self._buffers):
88 | if not isinstance(param, (torch.Tensor, torch.autograd.Variable)):
89 | continue
90 | fix_cfg = _get_fix_cfg(self, n)
91 | fix_grad_cfg = _get_fix_cfg(self, n, grad=True)
92 | set_n, _ = quant.quantize(
93 | param, fix_cfg, fix_grad_cfg, kwarg_cfg=inputs, name=n
94 | )
95 | object.__setattr__(self, n, set_n)
96 | res = super(cur_cls, self).forward(inputs["inputs"], **kwargs)
97 | for n, param in six.iteritems(self._buffers):
98 | # set buffer back, as there will be no gradient, just in-place modification
99 | # FIXME: For fixed-point batch norm,
100 | # the running mean/var accumulation is on quantized mean/var,
101 | # which means it might fail to update the running mean/var
102 | # if the updating momentum is too small
103 | updated_buffer = getattr(self, n)
104 | if updated_buffer is not self._buffers[n]:
105 | self._buffers[n].copy_(updated_buffer)
106 | return res
107 |
108 | return fix_forward
109 |
110 |
111 | class FixMeta(type):
112 | def __new__(mcs, name, bases, attrs):
113 | # Construct class name
114 | if not attrs.get("__register_name__", None):
115 | attrs["__register_name__"] = bases[0].__name__ + "_fix"
116 | name = attrs["__register_name__"]
117 | cls = super(FixMeta, mcs).__new__(mcs, name, bases, attrs)
118 | # if already subclass
119 | if not isinstance(bases[0], FixMeta):
120 | cls.forward = get_fix_forward(cur_cls=cls)
121 | setattr(nn_fix, name, cls)
122 | return cls
123 |
124 |
125 | def register_fix_module(cls, register_name=None):
126 | @six.add_metaclass(FixMeta)
127 | class __a_not_use_name(cls):
128 | __register_name__ = register_name
129 |
130 | def __init__(self, *args, **kwargs):
131 | kwargs = _get_kwargs(self, kwargs)
132 | # Pop and parse fix configuration from kwargs
133 | assert "nf_fix_params" in kwargs and isinstance(
134 | kwargs["nf_fix_params"], dict
135 | ), (
136 | "Must specifiy `nf_fix_params` keyword arguments, "
137 | "and `nf_fix_params_grad` is optional."
138 | )
139 | self.nf_fix_params = kwargs.pop("nf_fix_params")
140 | self.nf_fix_params_grad = kwargs.pop("nf_fix_params_grad", {}) or {}
141 | cls.__init__(self, *args, **kwargs)
142 | _register_fix_buffers(self, patch_register=True)
143 | # avail_keys = list(self._parameters.keys()) + list(self._buffers.keys())
144 | # self.nf_fix_params = {k: self.nf_fix_params[k]
145 | # for k in avail_keys if k in self.nf_fix_params}
146 | # self.nf_fix_params_grad = {k: self.nf_fix_params_grad[k]
147 | # for k in avail_keys if k in self.nf_fix_params_grad}
148 |
149 |
150 | class Activation_fix(Module):
151 | def __init__(self, **kwargs):
152 | super(Activation_fix, self).__init__()
153 | kwargs = _get_kwargs(self, kwargs)
154 | assert "nf_fix_params" in kwargs and isinstance(
155 | kwargs["nf_fix_params"], dict
156 | ), "Must specifiy `nf_fix_params` keyword arguments, and `nf_fix_params_grad` is optional."
157 | self.nf_fix_params = kwargs.pop("nf_fix_params")
158 | self.nf_fix_params_grad = kwargs.pop("nf_fix_params_grad", {}) or {}
159 | self.activation = None
160 |
161 | # register scale as buffers
162 | _register_fix_buffers(self, patch_register=False)
163 |
164 | def forward(self, inputs):
165 | if not isinstance(inputs, dict):
166 | inputs = {"inputs": inputs}
167 | name = "activation"
168 | fix_cfg = self.nf_fix_params.get(name, {})
169 | fix_grad_cfg = self.nf_fix_params_grad.get(name, {})
170 | self.activation, _ = quant.quantize(
171 | inputs["inputs"], fix_cfg, fix_grad_cfg, kwarg_cfg=inputs, name=name
172 | )
173 | return self.activation
174 |
175 |
176 | class ConvBN_fix(Module):
177 | def __init__(
178 | self,
179 | in_channels,
180 | out_channels,
181 | kernel_size=3,
182 | stride=1,
183 | padding=0,
184 | dilation=1,
185 | groups=1,
186 | eps=1e-05,
187 | momentum=0.1,
188 | affine=True,
189 | track_running_stats=True,
190 | **kwargs
191 | ):
192 | super(ConvBN_fix, self).__init__()
193 | kwargs = _get_kwargs(self, kwargs)
194 | assert "nf_fix_params" in kwargs and isinstance(
195 | kwargs["nf_fix_params"], dict
196 | ), "Must specifiy `nf_fix_params` keyword arguments, and `nf_fix_params_grad` is optional."
197 | self.nf_fix_params = kwargs.pop("nf_fix_params")
198 | self.nf_fix_params_grad = kwargs.pop("nf_fix_params_grad", {}) or {}
199 | if self.nf_fix_params_grad:
200 | warnings.warn(
201 | "Gradient fixed-point cfgs will NOT take effect! Because, "
202 | "Gradient quantization is usually used to simulate training on hardware. "
203 | "However, merged ConvBN is designed for mitigating the discrepancy between "
204 | "training and behaviour on deploy-only hardware; "
205 | "and enable relatively more accurate running mean/var accumulation "
206 | "during software training. Use these two together might not make sense."
207 | )
208 |
209 | # init the two floating-point sub-modules
210 | self.conv = torch.nn.Conv2d(
211 | in_channels,
212 | out_channels,
213 | kernel_size,
214 | stride,
215 | padding,
216 | dilation,
217 | groups,
218 | bias=False,
219 | )
220 | self.bn = torch.nn.BatchNorm2d(
221 | out_channels, eps, momentum, affine, track_running_stats
222 | )
223 |
224 | # conv and bn attributes
225 | self.stride = stride
226 | self.padding = padding
227 | self.dilation = dilation
228 | self.groups = groups
229 | self.kernel_size = self.conv.kernel_size
230 | self.in_channels = in_channels
231 | self.out_channels = out_channels
232 | self.eps = eps
233 | self.momentum = momentum
234 | self.affine = affine
235 | self.track_running_stats = track_running_stats
236 |
237 | # the quantized combined weights and bias
238 | self.weight = self.conv.weight
239 | self.bias = self.bn.bias
240 |
241 | # register scale as buffers
242 | _register_fix_buffers(self, patch_register=False)
243 |
244 | def forward(self, inputs):
245 | if self.training:
246 | out = self.conv(inputs)
247 | # dummy output, just to accumulate running mean and running var (floating-point)
248 | _ = self.bn(out)
249 |
250 | # calculate batch var/mean
251 | mean = torch.mean(out, dim=[0, 2, 3])
252 | var = torch.var(out, dim=[0, 2, 3])
253 | else:
254 | mean = self.bn.running_mean
255 | var = self.bn.running_var
256 |
257 | inputs = {"inputs": inputs}
258 | # parameters/buffers to be combined
259 | bn_scale = self.bn.weight
260 | bn_bias = self.bn.bias
261 | bn_eps = self.bn.eps
262 | conv_weight = self.conv.weight
263 | conv_bias = 0.0 if self.conv.bias is None else self.conv.bias  # conv is built with bias=False, so this is usually 0.0
264 |
265 | # combine new weights/bias
266 | comb_weight = conv_weight * (bn_scale / torch.sqrt(var + bn_eps)).view(
267 | -1, 1, 1, 1
268 | )
269 | comb_bias = bn_bias + (conv_bias - mean) * bn_scale / torch.sqrt(var + bn_eps)
270 |
271 | # quantize the combined weights/bias (as what would be done in hardware deploy scenario)
272 | comb_weight, _ = quant.quantize(
273 | comb_weight,
274 | self.nf_fix_params.get("weight", {}),
275 | {},
276 | kwarg_cfg=inputs,
277 | name="weight",
278 | )
279 | comb_bias, _ = quant.quantize(
280 | comb_bias,
281 | self.nf_fix_params.get("bias", {}),
282 | {},
283 | kwarg_cfg=inputs,
284 | name="bias",
285 | )
286 |
287 | # run the fixed-point combined convbn
288 | convbn_out = F.conv2d(
289 | inputs["inputs"],
290 | comb_weight,
291 | comb_bias,
292 | self.stride,
293 | self.padding,
294 | self.dilation,
295 | self.groups,
296 | )
297 | object.__setattr__(self, "weight", comb_weight)
298 | object.__setattr__(self, "bias", comb_bias)
299 | return convbn_out
300 |
301 |
302 | class FixTopModule(Module):
303 | """
304 | A module with some simple fix configuration manage utilities.
305 | """
306 |
307 | def __init__(self, *args, **kwargs):
308 | super(FixTopModule, self).__init__(*args, **kwargs)
309 |
310 | # To be portable between python2/3, use staticmethod for these utility methods,
311 | # and patch instance method here.
312 | # as Python 2 does not support binding an unbound method to an instance of an unrelated class.
313 | self.fix_state_dict = FixTopModule.fix_state_dict.__get__(self)
314 | self.load_fix_configs = FixTopModule.load_fix_configs.__get__(self)
315 | self.get_fix_configs = FixTopModule.get_fix_configs.__get__(self)
316 | self.print_fix_configs = FixTopModule.print_fix_configs.__get__(self)
317 | self.set_fix_method = FixTopModule.set_fix_method.__get__(self)
318 |
319 | @staticmethod
320 | def fix_state_dict(self, destination=None, prefix="", keep_vars=False):
321 | r"""FIXME: maybe do another quantization to make sure all vars are quantized?
322 |
323 | Returns a dictionary containing a whole fixed-point state of the module.
324 |
325 | Both parameters and persistent buffers (e.g. running averages) are
326 | included. Keys are corresponding parameter and buffer names.
327 |
328 | Returns:
329 | dict:
330 | a dictionary containing a whole state of the module
331 |
332 | Example::
333 |
334 | >>> module.state_dict().keys()
335 | ['bias', 'weight']
336 |
337 | """
338 | if destination is None:
339 | destination = OrderedDict()
340 | destination._metadata = OrderedDict()
341 | destination._metadata[prefix[:-1]] = local_metadata = dict(
342 | version=self._version
343 | )
344 | for name, param in self._parameters.items():
345 | if param is not None:
346 | if isinstance(self.__class__, FixMeta): # A fixed-point module
347 | # Get the last used version of the parameters
348 | thevar = getattr(self, name)
349 | else:
350 | thevar = param
351 | destination[prefix + name] = thevar if keep_vars else thevar.data
352 | for name, buf in self._buffers.items():
353 | if buf is not None:
354 | if isinstance(self.__class__, FixMeta): # A fixed-point module
355 | # Get the last saved version of the buffers,
356 | # which can be of float precision
357 | # (as buffers will be turned into fixed-point precision on the next forward)
358 | thevar = getattr(self, name)
359 | else:
360 | thevar = buf
361 | destination[prefix + name] = thevar if keep_vars else thevar.data
362 | for name, module in self._modules.items():
363 | if module is not None:
364 | FixTopModule.fix_state_dict(
365 | module, destination, prefix + name + ".", keep_vars=keep_vars
366 | )
367 | for hook in self._state_dict_hooks.values():
368 | hook_result = hook(self, destination, prefix, local_metadata)
369 | if hook_result is not None:
370 | destination = hook_result
371 | return destination
372 |
373 | @staticmethod
374 | def load_fix_configs(self, cfgs, grad=False):
375 | assert isinstance(cfgs, (OrderedDict, dict))
376 | for name, module in six.iteritems(self._modules):
377 | if isinstance(module.__class__, FixMeta) or isinstance(
378 | module, Activation_fix
379 | ):
380 | if name not in cfgs:
381 | print(
382 | (
383 | "WARNING: Fix configuration for {} not found in the configuration! "
384 | "Make sure you know why this happened or "
385 | "there might be some subtle error!"
386 | ).format(name)
387 | )
388 | else:
389 | setattr(
390 | module,
391 | "nf_fix_params" if not grad else "nf_fix_params_grad",
392 | utils.try_parse_variable(cfgs[name]),
393 | )
394 | elif isinstance(module, FixTopModule):
395 | module.load_fix_configs(cfgs[name], grad=grad)
396 | else:
397 | FixTopModule.load_fix_configs(module, cfgs[name], grad=grad)
398 |
399 | @staticmethod
400 | def get_fix_configs(self, grad=False, data_only=False):
401 | """
402 | get_fix_configs:
403 |
404 | Parameters:
405 | grad: BOOLEAN(default False), whether or not to get the gradient configs
406 | instead of data configs.
407 | data_only: BOOLEAN(default False), whether or not to get the numbers instead
408 | of `torch.Tensor` (which can be modified in place).
409 | """
410 | cfg_dct = OrderedDict()
411 | for name, module in six.iteritems(self._modules):
412 | if isinstance(module.__class__, FixMeta) or isinstance(
413 | module, Activation_fix
414 | ):
415 | cfg_dct[name] = getattr(
416 | module, "nf_fix_params" if not grad else "nf_fix_params_grad"
417 | )
418 | if data_only:
419 | cfg_dct[name] = utils.try_parse_int(cfg_dct[name])
420 | elif isinstance(module, FixTopModule):
421 | cfg_dct[name] = module.get_fix_configs(grad=grad, data_only=data_only)
422 | else:
423 | cfg_dct[name] = FixTopModule.get_fix_configs(
424 | module, grad=grad, data_only=data_only
425 | )
426 | return cfg_dct
427 |
428 | @staticmethod
429 | def print_fix_configs(self, data_fix_cfg=None, grad_fix_cfg=None, prefix_spaces=0):
430 | if data_fix_cfg is None:
431 | data_fix_cfg = self.get_fix_configs(grad=False)
432 | if grad_fix_cfg is None:
433 | grad_fix_cfg = self.get_fix_configs(grad=True)
434 |
435 | def _print(string, **kwargs):
436 | print(
437 | "\n".join([" " * prefix_spaces + line for line in string.split("\n")])
438 | + "\n",
439 | **kwargs
440 | )
441 |
442 | for key in data_fix_cfg:
443 | _print(key)
444 | d_cfg = data_fix_cfg[key]
445 | g_cfg = grad_fix_cfg[key]
446 | if isinstance(d_cfg, OrderedDict):
447 | self.print_fix_configs(d_cfg, g_cfg, prefix_spaces=2)
448 | else:
449 | # a dict of configs
450 | keys = set(d_cfg.keys()).union(g_cfg.keys())
451 | for param_name in keys:
452 | d_bw = utils.try_parse_int(
453 | d_cfg.get(param_name, {}).get("bitwidth", "f")
454 | )
455 | g_bw = utils.try_parse_int(
456 | g_cfg.get(param_name, {}).get("bitwidth", "f")
457 | )
458 | d_sc = utils.try_parse_int(
459 | d_cfg.get(param_name, {}).get("scale", "f")
460 | )
461 | g_sc = utils.try_parse_int(
462 | g_cfg.get(param_name, {}).get("scale", "f")
463 | )
464 | d_mt = utils.try_parse_int(
465 | d_cfg.get(param_name, {}).get("method", 0)
466 | )
467 | g_mt = utils.try_parse_int(
468 | g_cfg.get(param_name, {}).get("method", 0)
469 | )
470 | _print(
471 | (
472 | " {param_name:10}: d: bitwidth: {d_bw:3}; "
473 | "scale: {d_sc:3}; method: {d_mt:3}\n"
474 | + " " * 14
475 | + "g: bitwidth: {g_bw:3}; scale: {g_sc:3}; method: {g_mt:3}"
476 | ).format(
477 | param_name=param_name,
478 | d_bw=d_bw,
479 | g_bw=g_bw,
480 | d_sc=d_sc,
481 | g_sc=g_sc,
482 | d_mt=d_mt,
483 | g_mt=g_mt,
484 | )
485 | )
486 |
487 | @staticmethod
488 | def set_fix_method(
489 | self, method=None, method_by_type=None, method_by_name=None, grad=False
490 | ):
491 | for module_name, module in six.iteritems(self._modules):
492 | if isinstance(module.__class__, FixMeta) or isinstance(
493 | module, Activation_fix
494 | ):
495 | fix_params = getattr(
496 | module, "nf_fix_params" if not grad else "nf_fix_params_grad"
497 | )
498 | if method_by_name is not None and module_name in method_by_name:
499 | for param_n, param_method in six.iteritems(
500 | method_by_name[module_name]
501 | ):
502 | assert (
503 | param_n in fix_params
504 | ), "{} is not a quantized parameter of module {}".format(
505 | param_n, module_name
506 | )
507 | _set_method(fix_params[param_n], param_method)
508 | else:
509 | if method_by_type is not None:
510 | param_method_cfg = method_by_type.get(
511 | type(module).__name__,
512 | method_by_type.get(type(module), None),
513 | )
514 | else:
515 | param_method_cfg = None
516 | if param_method_cfg is not None:
517 | # specified by method_by_type
518 | for param_n, param_method in six.iteritems(param_method_cfg):
519 | assert (
520 | param_n in fix_params
521 | ), "{} is not a quantized parameter of module {}".format(
522 | param_n, module_name
523 | )
524 | _set_method(fix_params[param_n], param_method)
525 | elif method is not None:
526 | for param_n in fix_params:
527 | _set_method(fix_params[param_n], method)
528 | elif isinstance(module, FixTopModule):
529 | module.set_fix_method(method, method_by_type, method_by_name, grad=grad)
530 | else:
531 | FixTopModule.set_fix_method(module, method, method_by_type, method_by_name, grad=grad)
532 |
533 |
534 | # helpers
535 | def _set_method(param_cfg, new_method):
536 | if new_method is None:
537 | return
538 | if "method" in param_cfg:
539 | ori_method = param_cfg["method"]
540 | if isinstance(ori_method, torch.autograd.Variable):
541 | ori_method.data.numpy()[0] = new_method
542 | elif torch.is_tensor(ori_method):
543 | ori_method.numpy()[0] = new_method
544 | else:
545 | param_cfg["method"] = new_method
546 |
547 |
548 | nn_fix.Activation_fix = Activation_fix
549 | nn_fix.FixTopModule = FixTopModule
550 | nn_fix.ConvBN_fix = ConvBN_fix
551 |
--------------------------------------------------------------------------------
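`register_fix_module` wraps an existing `nn.Module` class with `FixMeta` so that its parameters are quantized on every forward according to a per-parameter config dict, and the generated class is exposed on `nn_fix` as `<Cls>_fix`. A minimal usage sketch, assuming the config layout implied by `quant.quantize_cfg` ("method"/"scale"/"bitwidth" tensors); the exact helper used by the repo's examples may differ.

```python
# A minimal usage sketch, assuming the per-parameter config layout implied by
# quant.quantize_cfg ("method"/"scale"/"bitwidth" tensors). The helper name
# make_fix_cfg is illustrative, not part of the library.
import torch
from torch import nn

import nics_fix_pt as nfp
import nics_fix_pt.nn_fix as nnf
from nics_fix_pt.fix_modules import register_fix_module

# Wrap a module type that is not registered by default; FixMeta then exposes
# the generated class as nnf.Conv1d_fix.
register_fix_module(nn.Conv1d)

def make_fix_cfg(names, bitwidth=8):
    # One entry per parameter name. `scale` is a tensor so it can be registered
    # as a buffer and updated in place when the method is FIX_AUTO.
    return {
        n: {
            "method": torch.IntTensor([nfp.FIX_AUTO]),
            "scale": torch.FloatTensor([0.0]),
            "bitwidth": torch.IntTensor([bitwidth]),
        }
        for n in names
    }

conv = nnf.Conv1d_fix(3, 8, kernel_size=3,
                      nf_fix_params=make_fix_cfg(["weight", "bias"]))
out = conv(torch.randn(2, 3, 16))  # weight/bias are quantized during forward
```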