├── hnn ├── __init__.py ├── hu │ ├── __init__.py │ ├── sampler.py │ ├── a2s_precision_convert.py │ ├── a2s_learnable_coding_sign_convert.py │ ├── s2a_global_rate_coding.py │ ├── precision_convert.py │ ├── a2s_poisson_coding_sign_convert.py │ ├── learnable_sampler.py │ ├── poisson_sampler.py │ ├── a2s_learnable_coding.py │ ├── model.py │ ├── s2a_rate_coding.py │ ├── window_set.py │ ├── global_average_window_conv.py │ ├── a2s_rate_coding.py │ ├── average_window_conv.py │ ├── rate_coding_sampler.py │ ├── s2ahu.py │ ├── s2a_learnable_rate_coding.py │ ├── learnable_window_conv.py │ ├── a2shu.py │ ├── hu.py │ └── window_conv.py ├── ann │ ├── __init__.py │ ├── flatten3d.py │ ├── q_add.py │ ├── q_module.py │ ├── q_adaptive_avgpool2d.py │ ├── q_linear.py │ └── q_conv2d.py ├── snn │ ├── __init__.py │ ├── surrogate │ │ ├── __init__.py │ │ └── rectangle.py │ ├── output_rate_coding.py │ ├── reset_mode.py │ ├── q_dynamics.py │ ├── fire.py │ ├── recorder.py │ ├── refractory.py │ ├── neuron.py │ ├── model.py │ ├── lif_recorder.py │ ├── q_module.py │ ├── threshold_accumulate_with_saturate.py │ ├── threshold_dynamics.py │ ├── lif.py │ ├── fire_with_constant_threshold.py │ ├── lif_neuron.py │ ├── saturate.py │ ├── hard_update_after_spike.py │ ├── accumulate.py │ ├── integrate_and_fire.py │ ├── q_integrate.py │ ├── threshold_accumulate.py │ ├── soft_update_after_spike.py │ ├── reset_after_spike.py │ ├── lif_with_tensor_threshold_and_reset_mode_and_refractory.py │ ├── if_with_tensor_threshold_and_reset_mode_and_refractory.py │ ├── accumulate_with_refractory.py │ ├── leaky.py │ ├── q_linear.py │ ├── q_conv2d.py │ └── extended_lif.py ├── network_type.py ├── utils.py ├── onnx_export_pass.py ├── grad.py └── fuse_bn.py ├── examples ├── __init__.py ├── ann │ ├── __init__.py │ ├── alexnet.py │ ├── lenet.py │ ├── vgg16.py │ ├── googlenet.py │ ├── vgg19.py │ ├── squeezenet.py │ └── small_squeezenet.py ├── hnn │ ├── __init__.py │ ├── s2ahnn_lenet.py │ ├── a2shnn_lenet.py │ └── as2hnn_lenet.py ├── snn │ ├── __init__.py │ ├── snn_mlp.py │ ├── snn_lenet.py │ ├── snn_vgg16.py │ └── snn_vgg19.py └── train_resnet50.md ├── unit_tests ├── __init__.py ├── test_q_conv.py ├── test_hnn.py ├── test_snn.py ├── test_fpmodel.py ├── test_restrict.py ├── test_aware.py ├── test_quantize.py ├── test_collect_q_params.py └── test_lif.py ├── .vscode ├── .env ├── settings.json └── launch.json ├── pytest.ini ├── docs └── sphinx │ ├── source │ ├── _static │ │ ├── hu.png │ │ ├── hnn1.png │ │ ├── hnn2.png │ │ ├── hnn3.png │ │ ├── encoder_snn.png │ │ ├── quantization1.png │ │ ├── snn_framework.png │ │ └── spike_input_snn.png │ ├── index.rst │ ├── SNN编程与量化框架.rst │ ├── HNN介绍.rst │ ├── HNN编程与量化框架.rst │ └── conf.py │ ├── Makefile │ └── make.bat ├── .gitignore ├── requirement.txt ├── setup.py └── README.md /hnn/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /hnn/hu/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /examples/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /hnn/ann/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- 
/hnn/snn/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /unit_tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /examples/ann/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /examples/hnn/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /examples/snn/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /hnn/snn/surrogate/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.vscode/.env: -------------------------------------------------------------------------------- 1 | PYTHONPATH=.:${PYTHONPATH} 2 | -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | testpaths=unit_tests 3 | addopts=-n8 4 | junit_family=xunit1 5 | -------------------------------------------------------------------------------- /docs/sphinx/source/_static/hu.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openBII/HNN/HEAD/docs/sphinx/source/_static/hu.png -------------------------------------------------------------------------------- /docs/sphinx/source/_static/hnn1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openBII/HNN/HEAD/docs/sphinx/source/_static/hnn1.png -------------------------------------------------------------------------------- /docs/sphinx/source/_static/hnn2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openBII/HNN/HEAD/docs/sphinx/source/_static/hnn2.png -------------------------------------------------------------------------------- /docs/sphinx/source/_static/hnn3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openBII/HNN/HEAD/docs/sphinx/source/_static/hnn3.png -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | **/temp/ 2 | **/__pycache__ 3 | data 4 | .data 5 | *.onnx 6 | *.dat 7 | docs/**/build 8 | dist 9 | *.egg-info/ -------------------------------------------------------------------------------- /docs/sphinx/source/_static/encoder_snn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openBII/HNN/HEAD/docs/sphinx/source/_static/encoder_snn.png -------------------------------------------------------------------------------- /docs/sphinx/source/_static/quantization1.png: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/openBII/HNN/HEAD/docs/sphinx/source/_static/quantization1.png -------------------------------------------------------------------------------- /docs/sphinx/source/_static/snn_framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openBII/HNN/HEAD/docs/sphinx/source/_static/snn_framework.png -------------------------------------------------------------------------------- /docs/sphinx/source/_static/spike_input_snn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openBII/HNN/HEAD/docs/sphinx/source/_static/spike_input_snn.png -------------------------------------------------------------------------------- /requirement.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | torch==1.11.0 3 | torchvision 4 | onnx 5 | onnx-simplifier 6 | spikingjelly 7 | pytest 8 | pytest-html 9 | pytest-xdist 10 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "python.envFile": "${workspaceFolder}/.vscode/.env", 3 | "cmake.configureOnOpen": false, 4 | "python.testing.pytestEnabled": true, 5 | "python.testing.pytestArgs": [ 6 | "unit_tests" 7 | ] 8 | } -------------------------------------------------------------------------------- /hnn/network_type.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) OpenBII 2 | # Team: CBICR 3 | # SPDX-License-Identifier: Apache-2.0 4 | # See: https://spdx.org/licenses/ 5 | 6 | from enum import Enum 7 | 8 | 9 | class NetworkType(Enum): 10 | ANN = 1 11 | SNN = 2 12 | HNN = 3 13 | -------------------------------------------------------------------------------- /hnn/snn/output_rate_coding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class OutputRateCoding(torch.nn.Module): 5 | def __init__(self, dim=0) -> None: 6 | super().__init__() 7 | self.dim = dim 8 | 9 | def forward(self, x: torch.Tensor): 10 | x = torch.stack(x, dim=0) 11 | return x.mean(dim=self.dim) -------------------------------------------------------------------------------- /hnn/snn/reset_mode.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) OpenBII 2 | # Team: CBICR 3 | # SPDX-License-Identifier: Apache-2.0 4 | # See: https://spdx.org/licenses/ 5 | 6 | from enum import Enum 7 | 8 | 9 | class ResetMode(Enum): 10 | '''脉冲神经元的膜电位的复位模式 11 | 12 | HARD: 膜电位复位到固定值 13 | SOFT: 将膜电位减去一个变量 14 | SOFT_CONSTANT: 将膜电位减去一个固定值 15 | ''' 16 | HARD = 0 17 | SOFT = 1 18 | SOFT_CONSTANT = 2 -------------------------------------------------------------------------------- /hnn/hu/sampler.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) OpenBII 2 | # Team: CBICR 3 | # SPDX-License-Identifier: Apache-2.0 4 | # See: https://spdx.org/licenses/ 5 | 6 | import torch 7 | 8 | 9 | class Sampler(torch.nn.Module): 10 | '''采样器 11 | 12 | 抽象类, 一般用于ANN到SNN的转换使用, 将静态的数据采样成时间序列 13 | 14 | Args: 15 | window_size: 采样后的时间序列的长度 16 | ''' 17 | def __init__(self, window_size) -> None: 18 | super(Sampler, self).__init__() 19 | self.window_size = window_size -------------------------------------------------------------------------------- /hnn/hu/a2s_precision_convert.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (C) OpenBII 2 | # Team: CBICR 3 | # SPDX-License-Identifier: Apache-2.0 4 | # See: https://spdx.org/licenses/ 5 | 6 | import torch 7 | from typing import Callable 8 | from hnn.hu.precision_convert import PrecisionConvert 9 | 10 | 11 | class A2SPrecisionConvert(PrecisionConvert): 12 | '''用于ANN到SNN的精度转换单元 13 | 14 | Args: 15 | converter: 精度转换函数 16 | ''' 17 | def __init__(self, converter: Callable[[torch.Tensor], torch.Tensor]) -> None: 18 | super(A2SPrecisionConvert, self).__init__(converter=converter) -------------------------------------------------------------------------------- /hnn/hu/a2s_learnable_coding_sign_convert.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) OpenBII 2 | # Team: CBICR 3 | # SPDX-License-Identifier: Apache-2.0 4 | # See: https://spdx.org/licenses/ 5 | 6 | import torch 7 | from hnn.hu.a2s_learnable_coding import A2SLearnableCoding 8 | 9 | 10 | class A2SLearnableCodingSignConvert(A2SLearnableCoding): 11 | '''使用可学习采样器并且精度转换函数为符号函数的A2SHU 12 | ''' 13 | def __init__(self, window_size: int, non_linear: torch.nn.Module = None) -> None: 14 | super(A2SLearnableCodingSignConvert, self).__init__( 15 | window_size, torch.sign, non_linear) -------------------------------------------------------------------------------- /.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | // Use IntelliSense to learn about possible attributes. 3 | // Hover to view descriptions of existing attributes. 4 | // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 5 | "version": "0.2.0", 6 | "configurations": [ 7 | { 8 | "name": "Python: Current File", 9 | "type": "python", 10 | "request": "launch", 11 | "program": "${file}", 12 | "console": "integratedTerminal", 13 | "cwd": "${workspaceFolder}" 14 | }, 15 | ] 16 | } -------------------------------------------------------------------------------- /hnn/ann/flatten3d.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) OpenBII 2 | # Team: CBICR 3 | # SPDX-License-Identifier: Apache-2.0 4 | # See: https://spdx.org/licenses/ 5 | 6 | import torch 7 | 8 | 9 | class Flatten3d(torch.nn.Module): 10 | '''将[N, C, H, W]排布的张量转换成[N, H, W, C]并按照C-order展开到一维 11 | ''' 12 | def __init__(self) -> None: 13 | super(Flatten3d, self).__init__() 14 | 15 | def forward(self, x: torch.Tensor): 16 | x = x.permute(0, 2, 3, 1) # [N, C, H, W] -> [N, H, W, C] 17 | # [N, H, W, C] -> [N, H * W * C] 18 | x = x.contiguous().view(x.size(0), -1) 19 | return x -------------------------------------------------------------------------------- /hnn/snn/q_dynamics.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) OpenBII 2 | # Team: CBICR 3 | # SPDX-License-Identifier: Apache-2.0 4 | # See: https://spdx.org/licenses/ 5 | 6 | from hnn.snn.q_module import QModule 7 | 8 | 9 | class QDynamics(QModule): 10 | '''类似于torch.nn.Module, QModel和其他支持量化的算子都继承于QModule类 11 | 12 | Attributes: 13 | scale: 量化参数, 用于对脉冲神经元参数进行放缩 14 | first_time: 只有初次执行时才会对输入膜电位进行量化, 后续时间步执行时的输入膜电位已经被量化过不需要再被量化 15 | freeze: 脉冲神经元参数是否处于冻结状态 16 | ''' 17 | def __init__(self): 18 | self.scale = None 19 | self.first_time = True 20 | self.freeze = False -------------------------------------------------------------------------------- /hnn/utils.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (C) OpenBII 2 | # Team: CBICR 3 | # SPDX-License-Identifier: Apache-2.0 4 | # See: https://spdx.org/licenses/ 5 | 6 | import random 7 | 8 | import numpy 9 | import torch 10 | 11 | 12 | def setup_random_seed(seed): 13 | torch.manual_seed(seed) 14 | torch.cuda.manual_seed_all(seed) 15 | numpy.random.seed(seed) 16 | random.seed(seed) 17 | torch.backends.cudnn.deterministic = True 18 | 19 | 20 | def get_int8_tensor(shape): 21 | assert type(shape) is tuple 22 | x = torch.rand(shape) * 2 - 1 23 | x = x.mul(128).round().clamp(-128, 127) 24 | return x 25 | -------------------------------------------------------------------------------- /hnn/hu/s2a_global_rate_coding.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) OpenBII 2 | # Team: CBICR 3 | # SPDX-License-Identifier: Apache-2.0 4 | # See: https://spdx.org/licenses/ 5 | 6 | import torch 7 | from hnn.hu.s2ahu import S2AHU 8 | from hnn.hu.global_average_window_conv import GlobalAverageWindowConv 9 | 10 | 11 | class S2AGlobalRateCoding(S2AHU): 12 | '''使用全局平均时间窗卷积的S2AHU 13 | ''' 14 | def __init__(self, window_size: int, non_linear: torch.nn.Module = None) -> None: 15 | super(S2AGlobalRateCoding, self).__init__(window_size, non_linear) 16 | self.window_conv = GlobalAverageWindowConv(window_size=window_size) 17 | self.check() -------------------------------------------------------------------------------- /hnn/snn/fire.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) OpenBII 2 | # Team: CBICR 3 | # SPDX-License-Identifier: Apache-2.0 4 | # See: https://spdx.org/licenses/ 5 | 6 | import torch 7 | 8 | 9 | class Fire(torch.nn.Module): 10 | '''脉冲发放 11 | 12 | 膜电位和阈值均为张量, 即每个神经元都可以有不同的阈值 13 | 14 | Args: 15 | surrogate_function: 梯度替代函数, 可使用的函数见hnn/snn/surrogate 16 | ''' 17 | def __init__(self, surrogate_function) -> None: 18 | super(Fire, self).__init__() 19 | self.surrogate_function = surrogate_function 20 | 21 | def forward(self, v, v_th) -> torch.Tensor: 22 | spike = self.surrogate_function.apply(v, v_th) 23 | return spike -------------------------------------------------------------------------------- /hnn/hu/precision_convert.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) OpenBII 2 | # Team: CBICR 3 | # SPDX-License-Identifier: Apache-2.0 4 | # See: https://spdx.org/licenses/ 5 | 6 | import torch 7 | from typing import Callable 8 | 9 | 10 | class PrecisionConvert(torch.nn.Module): 11 | '''精度转换单元 12 | 13 | 一般用于负责从ANN转换到SNN的HU中 14 | 15 | Args: 16 | converter: 精度转换函数 17 | ''' 18 | def __init__(self, converter: Callable[[torch.Tensor], torch.Tensor]) -> None: 19 | super(PrecisionConvert, self).__init__() 20 | self.converter = converter 21 | 22 | def forward(self, x: torch.Tensor) -> torch.Tensor: 23 | x = self.converter(x) 24 | return x -------------------------------------------------------------------------------- /docs/sphinx/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | SOURCEDIR = source 8 | BUILDDIR = build 9 | 10 | # Put it first so that "make" without argument is like "make help". 
11 | help: 12 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 13 | 14 | .PHONY: help Makefile 15 | 16 | # Catch-all target: route all unknown targets to Sphinx using the new 17 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 18 | %: Makefile 19 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) -------------------------------------------------------------------------------- /hnn/hu/a2s_poisson_coding_sign_convert.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) OpenBII 2 | # Team: CBICR 3 | # SPDX-License-Identifier: Apache-2.0 4 | # See: https://spdx.org/licenses/ 5 | 6 | import torch 7 | from hnn.hu.a2shu import A2SHU 8 | from hnn.hu.poisson_sampler import PoissonSampler 9 | 10 | 11 | class A2SPoissonCodingSignConvert(A2SHU): 12 | '''采样器为泊松采样器、精度转换函数为符号函数的A2SHU 13 | ''' 14 | def __init__(self, window_size: int, non_linear: torch.nn.Module = None) -> None: 15 | super(A2SPoissonCodingSignConvert, self).__init__( 16 | window_size, torch.sign, non_linear) 17 | self.sampler = PoissonSampler(window_size=self.window_size) 18 | self.check() -------------------------------------------------------------------------------- /hnn/hu/learnable_sampler.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) OpenBII 2 | # Team: CBICR 3 | # SPDX-License-Identifier: Apache-2.0 4 | # See: https://spdx.org/licenses/ 5 | 6 | import torch 7 | from hnn.hu.sampler import Sampler 8 | 9 | 10 | class LearnableSampler(Sampler): 11 | '''可学习的采样器 12 | 13 | 将输入数据看作一帧数据, 通过线性变换采样到window_size帧 14 | ''' 15 | def __init__(self, window_size: int) -> None: 16 | super(LearnableSampler, self).__init__(window_size=window_size) 17 | self.linear = torch.nn.Linear(1, self.window_size) 18 | 19 | def forward(self, x: torch.Tensor) -> torch.Tensor: 20 | x = x.unsqueeze(-1) 21 | x = self.linear(x) 22 | return x -------------------------------------------------------------------------------- /hnn/hu/poisson_sampler.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) OpenBII 2 | # Team: CBICR 3 | # SPDX-License-Identifier: Apache-2.0 4 | # See: https://spdx.org/licenses/ 5 | 6 | import torch 7 | from hnn.hu.rate_coding_sampler import RateCodingSampler 8 | from spikingjelly.clock_driven import encoding 9 | 10 | 11 | class PoissonSampler(RateCodingSampler): 12 | '''泊松采样器 13 | 14 | 首先通过正则化将数据转换到0到1区间, 然后通过泊松采样器采样, 这里复用了SpikingJelly中的PoissonEncoder 15 | ''' 16 | def __init__(self, window_size: int) -> None: 17 | super().__init__(window_size, encoding.PoissonEncoder()) 18 | 19 | def normalize(self, x: torch.Tensor) -> torch.Tensor: 20 | return (x - x.min()) / (x.max() - x.min()) -------------------------------------------------------------------------------- /docs/sphinx/source/index.rst: -------------------------------------------------------------------------------- 1 | .. TianjicX_Compiler documentation master file, created by 2 | sphinx-quickstart on Jan 12 2023. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | 7 | HNN编程框架 8 | =============== 9 | 10 | .. Documentation Overview 11 | .. ====================== 12 | 13 | HNN编程框架整体上可以分成三部分:ANN的自动量化框架,SNN的编程和自动量化框架,HNN的编程和自动量化框架。 14 | 15 | .. toctree:: 16 | :maxdepth: 4 17 | :caption: 目录: 18 | 19 | ANN量化框架 20 | SNN编程与量化框架 21 | HNN介绍 22 | HNN编程与量化框架 23 | 24 | .. Indices and tables 25 | .. 
================== 26 | 27 | .. * :ref:`genindex` 28 | .. * :ref:`modindex` 29 | .. * :ref:`search` 30 | -------------------------------------------------------------------------------- /examples/train_resnet50.md: -------------------------------------------------------------------------------- 1 | # Loading a pretrained ResNet50 model and running quantization-aware training 2 | 3 | 1. Load the pretrained model and check that the converted model behaves identically: `python train_resnet50.py --pretrain --eval --test_batch_size=256 --env_gpu=0` 4 | 2. Fuse BatchNorm layers: `python train_resnet50.py --pretrain --fuse_bn --checkpoint --env_gpu=0` 5 | 3. Restricted training: `python train_resnet50.py --pretrain --train --checkpoint --restrict --lr=1e-5 -b64 --test_batch_size=256 --env_gpu=0` 6 | 4. Post-training static quantization: `python train_resnet50.py --pretrain --collect --quantize --eval --checkpoint`; the checkpoint contains the model's state dict and the dictionary of quantization parameters (this step can be skipped) 7 | 5. Quantization-aware training: `python train_resnet50.py --pretrain --collect --aware --train --checkpoint --lr=1e-5 -b64 --test_batch_size=256 --env_gpu=0` 8 | 6. Evaluate the saved quantized model: `python train_resnet50.py --eval --quantized_pretrain --test_batch_size=256 --env_gpu=0` -------------------------------------------------------------------------------- /hnn/hu/a2s_learnable_coding.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) OpenBII 2 | # Team: CBICR 3 | # SPDX-License-Identifier: Apache-2.0 4 | # See: https://spdx.org/licenses/ 5 | 6 | import torch 7 | from hnn.hu.a2shu import A2SHU 8 | from typing import Callable 9 | from hnn.hu.learnable_sampler import LearnableSampler 10 | 11 | 12 | class A2SLearnableCoding(A2SHU): 13 | '''使用可学习采样器的A2SHU 14 | ''' 15 | def __init__(self, window_size: int, converter: Callable[[torch.Tensor], torch.Tensor], 16 | non_linear: torch.nn.Module = None) -> None: 17 | super(A2SLearnableCoding, self).__init__( 18 | window_size, converter, non_linear) 19 | self.sampler = LearnableSampler(window_size=self.window_size) 20 | self.check() 21 | -------------------------------------------------------------------------------- /hnn/hu/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from hnn.hu.a2shu import A2SHU 3 | from hnn.snn.model import Model 4 | 5 | 6 | class A2SModel(torch.nn.Module): 7 | def __init__(self, T) -> None: 8 | super().__init__() 9 | self.T = T 10 | self.ann: torch.nn.Module = None 11 | self.a2shu: A2SHU = None 12 | self.snn: Model = None 13 | self.encode: torch.nn.Module = None 14 | 15 | def reshape(self, x: torch.Tensor): 16 | return x 17 | 18 | def forward(self, x, *args): 19 | x = self.ann(x) 20 | x = self.a2shu(x) # [N, C, H, W] -> [N, C, H, W, T] 21 | x = self.reshape(x) # [N, C, H, W, T] -> [T, ...]
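        # NOTE: reshape() is an identity hook by default; subclasses are expected
        # to override it (e.g. with a permute that moves the time axis to the
        # front) so that the SNN below can be driven one time step at a time.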
22 | x = self.snn.multi_step_forward(x, *args) 23 | return self.encode(x) -------------------------------------------------------------------------------- /hnn/snn/recorder.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) OpenBII 2 | # Team: CBICR 3 | # SPDX-License-Identifier: Apache-2.0 4 | # See: https://spdx.org/licenses/ 5 | 6 | import torch 7 | 8 | 9 | class Recorder(torch.autograd.Function): 10 | '''记录脉冲神经元的各种参数信息和标识脉冲神经元 11 | 12 | 抽象类, 脉冲神经元需要根据需求继承Recorder类, 主要用于记录脉冲神经元的各种参数信息以及在计算图中起到标识脉冲神经元的作用 13 | ''' 14 | @staticmethod 15 | def forward(ctx, input): 16 | return input 17 | 18 | @staticmethod 19 | def backward(ctx, grad_output): 20 | return grad_output 21 | 22 | @staticmethod 23 | def symbolic(g: torch._C.Graph, input: torch._C.Value): 24 | # FIXME(huanyu): 这里有个pytorch的bug没有修复, 正常应该通过setType()设置形状, 但shape inference还是会missing 25 | # 这issue好几个月前就提了pytorch还没有修复烦死了😡 26 | return g.op("snn::Record", input) -------------------------------------------------------------------------------- /hnn/hu/s2a_rate_coding.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) OpenBII 2 | # Team: CBICR 3 | # SPDX-License-Identifier: Apache-2.0 4 | # See: https://spdx.org/licenses/ 5 | 6 | import torch 7 | from hnn.hu.s2ahu import S2AHU 8 | from hnn.hu.average_window_conv import AverageWindowConv 9 | 10 | 11 | class S2ARateCoding(S2AHU): 12 | '''使用平均时间窗卷积的S2AHU 13 | 14 | Args: 15 | window_size: 时间窗大小 16 | kernel_size: 卷积窗大小 17 | stride: 卷积窗滑动步长 18 | non_linear: 非线性函数 19 | ''' 20 | def __init__(self, window_size: int, kernel_size: int, stride: int, non_linear: torch.nn.Module = None) -> None: 21 | super(S2ARateCoding, self).__init__(window_size, non_linear) 22 | self.window_conv = AverageWindowConv( 23 | kernel_size=kernel_size, stride=stride) 24 | self.check() -------------------------------------------------------------------------------- /hnn/hu/window_set.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) OpenBII 2 | # Team: CBICR 3 | # SPDX-License-Identifier: Apache-2.0 4 | # See: https://spdx.org/licenses/ 5 | 6 | import torch 7 | 8 | 9 | class WindowSet(torch.nn.Module): 10 | '''设置时间窗 11 | 12 | 非抽象类, 已给出具体实现 13 | 假设输入数据维度为[batch_size, .., T], 最后一个维度为时间, 时间窗长度为t, 输出数据维度变为[batch_size, ..., T / t, t] 14 | 15 | Args: 16 | size: 时间窗长度 17 | ''' 18 | def __init__(self, size: int) -> None: 19 | super(WindowSet, self).__init__() 20 | self.size = size 21 | 22 | def forward(self, x: torch.Tensor) -> torch.Tensor: 23 | t = x.size(-1) 24 | num = t // self.size 25 | shape = list(x.size()) 26 | shape[-1] = num 27 | shape.append(self.size) 28 | x = x.unsqueeze(-2).reshape(shape) 29 | return x -------------------------------------------------------------------------------- /hnn/snn/refractory.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) OpenBII 2 | # Team: CBICR 3 | # SPDX-License-Identifier: Apache-2.0 4 | # See: https://spdx.org/licenses/ 5 | 6 | import torch 7 | from hnn.snn.hard_update_after_spike import HardUpdateAfterSpike 8 | 9 | 10 | class Refractory(torch.nn.Module): 11 | '''不应期 12 | 13 | 不应期减计数, 发放脉冲后不应期复位 14 | 15 | Args: 16 | reset.value = ref_len, 不应期长度 17 | ''' 18 | def __init__(self, ref_len) -> None: 19 | super(Refractory, self).__init__() 20 | self.reset = HardUpdateAfterSpike(value=ref_len) 21 | 22 | def forward(self, ref_cnt: torch.Tensor, spike: torch.Tensor) -> torch.Tensor: 
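        # Two steps, both outside autograd: (1) every neuron still in its
        # refractory period (ref_cnt > 0) has its counter decremented by one;
        # (2) neurons that just fired get their counter hard-reset to ref_len
        # by HardUpdateAfterSpike.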
23 | with torch.no_grad(): 24 | ref_cnt[ref_cnt > 0] = ref_cnt[ref_cnt > 0] - 1 25 | ref_cnt = self.reset(ref_cnt, spike) 26 | return ref_cnt 27 | -------------------------------------------------------------------------------- /hnn/snn/neuron.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) OpenBII 2 | # Team: CBICR 3 | # SPDX-License-Identifier: Apache-2.0 4 | # See: https://spdx.org/licenses/ 5 | 6 | import torch 7 | from hnn.snn.q_model import QModel 8 | 9 | 10 | class Neuron(QModel): 11 | '''神经元基类 12 | 13 | 需要导出成ONNX模型的SNN中的神经元需要继承Neuron类 14 | 主要功能为在进行实际的神经元计算前插入Recorder结点, 通过Recorder结点标识神经元并且记录神经元参数, 然后完成正常的神经元计算 15 | 16 | Args: 17 | recorder: 继承自Recorder类, 不完成实际计算, 主要用于标识神经元和记录各种参数 18 | T: SNN的时间步 19 | ''' 20 | def __init__(self, recorder, T): 21 | super(Neuron, self).__init__() 22 | self.recorder = recorder.apply 23 | self.T = T 24 | self.neuron = None 25 | 26 | def record(self, x: torch.Tensor) -> torch.Tensor: 27 | return self.recorder(x) 28 | 29 | def forward(self, *args): 30 | return self.neuron(self.record(args[0]), *args[1:]) -------------------------------------------------------------------------------- /hnn/hu/global_average_window_conv.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) OpenBII 2 | # Team: CBICR 3 | # SPDX-License-Identifier: Apache-2.0 4 | # See: https://spdx.org/licenses/ 5 | 6 | import torch 7 | from hnn.hu.average_window_conv import AverageWindowConv 8 | from typing import List 9 | 10 | 11 | class GlobalAverageWindowConv(AverageWindowConv): 12 | '''全局平均时间窗卷积 13 | 14 | 平均时间窗卷积的退化情况, 在时间窗内计算平均值 15 | 16 | Args: 17 | window_size: 时间窗大小 18 | ''' 19 | def __init__(self, window_size: int) -> None: 20 | super(GlobalAverageWindowConv, self).__init__( 21 | kernel_size=window_size, stride=window_size) 22 | 23 | def reshape(self, x: torch.Tensor, prefix_shape: List) -> torch.Tensor: 24 | '''重载的卷积后reshape方法 25 | 26 | 和父类的reshape方法效果相同, 但更简洁 27 | ''' 28 | prefix_shape.append(x.size(-2)) 29 | x = x.squeeze(-1).reshape(prefix_shape) 30 | return x -------------------------------------------------------------------------------- /docs/sphinx/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 
23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /hnn/hu/a2s_rate_coding.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) OpenBII 2 | # Team: CBICR 3 | # SPDX-License-Identifier: Apache-2.0 4 | # See: https://spdx.org/licenses/ 5 | 6 | import torch 7 | from hnn.hu.a2shu import A2SHU 8 | from hnn.hu.rate_coding_sampler import RateCodingSampler 9 | from typing import Callable 10 | 11 | 12 | class A2SRateCoding(A2SHU): 13 | '''采样器为基于rate coding的采样器 14 | 15 | Args: 16 | window_size: 转换后的时间序列长度 17 | non_linear: 非线性函数 18 | precision_convert.converter: 精度转换函数 19 | sampler.encoder: 用于采样的编码器 20 | ''' 21 | def __init__(self, window_size: int, encoder: torch.nn.Module, 22 | converter: Callable[[torch.Tensor], torch.Tensor], non_linear: torch.nn.Module = None) -> None: 23 | super(A2SRateCoding, self).__init__(window_size, converter, non_linear) 24 | self.sampler = RateCodingSampler( 25 | window_size=self.window_size, encoder=encoder) 26 | self.check() -------------------------------------------------------------------------------- /hnn/hu/average_window_conv.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) OpenBII 2 | # Team: CBICR 3 | # SPDX-License-Identifier: Apache-2.0 4 | # See: https://spdx.org/licenses/ 5 | 6 | import torch 7 | from hnn.hu.window_conv import WindowConv 8 | 9 | 10 | class AverageWindowConv(WindowConv): 11 | '''平均时间窗卷积 12 | 13 | 时间窗卷积的退化情况, 在卷积窗内计算平均值 14 | 15 | Args: 16 | kernel_size: 卷积窗大小 17 | stride: 滑窗的步长 18 | ''' 19 | def __init__(self, kernel_size: int, stride: int) -> None: 20 | super(AverageWindowConv, self).__init__() 21 | self.avgpool = torch.nn.AvgPool1d( 22 | kernel_size=kernel_size, stride=stride) 23 | 24 | def forward(self, x: torch.Tensor) -> torch.Tensor: 25 | # 输入特征图形状 [N, ..., num, size] 26 | # 输出特征图形状 [N, ..., T] 27 | x, prefix_shape = self.reshape1d(x) # [N * ..., num, size], [N, ...] 
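        # AvgPool1d treats num_of_windows as the channel dimension and averages
        # along the last (intra-window) axis: [N * ..., num, size] -> [N * ..., num, L_out]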
28 | x = self.avgpool(x) 29 | x = self.reshape(x, prefix_shape) 30 | return x -------------------------------------------------------------------------------- /hnn/hu/rate_coding_sampler.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) OpenBII 2 | # Team: CBICR 3 | # SPDX-License-Identifier: Apache-2.0 4 | # See: https://spdx.org/licenses/ 5 | 6 | import torch 7 | from hnn.hu.sampler import Sampler 8 | 9 | 10 | class RateCodingSampler(Sampler): 11 | '''基于RateCoding的采样器 12 | 13 | 通过encoder将一帧数据采样到window_size帧, 采样前可能会对数据进行正则化 14 | ''' 15 | def __init__(self, window_size: int, encoder: torch.nn.Module) -> None: 16 | super(RateCodingSampler, self).__init__(window_size=window_size) 17 | self.encoder = encoder 18 | 19 | def normalize(self, x: torch.Tensor) -> torch.Tensor: 20 | return x 21 | 22 | def forward(self, x: torch.Tensor) -> torch.Tensor: 23 | shape = list(x.size()) 24 | shape.append(self.window_size) 25 | out = torch.zeros(shape) 26 | x = self.normalize(x) 27 | for i in range(self.window_size): 28 | out[..., i] = self.encoder(x) 29 | return out -------------------------------------------------------------------------------- /hnn/snn/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from enum import Enum 3 | 4 | 5 | class InputMode(Enum): 6 | STATIC = 'static' 7 | SEQUENTIAL = 'sequential' 8 | 9 | 10 | class Model(torch.nn.Module): 11 | def __init__(self, time_interval: int, mode: InputMode) -> None: 12 | super(Model, self).__init__() 13 | self.time_interval = time_interval 14 | self.mode = mode 15 | 16 | def multi_step_forward(self, x, *args): 17 | outputs = [] 18 | if self.mode == InputMode.STATIC: 19 | for i in range(self.time_interval): 20 | output, *args = self.forward(x, *args) 21 | outputs.append(output) 22 | elif self.mode == InputMode.SEQUENTIAL: 23 | for i in range(self.time_interval): 24 | output, *args = self.forward(x[i], *args) 25 | outputs.append(output) 26 | else: 27 | raise ValueError('Unsupported input mode') 28 | return outputs 29 | -------------------------------------------------------------------------------- /hnn/snn/surrogate/rectangle.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) OpenBII 2 | # Team: CBICR 3 | # SPDX-License-Identifier: Apache-2.0 4 | # See: https://spdx.org/licenses/ 5 | 6 | import torch 7 | 8 | 9 | class Rectangle(torch.autograd.Function): 10 | '''SNN中脉冲发放的代理函数, 通过矩形窗函数来代替梯度 11 | ''' 12 | 13 | window_size = 1 14 | 15 | @staticmethod 16 | def forward(ctx, v3: torch.Tensor, v_th) -> torch.Tensor: 17 | ctx.save_for_backward(v3, torch.as_tensor(v_th, device=v3.device)) 18 | out = (v3 > v_th).float() 19 | return out 20 | 21 | @staticmethod 22 | def backward(ctx, grad_output): 23 | v3, v_th = ctx.saved_tensors 24 | mask = torch.abs(v3 - v_th) < Rectangle.window_size / 2 25 | return grad_output * mask.float() * 1 / Rectangle.window_size, None 26 | 27 | @staticmethod 28 | def symbolic(g: torch._C.Graph, input: torch._C.Value, v_th0: float) -> torch._C.Value: 29 | return g.op("snn::RectangleFire", input, v_th0_f=v_th0) -------------------------------------------------------------------------------- /hnn/hu/s2ahu.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) OpenBII 2 | # Team: CBICR 3 | # SPDX-License-Identifier: Apache-2.0 4 | # See: https://spdx.org/licenses/ 5 | 6 | import torch 7 | from hnn.hu.hu import HU 8 | from 
hnn.hu.window_set import WindowSet 9 | 10 | 11 | class S2AHU(HU): 12 | '''SNN到ANN转换的HU 13 | 14 | 抽象类, 子类必须具有window_set和window_conv两个部分, 不能包括sampler和precision_convert, 一般包括non_linear操作 15 | 16 | Args: 17 | window_size: 时间窗大小 18 | non_linear: 非线性函数 19 | window_set: 设置时间窗 20 | ''' 21 | def __init__(self, window_size: int, non_linear: torch.nn.Module = None) -> None: 22 | super(S2AHU, self).__init__(window_size, non_linear) 23 | self.window_set = WindowSet(size=window_size) 24 | 25 | def check(self): 26 | '''检查S2AHU是否符合基本要求 27 | 28 | 子类应该在构造函数最后调用check方法来检查合法性 29 | ''' 30 | assert (self.window_set is not None and self.window_conv is not None and 31 | self.sampler is None and self.precision_convert is None) -------------------------------------------------------------------------------- /hnn/hu/s2a_learnable_rate_coding.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) OpenBII 2 | # Team: CBICR 3 | # SPDX-License-Identifier: Apache-2.0 4 | # See: https://spdx.org/licenses/ 5 | 6 | import torch 7 | from hnn.hu.s2ahu import S2AHU 8 | from hnn.hu.learnable_window_conv import LearnableWindowConv 9 | from typing import Union, Tuple 10 | 11 | 12 | class S2ALearnableRateCoding(S2AHU): 13 | '''使用可学习时间窗卷积的S2AHU 14 | 15 | Args: 16 | window_size: 时间窗大小 17 | num_windows: 时间窗数量 18 | kernel_size: 卷积窗大小 19 | stride: 卷积步长 20 | padding: 卷积补零 21 | non_linear: 非线性函数 22 | ''' 23 | def __init__(self, window_size: int, num_windows: int, kernel_size: int, 24 | stride: int, padding: Union[int, Tuple[int]], non_linear: torch.nn.Module = None) -> None: 25 | super(S2ALearnableRateCoding, self).__init__(window_size, non_linear) 26 | self.window_conv = LearnableWindowConv( 27 | in_channels=num_windows, kernel_size=kernel_size, stride=stride, padding=padding) 28 | self.check() -------------------------------------------------------------------------------- /hnn/snn/lif_recorder.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) OpenBII 2 | # Team: CBICR 3 | # SPDX-License-Identifier: Apache-2.0 4 | # See: https://spdx.org/licenses/ 5 | 6 | import torch 7 | from hnn.snn.recorder import Recorder 8 | 9 | 10 | class LIFRecorder(Recorder): 11 | '''LIF神经元的Recorder 12 | 13 | 需要重载forward方法和symbolic方法, forward方法需要传入输入和神经元的各种参数, symbolic方法构建Recorder结点并记录参数 14 | ''' 15 | @staticmethod 16 | def forward(ctx, input, v_th, v_leaky_alpha, v_leaky_beta, v_reset, v_leaky_adpt_en, v_init, time_window_size): 17 | return input 18 | 19 | @staticmethod 20 | def symbolic(g: torch._C.Graph, input: torch._C.Value, v_th, v_leaky_alpha, v_leaky_beta, v_reset, v_leaky_adpt_en, v_init, time_window_size): 21 | return g.op("snn::LIFRecorder", input, 22 | v_th_f=v_th, v_leaky_alpha_f=v_leaky_alpha, 23 | v_leaky_beta_f=v_leaky_beta, v_reset_f=v_reset, 24 | v_leaky_adpt_en_i=v_leaky_adpt_en, v_init_f=v_init, 25 | time_window_size_f=time_window_size) -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | from os import path 3 | 4 | 5 | DIR = path.dirname(path.abspath(__file__)) 6 | with open(path.join(DIR, 'README.md')) as f: 7 | README = f.read() 8 | 9 | setup( 10 | name="hnn", 11 | version="0.0.3.5", 12 | packages=find_packages(exclude=["examples", "examples.*", "unit_tests", "unit_tests.*"]), 13 | keywords=["hybrid neural networks", "spiking neural networks", "quantization"], 14 | 
description="A programming framework based on PyTorch for hybrid neural networks with automatic quantization", 15 | long_description=README, 16 | long_description_content_type='text/markdown', 17 | license="Apache License 2.0", 18 | url="https://github.com/openBII/HNN", 19 | author="Huanyu", 20 | author_email="huanyu.qu@hotmail.com", 21 | include_package_data=True, 22 | platforms="any", 23 | install_requires=['numpy', 'torch==1.11.0', 'torchvision', 'onnx', 'onnx-simplifier', 'spikingjelly'], 24 | tests_require=['pytest', 'pytest-html', 'pytest-xdist'], 25 | ) 26 | -------------------------------------------------------------------------------- /unit_tests/test_q_conv.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) OpenBII 2 | # Team: CBICR 3 | # SPDX-License-Identifier: Apache-2.0 4 | # See: https://spdx.org/licenses/ 5 | 6 | import torch 7 | import unittest 8 | import logging 9 | from hnn.ann.q_conv2d import QConv2d 10 | 11 | 12 | class TestQConv2d(unittest.TestCase): 13 | def test_q_conv2d(self): 14 | x = torch.randn((1, 3, 3, 3)) 15 | qconv = QConv2d(3, 8, 3) 16 | conv = torch.nn.Conv2d(3, 8, 3) 17 | logging.debug(conv.weight.data.abs().max()) 18 | qconv.load_state_dict(conv.state_dict()) 19 | qconv.collect_q_params(2) 20 | logging.debug(qconv.weight_scale) 21 | logging.debug(qconv.bit_shift) 22 | qx = x / x.abs().max() * 128 23 | qx = qx.round().clamp(-128, 127) 24 | qconv.quantize() 25 | qy = qconv(qx) 26 | logging.debug(qy) 27 | qconv.aware() 28 | y = qconv(x / x.abs().max()) 29 | logging.debug(y * 128) 30 | 31 | 32 | if __name__ == '__main__': 33 | logging.basicConfig(level=logging.DEBUG) 34 | t1 = TestQConv2d() 35 | t1.test_q_conv2d() -------------------------------------------------------------------------------- /hnn/hu/learnable_window_conv.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) OpenBII 2 | # Team: CBICR 3 | # SPDX-License-Identifier: Apache-2.0 4 | # See: https://spdx.org/licenses/ 5 | 6 | import torch 7 | from hnn.hu.window_conv import WindowConv 8 | from typing import Union, Tuple 9 | 10 | 11 | class LearnableWindowConv(WindowConv): 12 | '''可学习的时间窗卷积 13 | 14 | 进行groups = in_channels = out_channels的1D卷积, 类似于1D的在时间维度上的深度可分离卷积 15 | 16 | Args: 17 | in_channels: 卷积的输入通道数, 等于时间窗的数量 18 | kernel_size: 卷积核大小 19 | stride: 卷积的步长 20 | padding: 卷积的补零 21 | ''' 22 | def __init__(self, in_channels: int, kernel_size: int, stride: int, padding: Union[int, Tuple[int]]) -> None: 23 | super(LearnableWindowConv, self).__init__() 24 | self.conv = torch.nn.Conv1d( 25 | in_channels=in_channels, out_channels=in_channels, 26 | kernel_size=kernel_size, stride=stride, 27 | padding=padding, groups=in_channels) 28 | 29 | def forward(self, x: torch.Tensor) -> torch.Tensor: 30 | x, prefix_shape = self.reshape1d(x) 31 | x = self.conv(x) 32 | x = self.reshape(x, prefix_shape) 33 | return x -------------------------------------------------------------------------------- /hnn/hu/a2shu.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) OpenBII 2 | # Team: CBICR 3 | # SPDX-License-Identifier: Apache-2.0 4 | # See: https://spdx.org/licenses/ 5 | 6 | import torch 7 | from hnn.hu.hu import HU 8 | from hnn.hu.a2s_precision_convert import A2SPrecisionConvert 9 | from typing import Callable 10 | 11 | 12 | class A2SHU(HU): 13 | '''ANN到SNN转换的HU 14 | 15 | 抽象类, 子类必须具有sampler和precision_convert两个部分, 不能包括window_set和window_conv, 一般包括non_linear操作 16 | 17 
| Args: 18 | window_size: 转换后的时间序列长度 19 | non_linear: 非线性函数 20 | precision_convert.converter: 精度转换函数 21 | ''' 22 | def __init__(self, window_size: int, converter: Callable[[torch.Tensor], torch.Tensor], 23 | non_linear: torch.nn.Module = None) -> None: 24 | super(A2SHU, self).__init__(window_size, non_linear) 25 | self.precision_convert = A2SPrecisionConvert(converter=converter) 26 | 27 | def check(self): 28 | '''检查A2SHU是否符合基本要求 29 | 30 | 子类应该在构造函数最后调用check方法来检查合法性 31 | ''' 32 | assert (self.window_set is None and self.window_conv is None and 33 | self.sampler is not None and self.precision_convert is not None) -------------------------------------------------------------------------------- /unit_tests/test_hnn.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) OpenBII 2 | # Team: CBICR 3 | # SPDX-License-Identifier: Apache-2.0 4 | # See: https://spdx.org/licenses/ 5 | 6 | import torch 7 | import unittest 8 | from hnn.hu.a2s_poisson_coding_sign_convert import A2SPoissonCodingSignConvert 9 | from hnn.hu.s2a_global_rate_coding import S2AGlobalRateCoding 10 | from hnn.hu.s2a_learnable_rate_coding import S2ALearnableRateCoding 11 | 12 | 13 | class TestHNN(unittest.TestCase): 14 | def test_hnn(self): 15 | a2s = A2SPoissonCodingSignConvert( 16 | window_size=5, non_linear=torch.nn.ReLU()) 17 | x = torch.randn(3, 4, 5, 6) 18 | y = a2s(x) 19 | self.assertTrue(y.shape == torch.Size([3, 4, 5, 6, 5])) 20 | 21 | s2a = S2AGlobalRateCoding(window_size=9, non_linear=torch.nn.ReLU()) 22 | x = torch.randn(3, 4, 5, 6, 9).le(0.5).to(torch.float) 23 | y = s2a(x) 24 | self.assertTrue(y.shape == torch.Size([3, 4, 5, 6, 1])) 25 | 26 | s2a = S2ALearnableRateCoding( 27 | window_size=3, num_windows=3, kernel_size=3, stride=1, padding=0) 28 | y = s2a(x) 29 | self.assertTrue(y.shape == torch.Size([3, 4, 5, 6, 3])) 30 | 31 | 32 | if __name__ == '__main__': 33 | t1 = TestHNN() 34 | t1.test_hnn() -------------------------------------------------------------------------------- /hnn/snn/q_module.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) OpenBII 2 | # Team: CBICR 3 | # SPDX-License-Identifier: Apache-2.0 4 | # See: https://spdx.org/licenses/ 5 | 6 | from abc import ABC, abstractmethod 7 | 8 | 9 | class QModule(ABC): 10 | '''类似于torch.nn.Module, QModel和其他支持量化的算子都继承于QModule类 11 | 12 | Attributes: 13 | quantization_mode: 表示QModule处于量化模式 14 | aware_mode: 表示QModule处于量化感知模式 15 | pretrained: 表示QModule已经加载过预训练模型 16 | ''' 17 | def __init__(self): 18 | self.quantization_mode = False 19 | self.aware_mode = False 20 | self.pretrained = False 21 | 22 | def quantize(self): 23 | '''抽象方法, 用于对模型进行量化 24 | 25 | 继承QModule的子类中的此方法需要先调用父类的此方法将模型置于正确的状态 26 | 27 | Raises: 28 | AssertionError: 如果模型已经处于量化状态则调用此方法会报错 29 | ''' 30 | assert not(self.quantization_mode), 'Model has been quantized' 31 | self.quantization_mode = True 32 | self.aware_mode = False 33 | 34 | def aware(self): 35 | '''抽象方法, 用于对模型进行量化感知训练 36 | 37 | 继承QModule的子类中的此方法需要先调用父类的此方法将模型置于正确的状态 38 | ''' 39 | self.aware_mode = True 40 | 41 | @abstractmethod 42 | def dequantize(self): 43 | '''抽象方法, 用于对模型进行反量化, 将量化模型转换成浮点数模型 44 | 45 | 继承QModule的子类中的此方法需要先调用父类的此方法将模型置于正确的状态 46 | ''' 47 | self.quantization_mode = False -------------------------------------------------------------------------------- /hnn/hu/hu.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) OpenBII 2 | # Team: CBICR 3 | # SPDX-License-Identifier: Apache-2.0 
4 | # See: https://spdx.org/licenses/ 5 | 6 | import torch 7 | 8 | 9 | class HU(torch.nn.Module): 10 | '''Hybrid Unit类 11 | 12 | Hybrid Unit抽象类, 包含五个部分: 13 | - window_set: 设置时间窗 14 | - window_conv: 时间窗内进行时间维度上的卷积 15 | - sampler: 采样器 16 | - non_linear: 非线性函数 17 | - precision_convert: 精度转换单元 18 | 19 | 所有Hybrid Unit实例均继承于HU类, 不需要实现forward函数 20 | ''' 21 | def __init__(self, window_size: int, non_linear: torch.nn.Module = None) -> None: 22 | '''HU构造函数 23 | 24 | Args: 25 | window_size: 时间窗大小 26 | non_linear: 非线性变换 27 | ''' 28 | super(HU, self).__init__() 29 | self.window_size = window_size 30 | self.window_set = None 31 | self.window_conv = None 32 | self.sampler = None 33 | self.non_linear = non_linear 34 | self.precision_convert = None 35 | 36 | def forward(self, x: torch.Tensor) -> torch.Tensor: 37 | if self.window_set is not None: 38 | x = self.window_set(x) 39 | if self.window_conv is not None: 40 | x = self.window_conv(x) 41 | if self.sampler is not None: 42 | x = self.sampler(x) 43 | if self.non_linear is not None: 44 | x = self.non_linear(x) 45 | if self.precision_convert is not None: 46 | x = self.precision_convert(x) 47 | return x -------------------------------------------------------------------------------- /unit_tests/test_snn.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) OpenBII 2 | # Team: CBICR 3 | # SPDX-License-Identifier: Apache-2.0 4 | # See: https://spdx.org/licenses/ 5 | 6 | import torch 7 | import unittest 8 | from hnn.snn.q_model import QModel 9 | from hnn.snn.q_linear import QLinear 10 | from hnn.snn.lif import QLIF 11 | from spikingjelly.clock_driven import encoding 12 | 13 | 14 | class SNN(QModel): 15 | def __init__(self): 16 | super(SNN, self).__init__() 17 | self.linear = QLinear(28 * 28, 10, bias=False) 18 | self.lif = QLIF(v_th=1, v_leaky_alpha=0.5, 19 | v_leaky_beta=0, v_reset=0) 20 | 21 | def forward(self, x: torch.Tensor, v: torch.Tensor = None): 22 | x = x.view(x.size(0), -1) 23 | x, q_param = self.linear(x) 24 | out, v = self.lif(x, q_param, v) 25 | return out, v 26 | 27 | 28 | class TestSNN(unittest.TestCase): 29 | def test_snn(self): 30 | x = torch.rand(1, 1, 28, 28) 31 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 32 | x = x.to(device) 33 | encoder = encoding.PoissonEncoder().to(device) 34 | snn = SNN().to(device) 35 | snn.collect_q_params() 36 | snn.quantize() 37 | snn.aware(x) 38 | length = 2 39 | v = None 40 | for _ in range(length): 41 | x = encoder(x) 42 | spike, v = snn(x, v) 43 | 44 | 45 | if __name__ == '__main__': 46 | t1 = TestSNN() 47 | t1.test_snn() -------------------------------------------------------------------------------- /hnn/hu/window_conv.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) OpenBII 2 | # Team: CBICR 3 | # SPDX-License-Identifier: Apache-2.0 4 | # See: https://spdx.org/licenses/ 5 | 6 | import torch 7 | from typing import List 8 | 9 | 10 | class WindowConv(torch.nn.Module): 11 | '''时间窗内进行时间维度上的卷积 12 | 13 | 抽象类 14 | ''' 15 | def __init__(self) -> None: 16 | super(WindowConv, self).__init__() 17 | 18 | def reshape1d(self, x: torch.Tensor): 19 | '''将形状为[batch_size, ..., num_of_windows, window_size]的数据转换成[batch_size * ..., num_of_windows, window_size] 20 | 21 | 此方法应用在进行卷积之前, 由于要进行1D卷积, 所以需要先将输入数据变成三维张量 22 | 23 | Returns: 24 | 第一个返回值: reshape之后的三维数据 25 | 第二个返回值: 原始数据的部分维度 26 | ''' 27 | shape = list(x.size()) 28 | batch_size = 1 29 | for i in range(0, len(shape) - 2): 30 | batch_size *= 
shape[i] 31 | num = x.size(-2) 32 | size = x.size(-1) 33 | x = x.reshape(batch_size, num, size) 34 | return x, shape[:-2] 35 | 36 | def reshape(self, x: torch.Tensor, prefix_shape: List) -> torch.Tensor: 37 | '''用于在完成时间窗卷积之后的形状变换 38 | 39 | 假设输入特征图形状为[N, num_of_windows, window_size], 此方法将数据形状转换成[*prefix_shape, num_of_windows * window_size], 最后一个维度为时间维度 40 | 41 | Args: 42 | prefix_shape: reshape1d的第二个返回值, prod(prefix_shape) = N 43 | ''' 44 | num = x.size(-2) 45 | size = x.size(-1) 46 | t = num * size 47 | prefix_shape.append(t) 48 | x = x.reshape(prefix_shape) 49 | return x -------------------------------------------------------------------------------- /hnn/snn/threshold_accumulate_with_saturate.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) OpenBII 2 | # Team: CBICR 3 | # SPDX-License-Identifier: Apache-2.0 4 | # See: https://spdx.org/licenses/ 5 | 6 | import torch 7 | from hnn.snn.q_model import QModel 8 | from hnn.snn.threshold_accumulate import ThresholdAccumulate, QThresholdAccumulate 9 | from hnn.snn.saturate import Saturate, QSaturate 10 | 11 | 12 | class ThresholdAccumulateWithSaturate(torch.nn.Module): 13 | '''带有下限饱和的膜电位阈值累加 14 | 15 | 包括膜电位阈值累加和下限饱和两个步骤 16 | 17 | Args: 18 | accumulate.vth0 = v_th0 19 | saturate.v_l = v_l 20 | ''' 21 | def __init__(self, v_th0, v_l) -> None: 22 | super(ThresholdAccumulateWithSaturate, self).__init__() 23 | self.accumulate = ThresholdAccumulate(v_th0=v_th0) 24 | self.saturate = Saturate(v_l=v_l) 25 | 26 | def forward(self, v_th_adpt) -> torch.Tensor: 27 | with torch.no_grad(): 28 | v_th = self.accumulate(v_th_adpt) 29 | v_th = self.saturate(v_th) 30 | return v_th 31 | 32 | 33 | class QThresholdAccumulateWithSaturate(QModel): 34 | '''支持量化的带有下限饱和的膜电位阈值累加 35 | ''' 36 | def __init__(self, v_th0, v_l) -> None: 37 | QModel.__init__(self) 38 | self.accumulate = QThresholdAccumulate(v_th0=v_th0) 39 | self.saturate = QSaturate(v_l=v_l) 40 | 41 | def forward(self, v_th_adpt, scale) -> torch.Tensor: 42 | with torch.no_grad(): 43 | v_th = self.accumulate.forward(v_th_adpt=v_th_adpt, scale=scale) 44 | v_th = self.saturate.forward(x=v_th, scale=scale) 45 | return v_th -------------------------------------------------------------------------------- /hnn/snn/threshold_dynamics.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) OpenBII 2 | # Team: CBICR 3 | # SPDX-License-Identifier: Apache-2.0 4 | # See: https://spdx.org/licenses/ 5 | 6 | import torch 7 | from hnn.snn.q_model import QModel 8 | from hnn.snn.leaky import Leaky, QLeaky 9 | from hnn.snn.soft_update_after_spike import SoftUpdateAfterSpike, QSoftUpdateAfterSpike 10 | 11 | 12 | class ThresholdDynamics(torch.nn.Module): 13 | '''膜电位阈值的动力学 14 | 15 | 包括阈值自适应分量的指数衰减和发放后导致的阈值增加 16 | 17 | Args: 18 | decay.alpha = v_th_alpha 19 | decay.beta = v_th_beta 20 | decay.adpt_en = v_th_adpt_en 21 | update.value = v_th_incre 22 | ''' 23 | def __init__(self, v_th_alpha, v_th_beta, v_th_incre, v_th_adpt_en=True) -> None: 24 | super(ThresholdDynamics, self).__init__() 25 | self.decay = Leaky(alpha=v_th_alpha, beta=v_th_beta, 26 | adpt_en=v_th_adpt_en) 27 | self.update = SoftUpdateAfterSpike(value=v_th_incre) 28 | 29 | def forward(self, v_th_adpt: torch.Tensor, spike: torch.Tensor) -> torch.Tensor: 30 | v_th_adpt = self.decay(v_th_adpt) 31 | v_th_adpt = self.update(v_th_adpt, spike) 32 | return v_th_adpt 33 | 34 | 35 | class QThresholdDynamics(QModel): 36 | '''支持量化的膜电位阈值的动力学 37 | ''' 38 | def __init__(self,
v_th_alpha, v_th_beta, v_th_incre, v_th_adpt_en=True) -> None: 39 | QModel.__init__(self) 40 | self.decay = QLeaky( 41 | alpha=v_th_alpha, beta=v_th_beta, adpt_en=v_th_adpt_en) 42 | self.update = QSoftUpdateAfterSpike(value=v_th_incre) 43 | 44 | def forward(self, v_th_adpt: torch.Tensor, spike: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: 45 | v_th_adpt = self.decay.forward(x=v_th_adpt, weight_scale=scale) 46 | v_th_adpt = self.update.forward( 47 | x=v_th_adpt, spike=spike, weight_scale=scale) 48 | return v_th_adpt -------------------------------------------------------------------------------- /examples/snn/snn_mlp.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) OpenBII 2 | # Team: CBICR 3 | # SPDX-License-Identifier: Apache-2.0 4 | # See: https://spdx.org/licenses/ 5 | 6 | import torch 7 | 8 | from hnn.snn.lif import QLIF 9 | from hnn.snn.output_rate_coding import OutputRateCoding 10 | from hnn.snn.q_linear import QLinear 11 | from hnn.snn.q_model import QModel 12 | 13 | 14 | class SNNMLP(QModel): 15 | def __init__(self, in_channels, T, num_classes=10): 16 | super(SNNMLP, self).__init__(time_window_size=T) 17 | self.linear1 = QLinear( 18 | in_features=in_channels, out_features=num_classes, bias=False, is_encoder=True) 19 | self.lif1 = QLIF(v_th=1, v_leaky_alpha=0.9, 20 | v_leaky_beta=0, v_reset=0) 21 | 22 | self.linear2 = QLinear( 23 | in_features=num_classes, out_features=num_classes, bias=False) 24 | self.lif2 = QLIF(v_th=1, v_leaky_alpha=0.9, 25 | v_leaky_beta=0, v_reset=0) 26 | 27 | self.linear3 = QLinear( 28 | in_features=num_classes, out_features=num_classes, bias=False) 29 | self.lif3 = QLIF(v_th=1, v_leaky_alpha=0.9, 30 | v_leaky_beta=0, v_reset=0) 31 | 32 | self.coding = OutputRateCoding() 33 | 34 | def forward(self, x: torch.Tensor): 35 | spike = torch.zeros((self.T, x.shape[0], x.shape[1])) 36 | v1 = None 37 | v2 = None 38 | v3 = None 39 | for i in range(self.T): 40 | x, q = self.linear1(x) 41 | out, v1 = self.lif1(x, q, v1) 42 | 43 | x, q = self.linear2(out) 44 | out, v2 = self.lif2(x, q, v2) 45 | 46 | x, q = self.linear3(out) 47 | out, v3 = self.lif3(x, q, v3) 48 | spike[i] = out 49 | return self.coding(spike) 50 | 51 | 52 | if __name__ == '__main__': 53 | x = torch.randn((2, 10)) 54 | model = SNNMLP(10, 10, 10) 55 | y = model(x) 56 | 57 | torch.onnx.export(model, x, 'temp/SNNMLP.onnx', 58 | custom_opsets={'snn': 1}, opset_version=11) 59 | -------------------------------------------------------------------------------- /hnn/snn/lif.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) OpenBII 2 | # Team: CBICR 3 | # SPDX-License-Identifier: Apache-2.0 4 | # See: https://spdx.org/licenses/ 5 | 6 | import torch 7 | from hnn.snn.q_model import QModel 8 | from hnn.snn.integrate_and_fire import IF, QIF 9 | from hnn.snn.leaky import Leaky, QLeaky 10 | 11 | 12 | class LIF(torch.nn.Module): 13 | '''Leaky-Integrate-and-Fire神经元 14 | 15 | Integrate阶段通过其他算子完成 16 | 由IF神经元和Leaky操作组成 17 | 18 | Args: 19 | if_node.reset.value = v_reset 20 | if_node.accumulate.v_init = v_init 21 | if_node.fire.v_th = v_th 22 | if_node.fire.surrogate_function: 默认为Rectangle 23 | v_leaky.alpha = v_leaky_alpha 24 | v_leaky.beta = v_leaky_beta 25 | v_leaky.adpt_en = v_leaky_adpt_en 26 | window_size: Rectangle的矩形窗宽度, default = 1 27 | ''' 28 | def __init__(self, v_th, v_leaky_alpha, v_leaky_beta, v_reset=0, v_leaky_adpt_en=False, v_init=None, window_size=1): 29 | super(LIF, self).__init__() 30 | 
self.if_node = IF(v_th=v_th, v_reset=v_reset, 31 | v_init=v_init, window_size=window_size) 32 | self.v_leaky = Leaky(alpha=v_leaky_alpha, 33 | beta=v_leaky_beta, adpt_en=v_leaky_adpt_en) 34 | 35 | def forward(self, u_in: torch.Tensor, v=None): 36 | spike, v = self.if_node(u_in, v) 37 | v = self.v_leaky(v) 38 | return spike, v 39 | 40 | 41 | class QLIF(QModel): 42 | '''支持量化的LIF神经元 43 | ''' 44 | def __init__(self, v_th, v_leaky_alpha, v_leaky_beta, v_reset=0, v_leaky_adpt_en=False, v_init=None, window_size=1): 45 | QModel.__init__(self) 46 | self.if_node = QIF(v_th=v_th, v_reset=v_reset, 47 | v_init=v_init, window_size=window_size) 48 | self.v_leaky = QLeaky(alpha=v_leaky_alpha, 49 | beta=v_leaky_beta, adpt_en=v_leaky_adpt_en) 50 | 51 | def forward(self, u_in: torch.Tensor, scale: torch.Tensor, v=None): 52 | spike, v = self.if_node.forward(u_in, scale, v) 53 | v = self.v_leaky.forward(v, scale) 54 | return spike, v -------------------------------------------------------------------------------- /examples/ann/alexnet.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) OpenBII 2 | # Team: CBICR 3 | # SPDX-License-Identifier: Apache-2.0 4 | # See: https://spdx.org/licenses/ 5 | 6 | import torch 7 | import torch.nn as nn 8 | from hnn.ann.q_conv2d import QConv2d 9 | from hnn.ann.q_linear import QLinear 10 | from hnn.ann.q_model import QModel 11 | 12 | 13 | class QAlexNet(QModel): 14 | def __init__(self, num_classes=1000): 15 | super(QAlexNet, self).__init__() 16 | self.conv0 = QConv2d(3, 64, kernel_size=11, stride=4, padding=2) 17 | self.relu = nn.ReLU(inplace=True) 18 | self.maxpool0 = nn.MaxPool2d(3, stride=2) 19 | 20 | self.conv1 = QConv2d(64, 192, kernel_size=5, padding=2) 21 | self.maxpool1 = nn.MaxPool2d(3, stride=2) 22 | self.conv2 = QConv2d(192, 384, kernel_size=3, stride=1, padding=1) 23 | self.conv3 = QConv2d(384, 256, kernel_size=3, stride=1, padding=1) 24 | self.conv4 = QConv2d(256, 256, kernel_size=3, stride=1, padding=1) 25 | self.maxpool2 = nn.MaxPool2d(3, stride=2) 26 | 27 | self.fc0 = QLinear(6 * 6 * 256, 4096) 28 | self.fc1 = QLinear(4096, 4096) 29 | self.fc2 = QLinear(4096, num_classes) 30 | 31 | self.model_name = 'QAlexNet' 32 | self.input_shape = (1, 3, 224, 224) 33 | 34 | def forward(self, x) -> torch.Tensor: 35 | x = self.conv0(x) 36 | x = self.relu(x) 37 | x = self.maxpool0(x) 38 | x = self.conv1(x) 39 | x = self.relu(x) 40 | x = self.maxpool1(x) 41 | x = self.conv2(x) 42 | x = self.relu(x) 43 | x = self.conv3(x) 44 | x = self.relu(x) 45 | x = self.conv4(x) 46 | x = self.relu(x) 47 | x = self.maxpool2(x) 48 | 49 | x = torch.flatten(x, 1) 50 | x = self.fc0(x) 51 | x = self.relu(x) 52 | x = self.fc1(x) 53 | x = self.relu(x) 54 | x = self.fc2(x) 55 | return x 56 | 57 | 58 | if __name__ == '__main__': 59 | model = QAlexNet() 60 | model.execute(is_random_input=True, fix_random_seed=True, 61 | result_path='temp/QAlexNet/o_0_0_0.dat', export_onnx_path='temp/QAlexNet/QAlexNet.onnx') 62 | -------------------------------------------------------------------------------- /hnn/ann/q_add.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) OpenBII 2 | # Team: CBICR 3 | # SPDX-License-Identifier: Apache-2.0 4 | # See: https://spdx.org/licenses/ 5 | 6 | import torch 7 | from hnn.ann.q_module import QModule 8 | from hnn.grad import FakeQuantizeFloor 9 | 10 | 11 | class QAdd(QModule, torch.nn.Module): 12 | '''支持量化的张量加算子 13 | 14 | Args: 15 | bit_shift: 完成定点数计算后需要的量化参数 16 | is_last_node: 
是否是最后一个算子 17 | ''' 18 | def __init__(self, is_last_node=False): 19 | torch.nn.Module.__init__(self) 20 | QModule.__init__(self) 21 | self.bit_shift = None 22 | self.is_last_node = is_last_node 23 | 24 | def collect_q_params(self): 25 | '''计算张量加的量化参数 26 | 27 | 如果采用先限制激活再量化的方法, 则量化参数为固定值, 在完成定点数张量加之后不需要进行特殊处理 28 | ''' 29 | QModule.collect_q_params(self) 30 | self.bit_shift = 0 31 | 32 | def forward(self, x: torch.Tensor, y: torch.Tensor): 33 | if self.restricted: 34 | x = x.clamp(-QModule.activation_absmax, QModule.activation_absmax) 35 | y = y.clamp(-QModule.activation_absmax, QModule.activation_absmax) 36 | if self.aware_mode: 37 | assert not( 38 | self.quantization_mode), 'Quantization mode and QAT mode are mutual exclusive' 39 | x = FakeQuantizeFloor.apply(x, 128 / QModule.activation_absmax) 40 | y = FakeQuantizeFloor.apply(y, 128 / QModule.activation_absmax) 41 | out = x + y 42 | if self.quantization_mode and not(self.is_last_node): 43 | assert not( 44 | self.aware_mode), 'Quantization mode and QAT mode are mutual exclusive' 45 | out = out.clamp(-2147483648, 46 | 2147483647).div(2 ** self.bit_shift).floor().clamp(-128, 127) 47 | if self.is_last_node: 48 | out = out.clamp(-2147483648, 2147483647) 49 | return out 50 | 51 | def quantize(self): 52 | QModule.quantize(self) 53 | 54 | def dequantize(self): 55 | QModule.dequantize(self) 56 | 57 | def aware(self): 58 | if self.quantization_mode: 59 | self.dequantize() 60 | QModule.aware(self) 61 | -------------------------------------------------------------------------------- /hnn/snn/fire_with_constant_threshold.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) OpenBII 2 | # Team: CBICR 3 | # SPDX-License-Identifier: Apache-2.0 4 | # See: https://spdx.org/licenses/ 5 | 6 | import torch 7 | from hnn.snn.q_module import QModule 8 | 9 | 10 | class FireWithConstantThreshold(torch.nn.Module): 11 | '''固定阈值的脉冲发放 12 | 13 | 一层神经元共享相同的阈值 14 | 15 | Args: 16 | surrogate_function: 梯度替代函数, 可使用的函数见hnn/snn/surrogate 17 | v_th: 固定的阈值常量 18 | ''' 19 | def __init__(self, surrogate_function, v_th) -> None: 20 | super(FireWithConstantThreshold, self).__init__() 21 | self.surrogate_function = surrogate_function 22 | self.v_th = v_th 23 | 24 | def forward(self, v) -> torch.Tensor: 25 | spike = self.surrogate_function.apply(v, self.v_th) 26 | return spike 27 | 28 | 29 | class QFireWithConstantThreshold(QModule, FireWithConstantThreshold): 30 | '''支持量化的固定阈值的脉冲发放 31 | 32 | 其他说明类似于hnn/snn/accumulate.py 33 | ''' 34 | def __init__(self, surrogate_function, v_th) -> None: 35 | QModule.__init__(self) 36 | FireWithConstantThreshold.__init__(self, surrogate_function, v_th) 37 | self.weight_scale = None 38 | self.first_time = True 39 | self.pretrained = False 40 | self.freeze = False 41 | 42 | def forward(self, v: torch.Tensor, weight_scale: torch.Tensor): 43 | self.weight_scale = weight_scale 44 | if self.quantization_mode: 45 | self._quantize() 46 | spike = FireWithConstantThreshold.forward(self, v) 47 | return spike 48 | 49 | def _quantize(self): 50 | if self.first_time: 51 | self.first_time = False 52 | if not self.pretrained and not self.freeze: 53 | self.v_th = round(self.v_th * self.weight_scale.item()) 54 | 55 | def dequantize(self): 56 | QModule.dequantize(self) 57 | self.v_th = self.v_th / self.weight_scale.item() 58 | 59 | def aware(self): 60 | if self.quantization_mode: 61 | self.dequantize() 62 | QModule.aware(self) 63 | self.v_th = round(self.v_th * self.weight_scale.item() 64 | ) / self.weight_scale.item() 
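# Editor's note: illustrative usage sketch, not part of the original file; names other than the
# classes defined above (and all numeric values) are hypothetical. QFireWithConstantThreshold
# receives its quantization scale at run time from the Integrate-stage operator (e.g. a QLinear
# or QConv2d), and on the first forward call in quantization mode it folds the float threshold
# into the integer domain:
#
#   fire = QFireWithConstantThreshold(surrogate_function=Rectangle, v_th=1.0)
#   # ... the owning QModel puts the module into quantization mode ...
#   spike = fire.forward(v, weight_scale=torch.tensor(64.0))
#   # after this first call fire.v_th == round(1.0 * 64.0) == 64, so the threshold is compared
#   # against the integer-scaled membrane potential v on the same scale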
-------------------------------------------------------------------------------- /hnn/snn/lif_neuron.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) OpenBII 2 | # Team: CBICR 3 | # SPDX-License-Identifier: Apache-2.0 4 | # See: https://spdx.org/licenses/ 5 | 6 | import torch 7 | from hnn.snn.neuron import Neuron 8 | from hnn.snn.lif_recorder import LIFRecorder 9 | from hnn.snn.lif import LIF, QLIF 10 | 11 | 12 | class LIFNeuron(Neuron): 13 | '''含有Recorder的LIF神经元 14 | 15 | 包含recorder和LIF神经元两部分, 需要重载record方法和_forward方法 16 | record方法用来记录神经元参数, _forward方法直接调用神经元的前向推理过程 17 | 基类Neuron中的forward方法会自动调用record方法和_forward方法 18 | ''' 19 | def __init__(self, T, v_th, v_leaky_alpha, v_leaky_beta, v_reset=0, v_leaky_adpt_en=False, v_init=None, window_size=1): 20 | Neuron.__init__(self, LIFRecorder, T) 21 | self.neuron = LIF(v_th=v_th, v_leaky_alpha=v_leaky_alpha, v_leaky_beta=v_leaky_beta, v_reset=v_reset, 22 | v_leaky_adpt_en=v_leaky_adpt_en, v_init=v_init, window_size=window_size) 23 | 24 | def record(self, x: torch.Tensor): 25 | return self.recorder( 26 | x, 27 | self.neuron.if_node.fire.v_th, 28 | self.neuron.v_leaky.alpha, 29 | self.neuron.v_leaky.beta, 30 | self.neuron.if_node.reset.value, 31 | self.neuron.v_leaky.adpt_en, 32 | self.neuron.if_node.accumulate.v_init, 33 | self.T 34 | ) 35 | 36 | 37 | class QLIFNeuron(Neuron): 38 | '''支持量化的含有Recorder的LIF神经元 39 | ''' 40 | def __init__(self, T, v_th, v_leaky_alpha, v_leaky_beta, v_reset=0, v_leaky_adpt_en=False, v_init=None, window_size=1): 41 | Neuron.__init__(self, LIFRecorder, T) 42 | self.neuron = QLIF(v_th=v_th, v_leaky_alpha=v_leaky_alpha, v_leaky_beta=v_leaky_beta, v_reset=v_reset, 43 | v_leaky_adpt_en=v_leaky_adpt_en, v_init=v_init, window_size=window_size) 44 | 45 | def record(self, x: torch.Tensor): 46 | return self.recorder( 47 | x, 48 | self.neuron.if_node.fire.v_th, 49 | self.neuron.v_leaky.alpha, 50 | self.neuron.v_leaky.beta, 51 | self.neuron.if_node.reset.value, 52 | self.neuron.v_leaky.adpt_en, 53 | self.neuron.if_node.accumulate.v_init, 54 | self.T 55 | ) -------------------------------------------------------------------------------- /hnn/snn/saturate.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) OpenBII 2 | # Team: CBICR 3 | # SPDX-License-Identifier: Apache-2.0 4 | # See: https://spdx.org/licenses/ 5 | 6 | import torch 7 | from hnn.grad import FakeQuantizeINT28 8 | from hnn.snn.q_module import QModule 9 | 10 | 11 | class Saturate(torch.nn.Module): 12 | '''下限饱和 13 | 14 | 当输入低于阈值时, 输入取阈值 15 | 16 | Args: 17 | v_l: 下限饱和阈值 18 | ''' 19 | def __init__(self, v_l): 20 | super(Saturate, self).__init__() 21 | self.v_l = v_l 22 | 23 | def forward(self, x: torch.Tensor) -> torch.Tensor: 24 | out = x.clamp(min=self.v_l) 25 | return out 26 | 27 | 28 | class QSaturate(QModule, Saturate): 29 | '''支持量化的下限饱和 30 | 31 | 其他说明类似于hnn/snn/accumulate.py 32 | ''' 33 | def __init__(self, v_l): 34 | QModule.__init__(self) 35 | Saturate.__init__(self, v_l) 36 | self.scale = None 37 | self.first_time = True 38 | self.pretrained = False 39 | self.freeze = False 40 | 41 | def forward(self, x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: 42 | self.scale = scale 43 | if self.quantization_mode: 44 | self._quantize() 45 | if self.aware_mode: 46 | assert not( 47 | self.quantization_mode), 'Quantization mode and QAT mode are mutual exclusive' 48 | x = FakeQuantizeINT28.apply(x, scale) 49 | # forward 50 | x = Saturate.forward(self, x) 51 | if 
self.quantization_mode: 52 | assert not( 53 | self.aware_mode), 'Quantization mode and QAT mode are mutual exclusive' 54 | x = x.clamp(-134217728, 134217727) # INT28 55 | return x 56 | 57 | def _quantize(self): 58 | if self.first_time: 59 | self.first_time = False 60 | if not self.pretrained and not self.freeze: 61 | self.v_l = round(self.v_l * self.scale.item()) 62 | 63 | def dequantize(self): 64 | QModule.dequantize(self) 65 | self.v_l = self.v_l / self.scale.item() 66 | 67 | def aware(self): 68 | if self.quantization_mode: 69 | self.dequantize() 70 | QModule.aware(self) 71 | self.v_l = round(self.v_l * self.scale.item()) / self.scale.item() -------------------------------------------------------------------------------- /hnn/snn/hard_update_after_spike.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) OpenBII 2 | # Team: CBICR 3 | # SPDX-License-Identifier: Apache-2.0 4 | # See: https://spdx.org/licenses/ 5 | 6 | import torch 7 | from hnn.snn.q_module import QModule 8 | 9 | 10 | class HardUpdateAfterSpike(torch.nn.Module): 11 | '''发放脉冲后对输入进行hard reset 12 | 13 | Attributes: 14 | value: 重置后的值 15 | ''' 16 | def __init__(self, value: float) -> None: 17 | super(HardUpdateAfterSpike, self).__init__() 18 | self.value = value 19 | 20 | def forward(self, x: torch.Tensor, spike: torch.Tensor) -> torch.Tensor: 21 | out = spike * self.value + (1 - spike) * x # 避免inplace操作同时保证可导 22 | return out 23 | 24 | 25 | class QHardUpdateAfterSpike(QModule, HardUpdateAfterSpike): 26 | '''支持量化的发放脉冲后对输入进行hard reset操作 27 | 28 | 其他说明类似于hnn/snn/q_accumulate.py 29 | ''' 30 | def __init__(self, value) -> None: 31 | QModule.__init__(self) 32 | HardUpdateAfterSpike.__init__(self, value) 33 | self.weight_scale = None 34 | self.first_time = True 35 | self.pretrained = False 36 | self.freeze = False 37 | 38 | def forward(self, x: torch.Tensor, spike: torch.Tensor, weight_scale: torch.Tensor): 39 | self.weight_scale = weight_scale 40 | if self.quantization_mode: 41 | self._quantize() 42 | x = HardUpdateAfterSpike.forward(self, x, spike) 43 | if self.quantization_mode: 44 | assert not( 45 | self.aware_mode), 'Quantization mode and QAT mode are mutual exclusive' 46 | x = x.clamp(-134217728, 134217727) # INT28 47 | return x 48 | 49 | def _quantize(self): 50 | if self.first_time: 51 | self.first_time = False 52 | if not self.pretrained and not self.freeze: 53 | self.value = round(self.value * self.weight_scale.item()) 54 | 55 | def dequantize(self): 56 | QModule.dequantize(self) 57 | self.value = self.value / self.weight_scale.item() 58 | 59 | def aware(self): 60 | if self.quantization_mode: 61 | self.dequantize() 62 | QModule.aware(self) 63 | self.value = round(self.value * self.weight_scale.item() 64 | ) / self.weight_scale.item() 65 | -------------------------------------------------------------------------------- /hnn/snn/accumulate.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) OpenBII 2 | # Team: CBICR 3 | # SPDX-License-Identifier: Apache-2.0 4 | # See: https://spdx.org/licenses/ 5 | 6 | import torch 7 | from hnn.snn.q_dynamics import QDynamics 8 | from hnn.grad import FakeQuantizeINT28 9 | 10 | 11 | class Accumulate(torch.nn.Module): 12 | '''膜电位累加 13 | 14 | Args: 15 | v_init: 如果输入膜电位为None, 则输入膜电位默认为固定初始值 16 | ''' 17 | def __init__(self, v_init) -> None: 18 | super(Accumulate, self).__init__() 19 | self.v_init = v_init 20 | 21 | def forward(self, u_in, v=None) -> torch.Tensor: 22 | if v is None: 23 | v = 
torch.full_like(u_in, self.v_init) 24 | return u_in + v 25 | 26 | 27 | class QAccumulate(QDynamics, Accumulate): 28 | '''支持量化的膜电位累加操作 29 | ''' 30 | def __init__(self, v_init) -> None: 31 | QDynamics.__init__(self) 32 | Accumulate.__init__(self, v_init) 33 | 34 | def forward(self, x, scale: torch.Tensor, v=None): 35 | self.scale = scale 36 | if self.quantization_mode: 37 | v = self._quantize(v) 38 | if self.aware_mode: 39 | v = self._aware(v) 40 | v = Accumulate.forward(self, x, v) 41 | if self.quantization_mode: 42 | v = v.clamp(-134217728, 134217727) # INT28 43 | return v 44 | 45 | def _quantize(self, v: torch.Tensor) -> torch.Tensor: 46 | '''运行时量化方法 47 | 48 | 由于SNN脉冲神经元的量化参数需要Integrate阶段的算子给出, 所以此方法在运行时被调用 49 | 只有初次被调用会对输入膜电位进行量化 50 | 两次推理过程中调用refresh方法可以冻结神经元参数, 只有没有加载预训练模型且不处于冻结状态时会对脉冲神经元参数进行量化 51 | ''' 52 | if self.first_time: 53 | self.first_time = False 54 | if not self.freeze: 55 | self.v_init = round(self.v_init * self.scale.item()) 56 | if v is not None: 57 | v = v.mul(self.scale).round( 58 | ).clamp(-134217728, 134217727) 59 | return v 60 | 61 | def dequantize(self): 62 | QDynamics.dequantize(self) 63 | self.v_init = self.v_init / self.scale.item() 64 | 65 | def _aware(self, v: torch.Tensor): 66 | if self.quantization_mode: 67 | self.dequantize() 68 | self.v_init = round( 69 | self.v_init * self.scale.item()) / self.scale.item() 70 | if v is not None: 71 | v = FakeQuantizeINT28.apply(v, self.scale) 72 | return v -------------------------------------------------------------------------------- /hnn/snn/integrate_and_fire.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) OpenBII 2 | # Team: CBICR 3 | # SPDX-License-Identifier: Apache-2.0 4 | # See: https://spdx.org/licenses/ 5 | 6 | import torch 7 | from hnn.snn.q_model import QModel 8 | from hnn.snn.hard_update_after_spike import HardUpdateAfterSpike, QHardUpdateAfterSpike 9 | from hnn.snn.surrogate.rectangle import Rectangle 10 | from hnn.snn.accumulate import Accumulate, QAccumulate 11 | from hnn.snn.fire_with_constant_threshold import FireWithConstantThreshold, QFireWithConstantThreshold 12 | 13 | 14 | class IF(torch.nn.Module): 15 | '''Integrate-and-Fire神经元 16 | 17 | Integrate阶段通过其他算子完成 18 | 包括膜电位累加、脉冲发放和膜电位复位三个阶段 19 | 20 | Args: 21 | reset.value = v_reset 22 | accumulate.v_init = v_init 23 | fire.v_th = v_th 24 | fire.surrogate_function: 默认为Rectangle 25 | window_size: Rectangle的矩形窗宽度, default = 1 26 | ''' 27 | def __init__(self, v_th, v_reset, v_init=None, window_size=1): 28 | super(IF, self).__init__() 29 | self.reset = HardUpdateAfterSpike(value=v_reset) 30 | self.accumulate = Accumulate( 31 | v_init=self.reset.value if v_init is None else v_init) 32 | Rectangle.window_size = window_size 33 | self.fire = FireWithConstantThreshold( 34 | surrogate_function=Rectangle, v_th=v_th) 35 | 36 | def forward(self, u_in: torch.Tensor, v: torch.Tensor = None): 37 | # update 38 | v_update = self.accumulate(u_in, v) 39 | # fire 40 | spike = self.fire(v_update) 41 | v = self.reset(v_update, spike) 42 | return spike, v 43 | 44 | 45 | class QIF(QModel): 46 | '''支持量化的IF神经元 47 | ''' 48 | def __init__(self, v_th, v_reset, v_init=None, window_size=1): 49 | QModel.__init__(self) 50 | self.reset = QHardUpdateAfterSpike(value=v_reset) 51 | self.accumulate = QAccumulate( 52 | v_init=self.reset.value if v_init is None else v_init) 53 | Rectangle.window_size = window_size 54 | self.fire = QFireWithConstantThreshold( 55 | surrogate_function=Rectangle, v_th=v_th) 56 | 57 | def forward(self, u_in: 
torch.Tensor, scale: torch.Tensor, v: torch.Tensor = None): 58 | # update 59 | v_update = self.accumulate.forward(u_in, scale, v) 60 | # fire 61 | spike = self.fire.forward(v_update, scale) 62 | v = self.reset.forward(v_update, spike, scale) 63 | return spike, v -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Hybrid Programming Framework 2 | 3 | ## 简介 4 | 5 | HNN编程框架基于PyTorch进行开发,提供了编写ANN、SNN、HNN模型的编程接口,同时可以支持通过此编程框架描述的ANN、SNN模型的自动化量化(HNN的自动化量化仍在开发中),可以支持后训练静态量化和量化感知训练。下面对SNN和HNN编程进行简要说明: 6 | - SNN编程由一系列基本SNN操作组成,通过这些基本操作可以组成灵活的、功能丰富的扩展LIF神经元模型,用户也可以基于这些基本操作实现自定义的神经元模型。 7 | - HNN编程中的HNN主要指[1]中以Hybrid Unit (HU)为转换单元来连接ANN和SNN网络的混合网络,编程框架中实现了可扩展的HU,用户可使用编程框架中提供的各种HU或自定义HU。 8 | 9 | 此框架的开发过程中考虑了与BiMap的融合,通过此编程框架描述的网络可以进一步被BiMap中的编译系统编译部署到支持的类脑计算芯片上。 10 | 11 | HNN编程框架的详细开发及使用文档请见工程文档。 12 | 13 | ## 基本使用 14 | 15 | 通过pip安装: 16 | ```bash 17 | pip install hnn 18 | ``` 19 | 20 | 注:目前因为ONNX版本兼容问题,Pytorch需要使用1.11.0版本 21 | 22 | `examples`文件夹下为通过此编程框架写出的一些ANN、SNN、HNN模型,以需要量化的SNN为例,SNN模型需要继承`src.snn`中的`QModel`类,并通过`QConv2d`, `QLinear`, `QLIF`等算子来搭建网络: 23 | ```python 24 | from src.snn import QModel, QLinear, QLIF 25 | 26 | 27 | class SNN(QModel): 28 | def __init__(self, in_channels, T, num_classes=10): 29 | super(SNN, self).__init__(time_window_size=T) 30 | self.linear = QLinear( 31 | in_features=in_channels, out_features=num_classes) 32 | self.lif = QLIF(v_th=1, v_leaky_alpha=0.9, 33 | v_leaky_beta=0, v_reset=0) 34 | ``` 35 | 36 | 37 | ## 参考引用 38 | 39 | 如果使用到本编程框架的HNN部分,请引用[1]: 40 | 41 | @article{Zhao2022, 42 | doi = {10.1038/s41467-022-30964-7}, 43 | url = {https://doi.org/10.1038/s41467-022-30964-7}, 44 | year = {2022}, 45 | month = jun, 46 | publisher = {Springer Science and Business Media {LLC}}, 47 | volume = {13}, 48 | number = {1}, 49 | author = {Rong Zhao and Zheyu Yang and Hao Zheng and Yujie Wu and Faqiang Liu and Zhenzhi Wu and Lukai Li and Feng Chen and Seng Song and Jun Zhu and Wenli Zhang and Haoyu Huang and Mingkun Xu and Kaifeng Sheng and Qianbo Yin and Jing Pei and Guoqi Li and Youhui Zhang and Mingguo Zhao and Luping Shi}, 50 | title = {A framework for the general design and computation of hybrid neural networks}, 51 | journal = {Nature Communications} 52 | } 53 | 54 | 本工程的SNN和HNN编程部分参考或复用了部分[SpikingJelly](https://github.com/fangwei123456/spikingjelly)的代码: 55 | 56 | @misc{SpikingJelly, 57 | title = {SpikingJelly}, 58 | author = {Fang, Wei and Chen, Yanqi and Ding, Jianhao and Chen, Ding and Yu, Zhaofei and Zhou, Huihui and Timothée Masquelier and Tian, Yonghong and other contributors}, 59 | year = {2020}, 60 | howpublished = {\url{https://github.com/fangwei123456/spikingjelly}}, 61 | note = {Accessed: YYYY-MM-DD}, 62 | } 63 | -------------------------------------------------------------------------------- /hnn/snn/q_integrate.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) OpenBII 2 | # Team: CBICR 3 | # SPDX-License-Identifier: Apache-2.0 4 | # See: https://spdx.org/licenses/ 5 | 6 | import logging 7 | from abc import ABC, abstractmethod 8 | from hnn.snn.q_module import QModule 9 | 10 | 11 | class QIntegrate(QModule, ABC): 12 | '''类似于torch.nn.Module, QModel和其他支持量化的算子都继承于QModule类 13 | 14 | Attributes: 15 | quantization_mode: 表示QModule处于量化模式 16 | aware_mode: 表示QModule处于量化感知模式 17 | q_params_ready: 表示QModule中的量化参数已经统计完毕 18 | pretrained: 表示QModule已经加载过预训练模型 19 | ''' 20 | def __init__(self, is_encoder: 
bool): 21 | super().__init__() 22 | self.q_params_ready = False 23 | self.is_encoder = is_encoder 24 | self.collecting = False 25 | self.weight_scale = None 26 | self.bias_scale = None 27 | if self.is_encoder: 28 | self.input_scale = None 29 | 30 | @abstractmethod 31 | def collect_q_params(self): 32 | '''统计量化参数 33 | 34 | 继承QIntegrate的子类中的此方法需要先调用父类的此方法将模型置于正确的状态 35 | 如果算子作为encoder使用, 会将算子置于统计量化参数的状态 36 | ''' 37 | self.quantization_mode = False 38 | self.aware_mode = False 39 | if self.is_encoder: 40 | self.collecting = True 41 | if not(self.pretrained): 42 | logging.warning( 43 | 'Collecting quantization parameters usually requires a pretrained model') 44 | 45 | @abstractmethod 46 | def calculate_q_params(self): 47 | '''抽象方法, 用于计算量化参数 48 | 49 | 继承QIntegrate的子类中的此方法需要先调用父类的此方法将模型置于正确的状态 50 | ''' 51 | self.quantization_mode = False 52 | self.aware_mode = False 53 | self.q_params_ready = True 54 | self.collecting = False 55 | 56 | @abstractmethod 57 | def quantize(self): 58 | '''抽象方法, 用于对模型进行量化 59 | 60 | 继承QIntegrate的子类中的此方法需要先调用父类的此方法将模型置于正确的状态 61 | 62 | Raises: 63 | AssertionError: 如果模型已经处于量化状态则调用此方法会报错 64 | ''' 65 | QModule.quantize(self) 66 | assert self.q_params_ready, 'Quantization cannot be executed unless quantization parameters have been collected' 67 | 68 | @abstractmethod 69 | def aware(self): 70 | '''抽象方法, 用于对模型进行量化感知训练 71 | 72 | 继承QModule的子类中的此方法需要先调用父类的此方法将模型置于正确的状态 73 | 74 | Raises: 75 | AssertionError: 如果模型没有计算得到量化参数则调用此方法会报错 76 | ''' 77 | QModule.aware(self) 78 | assert self.q_params_ready, 'QAT cannot be executed unless quantization parameters have been collected' 79 | -------------------------------------------------------------------------------- /hnn/onnx_export_pass.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) OpenBII 2 | # Team: CBICR 3 | # SPDX-License-Identifier: Apache-2.0 4 | # See: https://spdx.org/licenses/ 5 | 6 | import os 7 | 8 | import onnx 9 | import onnxsim 10 | import torch 11 | from onnx.shape_inference import infer_shapes 12 | 13 | from hnn.network_type import NetworkType 14 | 15 | 16 | def onnx_export( 17 | model: torch.nn.Module, input, output_path, 18 | model_path=None, 19 | reserve_control_flow=False, 20 | network_type: NetworkType = NetworkType.ANN 21 | ): 22 | if network_type == NetworkType.ANN: 23 | if model_path is not None: 24 | if hasattr(model, 'load_quantized_model'): 25 | model.load_quantized_model( 26 | checkpoint_path=model_path, 27 | device=torch.device('cpu') 28 | ) 29 | else: 30 | state_dict = torch.load(model_path) 31 | model.load_state_dict(state_dict) 32 | if reserve_control_flow: 33 | model = torch.jit.script(model) 34 | os.makedirs(os.path.dirname(output_path), exist_ok=True) 35 | torch.onnx.export(model=model, args=input, f=output_path, 36 | keep_initializers_as_inputs=True, 37 | do_constant_folding=True) 38 | onnx_model = onnx.load(output_path) 39 | onnx_model, _ = onnxsim.simplify(onnx_model) 40 | onnx.save(onnx_model, output_path) 41 | elif network_type == NetworkType.SNN: 42 | if model_path is not None: 43 | if hasattr(model, 'load_quantized_model'): 44 | model.load_quantized_model( 45 | checkpoint_path=model_path, 46 | device=torch.device('cpu') 47 | ) 48 | else: 49 | state_dict = torch.load(model_path) 50 | model.load_state_dict(state_dict) 51 | if reserve_control_flow: 52 | model = torch.jit.script(model) 53 | os.makedirs(os.path.dirname(output_path), exist_ok=True) 54 | torch.onnx.export(model=model, args=input, f=output_path, 55 | keep_initializers_as_inputs=True, 
56 | do_constant_folding=False, 57 | custom_opsets={'snn': 1}) 58 | # SNN中由于存在自定义算子无法调用onnx-simplifier, 需要在onnxruntime中实现自定义算子 59 | onnx_model = onnx.load(output_path) 60 | onnx.save(infer_shapes(onnx_model), output_path) 61 | else: 62 | raise NotImplementedError(network_type.name + 'has not been supported') 63 | 64 | return onnx.load(output_path) 65 | -------------------------------------------------------------------------------- /hnn/snn/threshold_accumulate.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) OpenBII 2 | # Team: CBICR 3 | # SPDX-License-Identifier: Apache-2.0 4 | # See: https://spdx.org/licenses/ 5 | 6 | import torch 7 | from hnn.grad import FakeQuantizeINT28 8 | from hnn.snn.q_module import QModule 9 | 10 | 11 | class ThresholdAccumulate(torch.nn.Module): 12 | '''膜电位阈值累加 13 | 14 | Args: 15 | v_th0: 固定阈值 16 | ''' 17 | def __init__(self, v_th0) -> None: 18 | super(ThresholdAccumulate, self).__init__() 19 | self.v_th0 = v_th0 20 | 21 | def forward(self, v_th_adpt) -> torch.Tensor: 22 | with torch.no_grad(): 23 | v_th_adpt = torch.as_tensor(v_th_adpt) 24 | v_th = self.v_th0 + v_th_adpt 25 | return v_th 26 | 27 | 28 | class QThresholdAccumulate(QModule, ThresholdAccumulate): 29 | '''支持量化的膜电位阈值累加 30 | 31 | 其他说明类似于hnn/snn/accumulate.py 32 | ''' 33 | def __init__(self, v_th0): 34 | QModule.__init__(self) 35 | ThresholdAccumulate.__init__(self, v_th0) 36 | self.scale = None 37 | self.first_time = True 38 | self.pretrained = False 39 | self.freeze = False 40 | 41 | def forward(self, v_th_adpt, scale: torch.Tensor) -> torch.Tensor: 42 | self.scale = scale 43 | if self.quantization_mode: 44 | self._quantize(v_th_adpt) 45 | if self.aware_mode: 46 | assert not( 47 | self.quantization_mode), 'Quantization mode and QAT mode are mutual exclusive' 48 | v_th_adpt = FakeQuantizeINT28.apply(v_th_adpt, scale) 49 | # forward 50 | v_th = ThresholdAccumulate.forward(self, v_th_adpt) 51 | if self.quantization_mode: 52 | assert not( 53 | self.aware_mode), 'Quantization mode and QAT mode are mutual exclusive' 54 | v_th = v_th.clamp(-134217728, 134217727) # INT28 55 | return v_th 56 | 57 | def _quantize(self, v) -> torch.Tensor: 58 | if self.first_time: 59 | self.first_time = False 60 | if not self.pretrained and not self.freeze: 61 | self.v_th0 = round(self.v_th0 * self.scale.item()) 62 | v = torch.as_tensor(v) 63 | v = v.mul(self.scale).round().clamp(-134217728, 134217727) 64 | return v 65 | 66 | def dequantize(self): 67 | QModule.dequantize(self) 68 | self.v_th0 = self.v_th0 / self.scale.item() 69 | 70 | def aware(self): 71 | if self.quantization_mode: 72 | self.dequantize() 73 | QModule.aware(self) 74 | self.v_th0 = round(self.v_th0 * self.scale.item()) / self.scale.item() -------------------------------------------------------------------------------- /hnn/snn/soft_update_after_spike.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) OpenBII 2 | # Team: CBICR 3 | # SPDX-License-Identifier: Apache-2.0 4 | # See: https://spdx.org/licenses/ 5 | 6 | import torch 7 | from hnn.snn.q_module import QModule 8 | 9 | 10 | class SoftUpdateAfterSpike(torch.nn.Module): 11 | '''发放脉冲后对输入进行soft reset 12 | 13 | soft reset包含两种模式, constant模式在输入上加上一个常量, 否则对输入加一个张量 14 | 15 | Args: 16 | value: constant模式下的常量, default = None 17 | ''' 18 | def __init__(self, value=None) -> None: 19 | super(SoftUpdateAfterSpike, self).__init__() 20 | self.value = value 21 | 22 | def forward(self, x: torch.Tensor, spike: 
torch.Tensor, update: torch.Tensor = None): 23 | if self.value is None: 24 | assert update is not None 25 | out = x + spike * update 26 | else: 27 | out = x + spike * self.value 28 | return out 29 | 30 | 31 | class QSoftUpdateAfterSpike(QModule, SoftUpdateAfterSpike): 32 | '''支持量化的发放脉冲后对输入进行soft reset操作 33 | 34 | 其他说明类似于hnn/snn/accumulate.py 35 | ''' 36 | def __init__(self, value=None) -> None: 37 | QModule.__init__(self) 38 | SoftUpdateAfterSpike.__init__(self, value) 39 | self.weight_scale = None 40 | self.first_time = True 41 | self.pretrained = False 42 | self.freeze = False 43 | 44 | def forward(self, x: torch.Tensor, spike: torch.Tensor, weight_scale: torch.Tensor, update: torch.Tensor = None): 45 | self.weight_scale = weight_scale 46 | if self.quantization_mode: 47 | self._quantize() 48 | x = SoftUpdateAfterSpike.forward(self, x, spike, update) 49 | if self.quantization_mode: 50 | assert not( 51 | self.aware_mode), 'Quantization mode and QAT mode are mutual exclusive' 52 | x = x.clamp(-134217728, 134217727) # INT28 53 | return x 54 | 55 | def _quantize(self): 56 | if self.first_time: 57 | self.first_time = False 58 | if not self.pretrained and not self.freeze: 59 | if self.value is not None: 60 | self.value = round(self.value * self.weight_scale.item()) 61 | 62 | def dequantize(self): 63 | QModule.dequantize(self) 64 | if self.value is not None: 65 | self.value = self.value / self.weight_scale.item() 66 | 67 | def aware(self): 68 | if self.quantization_mode: 69 | self.dequantize() 70 | QModule.aware(self) 71 | if self.value is not None: 72 | self.value = round( 73 | self.value * self.weight_scale.item()) / self.weight_scale.item() -------------------------------------------------------------------------------- /hnn/snn/reset_after_spike.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) OpenBII 2 | # Team: CBICR 3 | # SPDX-License-Identifier: Apache-2.0 4 | # See: https://spdx.org/licenses/ 5 | 6 | import torch 7 | from hnn.snn.q_model import QModel 8 | from hnn.snn.reset_mode import ResetMode 9 | from hnn.snn.hard_update_after_spike import HardUpdateAfterSpike, QHardUpdateAfterSpike 10 | from hnn.snn.soft_update_after_spike import SoftUpdateAfterSpike, QSoftUpdateAfterSpike 11 | 12 | 13 | class ResetAfterSpike(torch.nn.Module): 14 | '''脉冲发放后膜电位复位 15 | 16 | 根据reset_mode不同选择不同的发放模式 17 | 18 | Args: 19 | reset_mode: 发放模式 20 | reset.v_reset: 当发放模式为HARD时的复位值 21 | reset.dv: 当发放模式为SOFT_CONSTANT时膜电位减去dv 22 | ''' 23 | def __init__(self, reset_mode: ResetMode, v_reset=None, dv=None) -> None: 24 | super(ResetAfterSpike, self).__init__() 25 | self.reset_mode = reset_mode 26 | if self.reset_mode == ResetMode.HARD: 27 | assert v_reset is not None 28 | self.reset = HardUpdateAfterSpike(value=v_reset) 29 | elif self.reset_mode == ResetMode.SOFT_CONSTANT: 30 | assert dv is not None 31 | self.reset = SoftUpdateAfterSpike(value=-dv) 32 | else: 33 | assert self.reset_mode == ResetMode.SOFT, "Invalid reset mode" 34 | self.reset = SoftUpdateAfterSpike() 35 | 36 | def forward(self, v: torch.Tensor, spike: torch.Tensor, update: torch.Tensor = None): 37 | if self.reset_mode == ResetMode.SOFT: 38 | out = self.reset(v, spike, -update) 39 | else: 40 | out = self.reset(v, spike) 41 | return out 42 | 43 | 44 | class QResetAfterSpike(QModel): 45 | '''支持量化的脉冲发放后膜电位复位操作 46 | ''' 47 | def __init__(self, reset_mode: ResetMode, v_reset=None, dv=None) -> None: 48 | QModel.__init__(self) 49 | self.reset_mode = reset_mode 50 | if self.reset_mode == ResetMode.HARD: 51 
| assert v_reset is not None 52 | self.reset = QHardUpdateAfterSpike(value=v_reset) 53 | elif self.reset_mode == ResetMode.SOFT_CONSTANT: 54 | assert dv is not None 55 | self.reset = QSoftUpdateAfterSpike(value=-dv) 56 | else: 57 | assert self.reset_mode == ResetMode.SOFT, "Invalid reset mode" 58 | self.reset = SoftUpdateAfterSpike() 59 | 60 | def forward(self, v: torch.Tensor, spike: torch.Tensor, scale: torch.Tensor, update: torch.Tensor = None): 61 | if self.reset_mode == ResetMode.SOFT: 62 | out = self.reset.forward(v, spike, -update) # 这里要求update必须已经被量化过 63 | else: 64 | out = self.reset.forward(v, spike, weight_scale=scale) 65 | return out -------------------------------------------------------------------------------- /docs/sphinx/source/SNN编程与量化框架.rst: -------------------------------------------------------------------------------- 1 | ======================================================================== 2 | SNN编程与量化框架 3 | ======================================================================== 4 | 5 | 本文档作者:曲环宇 6 | 7 | SNN框架包括基本的SNN编程框架和SNN量化框架,其中量化框架在不置于量化模式时的功能等同于编程框架。 8 | 9 | SNN编程 10 | ###################### 11 | 12 | 下图为SNN编程框架的组成,最底层为编程框架提供的一系列基本操作,这些基本操作可以组成更复杂的二级操作,其中Refractory虽然包含其他基本操作,但本身具有特别的功能,所以也可以视为基本操作。二级操作中包括如复位和阈值等基本的动力学,同时也已经可以组成出IF这样基本的神经元。二级操作可以进一步组成更高级的操作,例如LIF等神经元级操作,整个SNN编程框架的最大集是Extended LIF神经元。 13 | 14 | .. image:: _static/snn_framework.png 15 | :width: 100% 16 | :align: center 17 | 18 | 每个基本操作的功能和Extended LIF整体的介绍待补充。 19 | 20 | SNN量化 21 | ######################################## 22 | 23 | SNN量化原理 24 | ******************* 25 | 26 | 脉冲化输入层量化 27 | ------------------ 28 | 29 | .. image:: _static/spike_input_snn.png 30 | :width: 100% 31 | :align: center 32 | 33 | 左图展示了基本的SNN操作在浮点数形式下的计算,当SNN层的输入为脉冲化输入时,并不需要对输入进行量化,只需要对权重进行量化。当权重以线性方式进行映射时,需要对输入的初始膜电位、阈值、复位值和膜电位泄漏时的偏置值应用相同的线性映射,保证量化后的整个神经元计算流程和浮点数的计算是近似等价的。 34 | 35 | 编码层量化 36 | ------------------ 37 | 38 | .. image:: _static/encoder_snn.png 39 | :width: 100% 40 | :align: center 41 | 42 | 在现在的SNN网络中,常用第一个SNN层当作可学习的编码层将浮点数的输入编码成脉冲化的输入然后用于后续的SNN网络进行处理。此时integrate步骤处理的不是脉冲化输入,则需要对输入和权重同时进行量化,然后再根据输入和权重对应的量化参数对输入的初始膜电位、阈值、复位值和膜电位泄漏时的偏置值进行量化。 43 | 44 | 由于输入的量化参数需要在训练集上进行统计才能得到,整体上的量化步骤如下: 45 | 46 | 1. 训练浮点数模型 47 | 2. 统计量化参数 48 | 3. 进行量化 49 | 4. 对量化模型进行测试 50 | 51 | SNN量化框架实现 52 | *********************** 53 | 54 | 待补充 55 | 56 | SNN量化框架基本使用流程 57 | ************************************ 58 | 59 | 60 | 1. 编写类似于下面的SNN模型: 61 | 62 | .. code:: python 63 | 64 | import QModel, QLinear, QLIF 65 | 66 | class SNN(QModel): 67 | def __init__(self): 68 | super(SNN, self).__init__() 69 | self.linear = QLinear(28 * 28, 10, bias=False) 70 | self.lif = QLIF(v_th_0=1, v_leaky_alpha=0.5, v_leaky_beta=0, v_reset=0) 71 | 72 | def forward(self, x: torch.Tensor, v: torch.Tensor=None): 73 | x = x.view(x.size(0), -1) 74 | x, q_param = self.linear(x) 75 | out, v = self.lif(x, q_param, v) 76 | return out, v 77 | 78 | 2. 在当前状态下进行浮点数模型的训练 79 | 80 | 3. 将模型中的Conv-BN融合成一个卷积计算 81 | 82 | 4. 如果SNN中包含浮点数输入的编码层,则首先调用\ ``QModel``\ 中的\ ``collect_q_params()``\ 方法将模型置于待统计量化参数的状态,否则跳至7 83 | 84 | 5. 将SNN在给定数据集上进行推理,推理过程中会自动完成统计量化参数的工作 85 | 86 | 6. 调用\ ``QModel``\ 中的\ ``calculate_q_params()``\ 方法计算出用于量化的量化参数 87 | 88 | 7. 调用\ ``quantize()``\ 方法将\ ``QModel``\ 置于量化模式,\ ``quantize()``\ 方法会调用\ ``QModel``\ 的所有实例化\ ``QModule``\ 的属性的\ ``quantize()``\ 方法 89 | 90 | 8. 对\ ``QModel``\ 进行正常的前向推理,得到静态量化的测试精度;前向推理时不需要对输入进行量化 91 | 92 | 9. 
调用\ ``save_quantized_model(checkpoint_path)``\ 方法保存量化模型(模型和量化参数会被保存到同一个文件) -------------------------------------------------------------------------------- /examples/hnn/s2ahnn_lenet.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) OpenBII 2 | # Team: CBICR 3 | # SPDX-License-Identifier: Apache-2.0 4 | # See: https://spdx.org/licenses/ 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | from hnn.ann.q_linear import QLinear 10 | from hnn.hu.s2a_global_rate_coding import S2AGlobalRateCoding 11 | from hnn.snn.lif import QLIF 12 | from hnn.snn.q_conv2d import QConv2d 13 | from hnn.snn.q_model import QModel 14 | 15 | 16 | class S2AHNNLeNet(QModel): 17 | def __init__(self, T): 18 | super(S2AHNNLeNet, self).__init__(time_window_size=T) 19 | self.conv1 = QConv2d(in_channels=1, out_channels=6, 20 | kernel_size=5, stride=1, padding=2, bias=False) 21 | self.lif1 = QLIF(v_th=1, v_leaky_alpha=0.9, 22 | v_leaky_beta=0, v_reset=0) 23 | self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2) 24 | # Nx6x14x14 25 | self.conv2 = QConv2d(in_channels=6, out_channels=16, 26 | kernel_size=5, stride=1, padding=0, bias=False) 27 | self.lif2 = QLIF(v_th=1, v_leaky_alpha=0.9, 28 | v_leaky_beta=0, v_reset=0) 29 | # Nx16x10x10 30 | self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2) 31 | # Nx16x5x5 32 | self.s2a = S2AGlobalRateCoding( 33 | window_size=T, non_linear=torch.nn.ReLU()) 34 | self.linear1 = QLinear(in_features=400, out_features=120) 35 | self.linear2 = QLinear(in_features=120, out_features=84) 36 | self.linear3 = QLinear(in_features=84, out_features=10) 37 | self.relu = nn.ReLU() 38 | 39 | def forward(self, inputs: torch.Tensor): 40 | spike = torch.zeros((self.T, 10, 16, 5, 5)) 41 | v1 = None 42 | v2 = None 43 | for i in range(self.T): 44 | x, q = self.conv1(inputs[i]) 45 | out, v1 = self.lif1(x, q, v1) 46 | out = self.maxpool1(out) 47 | 48 | x, q = self.conv2(out) 49 | out, v2 = self.lif2(x, q, v2) 50 | out = self.maxpool2(out) 51 | spike[i] = out 52 | 53 | spike = spike.permute(1, 2, 3, 4, 0) 54 | x = self.s2a(spike) 55 | x = x.view(x.size(0), -1) 56 | 57 | x = self.linear1(x) 58 | x = self.relu(x) 59 | x = self.linear2(x) 60 | x = self.relu(x) 61 | x = self.linear3(x) 62 | 63 | return x 64 | 65 | 66 | if __name__ == '__main__': 67 | x = torch.randn((10, 10, 1, 28, 28)) # INPUT.SHAPE:[T, N, C, H, W] 68 | model = S2AHNNLeNet(10) 69 | y = model(x) 70 | print(y.shape) 71 | 72 | torch.onnx.export(model, x, 'temp/s2a_hnn_lenet.onnx', 73 | custom_opsets={'snn': 1}, opset_version=11) 74 | -------------------------------------------------------------------------------- /hnn/snn/lif_with_tensor_threshold_and_reset_mode_and_refractory.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) OpenBII 2 | # Team: CBICR 3 | # SPDX-License-Identifier: Apache-2.0 4 | # See: https://spdx.org/licenses/ 5 | 6 | import torch 7 | from hnn.snn.q_model import QModel 8 | from hnn.snn.if_with_tensor_threshold_and_reset_mode_and_refractory import IFWithTensorThresholdAndResetModeAndRefractory, QIFWithTensorThresholdAndResetModeAndRefractory 9 | from hnn.snn.reset_mode import ResetMode 10 | from hnn.snn.leaky import Leaky, QLeaky 11 | 12 | 13 | class LIFWithTensorThresholdAndResetModeAndRefractory(torch.nn.Module): 14 | '''Leaky-Integrate-and-Fire神经元, 不同神经元可以有不同的阈值, 支持可配置的膜电位复位模式和不应期 15 | 16 | Integrate阶段通过其他算子完成 17 | 包括支持不应期的膜电位累加、支持不同神经元有不同阈值的脉冲发放、支持多种模式的膜电位复位和膜电位泄漏四个阶段 18 | 19 | Args: 20 | if_node.reset.reset_mode = 
reset_mode 21 | if_node.reset.v_reset = v_reset 22 | if_node.reset.dv = dv 23 | if_node.accumulate.v_init = v_init 24 | if_node.fire.surrogate_function: 默认为Rectangle 25 | window_size: Rectangle的矩形窗宽度, default = 1 26 | v_leaky.alpha = v_leaky_alpha 27 | v_leaky.beta = v_leaky_beta 28 | v_leaky.adpt_en = v_leaky_adpt_en 29 | ''' 30 | def __init__(self, v_leaky_alpha, v_leaky_beta, reset_mode: ResetMode, v_reset=None, dv=None, v_leaky_adpt_en=False, v_init=None, window_size=1): 31 | super(LIFWithTensorThresholdAndResetModeAndRefractory, self).__init__() 32 | self.if_node = IFWithTensorThresholdAndResetModeAndRefractory( 33 | reset_mode=reset_mode, v_reset=v_reset, dv=dv, v_init=v_init, window_size=window_size) 34 | self.v_leaky = Leaky(alpha=v_leaky_alpha, 35 | beta=v_leaky_beta, adpt_en=v_leaky_adpt_en) 36 | 37 | def forward(self, u_in: torch.Tensor, v_th, v=None, ref_cnt=None): 38 | spike, v = self.if_node(u_in, v_th, v, ref_cnt) 39 | v = self.v_leaky(v) 40 | return spike, v 41 | 42 | 43 | class QLIFWithTensorThresholdAndResetModeAndRefractory(QModel): 44 | def __init__(self, v_leaky_alpha, v_leaky_beta, reset_mode: ResetMode, v_reset=None, dv=None, v_leaky_adpt_en=False, v_init=None, window_size=1): 45 | QModel.__init__(self) 46 | self.if_node = QIFWithTensorThresholdAndResetModeAndRefractory( 47 | reset_mode=reset_mode, v_reset=v_reset, dv=dv, v_init=v_init, window_size=window_size) 48 | self.v_leaky = QLeaky(alpha=v_leaky_alpha, 49 | beta=v_leaky_beta, adpt_en=v_leaky_adpt_en) 50 | 51 | def forward(self, u_in: torch.Tensor, v_th, scale, v=None, ref_cnt=None): 52 | spike, v = self.if_node.forward(u_in, v_th, scale, v, ref_cnt) 53 | v = self.v_leaky.forward(v, scale) 54 | return spike, v -------------------------------------------------------------------------------- /unit_tests/test_fpmodel.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) OpenBII 2 | # Team: CBICR 3 | # SPDX-License-Identifier: Apache-2.0 4 | # See: https://spdx.org/licenses/ 5 | 6 | import unittest 7 | import torch 8 | import torch.nn as nn 9 | 10 | from hnn.ann.q_conv2d import QConv2d 11 | from hnn.ann.q_linear import QLinear 12 | from hnn.ann.q_model import QModel 13 | 14 | 15 | class LeNet(nn.Module): 16 | def __init__(self): 17 | super(LeNet, self).__init__() 18 | self.conv1 = nn.Conv2d(1, 6, 5, padding=2) 19 | self.maxpool1 = nn.MaxPool2d(2, 2) 20 | self.conv2 = nn.Conv2d(6, 16, 5) 21 | self.maxpool2 = nn.MaxPool2d(2, 2) 22 | self.linear1 = nn.Linear(400, 120) 23 | self.linear2 = nn.Linear(120, 84) 24 | self.linear3 = nn.Linear(84, 10) 25 | self.relu = nn.ReLU(inplace=True) 26 | 27 | def forward(self, x): 28 | x = self.conv1(x) 29 | x = self.relu(x) 30 | x = self.maxpool1(x) 31 | x = self.conv2(x) 32 | x = self.relu(x) 33 | x = self.maxpool2(x) 34 | x = x.view(x.size(0), -1) 35 | x = self.linear1(x) 36 | x = self.relu(x) 37 | x = self.linear2(x) 38 | x = self.relu(x) 39 | x = self.linear3(x) 40 | return x 41 | 42 | 43 | class QLeNet(QModel): 44 | def __init__(self): 45 | super(QLeNet, self).__init__() 46 | self.conv1 = QConv2d(1, 6, 5, padding=2) 47 | self.maxpool1 = nn.MaxPool2d(2, 2) 48 | self.conv2 = QConv2d(6, 16, 5) 49 | self.maxpool2 = nn.MaxPool2d(2, 2) 50 | self.linear1 = QLinear(400, 120) 51 | self.linear2 = QLinear(120, 84) 52 | self.linear3 = QLinear(84, 10) 53 | self.relu = nn.ReLU(inplace=True) 54 | 55 | def forward(self, x): 56 | x = self.conv1(x) 57 | x = self.relu(x) 58 | x = self.maxpool1(x) 59 | x = self.conv2(x) 60 | x = self.relu(x) 61 | x = 
self.maxpool2(x) 62 | x = x.view(x.size(0), -1) 63 | x = self.linear1(x) 64 | x = self.relu(x) 65 | x = self.linear2(x) 66 | x = self.relu(x) 67 | x = self.linear3(x) 68 | return x 69 | 70 | 71 | class TestFPModel(unittest.TestCase): 72 | def test_fpmodel(self): 73 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 74 | model = LeNet() 75 | qmodel = QLeNet() 76 | qmodel.load_state_dict(model.state_dict()) 77 | model.to(device) 78 | qmodel.to(device) 79 | x = torch.randn((2, 1, 28, 28)) 80 | x = x.to(device) 81 | y = model(x) 82 | qy = qmodel(x) 83 | self.assertTrue(y.equal(qy)) 84 | 85 | if __name__ == '__main__': 86 | t1 = TestFPModel() 87 | t1.test_fpmodel() 88 | -------------------------------------------------------------------------------- /hnn/snn/if_with_tensor_threshold_and_reset_mode_and_refractory.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) OpenBII 2 | # Team: CBICR 3 | # SPDX-License-Identifier: Apache-2.0 4 | # See: https://spdx.org/licenses/ 5 | 6 | import torch 7 | from hnn.snn.q_model import QModel 8 | from hnn.snn.reset_mode import ResetMode 9 | from hnn.snn.reset_after_spike import ResetAfterSpike, QResetAfterSpike 10 | from hnn.snn.surrogate.rectangle import Rectangle 11 | from hnn.snn.accumulate_with_refractory import AccumulateWithRefractory, QAccumulateWithRefractory 12 | from hnn.snn.fire import Fire 13 | 14 | 15 | class IFWithTensorThresholdAndResetModeAndRefractory(torch.nn.Module): 16 | '''Integrate-and-Fire神经元, 不同神经元可以有不同的阈值, 支持可配置的膜电位复位模式和不应期 17 | 18 | Integrate阶段通过其他算子完成 19 | 包括支持不应期的膜电位累加、支持不同神经元有不同阈值的脉冲发放和支持多种模式的膜电位复位三个阶段 20 | 21 | Args: 22 | reset.reset_mode = reset_mode 23 | reset.v_reset = v_reset 24 | reset.dv = dv 25 | accumulate.v_init = v_init 26 | fire.surrogate_function: 默认为Rectangle 27 | window_size: Rectangle的矩形窗宽度, default = 1 28 | ''' 29 | def __init__(self, reset_mode: ResetMode, v_reset=None, dv=None, v_init=None, window_size=1): 30 | super(IFWithTensorThresholdAndResetModeAndRefractory, self).__init__() 31 | self.reset = ResetAfterSpike( 32 | reset_mode=reset_mode, v_reset=v_reset, dv=dv) 33 | self.accumulate = AccumulateWithRefractory( 34 | v_init=v_reset if v_init is None else v_init) 35 | Rectangle.window_size = window_size 36 | self.fire = Fire(surrogate_function=Rectangle) 37 | 38 | def forward(self, u_in: torch.Tensor, v_th: torch.Tensor, v: torch.Tensor = None, ref_cnt: torch.Tensor = None): 39 | # update 40 | v_update = self.accumulate(u_in, v, ref_cnt) 41 | # fire 42 | spike = self.fire(v_update, v_th) 43 | v = self.reset(v_update, spike) 44 | return spike, v 45 | 46 | 47 | class QIFWithTensorThresholdAndResetModeAndRefractory(QModel): 48 | def __init__(self, reset_mode: ResetMode, v_reset=None, dv=None, v_init=None, window_size=1): 49 | QModel.__init__(self) 50 | self.reset = QResetAfterSpike( 51 | reset_mode=reset_mode, v_reset=v_reset, dv=dv) 52 | self.accumulate = QAccumulateWithRefractory( 53 | v_init=v_reset if v_init is None else v_init) 54 | Rectangle.window_size = window_size 55 | self.fire = Fire(surrogate_function=Rectangle) 56 | 57 | def forward(self, u_in: torch.Tensor, v_th: torch.Tensor, scale: torch.Tensor, v: torch.Tensor = None, ref_cnt: torch.Tensor = None): 58 | # update 59 | v_update = self.accumulate.forward(u_in, scale, v, ref_cnt) 60 | # fire 61 | spike = self.fire.forward(v_update, v_th) 62 | v = self.reset.forward(v_update, spike, scale, v_th) 63 | return spike, v 
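# Editor's note: illustrative usage sketch, not part of the original file; tensor shapes and
# values are hypothetical. A single simulation step with per-neuron thresholds, soft reset and
# refractory counters might look like:
#
#   node = QIFWithTensorThresholdAndResetModeAndRefractory(reset_mode=ResetMode.SOFT, v_init=0)
#   v_th = torch.full((batch, n), 64.0)   # per-neuron thresholds, assumed already on the
#                                         # integer scale produced by the Integrate stage
#   spike, v = node.forward(u_in, v_th, scale, v=None, ref_cnt=torch.zeros_like(u_in))
#   # neurons with ref_cnt > 0 skip the accumulation of u_in; in SOFT mode the membrane
#   # potential of firing neurons is decreased by v_th (v_th is passed as the soft-reset update)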
-------------------------------------------------------------------------------- /unit_tests/test_restrict.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) OpenBII 2 | # Team: CBICR 3 | # SPDX-License-Identifier: Apache-2.0 4 | # See: https://spdx.org/licenses/ 5 | 6 | import unittest 7 | import torch 8 | import torch.nn as nn 9 | 10 | from hnn.ann.q_conv2d import QConv2d 11 | from hnn.ann.q_linear import QLinear 12 | from hnn.ann.q_model import QModel 13 | 14 | 15 | class LeNet(nn.Module): 16 | def __init__(self): 17 | super(LeNet, self).__init__() 18 | self.conv1 = nn.Conv2d(1, 6, 5, padding=2) 19 | self.maxpool1 = nn.MaxPool2d(2, 2) 20 | self.conv2 = nn.Conv2d(6, 16, 5) 21 | self.maxpool2 = nn.MaxPool2d(2, 2) 22 | self.linear1 = nn.Linear(400, 120) 23 | self.linear2 = nn.Linear(120, 84) 24 | self.linear3 = nn.Linear(84, 10) 25 | self.relu = nn.ReLU(inplace=True) 26 | 27 | def forward(self, x): 28 | x = self.conv1(x) 29 | x = x.clamp(-1, 1) 30 | x = self.relu(x) 31 | x = self.maxpool1(x) 32 | x = self.conv2(x) 33 | x = x.clamp(-1, 1) 34 | x = self.relu(x) 35 | x = self.maxpool2(x) 36 | x = x.view(x.size(0), -1) 37 | x = self.linear1(x) 38 | x = x.clamp(-1, 1) 39 | x = self.relu(x) 40 | x = self.linear2(x) 41 | x = x.clamp(-1, 1) 42 | x = self.relu(x) 43 | x = self.linear3(x) 44 | x = x.clamp(-1, 1) 45 | return x 46 | 47 | 48 | class QLeNet(QModel): 49 | def __init__(self): 50 | super(QLeNet, self).__init__() 51 | self.conv1 = QConv2d(1, 6, 5, padding=2) 52 | self.maxpool1 = nn.MaxPool2d(2, 2) 53 | self.conv2 = QConv2d(6, 16, 5) 54 | self.maxpool2 = nn.MaxPool2d(2, 2) 55 | self.linear1 = QLinear(400, 120) 56 | self.linear2 = QLinear(120, 84) 57 | self.linear3 = QLinear(84, 10) 58 | self.relu = nn.ReLU(inplace=True) 59 | 60 | def forward(self, x): 61 | x = self.conv1(x) 62 | x = self.relu(x) 63 | x = self.maxpool1(x) 64 | x = self.conv2(x) 65 | x = self.relu(x) 66 | x = self.maxpool2(x) 67 | x = x.view(x.size(0), -1) 68 | x = self.linear1(x) 69 | x = self.relu(x) 70 | x = self.linear2(x) 71 | x = self.relu(x) 72 | x = self.linear3(x) 73 | return x 74 | 75 | 76 | class TestRestrict(unittest.TestCase): 77 | def test_restrict(self): 78 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 79 | model = LeNet() 80 | qmodel = QLeNet() 81 | qmodel.load_state_dict(model.state_dict()) 82 | model.to(device) 83 | qmodel.to(device) 84 | qmodel.restrict() 85 | x = torch.randn((2, 1, 28, 28)) 86 | x = x.to(device) 87 | qy = qmodel(x) 88 | y = model(x.clamp(-1, 1)) 89 | self.assertTrue(y.equal(qy)) 90 | 91 | if __name__ == '__main__': 92 | t1 = TestRestrict() 93 | t1.test_restrict() -------------------------------------------------------------------------------- /hnn/snn/accumulate_with_refractory.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) OpenBII 2 | # Team: CBICR 3 | # SPDX-License-Identifier: Apache-2.0 4 | # See: https://spdx.org/licenses/ 5 | 6 | import torch 7 | from hnn.snn.q_module import QModule 8 | from hnn.grad import FakeQuantizeINT28 9 | 10 | 11 | class AccumulateWithRefractory(torch.nn.Module): 12 | '''考虑不应期的膜电位累加 13 | 14 | 不应期计数不为0时不进行膜电位累加 15 | 16 | Args: 17 | v_init: 如果输入膜电位为None, 则输入膜电位默认为固定初始值 18 | ''' 19 | def __init__(self, v_init) -> None: 20 | super(AccumulateWithRefractory, self).__init__() 21 | self.v_init = v_init 22 | 23 | def forward(self, u_in: torch.Tensor, v=None, ref_cnt=None) -> torch.Tensor: 24 | if v is None: 25 | v = 
torch.full_like(u_in, self.v_init) 26 | if ref_cnt is None: 27 | ref_cnt = torch.zeros_like(u_in) 28 | ref_mask = (1 - ref_cnt).clamp(min=0) 29 | return u_in * ref_mask + v 30 | 31 | 32 | class QAccumulateWithRefractory(QModule, AccumulateWithRefractory): 33 | '''支持量化的考虑不应期的膜电位累加操作 34 | 35 | 其他说明类似于hnn/snn/q_accumulate.py 36 | ''' 37 | def __init__(self, v_init) -> None: 38 | QModule.__init__(self) 39 | AccumulateWithRefractory.__init__(self, v_init) 40 | self.weight_scale = None 41 | self.first_time = True 42 | self.pretrained = False 43 | self.freeze = False 44 | 45 | def forward(self, u_in, weight_scale: torch.Tensor, v=None, ref_cnt=None): 46 | self.weight_scale = weight_scale 47 | if self.quantization_mode: 48 | v = self._quantize(v) 49 | if self.aware_mode: 50 | assert not( 51 | self.quantization_mode), 'Quantization mode and QAT mode are mutual exclusive' 52 | if v is not None: 53 | v = FakeQuantizeINT28.apply(v, weight_scale) 54 | v = AccumulateWithRefractory.forward(self, u_in, v, ref_cnt) 55 | if self.quantization_mode: 56 | assert not( 57 | self.aware_mode), 'Quantization mode and QAT mode are mutual exclusive' 58 | v = v.clamp(-134217728, 134217727) # INT28 59 | return v 60 | 61 | def _quantize(self, v: torch.Tensor) -> torch.Tensor: 62 | if self.first_time: 63 | self.first_time = False 64 | if not self.pretrained and not self.freeze: 65 | self.v_init = round(self.v_init * self.weight_scale.item()) 66 | if v is not None: 67 | v = v.mul(self.weight_scale).round( 68 | ).clamp(-134217728, 134217727) 69 | return v 70 | 71 | def dequantize(self): 72 | QModule.dequantize(self) 73 | self.v_init = self.v_init / self.weight_scale.item() 74 | 75 | def aware(self): 76 | if self.quantization_mode: 77 | self.dequantize() 78 | QModule.aware(self) 79 | self.v_init = round( 80 | self.v_init * self.weight_scale.item()) / self.weight_scale.item() -------------------------------------------------------------------------------- /examples/hnn/a2shnn_lenet.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) OpenBII 2 | # Team: CBICR 3 | # SPDX-License-Identifier: Apache-2.0 4 | # See: https://spdx.org/licenses/ 5 | 6 | import torch 7 | from hnn.hu.a2s_learnable_coding import A2SLearnableCoding 8 | from hnn.snn.lif import LIF 9 | from hnn.snn.output_rate_coding import OutputRateCoding 10 | from hnn.snn.model import Model 11 | from hnn.hu.model import A2SModel 12 | from hnn.snn.model import InputMode 13 | 14 | 15 | class ANN(torch.nn.Module): 16 | def __init__(self) -> None: 17 | super().__init__() 18 | self.conv1 = torch.nn.Conv2d(in_channels=1, out_channels=6, 19 | kernel_size=5, stride=1, padding=2) 20 | self.maxpool1 = torch.nn.MaxPool2d(kernel_size=2, stride=2) 21 | self.relu = torch.nn.ReLU() 22 | self.conv2 = torch.nn.Conv2d(in_channels=6, out_channels=16, 23 | kernel_size=5, stride=1, padding=0) 24 | self.maxpool2 = torch.nn.MaxPool2d(kernel_size=2, stride=2) 25 | 26 | def forward(self, x): 27 | x = self.conv1(x) 28 | x = self.maxpool1(x) 29 | x = self.relu(x) 30 | x = self.conv2(x) 31 | x = self.maxpool2(x) 32 | x = self.relu(x) 33 | return x 34 | 35 | 36 | class SNN(Model): 37 | def __init__(self, time_interval, mode) -> None: 38 | super().__init__(time_interval, mode) 39 | self.linear1 = torch.nn.Linear(in_features=400, out_features=120) 40 | self.lif1 = LIF(v_th=1, v_leaky_alpha=0.5, 41 | v_leaky_beta=0, v_reset=0, v_leaky_adpt_en=False) 42 | self.linear2 = torch.nn.Linear(in_features=120, out_features=84) 43 | self.lif2 = LIF(v_th=1, 
v_leaky_alpha=0.5, 44 | v_leaky_beta=0, v_reset=0, v_leaky_adpt_en=False) 45 | self.linear3 = torch.nn.Linear(in_features=84, out_features=10) 46 | self.lif3 = LIF(v_th=1, v_leaky_alpha=0.5, 47 | v_leaky_beta=0, v_reset=0, v_leaky_adpt_en=False) 48 | 49 | def forward(self, x, v1=None, v2=None, v3=None): 50 | x = self.linear1(x) 51 | x, v1 = self.lif1(x, v1) 52 | x = self.linear2(x) 53 | x, v2 = self.lif2(x, v2) 54 | x = self.linear3(x) 55 | x, v3 = self.lif3(x, v3) 56 | return x, v1, v2, v3 57 | 58 | 59 | class HNN(A2SModel): 60 | def __init__(self, T): 61 | super().__init__(T=T) 62 | self.ann = ANN() 63 | self.snn = SNN(time_interval=T, mode=InputMode.SEQUENTIAL) 64 | self.a2shu = A2SLearnableCoding(window_size=T, converter=torch.nn.Identity(), non_linear=torch.nn.ReLU()) 65 | self.encode = OutputRateCoding() 66 | 67 | def reshape(self, x: torch.Tensor): 68 | x = x.view(x.size(0), -1, x.size(-1)) 69 | return x.permute(2, 0, 1) 70 | 71 | 72 | if __name__ == '__main__': 73 | x = torch.randn((10, 1, 28, 28)) 74 | model = HNN(5) 75 | y = model(x) 76 | print(y.shape) 77 | -------------------------------------------------------------------------------- /examples/hnn/as2hnn_lenet.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) OpenBII 2 | # Team: CBICR 3 | # SPDX-License-Identifier: Apache-2.0 4 | # See: https://spdx.org/licenses/ 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | from hnn.ann.q_conv2d import QConv2d 10 | from hnn.hu.a2s_poisson_coding_sign_convert import A2SPoissonCodingSignConvert 11 | from hnn.snn.lif import QLIF 12 | from hnn.snn.output_rate_coding import OutputRateCoding 13 | from hnn.snn.q_linear import QLinear 14 | from hnn.snn.q_model import QModel 15 | 16 | 17 | class HNNLeNet(QModel): 18 | def __init__(self, T): 19 | super(HNNLeNet, self).__init__(time_window_size=T) 20 | self.conv1 = QConv2d(in_channels=1, out_channels=6, 21 | kernel_size=5, stride=1, padding=2, bias=False) 22 | self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2) 23 | # Nx6x16x16 24 | self.conv2 = QConv2d(in_channels=6, out_channels=16, 25 | kernel_size=5, stride=1, padding=0, bias=False) 26 | # Nx16x12x12 27 | self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2) 28 | # Nx16x6x6 29 | self.a2shu = A2SPoissonCodingSignConvert( 30 | window_size=T, non_linear=torch.nn.ReLU()) 31 | self.linear1 = QLinear(in_features=400, out_features=120) 32 | self.lif1 = QLIF(v_th=1, v_leaky_alpha=0.9, 33 | v_leaky_beta=0, v_reset=0) 34 | self.linear2 = QLinear(in_features=120, out_features=84) 35 | self.lif2 = QLIF(v_th=1, v_leaky_alpha=0.9, 36 | v_leaky_beta=0, v_reset=0) 37 | self.linear3 = QLinear(in_features=84, out_features=10) 38 | self.lif3 = QLIF(v_th=1, v_leaky_alpha=0.9, 39 | v_leaky_beta=0, v_reset=0) 40 | self.coding = OutputRateCoding() 41 | 42 | def forward(self, x: torch.Tensor): 43 | x = self.conv1(x) 44 | x = self.maxpool1(x) 45 | x = self.conv2(x) 46 | x = self.maxpool2(x) 47 | # A2SHU 48 | x = self.a2shu(x) # [N, C, H, W] -> [N, C, H, W, T] 49 | spike = torch.zeros((self.T, 10, 10)) 50 | x = x.permute(4, 0, 1, 2, 3) # [T, N, C, H, W] 51 | input = x.view(x.size(0), x.size(1), -1) # [T, N, C * H * W] 52 | v1 = None 53 | v2 = None 54 | v3 = None 55 | for i in range(self.T): 56 | x, q = self.linear1(input[i]) 57 | out, v1 = self.lif1(x, q, v1) 58 | 59 | x, q = self.linear2(out) 60 | out, v2 = self.lif2(x, q, v2) 61 | 62 | x, q = self.linear3(out) 63 | out, v3 = self.lif3(x, q, v3) 64 | 65 | spike[i] = out 66 | return self.coding(spike) 67 | 68 
| 69 | if __name__ == '__main__': 70 | x = torch.randn((10, 1, 28, 28)) 71 | model = HNNLeNet(5) 72 | y = model(x) 73 | print(y.shape) 74 | 75 | torch.onnx.export(model, x, 'temp/a2s_hnn_lenet.onnx', 76 | custom_opsets={'snn': 1}, opset_version=11) 77 | -------------------------------------------------------------------------------- /hnn/snn/leaky.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) OpenBII 2 | # Team: CBICR 3 | # SPDX-License-Identifier: Apache-2.0 4 | # See: https://spdx.org/licenses/ 5 | 6 | import torch 7 | from hnn.grad import DifferentiableFloor 8 | from hnn.snn.q_module import QModule 9 | 10 | 11 | class Leaky(torch.nn.Module): 12 | '''指数衰减操作 13 | 14 | y = a x + b (a <= 1) 15 | 16 | Args: 17 | alpha: 指数衰减系数 18 | beta: 指数衰减常数 19 | adpt_en: 是否进行指数衰减, 否相当于alpha = 1 20 | ''' 21 | def __init__(self, alpha, beta, adpt_en=True): 22 | super(Leaky, self).__init__() 23 | self.alpha = alpha 24 | self.beta = beta 25 | assert alpha <= 1 26 | self.adpt_en = adpt_en 27 | 28 | def forward(self, x: torch.Tensor): 29 | if self.adpt_en: 30 | out = self.alpha * x + self.beta 31 | else: 32 | out = x + self.beta 33 | return out 34 | 35 | 36 | class QLeaky(QModule, Leaky): 37 | '''支持量化的指数衰减操作 38 | ''' 39 | def __init__(self, alpha, beta, adpt_en=True): 40 | QModule.__init__(self) 41 | Leaky.__init__(self, alpha=alpha, beta=beta, adpt_en=adpt_en) 42 | self.weight_scale = None 43 | self.first_time = True 44 | self.pretrained = False 45 | self.freeze = False 46 | 47 | def forward(self, x: torch.Tensor, weight_scale: torch.Tensor): 48 | self.weight_scale = weight_scale 49 | if self.quantization_mode: 50 | self._quantize() 51 | # forward 52 | if self.adpt_en: 53 | if self.quantization_mode: 54 | x = torch.floor(self.alpha * x) + self.beta 55 | elif self.aware_mode: 56 | assert not( 57 | self.quantization_mode), 'Quantization mode and QAT mode are mutual exclusive' 58 | x = DifferentiableFloor.apply( 59 | self.alpha * x * weight_scale) / weight_scale + self.beta 60 | else: 61 | x = self.alpha * x + self.beta 62 | else: 63 | x = x + self.beta 64 | if self.quantization_mode: 65 | assert not( 66 | self.aware_mode), 'Quantization mode and QAT mode are mutual exclusive' 67 | x = x.clamp(-134217728, 134217727) # INT28 68 | return x 69 | 70 | def _quantize(self): 71 | if self.first_time: 72 | self.first_time = False 73 | if not self.pretrained and not self.freeze: 74 | self.alpha = round(self.alpha * 256) / 256 75 | self.beta = round(self.beta * self.weight_scale.item()) 76 | 77 | def dequantize(self): 78 | QModule.dequantize(self) 79 | self.beta = self.beta / self.weight_scale.item() 80 | 81 | def aware(self): 82 | if self.quantization_mode: 83 | self.dequantize() 84 | QModule.aware(self) 85 | self.alpha = round(self.alpha * 256) / 256 86 | self.beta = round(self.beta * self.weight_scale.item() 87 | ) / self.weight_scale.item() -------------------------------------------------------------------------------- /hnn/ann/q_module.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) OpenBII 2 | # Team: CBICR 3 | # SPDX-License-Identifier: Apache-2.0 4 | # See: https://spdx.org/licenses/ 5 | 6 | import torch 7 | from abc import ABC, abstractmethod 8 | 9 | 10 | class QModule(ABC): 11 | '''类似于torch.nn.Module, QModel和其他支持量化的算子都继承于QModule类 12 | 13 | Attributes: 14 | activation_absmax: 静态变量, 用于训练过程中限制激活的范围, default = 1 15 | quantization_mode: 表示QModule处于量化模式 16 | aware_mode: 表示QModule处于量化感知模式 17 | q_params_ready:
表示QModule中的量化参数已经统计完毕 18 | restricted: 表示QModule处于激活被限制状态 19 | bit_shift_unit: 硬件上用于实现量化时需要的参数 20 | ''' 21 | activation_absmax = 1 22 | 23 | def __init__(self): 24 | self.quantization_mode: bool = False 25 | self.aware_mode: bool = False 26 | self.q_params_ready: bool = False 27 | self.restricted: bool = False 28 | self.bit_shift_unit: int = None 29 | 30 | @abstractmethod 31 | def collect_q_params(self): 32 | '''抽象方法, 用于计算量化参数 33 | 34 | 继承QModule的子类中的此方法需要先调用父类的此方法将模型置于正确的状态 35 | ''' 36 | self.quantization_mode = False 37 | self.aware_mode = False 38 | self.q_params_ready = True 39 | 40 | @abstractmethod 41 | def quantize(self): 42 | '''抽象方法, 用于对模型进行量化 43 | 44 | 继承QModule的子类中的此方法需要先调用父类的此方法将模型置于正确的状态 45 | 46 | Raises: 47 | AssertionError: 如果模型已经处于量化状态则调用此方法会报错 48 | AssertionError: 如果模型没有计算得到量化参数则调用此方法会报错 49 | ''' 50 | assert not(self.quantization_mode), 'Model has been quantized' 51 | self.quantization_mode = True 52 | self.aware_mode = False 53 | self.restricted = False 54 | assert self.q_params_ready, 'Quantization cannot be executed unless quantization parameters have been collected' 55 | 56 | @abstractmethod 57 | def aware(self): 58 | '''抽象方法, 用于对模型进行量化感知训练 59 | 60 | 继承QModule的子类中的此方法需要先调用父类的此方法将模型置于正确的状态 61 | 62 | Raises: 63 | AssertionError: 如果模型没有计算得到量化参数则调用此方法会报错 64 | ''' 65 | self.aware_mode = True 66 | assert self.q_params_ready, 'QAT cannot be executed unless quantization parameters have been collected' 67 | 68 | @abstractmethod 69 | def dequantize(self): 70 | '''抽象方法, 用于对模型进行反量化, 将量化模型转换成浮点数模型 71 | 72 | 继承QModule的子类中的此方法需要先调用父类的此方法将模型置于正确的状态 73 | ''' 74 | self.quantization_mode = False 75 | 76 | def restrict(self): 77 | '''抽象方法, 用于对模型的激活值范围进行限制 78 | 79 | 继承QModule的子类中的此方法需要先调用父类的此方法将模型置于正确的状态 80 | 81 | Raises: 82 | AssertionError: 如果模型已经处于量化状态则调用此方法会报错 83 | ''' 84 | self.restricted = True 85 | assert not(self.quantization_mode) 86 | 87 | @staticmethod 88 | def quantize_input(x: torch.Tensor): 89 | x = x.div(x.abs().max()).mul(128).floor().clamp(-128, 127) 90 | return x 91 | 92 | @staticmethod 93 | def restrict_input(x: torch.Tensor): 94 | x = x.div(x.abs().max()).mul(QModule.activation_absmax) 95 | return x -------------------------------------------------------------------------------- /unit_tests/test_aware.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) OpenBII 2 | # Team: CBICR 3 | # SPDX-License-Identifier: Apache-2.0 4 | # See: https://spdx.org/licenses/ 5 | 6 | import unittest 7 | import torch 8 | import torch.nn as nn 9 | 10 | from hnn.ann.q_conv2d import QConv2d 11 | from hnn.ann.q_linear import QLinear 12 | from hnn.ann.q_model import QModel 13 | 14 | 15 | class LeNet(nn.Module): 16 | def __init__(self): 17 | super(LeNet, self).__init__() 18 | self.conv1 = nn.Conv2d(1, 6, 5, padding=2) 19 | self.maxpool1 = nn.MaxPool2d(2, 2) 20 | self.conv2 = nn.Conv2d(6, 16, 5) 21 | self.maxpool2 = nn.MaxPool2d(2, 2) 22 | self.linear1 = nn.Linear(400, 120) 23 | self.linear2 = nn.Linear(120, 84) 24 | self.linear3 = nn.Linear(84, 10) 25 | self.relu = nn.ReLU(inplace=True) 26 | 27 | def forward(self, x): 28 | x = self.conv1(x) 29 | x = self.relu(x) 30 | x = self.maxpool1(x) 31 | x = self.conv2(x) 32 | x = self.relu(x) 33 | x = self.maxpool2(x) 34 | x = x.view(x.size(0), -1) 35 | x = self.linear1(x) 36 | x = self.relu(x) 37 | x = self.linear2(x) 38 | x = self.relu(x) 39 | x = self.linear3(x) 40 | return x 41 | 42 | 43 | class QLeNet(QModel): 44 | def __init__(self): 45 | super(QLeNet, self).__init__() 46 | self.conv1 = QConv2d(1, 
6, 5, padding=2) 47 | self.maxpool1 = nn.MaxPool2d(2, 2) 48 | self.conv2 = QConv2d(6, 16, 5) 49 | self.maxpool2 = nn.MaxPool2d(2, 2) 50 | self.linear1 = QLinear(400, 120) 51 | self.linear2 = QLinear(120, 84) 52 | self.linear3 = QLinear(84, 10) 53 | self.relu = nn.ReLU(inplace=True) 54 | 55 | def forward(self, x): 56 | x = self.conv1(x) 57 | x = self.relu(x) 58 | x = self.maxpool1(x) 59 | x = self.conv2(x) 60 | x = self.relu(x) 61 | x = self.maxpool2(x) 62 | x = x.view(x.size(0), -1) 63 | x = self.linear1(x) 64 | x = self.relu(x) 65 | x = self.linear2(x) 66 | x = self.relu(x) 67 | x = self.linear3(x) 68 | return x 69 | 70 | 71 | class TestAware(unittest.TestCase): 72 | def test_aware(self): 73 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 74 | model = LeNet() 75 | qmodel = QLeNet() 76 | qmodel.load_state_dict(model.state_dict()) 77 | model.to(device) 78 | qmodel.to(device) 79 | qmodel.collect_q_params() 80 | qmodel.quantize() 81 | x = torch.randn((2, 1, 28, 28)) 82 | qx = (x / x.abs().max()).mul(128).round().clamp(-128, 127) 83 | qx = qx.to(device) 84 | qy = qmodel(qx) 85 | x = (x / x.abs().max()).mul(128).round().clamp(-128, 127).div(128) 86 | qmodel.aware() 87 | x = x.to(device) 88 | y = qmodel(x) 89 | self.assertTrue((qy / y).equal((qy / y).mean() * torch.ones_like(y))) 90 | 91 | if __name__ == '__main__': 92 | t1 = TestAware() 93 | t1.test_aware() 94 | -------------------------------------------------------------------------------- /unit_tests/test_quantize.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) OpenBII 2 | # Team: CBICR 3 | # SPDX-License-Identifier: Apache-2.0 4 | # See: https://spdx.org/licenses/ 5 | 6 | import unittest 7 | import logging 8 | import torch 9 | import torch.nn as nn 10 | 11 | from hnn.ann.q_conv2d import QConv2d 12 | from hnn.ann.q_linear import QLinear 13 | from hnn.ann.q_model import QModel 14 | from hnn.ann.q_module import QModule 15 | 16 | 17 | class LeNet(nn.Module): 18 | def __init__(self): 19 | super(LeNet, self).__init__() 20 | self.conv1 = nn.Conv2d(1, 6, 5, padding=2) 21 | self.maxpool1 = nn.MaxPool2d(2, 2) 22 | self.conv2 = nn.Conv2d(6, 16, 5) 23 | self.maxpool2 = nn.MaxPool2d(2, 2) 24 | self.linear1 = nn.Linear(400, 120) 25 | self.linear2 = nn.Linear(120, 84) 26 | self.linear3 = nn.Linear(84, 10) 27 | self.relu = nn.ReLU(inplace=True) 28 | 29 | def forward(self, x): 30 | x = self.conv1(x) 31 | x = self.relu(x) 32 | x = self.maxpool1(x) 33 | x = self.conv2(x) 34 | x = self.relu(x) 35 | x = self.maxpool2(x) 36 | x = x.view(x.size(0), -1) 37 | x = self.linear1(x) 38 | x = self.relu(x) 39 | x = self.linear2(x) 40 | x = self.relu(x) 41 | x = self.linear3(x) 42 | return x 43 | 44 | 45 | class QLeNet(QModel): 46 | def __init__(self): 47 | super(QLeNet, self).__init__() 48 | self.conv1 = QConv2d(1, 6, 5, padding=2) 49 | self.maxpool1 = nn.MaxPool2d(2, 2) 50 | self.conv2 = QConv2d(6, 16, 5) 51 | self.maxpool2 = nn.MaxPool2d(2, 2) 52 | self.linear1 = QLinear(400, 120) 53 | self.linear2 = QLinear(120, 84) 54 | self.linear3 = QLinear(84, 10) 55 | self.relu = nn.ReLU(inplace=True) 56 | 57 | def forward(self, x): 58 | x = self.conv1(x) 59 | x = self.relu(x) 60 | x = self.maxpool1(x) 61 | x = self.conv2(x) 62 | x = self.relu(x) 63 | x = self.maxpool2(x) 64 | x = x.view(x.size(0), -1) 65 | x = self.linear1(x) 66 | x = self.relu(x) 67 | x = self.linear2(x) 68 | x = self.relu(x) 69 | x = self.linear3(x) 70 | return x 71 | 72 | 73 | class TestQuantize(unittest.TestCase): 74 | def 
test_quantize(self): 75 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 76 | model = LeNet() 77 | qmodel = QLeNet() 78 | qmodel.load_state_dict(model.state_dict()) 79 | model.to(device) 80 | qmodel.to(device) 81 | qmodel.collect_q_params() 82 | for name, module in qmodel.named_modules(): 83 | if not(isinstance(module, QModel)) and isinstance(module, QModule): 84 | logging.debug(name + ': ' + str(int(module.bit_shift))) 85 | qmodel.quantize() 86 | x = torch.randn((2, 1, 28, 28)) 87 | x = x.to(device) 88 | qx = (x / x.abs().max()).mul(128).round().clamp(-128, 127) 89 | qy = qmodel(qx) 90 | logging.debug(qy) 91 | 92 | 93 | if __name__ == '__main__': 94 | logging.basicConfig(level=logging.DEBUG) 95 | t1 = TestQuantize() 96 | t1.test_quantize() -------------------------------------------------------------------------------- /hnn/ann/q_adaptive_avgpool2d.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) OpenBII 2 | # Team: CBICR 3 | # SPDX-License-Identifier: Apache-2.0 4 | # See: https://spdx.org/licenses/ 5 | 6 | import torch 7 | import math 8 | from hnn.ann.q_module import QModule 9 | from hnn.grad import FakeQuantizeFloor 10 | 11 | 12 | class QAdaptiveAvgPool2d(QModule, torch.nn.AdaptiveAvgPool2d): 13 | '''支持量化的平均池化算子 14 | 15 | 算子继承自torch.nn.AdaptiveAvgPool2d, 基本参数与torch.nn.AdaptiveAvgPool2d完全相同, 此处不再赘述 16 | 目前只考虑了整个模型中只出现一个平均池化 17 | 18 | Args: 19 | bit_shift: 完成定点数计算后需要的量化参数 20 | absmax: 对输出激活进行限制时的范围 21 | kernel_size: 池化窗的大小 22 | is_last_node: 是否是最后一个算子 23 | ''' 24 | 25 | def __init__(self, output_size, kernel_size, is_last_node=False): 26 | torch.nn.AdaptiveAvgPool2d.__init__(self, output_size) 27 | QModule.__init__(self) 28 | self.bit_shift = None 29 | self.absmax = None 30 | self.kernel_size = kernel_size 31 | self.is_last_node = is_last_node 32 | 33 | def collect_q_params(self): 34 | QModule.collect_q_params(self) 35 | self.bit_shift = self.bit_shift_unit * round(math.log(self.kernel_size, 2)) 36 | self.absmax = 2 ** self.bit_shift / self.kernel_size ** 2 37 | 38 | def forward(self, x: torch.Tensor): 39 | if self.restricted: 40 | x = x.clamp(-QModule.activation_absmax, QModule.activation_absmax) 41 | QModule.activation_absmax = self.absmax 42 | if self.aware_mode: 43 | assert not( 44 | self.quantization_mode), 'Quantization mode and QAT mode are mutual exclusive' 45 | x = FakeQuantizeFloor.apply(x, 128 / QModule.activation_absmax) 46 | QModule.activation_absmax = self.absmax 47 | out = torch.nn.AdaptiveAvgPool2d.forward(self, x) 48 | if self.quantization_mode and not(self.is_last_node): 49 | assert not( 50 | self.aware_mode), 'Quantization mode and QAT mode are mutual exclusive' 51 | out = out.mul(self.kernel_size ** 2).clamp(-2147483648, 2147483647).div(2 ** 52 | self.bit_shift).floor().clamp(-128, 127) 53 | if self.is_last_node: 54 | out = out.clamp(-2147483648, 2147483647) 55 | return out 56 | 57 | def quantize(self): 58 | QModule.quantize(self) 59 | 60 | def dequantize(self): 61 | QModule.dequantize(self) 62 | 63 | def aware(self): 64 | if self.quantization_mode: 65 | self.dequantize() 66 | QModule.aware(self) 67 | 68 | def restrict(self): 69 | '''平均池化量化参数的计算在restrict方法中完成 70 | 71 | y = (x1 + x2 + ... + x_kernel_size) / kernel_size^2 72 | 128y = (128x1 + 128x2 + ... + 128x_kernel_size) / kernel_size^2 73 | n = round((log_2 kernel_size^2) / 2) 74 | bit_shift = n * bit_shift_unit 75 | 128 * kernel_size^2 / 2^bit_shift y = (128x1 + 128x2 + ... 
+ 128x_kernel_size) / 2^bit_shift 76 | ''' 77 | QModule.restrict(self) 78 | self.bit_shift = self.bit_shift_unit * round(math.log(self.kernel_size, 2)) 79 | self.absmax = 2 ** self.bit_shift / self.kernel_size ** 2 -------------------------------------------------------------------------------- /examples/snn/snn_lenet.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) OpenBII 2 | # Team: CBICR 3 | # SPDX-License-Identifier: Apache-2.0 4 | # See: https://spdx.org/licenses/ 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | from hnn.snn.lif import QLIF 10 | from hnn.snn.output_rate_coding import OutputRateCoding 11 | from hnn.snn.q_conv2d import QConv2d 12 | from hnn.snn.q_linear import QLinear 13 | from hnn.snn.q_model import QModel 14 | 15 | 16 | class SNNLeNet(QModel): 17 | def __init__(self, T): 18 | super(SNNLeNet, self).__init__(time_window_size=T) 19 | self.conv1 = QConv2d(in_channels=1, out_channels=6, 20 | kernel_size=5, stride=1, padding=2, bias=False) 21 | # Nx6x32x32 22 | self.lif1 = QLIF(v_th=1, v_leaky_alpha=0.9, 23 | v_leaky_beta=0, v_reset=0) 24 | self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2) 25 | # Nx6x16x16 26 | self.conv2 = QConv2d(in_channels=6, out_channels=16, 27 | kernel_size=5, stride=1, padding=0, bias=False) 28 | # Nx16x12x12 29 | self.lif2 = QLIF(v_th=1, v_leaky_alpha=0.9, 30 | v_leaky_beta=0, v_reset=0) 31 | self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2) 32 | # Nx16x6x6 33 | self.linear1 = QLinear(in_features=576, out_features=120) 34 | self.lif3 = QLIF(v_th=1, v_leaky_alpha=0.9, 35 | v_leaky_beta=0, v_reset=0) 36 | self.linear2 = QLinear(in_features=120, out_features=84) 37 | self.lif4 = QLIF(v_th=1, v_leaky_alpha=0.9, 38 | v_leaky_beta=0, v_reset=0) 39 | self.linear3 = QLinear(in_features=84, out_features=10) 40 | self.lif5 = QLIF(v_th=1, v_leaky_alpha=0.9, 41 | v_leaky_beta=0, v_reset=0) 42 | self.coding = OutputRateCoding() 43 | 44 | def forward(self, x: torch.Tensor): 45 | # x_seq = x.unsqueeze(0).repeat( 46 | # self.T, 1, 1, 1, 1) # [N, C, H, W] -> [T,N, C, H, W] 47 | # x_seq = self.snnlenet(x_seq) 48 | # fr = x_seq.mean(0) 49 | spike = torch.zeros((self.T, x.size(0), 10)) 50 | xx = x 51 | v1 = None 52 | v2 = None 53 | v3 = None 54 | v4 = None 55 | v5 = None 56 | for i in range(self.T): 57 | x, q = self.conv1(xx) 58 | out, v1 = self.lif1(x, q, v1) 59 | out = self.maxpool1(out) 60 | 61 | x, q = self.conv2(out) 62 | out, v2 = self.lif2(x, q, v2) 63 | out = self.maxpool2(out) 64 | 65 | out = torch.flatten(out, 1, -1) 66 | 67 | x, q = self.linear1(out) 68 | out, v3 = self.lif3(x, q, v3) 69 | 70 | x, q = self.linear2(out) 71 | out, v4 = self.lif4(x, q, v4) 72 | 73 | x, q = self.linear3(out) 74 | out, v5 = self.lif5(x, q, v5) 75 | 76 | spike[i] = out 77 | return self.coding(spike) 78 | 79 | 80 | if __name__ == '__main__': 81 | x = torch.randn((10, 1, 32, 32)) 82 | model = SNNLeNet(10) 83 | y = model(x) 84 | 85 | torch.onnx.export(model, x, 'temp/SNNLeNet.onnx', 86 | custom_opsets={'snn': 1}, opset_version=11) 87 | -------------------------------------------------------------------------------- /unit_tests/test_collect_q_params.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) OpenBII 2 | # Team: CBICR 3 | # SPDX-License-Identifier: Apache-2.0 4 | # See: https://spdx.org/licenses/ 5 | 6 | import math 7 | import unittest 8 | 9 | import torch 10 | import torch.nn as nn 11 | 12 | from hnn.ann.q_conv2d import QConv2d 13 | from hnn.ann.q_linear import 
QLinear 14 | from hnn.ann.q_model import QModel 15 | from hnn.ann.q_module import QModule 16 | 17 | 18 | class LeNet(nn.Module): 19 | def __init__(self): 20 | super(LeNet, self).__init__() 21 | self.conv1 = nn.Conv2d(1, 6, 5, padding=2) 22 | self.maxpool1 = nn.MaxPool2d(2, 2) 23 | self.conv2 = nn.Conv2d(6, 16, 5) 24 | self.maxpool2 = nn.MaxPool2d(2, 2) 25 | self.linear1 = nn.Linear(400, 120) 26 | self.linear2 = nn.Linear(120, 84) 27 | self.linear3 = nn.Linear(84, 10) 28 | self.relu = nn.ReLU(inplace=True) 29 | 30 | def forward(self, x): 31 | x = self.conv1(x) 32 | x = self.relu(x) 33 | x = self.maxpool1(x) 34 | x = self.conv2(x) 35 | x = self.relu(x) 36 | x = self.maxpool2(x) 37 | x = x.view(x.size(0), -1) 38 | x = self.linear1(x) 39 | x = self.relu(x) 40 | x = self.linear2(x) 41 | x = self.relu(x) 42 | x = self.linear3(x) 43 | return x 44 | 45 | 46 | class QLeNet(QModel): 47 | def __init__(self): 48 | super(QLeNet, self).__init__() 49 | self.conv1 = QConv2d(1, 6, 5, padding=2) 50 | self.maxpool1 = nn.MaxPool2d(2, 2) 51 | self.conv2 = QConv2d(6, 16, 5) 52 | self.maxpool2 = nn.MaxPool2d(2, 2) 53 | self.linear1 = QLinear(400, 120) 54 | self.linear2 = QLinear(120, 84) 55 | self.linear3 = QLinear(84, 10) 56 | self.relu = nn.ReLU(inplace=True) 57 | 58 | def forward(self, x): 59 | x = self.conv1(x) 60 | x = self.relu(x) 61 | x = self.maxpool1(x) 62 | x = self.conv2(x) 63 | x = self.relu(x) 64 | x = self.maxpool2(x) 65 | x = x.view(x.size(0), -1) 66 | x = self.linear1(x) 67 | x = self.relu(x) 68 | x = self.linear2(x) 69 | x = self.relu(x) 70 | x = self.linear3(x) 71 | return x 72 | 73 | 74 | class TestCollectQParams(unittest.TestCase): 75 | def test_collect_q_params(self): 76 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 77 | model = LeNet() 78 | qmodel = QLeNet() 79 | qmodel.load_state_dict(model.state_dict()) 80 | model.to(device) 81 | qmodel.to(device) 82 | qmodel.collect_q_params() 83 | compare_dict = {} 84 | for name, module in qmodel.named_modules(): 85 | if not(isinstance(module, QModel)) and isinstance(module, QModule): 86 | compare_dict[name] = [int(module.bit_shift)] 87 | for name, module in model.named_modules(): 88 | if isinstance(module, (nn.Conv2d, nn.Linear)): 89 | compare_dict[name].append(math.log(128 / module.weight.abs().max(), 2)) 90 | self.assertTrue(abs(compare_dict[name][0] - compare_dict[name][1]) < 1.5) 91 | 92 | 93 | if __name__ == '__main__': 94 | t1 = TestCollectQParams() 95 | t1.test_collect_q_params() 96 | -------------------------------------------------------------------------------- /hnn/grad.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) OpenBII 2 | # Team: CBICR 3 | # SPDX-License-Identifier: Apache-2.0 4 | # See: https://spdx.org/licenses/ 5 | 6 | import torch 7 | 8 | 9 | class FakeQuantize(torch.autograd.Function): 10 | @staticmethod 11 | def forward(ctx, x, scale): 12 | ctx.save_for_backward( 13 | x, torch.as_tensor(-128 / scale), torch.as_tensor(127 / scale)) 14 | x = x.mul(scale).round().clamp(-128, 127).div(scale) # 量化反量化 15 | return x 16 | 17 | @staticmethod 18 | def backward(ctx, grad_output): 19 | x, x_min, x_max = ctx.saved_tensors 20 | zeros = torch.zeros_like(x) 21 | ones = torch.ones_like(x) 22 | mask0 = torch.where(x < x_min, zeros, ones) 23 | mask1 = torch.where(x > x_max, zeros, ones) 24 | mask = mask0 * mask1 25 | grad = grad_output * mask 26 | return grad, None 27 | 28 | 29 | class FakeQuantizeFloor(torch.autograd.Function): 30 | @staticmethod 31 | def 
forward(ctx, x, scale): 32 | ctx.save_for_backward( 33 | x, torch.as_tensor(-128 / scale), torch.as_tensor(127 / scale)) 34 | x = x.mul(scale).floor().clamp(-128, 127).div(scale) # 量化反量化 35 | return x 36 | 37 | @staticmethod 38 | def backward(ctx, grad_output): 39 | x, x_min, x_max = ctx.saved_tensors 40 | zeros = torch.zeros_like(x) 41 | ones = torch.ones_like(x) 42 | mask0 = torch.where(x < x_min, zeros, ones) 43 | mask1 = torch.where(x > x_max, zeros, ones) 44 | mask = mask0 * mask1 45 | grad = grad_output * mask 46 | return grad, None 47 | 48 | 49 | class DifferentiableFloor(torch.autograd.Function): 50 | @staticmethod 51 | def forward(ctx, x: torch.Tensor): 52 | return torch.floor(x) 53 | 54 | @staticmethod 55 | def backward(ctx, grad_output): 56 | return grad_output 57 | 58 | 59 | class FakeQuantizeINT32(torch.autograd.Function): 60 | @staticmethod 61 | def forward(ctx, x, scale): 62 | ctx.save_for_backward( 63 | x, torch.as_tensor(-2147483648 / scale), torch.as_tensor(2147483647 / scale)) 64 | x = x.mul(scale).round().clamp(-2147483648, 65 | 2147483647).div(scale) # 量化反量化 66 | return x 67 | 68 | @staticmethod 69 | def backward(ctx, grad_output): 70 | x, x_min, x_max = ctx.saved_tensors 71 | zeros = torch.zeros_like(x) 72 | ones = torch.ones_like(x) 73 | mask0 = torch.where(x < x_min, zeros, ones) 74 | mask1 = torch.where(x > x_max, zeros, ones) 75 | mask = mask0 * mask1 76 | grad = grad_output * mask 77 | return grad, None 78 | 79 | 80 | class FakeQuantizeINT28(torch.autograd.Function): 81 | @staticmethod 82 | def forward(ctx, x, scale): 83 | ctx.save_for_backward( 84 | x, torch.as_tensor(-134217728 / scale), torch.as_tensor(134217727 / scale)) 85 | x = x.mul(scale).round().clamp(-134217728, 86 | 134217727).div(scale) # 量化反量化 87 | return x 88 | 89 | @staticmethod 90 | def backward(ctx, grad_output): 91 | x, x_min, x_max = ctx.saved_tensors 92 | zeros = torch.zeros_like(x) 93 | ones = torch.ones_like(x) 94 | mask0 = torch.where(x < x_min, zeros, ones) 95 | mask1 = torch.where(x > x_max, zeros, ones) 96 | mask = mask0 * mask1 97 | grad = grad_output * mask 98 | return grad, None 99 | -------------------------------------------------------------------------------- /hnn/ann/q_linear.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) OpenBII 2 | # Team: CBICR 3 | # SPDX-License-Identifier: Apache-2.0 4 | # See: https://spdx.org/licenses/ 5 | 6 | import torch 7 | import math 8 | from hnn.ann.q_module import QModule 9 | from hnn.grad import FakeQuantizeFloor, FakeQuantize, FakeQuantizeINT32 10 | 11 | 12 | class QLinear(QModule, torch.nn.Linear): 13 | '''支持量化的Linear算子 14 | 15 | 算子继承自torch.nn.Linear, 基本参数与torch.nn.Linear完全相同, 此处不再赘述 16 | 17 | Args: 18 | weight_scale: 权重的放缩系数 19 | bit_shift: 完成定点数计算后需要的量化参数 20 | is_last_node: 是否是最后一个算子 21 | ''' 22 | def __init__(self, in_features, out_features, bias=True, device=None, dtype=None, is_last_node=False): 23 | torch.nn.Linear.__init__( 24 | self, in_features, out_features, bias, device, dtype) 25 | QModule.__init__(self) 26 | self.weight_scale = None 27 | self.bit_shift = None 28 | self.is_last_node = is_last_node 29 | 30 | def collect_q_params(self): 31 | '''全连接中计算量化参数 32 | 33 | weight_absmax * weight_scale = 128 34 | weight_scale = 2^(bit_shift_unit * n) 35 | bit_shift = bit_shift_unit * n = log_2 (128 / weight_absmax) 36 | n = round(log_2 (128 / weight_absmax) / bit_shift_unit) 37 | 这里取整方法可以有很多 38 | ''' 39 | QModule.collect_q_params(self) 40 | weight_absmax = self.weight.data.abs().max() 41 
| temp = math.log(128 / weight_absmax, 2) / self.bit_shift_unit 42 | if temp - math.floor(temp) >= 0.75: # 经验公式 43 | n = math.ceil(temp) 44 | else: 45 | n = math.floor(temp) 46 | self.bit_shift = self.bit_shift_unit * n 47 | self.weight_scale = 2 ** self.bit_shift 48 | 49 | def forward(self, x: torch.Tensor): 50 | if self.restricted: 51 | x = x.clamp(-QModule.activation_absmax, QModule.activation_absmax) 52 | if self.aware_mode: 53 | assert not( 54 | self.quantization_mode), 'Quantization mode and QAT mode are mutual exclusive' 55 | x = FakeQuantizeFloor.apply(x, 128 / QModule.activation_absmax) 56 | out = torch.nn.Linear.forward(self, x) 57 | if self.quantization_mode and not(self.is_last_node): 58 | assert not( 59 | self.aware_mode), 'Quantization mode and QAT mode are mutual exclusive' 60 | out = out.clamp(-2147483648, 61 | 2147483647).div(self.weight_scale).floor().clamp(-128, 127) 62 | if self.is_last_node: 63 | out = out.clamp(-2147483648, 2147483647) 64 | return out 65 | 66 | def quantize(self): 67 | QModule.quantize(self) 68 | self.weight.data = self.weight.data.mul( 69 | self.weight_scale).round().clamp(-128, 127) # INT8 70 | self.bias.data = self.bias.data.mul( 71 | self.weight_scale * 128 / QModule.activation_absmax).round().clamp(-2147483648, 2147483647) # INT32 72 | 73 | def dequantize(self): 74 | QModule.dequantize(self) 75 | self.weight.data = self.weight.data.div(self.weight_scale) 76 | self.bias.data = self.bias.data.div( 77 | self.weight_scale * 128 / QModule.activation_absmax) 78 | 79 | def aware(self): 80 | if self.quantization_mode: 81 | self.dequantize() 82 | QModule.aware(self) 83 | self.weight.data = FakeQuantize.apply( 84 | self.weight.data, self.weight_scale) 85 | self.bias.data = FakeQuantizeINT32.apply( 86 | self.bias.data, self.weight_scale * 128 / QModule.activation_absmax) 87 | -------------------------------------------------------------------------------- /docs/sphinx/source/HNN介绍.rst: -------------------------------------------------------------------------------- 1 | ======================================================================== 2 | HNN介绍 3 | ======================================================================== 4 | 5 | 本文档作者:刘发强、曲环宇 6 | 7 | 本文档分为HNN框架和HNN转换接口两部分来对HNN进行基本介绍,对HNN更加深入的介绍详见 [1]_ 。整体上讲,HNN可以分成三部分:ANN网络、SNN网络和将解耦合的ANN和SNN网络联系在一起的转换接口 -- Hybrid Unit (HU),HU的设计又可以分成根据专家知识手动设计和自适应学习HU两大类。 8 | 9 | HNN框架 10 | ###################### 11 | 12 | HNN框架这部分主要介绍ANN和SNN不同的神经元以及两种网络中不同的信息表示。 13 | 14 | 基本神经元模型 15 | ************************ 16 | 17 | ANN神经元模型 18 | ------------------------ 19 | 20 | ANN神经元模型不具有时间上的动力学,神经元的输入输出一般都为实数,基本模型可以表示为: 21 | 22 | .. math:: 23 | 24 | y = f( \mathbf{w} \cdot \mathbf{x} + b) 25 | 26 | 其中\ :math:`f(x)`\ 为激活函数,一般为非线性函数,例如ReLU。 27 | 28 | SNN神经元模型 29 | ------------------------ 30 | 31 | SNN神经元模型具有时间上的动力学,一般通过微分方程来对SNN神经元进行建模,神经元的输入输出一般为脉冲序列,其基本模型可以表示为(参考SpikingJelly [2]_ ): 32 | 33 | .. math:: 34 | u(t) & = \mathbf{w} \cdot \mathbf{s(t - 1)} + b \\ 35 | v(t) & = f(v(t - 1), u(t)) \\ 36 | s(t) & = \Theta(v(t) - v_{threshold}) 37 | 38 | 其中\ :math:`u(t)`\ 表示当前时刻的输入,\ :math:`v(t)`\ 代表当前时刻的膜电位,为前一时刻的膜电位和当前时刻输入的函数,\ :math:`s(t)`\ 代表当前时刻脉冲神经元的输出,其中 39 | 40 | .. 
math:: 41 | \Theta(x) = 42 | \begin{cases} 43 | 1, & x \geq 0 \\ 44 | 0, & x < 0 45 | \end{cases} 46 | 47 | 表示如果当前时刻膜电位高于阈值则发放脉冲,反之则不发放脉冲,如果神经元发放脉冲则神经元的膜电位会被重置。 48 | 49 | 在HNN网络中,为了保证ANN神经元与SNN神经元的解耦,引入了HU来作为ANN神经元与SNN神经元之间的转换接口,HU将在文档的第二部分具体介绍。 50 | 51 | 信号表示 52 | *********************** 53 | 54 | 根据上一小节对ANN神经元和SNN神经元模型的介绍,我们可以总结得到:其中ANN神经元中是同步的、连续的信号,ANN中的信号为实数域表示,而SNN神经元中是异步的、离散的信号,SNN中的信号为脉冲表示。 55 | 56 | 我们将信息流分成两种:直接的传递和间接的调制,如下图所示: 57 | 58 | .. image:: _static/hnn1.png 59 | :width: 100% 60 | :align: center 61 | 62 | 其中传递代表神经元之间通过突触传递的信息来直接影响神经元的状态,调制代表某个或某些神经元通过调整神经元或突触参数来间接影响其他神经元的状态,例如改变神经元阈值和突触权重等。在ANN和SNN中,信息的传递和调制都是在同质的信号之间发生。 63 | 64 | 在HNN中,我们通过HU来对两种异质的信号进行转换,信息流可以分成混合传递和混合调制两种,其中每种信息流又可以进一步分成ANN到SNN信号转换和SNN到ANN信号转换两种,如下图所示: 65 | 66 | .. image:: _static/hnn2.png 67 | :width: 100% 68 | :align: center 69 | 70 | 图中的实线代表ANN,虚线代表SNN。 71 | 72 | 73 | HNN转换接口 -- Hybrid Unit 74 | ######################################## 75 | 76 | 由于HU将ANN和SNN相互解耦,所以HNN中的ANN和SNN部分和常见的ANN、SNN相同,这一部分文档介绍HNN中的转换接口,Hybrid Unit。 77 | 78 | 基本操作 79 | ********************** 80 | 81 | HU在整体上可以分成4个计算步骤:truncating (\ :math:`W(t)`\),filtering (\ :math:`H(t)`\), non-linearity (\ :math:`F`\)和discretization (\ :math:`Q`\)。假设输入原始数据为\ :math:`X`\ ,输入数据为\ :math:`Y`\ ,HU可以表示为: 82 | 83 | .. math:: 84 | Y & = HU[X] \\ 85 | & = Q \cdot F \cdot H \cdot W(X) 86 | 87 | 其中:(这部分感觉写的不是很清楚) 88 | 89 | - **truncating:** 由于HU的输入和输出是两个没有时间上的依赖关系的时间序列,所以需要一个参数化的窗函数来同步HU的输入和输出的时间尺度,窗函数\ :math:`W(t, k, T_s)`\ 对输入\ :math:`X`\ 进行截取,其中\ :math:`T_s`\ 为时间窗长度。此步骤的输出为\ :math:`X \cdot W(t, k, T_s)`\ 。 90 | - **filtering:** 通过kernel函数\ :math:`H(t))`\ 对 \ :math:`X \cdot W(t, k, T_s)`\ 进行时域上的卷积。 91 | - **non-linearity:** 代表对上一步卷积得到的结果进行非线性变换。 92 | - **discretization:** 代表离散化操作,例如将连续信号转换成SNN中的脉冲序列,为可选操作。 93 | 94 | 95 | 配置方式 96 | ********************** 97 | 98 | 与传统的信号转换相比,HU是可配置的,例如HU中的kernel函数\ :math:`H(t)`\ 和非线性函数\ :math:`F`\ 都是参数化的。 99 | 100 | HU有两种配置方式:手动设计和自动学习。当HU的输入和输出之间的关系是确定的、简单并且已知的情况下,可以采用手动设计的方式来配置HU。在更加复杂的情况下,可以通过自动学习来对HU进行配置,HU可以通过下图中的3种方式进行学习: 101 | 102 | .. image:: _static/hnn3.png 103 | :width: 80% 104 | :align: center 105 | 106 | 1. 与前端或后端网络和在一起进行训练。 107 | 2. 通过设定特殊的优化目标来单独训练。 108 | 3. 和整个网络一起训练。 109 | 110 | 111 | -------------------- 112 | 113 | .. [1] `A framework for the general design and computation of hybrid neural networks `__ 114 | .. 
[2] `SpikingJelly `__ -------------------------------------------------------------------------------- /examples/ann/lenet.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) OpenBII 2 | # Team: CBICR 3 | # SPDX-License-Identifier: Apache-2.0 4 | # See: https://spdx.org/licenses/ 5 | 6 | import pickle 7 | 8 | import numpy as np 9 | import torch.nn as nn 10 | 11 | from hnn.ann.q_conv2d import QConv2d 12 | from hnn.ann.q_linear import QLinear 13 | from hnn.ann.q_model import QModel 14 | 15 | 16 | class LeNet(nn.Module): 17 | def __init__(self): 18 | super(LeNet, self).__init__() 19 | 20 | self.conv1 = nn.Conv2d(1, 6, 5, padding=2) 21 | self.maxpool1 = nn.MaxPool2d(2, 2) 22 | self.conv2 = nn.Conv2d(6, 16, 5) 23 | self.maxpool2 = nn.MaxPool2d(2, 2) 24 | self.linear1 = nn.Linear(400, 120) 25 | self.linear2 = nn.Linear(120, 84) 26 | self.linear3 = nn.Linear(84, 10) 27 | self.relu = nn.ReLU(inplace=True) 28 | 29 | def forward(self, x): 30 | x = self.conv1(x) 31 | x = self.relu(x) 32 | x = self.maxpool1(x) 33 | x = self.conv2(x) 34 | x = self.relu(x) 35 | x = self.maxpool2(x) 36 | x = x.view(x.size(0), -1) 37 | x = self.linear1(x) 38 | x = self.relu(x) 39 | x = self.linear2(x) 40 | x = self.relu(x) 41 | x = self.linear3(x) 42 | return x 43 | 44 | 45 | class QLeNet(QModel): 46 | def __init__(self): 47 | super(QLeNet, self).__init__() 48 | self.conv1 = QConv2d(1, 6, 5, padding=2) 49 | self.maxpool1 = nn.MaxPool2d(2, 2) 50 | self.conv2 = QConv2d(6, 16, 5) 51 | self.maxpool2 = nn.MaxPool2d(2, 2) 52 | self.linear1 = QLinear(400, 120) 53 | self.linear2 = QLinear(120, 84) 54 | self.linear3 = QLinear(84, 10) 55 | self.relu = nn.ReLU(inplace=True) 56 | 57 | self.model_name = 'QLeNet' 58 | self.input_shape = (1, 1, 28, 28) 59 | 60 | def _forward(self, x): 61 | x = self.conv1(x) 62 | x = self.relu(x) 63 | x = self.maxpool1(x) 64 | x = self.conv2(x) 65 | x = self.relu(x) 66 | x = self.maxpool2(x) 67 | x = x.view(x.size(0), -1) 68 | x = self.linear1(x) 69 | x = self.relu(x) 70 | x = self.linear2(x) 71 | x = self.relu(x) 72 | x = self.linear3(x) 73 | return x 74 | 75 | def forward(self, x, record_path=None): 76 | if record_path is not None: 77 | record_dict = {} 78 | record_dict.update({0: x.squeeze(0).permute( 79 | 1, 2, 0).detach().numpy().astype(np.int32)}) 80 | x = self.conv1(x) 81 | x = self.relu(x) 82 | x = self.maxpool1(x) 83 | record_dict.update({8: x.squeeze(0).permute( 84 | 1, 2, 0).detach().numpy().astype(np.int32)}) 85 | x = self.conv2(x) 86 | x = self.relu(x) 87 | x = self.maxpool2(x) 88 | x = x.view(x.size(0), -1) 89 | record_dict.update( 90 | {18: x.squeeze(0).detach().numpy().astype(np.int32)}) 91 | x = self.linear1(x) 92 | x = self.relu(x) 93 | record_dict.update( 94 | {24: x.squeeze(0).detach().numpy().astype(np.int32)}) 95 | x = self.linear2(x) 96 | x = self.relu(x) 97 | record_dict.update( 98 | {30: x.squeeze(0).detach().numpy().astype(np.int32)}) 99 | x = self.linear3(x) 100 | record_dict.update( 101 | {34: x.view(-1).detach().numpy().astype(np.int32)}) 102 | with open(record_path, 'wb') as f: 103 | pickle.dump(record_dict, f) 104 | return x 105 | else: 106 | return self._forward(x) 107 | 108 | 109 | if __name__ == '__main__': 110 | model = QLeNet() 111 | model.execute(is_random_input=True, fix_random_seed=True, 112 | result_path='temp/QLeNet/o_0_0_0.dat', export_onnx_path='temp/QLeNet/QLeNet.onnx') 113 | -------------------------------------------------------------------------------- /hnn/snn/q_linear.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (C) OpenBII 2 | # Team: CBICR 3 | # SPDX-License-Identifier: Apache-2.0 4 | # See: https://spdx.org/licenses/ 5 | 6 | import torch 7 | from hnn.snn.q_integrate import QIntegrate 8 | from hnn.grad import FakeQuantize, FakeQuantizeINT28 9 | 10 | 11 | class QLinear(QIntegrate, torch.nn.Linear): 12 | '''支持量化的Linear算子 13 | 14 | 算子继承自torch.nn.Linear, 基本参数与torch.nn.Linear完全相同, 此处不再赘述 15 | 16 | Args: 17 | weight_scale: 权重的放缩系数 18 | is_encoder: 是否作为SNN中的encoder使用 19 | input_scale: 作为SNN中的encoder使用时对输入的放缩系数 20 | ''' 21 | def __init__(self, in_features, out_features, bias=True, device=None, dtype=None, is_encoder=False): 22 | torch.nn.Linear.__init__( 23 | self, in_features, out_features, bias, device, dtype) 24 | QIntegrate.__init__(self, is_encoder=is_encoder) 25 | 26 | def collect_q_params(self): 27 | '''统计量化参数 28 | 29 | 权重的放缩系数直接计算得到 30 | 如果作为encoder使用, 会将算子置于统计量化参数的状态 31 | ''' 32 | QIntegrate.collect_q_params(self) 33 | if self.is_encoder: 34 | self.num_inputs = 0 35 | self.sum_absmax = 0 36 | 37 | def calculate_q_params(self): 38 | '''计算量化参数 39 | 40 | 计算输入的放缩系数 41 | ''' 42 | QIntegrate.calculate_q_params(self) 43 | if self.is_encoder: 44 | self.input_scale = 128 / (self.sum_absmax / self.num_inputs) 45 | weight_absmax = self.weight.data.abs().max() 46 | self.weight_scale = 128 / weight_absmax 47 | 48 | def forward(self, x: torch.Tensor): 49 | '''前向推理 50 | 51 | Args: 52 | x: 张量输入 53 | 54 | Returns: 55 | 第一个输出: 张量输出 56 | 第二个输出: 传递给脉冲神经元的量化参数 57 | ''' 58 | if self.collecting: 59 | self.num_inputs += 1 60 | self.sum_absmax += x.data.abs().max() 61 | if self.is_encoder and self.quantization_mode: 62 | x = x.mul(self.input_scale).round().clamp(-128, 127) 63 | out = torch.nn.Linear.forward(self, x) 64 | if self.quantization_mode: 65 | out = out.clamp(-134217728, 134217727) # INT28 66 | if self.is_encoder: 67 | return out, self.weight_scale * self.input_scale if (self.weight_scale is not None and self.input_scale is not None) else None 68 | else: 69 | return out, self.weight_scale 70 | 71 | def quantize(self): 72 | QIntegrate.quantize(self) 73 | self.weight.data = self.weight.data.mul( 74 | self.weight_scale).round().clamp(-128, 127) # INT8 75 | if self.bias is not None: 76 | if self.is_encoder: 77 | self.bias.data = self.bias.data.mul( 78 | self.weight_scale * self.input_scale).round().clamp(-134217728, 134217727) # INT28 79 | else: 80 | self.bias.data = self.bias.data.mul( 81 | self.weight_scale).round().clamp(-134217728, 134217727) # INT28 82 | 83 | def dequantize(self): 84 | QIntegrate.dequantize(self) 85 | self.weight.data = self.weight.data.div(self.weight_scale) 86 | if self.bias is not None: 87 | if self.is_encoder: 88 | self.bias.data = self.bias.data.div( 89 | self.weight_scale * self.input_scale) 90 | else: 91 | self.bias.data = self.bias.data.div(self.weight_scale) 92 | 93 | def aware(self): 94 | if self.quantization_mode: 95 | self.dequantize() 96 | QIntegrate.aware(self) 97 | self.weight.data = FakeQuantize.apply( 98 | self.weight.data, self.weight_scale) 99 | if self.bias is not None: 100 | if self.is_encoder: 101 | self.bias.data = FakeQuantizeINT28.apply( 102 | self.bias.data, self.weight_scale * self.input_scale) 103 | else: 104 | self.bias.data = FakeQuantizeINT28.apply( 105 | self.bias.data, self.weight_scale) -------------------------------------------------------------------------------- /examples/ann/vgg16.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (C) OpenBII 2 | # Team: CBICR 3 | # SPDX-License-Identifier: Apache-2.0 4 | # See: https://spdx.org/licenses/ 5 | 6 | import pickle 7 | 8 | import numpy as np 9 | import torch 10 | import torch.nn as nn 11 | 12 | from hnn.ann.q_conv2d import QConv2d 13 | from hnn.ann.q_linear import QLinear 14 | from hnn.ann.q_model import QModel 15 | 16 | 17 | class QVGG16(QModel): 18 | def __init__(self, num_classes=1000): 19 | super(QVGG16, self).__init__() 20 | self.relu = nn.ReLU(inplace=True) 21 | self.conv0 = QConv2d(3, 64, 3, stride=1, padding=1) 22 | self.conv1 = QConv2d(64, 64, 3, stride=1, padding=1) 23 | self.maxpool0 = nn.MaxPool2d(2, stride=2) 24 | self.conv2 = QConv2d(64, 128, kernel_size=3, stride=1, padding=1) 25 | self.conv3 = QConv2d(128, 128, kernel_size=3, stride=1, padding=1) 26 | self.maxpool1 = nn.MaxPool2d(2, stride=2) 27 | self.conv4 = QConv2d(128, 256, kernel_size=3, stride=1, padding=1) 28 | self.conv5 = QConv2d(256, 256, kernel_size=3, stride=1, padding=1) 29 | self.conv6 = QConv2d(256, 256, kernel_size=3, stride=1, padding=1) 30 | self.maxpool2 = nn.MaxPool2d(2, stride=2) 31 | self.conv7 = QConv2d(256, 512, kernel_size=3, stride=1, padding=1) 32 | self.conv8 = QConv2d(512, 512, kernel_size=3, stride=1, padding=1) 33 | self.conv9 = QConv2d(512, 512, kernel_size=3, stride=1, padding=1) 34 | self.maxpool3 = nn.MaxPool2d(2, stride=2) 35 | self.conv10 = QConv2d(512, 512, kernel_size=3, stride=1, padding=1) 36 | self.conv11 = QConv2d(512, 512, kernel_size=3, stride=1, padding=1) 37 | self.conv12 = QConv2d(512, 512, kernel_size=3, stride=1, padding=1) 38 | self.maxpool4 = nn.MaxPool2d(2, stride=2) 39 | 40 | self.fc0 = QLinear(512 * 7 * 7, 4096) 41 | self.fc1 = QLinear(4096, 4096) 42 | self.fc2 = QLinear(4096, num_classes) 43 | 44 | self.model_name = 'QVGG16' 45 | self.input_shape = (1, 3, 224, 224) 46 | 47 | def _forward(self, x) -> torch.Tensor: 48 | x = self.conv0(x) 49 | x = self.relu(x) 50 | x = self.conv1(x) 51 | x = self.relu(x) 52 | x = self.maxpool0(x) 53 | x = self.conv2(x) 54 | x = self.relu(x) 55 | x = self.conv3(x) 56 | x = self.relu(x) 57 | x = self.maxpool1(x) 58 | x = self.conv4(x) 59 | x = self.relu(x) 60 | x = self.conv5(x) 61 | x = self.relu(x) 62 | x = self.conv6(x) 63 | x = self.relu(x) 64 | x = self.maxpool2(x) 65 | x = self.conv7(x) 66 | x = self.relu(x) 67 | x = self.conv8(x) 68 | x = self.relu(x) 69 | x = self.conv9(x) 70 | x = self.relu(x) 71 | x = self.maxpool3(x) 72 | x = self.conv10(x) 73 | x = self.relu(x) 74 | x = self.conv11(x) 75 | x = self.relu(x) 76 | x = self.conv12(x) 77 | x = self.relu(x) 78 | x = self.maxpool4(x) 79 | 80 | x = torch.flatten(x, 1) 81 | x = self.fc0(x) 82 | x = self.relu(x) 83 | x = self.fc1(x) 84 | x = self.relu(x) 85 | x = self.fc2(x) 86 | return x 87 | 88 | def forward(self, x, record_path=None): 89 | if record_path is not None: 90 | record_dict = {} 91 | record_dict.update({0: x.squeeze(0).permute( 92 | 1, 2, 0).detach().numpy().astype(np.int32)}) 93 | x = self._forward(x) 94 | record_dict.update( 95 | {106: x.view(-1).detach().numpy().astype(np.int32)}) 96 | 97 | with open(record_path, 'wb') as f: 98 | pickle.dump(record_dict, f) 99 | return x 100 | else: 101 | return self._forward(x) 102 | 103 | 104 | if __name__ == '__main__': 105 | model = QVGG16() 106 | model.execute(is_random_input=True, fix_random_seed=True, 107 | result_path='temp/QVGG16/o_0_0_0.dat', export_onnx_path='temp/QVGG16/QVGG16.onnx') 108 | 
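# --------------------------------------------------------------------------
# Usage sketch (not part of the original file): the QModel examples above are
# quantized post-training with the same pattern used in
# unit_tests/test_quantize.py -- collect quantization parameters, quantize the
# weights, then feed inputs scaled into the INT8 range. The helper name below
# is hypothetical; it simply strings together calls that appear elsewhere in
# this repository.
def run_quantized_qvgg16_demo():
    import torch
    model = QVGG16(num_classes=1000)   # defined above in this file
    model.collect_q_params()           # derive per-layer bit_shift / weight_scale
    model.quantize()                   # weights -> INT8, biases -> INT32
    x = torch.randn((1, 3, 224, 224))
    # Scale the input into the INT8 range expected by quantized layers
    qx = (x / x.abs().max()).mul(128).round().clamp(-128, 127)
    return model(qx)
# --------------------------------------------------------------------------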
-------------------------------------------------------------------------------- /hnn/ann/q_conv2d.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) OpenBII 2 | # Team: CBICR 3 | # SPDX-License-Identifier: Apache-2.0 4 | # See: https://spdx.org/licenses/ 5 | 6 | import torch 7 | import math 8 | from hnn.ann.q_module import QModule 9 | from hnn.grad import FakeQuantizeFloor, FakeQuantize, FakeQuantizeINT32 10 | 11 | 12 | class QConv2d(QModule, torch.nn.Conv2d): 13 | '''支持量化的Conv2d算子 14 | 15 | 算子继承自torch.nn.Conv2d, 基本参数与torch.nn.Conv2d完全相同, 此处不再赘述 16 | 17 | Args: 18 | weight_scale: 权重的放缩系数 19 | bit_shift: 完成定点数卷积计算后需要的量化参数 20 | is_last_node: 是否是最后一个算子 21 | ''' 22 | def __init__( 23 | self, 24 | in_channels, 25 | out_channels, 26 | kernel_size, 27 | stride=1, 28 | padding=0, 29 | dilation=1, 30 | groups=1, 31 | bias=True, 32 | padding_mode='zeros', 33 | device=None, 34 | dtype=None, 35 | is_last_node=False): 36 | torch.nn.Conv2d.__init__( 37 | self, 38 | in_channels=in_channels, 39 | out_channels=out_channels, 40 | kernel_size=kernel_size, 41 | stride=stride, 42 | padding=padding, 43 | dilation=dilation, 44 | groups=groups, 45 | bias=bias, 46 | padding_mode=padding_mode, 47 | device=device, 48 | dtype=dtype 49 | ) 50 | QModule.__init__(self) 51 | self.weight_scale = None 52 | self.bit_shift = None 53 | self.is_last_node = is_last_node 54 | 55 | def collect_q_params(self): 56 | '''卷积中计算量化参数 57 | 58 | weight_absmax * weight_scale = 128 59 | weight_scale = 2^(bit_shift_unit * n) 60 | bit_shift = bit_shift_unit * n = log_2 (128 / weight_absmax) 61 | n = round(log_2 (128 / weight_absmax) / bit_shift_unit) 62 | 这里取整方法可以有很多 63 | ''' 64 | QModule.collect_q_params(self) 65 | weight_absmax = self.weight.data.abs().max() 66 | temp = math.log(128 / weight_absmax, 2) / self.bit_shift_unit 67 | if temp - math.floor(temp) >= 0.75: # 经验公式 68 | n = math.ceil(temp) 69 | else: 70 | n = math.floor(temp) 71 | self.bit_shift = self.bit_shift_unit * n 72 | self.weight_scale = 2 ** self.bit_shift 73 | 74 | def forward(self, x: torch.Tensor): 75 | if self.restricted: 76 | x = x.clamp(-QModule.activation_absmax, QModule.activation_absmax) 77 | if self.aware_mode: 78 | assert not( 79 | self.quantization_mode), 'Quantization mode and QAT mode are mutual exclusive' 80 | x = FakeQuantizeFloor.apply(x, 128 / QModule.activation_absmax) 81 | out = torch.nn.Conv2d.forward(self, x) 82 | if self.quantization_mode and not(self.is_last_node): 83 | assert not( 84 | self.aware_mode), 'Quantization mode and QAT mode are mutual exclusive' 85 | out = out.clamp(-2147483648, 86 | 2147483647).div(self.weight_scale).floor().clamp(-128, 127) 87 | if self.is_last_node: 88 | out = out.clamp(-2147483648, 2147483647) 89 | return out 90 | 91 | def quantize(self): 92 | QModule.quantize(self) 93 | self.weight.data = self.weight.data.mul( 94 | self.weight_scale).round().clamp(-128, 127) # INT8 95 | self.bias.data = self.bias.data.mul( 96 | self.weight_scale * 128 / QModule.activation_absmax).round().clamp(-2147483648, 2147483647) # INT32 97 | 98 | def dequantize(self): 99 | QModule.dequantize(self) 100 | self.weight.data = self.weight.data.div(self.weight_scale) 101 | self.bias.data = self.bias.data.div( 102 | self.weight_scale * 128 / QModule.activation_absmax) 103 | 104 | def aware(self): 105 | if self.quantization_mode: 106 | self.dequantize() 107 | QModule.aware(self) 108 | self.weight.data = FakeQuantize.apply( 109 | self.weight.data, self.weight_scale) 110 | self.bias.data = 
FakeQuantizeINT32.apply( 111 | self.bias.data, self.weight_scale * 128 / QModule.activation_absmax) 112 | -------------------------------------------------------------------------------- /examples/ann/googlenet.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) OpenBII 2 | # Team: CBICR 3 | # SPDX-License-Identifier: Apache-2.0 4 | # See: https://spdx.org/licenses/ 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | from hnn.ann.q_adaptive_avgpool2d import QAdaptiveAvgPool2d 10 | from hnn.ann.q_conv2d import QConv2d 11 | from hnn.ann.q_linear import QLinear 12 | from hnn.ann.q_model import QModel 13 | 14 | 15 | class BasicConv2d(nn.Module): 16 | def __init__(self, in_channels, out_channels, **kwargs): 17 | super(BasicConv2d, self).__init__() 18 | self.conv = QConv2d(in_channels, out_channels, **kwargs) 19 | self.relu = nn.ReLU(inplace=True) 20 | 21 | def forward(self, x): 22 | x = self.conv(x) 23 | x = self.relu(x) 24 | return x 25 | 26 | 27 | class Inception(nn.Module): 28 | def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj): 29 | super(Inception, self).__init__() 30 | self.branch1 = BasicConv2d(in_channels, ch1x1, kernel_size=1) 31 | self.branch2 = nn.Sequential( 32 | BasicConv2d(in_channels, ch3x3red, kernel_size=1), 33 | BasicConv2d(ch3x3red, ch3x3, kernel_size=3, padding=1)) 34 | self.branch3 = nn.Sequential( 35 | BasicConv2d(in_channels, ch5x5red, kernel_size=1), 36 | BasicConv2d(ch5x5red, ch5x5, kernel_size=5, padding=2)) 37 | self.branch4 = nn.Sequential( 38 | nn.MaxPool2d(kernel_size=3, stride=1, padding=1), 39 | BasicConv2d(in_channels, pool_proj, kernel_size=1)) 40 | 41 | def forward(self, x): 42 | branch1 = self.branch1(x) 43 | branch2 = self.branch2(x) 44 | branch3 = self.branch3(x) 45 | branch4 = self.branch4(x) 46 | outputs = [branch1, branch2, branch3, branch4] 47 | return torch.cat(outputs, 1) 48 | 49 | 50 | class QGoogLeNet(QModel): 51 | def __init__(self, num_classes=1000): 52 | super(QGoogLeNet, self).__init__() 53 | self.conv1 = BasicConv2d(3, 64, kernel_size=7, stride=2, padding=3) 54 | self.maxpool1 = nn.MaxPool2d(3, stride=2, ceil_mode=True) 55 | 56 | self.conv2 = BasicConv2d(64, 64, kernel_size=1) 57 | self.conv3 = BasicConv2d(64, 192, kernel_size=3, padding=1) 58 | self.maxpool2 = nn.MaxPool2d(3, stride=2, ceil_mode=True) 59 | 60 | self.inception3a = Inception(192, 64, 96, 128, 16, 32, 32) 61 | self.inception3b = Inception(256, 128, 128, 192, 32, 96, 64) 62 | self.maxpool3 = nn.MaxPool2d(3, stride=2, ceil_mode=True) 63 | 64 | self.inception4a = Inception(480, 192, 96, 208, 16, 48, 64) 65 | self.inception4b = Inception(512, 160, 112, 224, 24, 64, 64) 66 | self.inception4c = Inception(512, 128, 128, 256, 24, 64, 64) 67 | self.inception4d = Inception(512, 112, 144, 288, 32, 64, 64) 68 | self.inception4e = Inception(528, 256, 160, 320, 32, 128, 128) 69 | self.maxpool4 = nn.MaxPool2d(3, stride=2, ceil_mode=True) 70 | 71 | self.inception5a = Inception(832, 256, 160, 320, 32, 128, 128) 72 | self.inception5b = Inception(832, 384, 192, 384, 48, 128, 128) 73 | 74 | self.avgpool = QAdaptiveAvgPool2d((1, 1), 7) 75 | self.fc = QLinear(1024, num_classes) 76 | 77 | self.model_name = 'QGoogLeNet' 78 | self.input_shape = (1, 3, 224, 224) 79 | 80 | def forward(self, x): 81 | x = self.conv1(x) 82 | x = self.maxpool1(x) 83 | x = self.conv2(x) 84 | x = self.conv3(x) 85 | x = self.maxpool2(x) 86 | 87 | x = self.inception3a(x) 88 | x = self.inception3b(x) 89 | x = self.maxpool3(x) 90 | x = 
self.inception4a(x) 91 | x = self.inception4b(x) 92 | x = self.inception4c(x) 93 | x = self.inception4d(x) 94 | x = self.inception4e(x) 95 | x = self.maxpool4(x) 96 | 97 | x = self.inception5a(x) 98 | x = self.inception5b(x) 99 | x = self.avgpool(x) 100 | x = torch.flatten(x, 1) 101 | x = self.fc(x) 102 | return x 103 | 104 | 105 | if __name__ == '__main__': 106 | model = QGoogLeNet() 107 | model.execute(is_random_input=True, fix_random_seed=True, 108 | result_path='temp/QGoogLeNet/o_0_0_0.dat', export_onnx_path='temp/QGoogLeNet/QGoogLeNet.onnx') 109 | -------------------------------------------------------------------------------- /examples/ann/vgg19.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) OpenBII 2 | # Team: CBICR 3 | # SPDX-License-Identifier: Apache-2.0 4 | # See: https://spdx.org/licenses/ 5 | 6 | import pickle 7 | 8 | import numpy as np 9 | import torch 10 | import torch.nn as nn 11 | 12 | from hnn.ann.q_conv2d import QConv2d 13 | from hnn.ann.q_linear import QLinear 14 | from hnn.ann.q_model import QModel 15 | 16 | 17 | class QVGG19(QModel): 18 | def __init__(self, num_classes=1000): 19 | super(QVGG19, self).__init__() 20 | self.relu = nn.ReLU(inplace=True) 21 | self.conv0 = QConv2d(3, 64, 3, stride=1, padding=1) 22 | self.conv1 = QConv2d(64, 64, 3, stride=1, padding=1) 23 | self.maxpool0 = nn.MaxPool2d(2, stride=2) 24 | self.conv2 = QConv2d(64, 128, kernel_size=3, stride=1, padding=1) 25 | self.conv3 = QConv2d(128, 128, kernel_size=3, stride=1, padding=1) 26 | self.maxpool1 = nn.MaxPool2d(2, stride=2) 27 | self.conv4 = QConv2d(128, 256, kernel_size=3, stride=1, padding=1) 28 | self.conv5 = QConv2d(256, 256, kernel_size=3, stride=1, padding=1) 29 | self.conv6 = QConv2d(256, 256, kernel_size=3, stride=1, padding=1) 30 | self.conv7 = QConv2d(256, 256, kernel_size=3, stride=1, padding=1) 31 | self.maxpool2 = nn.MaxPool2d(2, stride=2) 32 | self.conv8 = QConv2d(256, 512, kernel_size=3, stride=1, padding=1) 33 | self.conv9 = QConv2d(512, 512, kernel_size=3, stride=1, padding=1) 34 | self.conv10 = QConv2d(512, 512, kernel_size=3, stride=1, padding=1) 35 | self.conv11 = QConv2d(512, 512, kernel_size=3, stride=1, padding=1) 36 | self.maxpool3 = nn.MaxPool2d(2, stride=2) 37 | self.conv12 = QConv2d(512, 512, kernel_size=3, stride=1, padding=1) 38 | self.conv13 = QConv2d(512, 512, kernel_size=3, stride=1, padding=1) 39 | self.conv14 = QConv2d(512, 512, kernel_size=3, stride=1, padding=1) 40 | self.conv15 = QConv2d(512, 512, kernel_size=3, stride=1, padding=1) 41 | self.maxpool4 = nn.MaxPool2d(2, stride=2) 42 | 43 | self.fc0 = QLinear(512 * 7 * 7, 4096) 44 | self.fc1 = QLinear(4096, 4096) 45 | self.fc2 = QLinear(4096, num_classes) 46 | 47 | self.model_name = 'QVGG19' 48 | self.input_shape = (1, 3, 224, 224) 49 | 50 | def _forward(self, x) -> torch.Tensor: 51 | x = self.conv0(x) 52 | x = self.relu(x) 53 | x = self.conv1(x) 54 | x = self.relu(x) 55 | x = self.maxpool0(x) 56 | x = self.conv2(x) 57 | x = self.relu(x) 58 | x = self.conv3(x) 59 | x = self.relu(x) 60 | x = self.maxpool1(x) 61 | x = self.conv4(x) 62 | x = self.relu(x) 63 | x = self.conv5(x) 64 | x = self.relu(x) 65 | x = self.conv6(x) 66 | x = self.relu(x) 67 | x = self.conv7(x) 68 | x = self.relu(x) 69 | x = self.maxpool2(x) 70 | x = self.conv8(x) 71 | x = self.relu(x) 72 | x = self.conv9(x) 73 | x = self.relu(x) 74 | x = self.conv10(x) 75 | x = self.relu(x) 76 | x = self.conv11(x) 77 | x = self.relu(x) 78 | x = self.maxpool3(x) 79 | x = self.conv12(x) 80 | x = 
self.relu(x) 81 | x = self.conv13(x) 82 | x = self.relu(x) 83 | x = self.conv14(x) 84 | x = self.relu(x) 85 | x = self.conv15(x) 86 | x = self.relu(x) 87 | x = self.maxpool4(x) 88 | 89 | x = torch.flatten(x, 1) 90 | x = self.fc0(x) 91 | x = self.relu(x) 92 | x = self.fc1(x) 93 | x = self.relu(x) 94 | x = self.fc2(x) 95 | return x 96 | 97 | def forward(self, x, record_path=None): 98 | if record_path is not None: 99 | record_dict = {} 100 | record_dict.update({0: x.squeeze(0).permute( 101 | 1, 2, 0).detach().numpy().astype(np.int32)}) 102 | x = self._forward(x) 103 | record_dict.update( 104 | {124: x.view(-1).detach().numpy().astype(np.int32)}) 105 | 106 | with open(record_path, 'wb') as f: 107 | pickle.dump(record_dict, f) 108 | return x 109 | else: 110 | return self._forward(x) 111 | 112 | 113 | if __name__ == '__main__': 114 | model = QVGG19() 115 | model.execute(is_random_input=True, fix_random_seed=True, 116 | result_path='temp/QVGG19/o_0_0_0.dat', export_onnx_path='temp/QVGG19/QVGG19.onnx') 117 | -------------------------------------------------------------------------------- /unit_tests/test_lif.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) OpenBII 2 | # Team: CBICR 3 | # SPDX-License-Identifier: Apache-2.0 4 | # See: https://spdx.org/licenses/ 5 | 6 | import torch 7 | import unittest 8 | from hnn.snn.q_module import QModule 9 | from hnn.snn.lif import LIF, QLIF 10 | from hnn.grad import FakeQuantizeINT28 11 | 12 | 13 | class RefQLIF(QModule, LIF): 14 | def __init__(self, v_th_0, v_leaky_alpha=1, v_leaky_beta=0, v_reset=0, v_leaky_adpt_en=False, v_init=None): 15 | QModule.__init__(self) 16 | LIF.__init__(self, v_th_0, v_leaky_alpha, v_leaky_beta, 17 | v_reset, v_leaky_adpt_en, v_init) 18 | self.weight_scale = None 19 | self.first_time = True 20 | self.pretrained = False 21 | 22 | def collect_q_params(self): 23 | QModule.collect_q_params(self) 24 | 25 | def forward(self, x, weight_scale, v=None): 26 | self.weight_scale = weight_scale 27 | if self.quantization_mode: 28 | v = self._quantize(v) 29 | if self.aware_mode: 30 | assert not( 31 | self.quantization_mode), 'Quantization mode and QAT mode are mutual exclusive' 32 | if v is not None: 33 | v = FakeQuantizeINT28.apply(v, weight_scale) 34 | spike, v = LIF.forward(self, x, v) 35 | if self.quantization_mode: 36 | assert not( 37 | self.aware_mode), 'Quantization mode and QAT mode are mutual exclusive' 38 | v = v.clamp(-134217728, 134217727) # INT28 39 | return spike, v 40 | 41 | def quantize(self): 42 | QModule.quantize(self) 43 | 44 | def _quantize(self, v: torch.Tensor) -> torch.Tensor: 45 | if self.first_time: 46 | self.first_time = False 47 | if not self.pretrained: 48 | self.if_node.fire.v_th = round( 49 | self.if_node.fire.v_th * self.weight_scale.item()) 50 | self.if_node.accumulate.v_init = round( 51 | self.if_node.accumulate.v_init * self.weight_scale.item()) 52 | self.if_node.reset.value = round( 53 | self.if_node.reset.value * self.weight_scale.item()) 54 | self.v_leaky.beta = round( 55 | self.v_leaky.beta * self.weight_scale.item()) 56 | if self.v_leaky.adpt_en: 57 | self.v_leaky.alpha = round(self.v_leaky.alpha * 256) / 256 58 | if v is not None: 59 | v = v.mul(self.weight_scale).round( 60 | ).clamp(-134217728, 134217727) 61 | return v 62 | 63 | def dequantize(self): 64 | QModule.dequantize(self) 65 | self.if_node.fire.v_th = self.if_node.fire.v_th / self.weight_scale.item() 66 | self.if_node.reset.value = self.if_node.reset.value / self.weight_scale.item() 67 | 
self.v_leaky.beta = self.v_leaky.beta / self.weight_scale.item() 68 | self.if_node.accumulate.v_init = self.if_node.accumulate.v_init / \ 69 | self.weight_scale.item() 70 | 71 | def aware(self): 72 | if self.quantization_mode: 73 | self.dequantize() 74 | QModule.aware(self) 75 | self.if_node.fire.v_th = round( 76 | self.if_node.fire.v_th * self.weight_scale.item()) / self.weight_scale.item() 77 | self.if_node.reset.value = round( 78 | self.if_node.reset.value * self.weight_scale.item()) / self.weight_scale.item() 79 | self.if_node.accumulate.v_init = round( 80 | self.if_node.accumulate.v_init * self.weight_scale.item()) / self.weight_scale.item() 81 | self.v_leaky.beta = round( 82 | self.v_leaky.beta * self.weight_scale.item()) / self.weight_scale.item() 83 | if self.v_leaky.adpt_en: 84 | self.v_leaky.alpha = round(self.v_leaky.alpha * 256) / 256 85 | 86 | 87 | 88 | class TestLIF(unittest.TestCase): 89 | def test_lif(self): 90 | x = torch.randn(1, 1, 28, 28) 91 | x = (x > 0).float() 92 | lif = RefQLIF(v_th_0=1, v_leaky_alpha=0.9, v_leaky_beta=0.5, 93 | v_reset=0.2, v_leaky_adpt_en=True, v_init=0.1) 94 | new_lif = QLIF(v_th=1, v_leaky_alpha=0.9, v_leaky_beta=0.5, 95 | v_reset=0.2, v_leaky_adpt_en=True, v_init=0.1) 96 | lif.quantize() 97 | new_lif.quantize() 98 | scale = torch.as_tensor(100) 99 | _, v_ref = lif.forward(x, scale) 100 | _, v = new_lif.forward(x, scale) 101 | self.assertTrue((v_ref - v).abs().mean() < 1) 102 | 103 | 104 | if __name__ == '__main__': 105 | t1 = TestLIF() 106 | t1.test_lif() -------------------------------------------------------------------------------- /hnn/snn/q_conv2d.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) OpenBII 2 | # Team: CBICR 3 | # SPDX-License-Identifier: Apache-2.0 4 | # See: https://spdx.org/licenses/ 5 | 6 | import torch 7 | from hnn.snn.q_integrate import QIntegrate 8 | from hnn.grad import FakeQuantize, FakeQuantizeINT28 9 | 10 | 11 | class QConv2d(QIntegrate, torch.nn.Conv2d): 12 | '''支持量化的Conv2d算子 13 | 14 | 算子继承自torch.nn.Conv2d, 基本参数与torch.nn.Conv2d完全相同, 此处不再赘述 15 | 16 | Args: 17 | weight_scale: 权重的放缩系数 18 | is_encoder: 是否作为SNN中的encoder使用 19 | input_scale: 作为SNN中的encoder使用时对输入的放缩系数 20 | ''' 21 | def __init__( 22 | self, 23 | in_channels, 24 | out_channels, 25 | kernel_size, 26 | stride=1, 27 | padding=0, 28 | dilation=1, 29 | groups=1, 30 | bias=True, 31 | padding_mode='zeros', 32 | device=None, 33 | dtype=None, 34 | is_encoder=False): 35 | torch.nn.Conv2d.__init__( 36 | self, 37 | in_channels=in_channels, 38 | out_channels=out_channels, 39 | kernel_size=kernel_size, 40 | stride=stride, 41 | padding=padding, 42 | dilation=dilation, 43 | groups=groups, 44 | bias=bias, 45 | padding_mode=padding_mode, 46 | device=device, 47 | dtype=dtype 48 | ) 49 | QIntegrate.__init__(self, is_encoder=is_encoder) 50 | 51 | def collect_q_params(self): 52 | '''统计量化参数 53 | 54 | 权重的放缩系数直接计算得到 55 | 如果作为encoder使用, 会将算子置于统计量化参数的状态 56 | ''' 57 | QIntegrate.collect_q_params(self) 58 | if self.is_encoder: 59 | self.num_inputs = 0 60 | self.sum_absmax = 0 61 | 62 | def calculate_q_params(self): 63 | '''计算量化参数 64 | 65 | 计算输入的放缩系数 66 | ''' 67 | QIntegrate.calculate_q_params(self) 68 | weight_absmax = self.weight.data.abs().max() 69 | self.weight_scale = 128 / weight_absmax 70 | if self.is_encoder: 71 | self.input_scale = 128 / (self.sum_absmax / self.num_inputs) 72 | self.bias_scale = self.weight_scale * self.input_scale 73 | else: 74 | self.bias_scale = self.weight_scale 75 | 76 | def forward(self, x: 
torch.Tensor): 77 | '''前向推理 78 | 79 | Args: 80 | x: 张量输入 81 | 82 | Returns: 83 | 第一个输出: 张量输出 84 | 第二个输出: 传递给脉冲神经元的量化参数 85 | ''' 86 | if self.collecting: 87 | self.num_inputs += 1 88 | self.sum_absmax += x.data.abs().max() 89 | if self.is_encoder and self.quantization_mode: 90 | x = x.mul(self.input_scale).round().clamp(-128, 127) 91 | if self.is_encoder and self.aware_mode: 92 | x = FakeQuantize.apply(x, self.input_scale) 93 | out = torch.nn.Conv2d.forward(self, x) 94 | if self.quantization_mode: 95 | out = out.clamp(-134217728, 134217727) # INT28 96 | if self.aware_mode: 97 | out = FakeQuantizeINT28.apply(out, self.bias_scale) 98 | return out, self.bias_scale 99 | 100 | def quantize(self): 101 | QIntegrate.quantize(self) 102 | self.weight.data = self.weight.data.mul( 103 | self.weight_scale).round().clamp(-128, 127) # INT8 104 | if self.bias is not None: 105 | if self.is_encoder: 106 | self.bias.data = self.bias.data.mul( 107 | self.bias_scale).round().clamp(-134217728, 134217727) # INT28 108 | else: 109 | self.bias.data = self.bias.data.mul( 110 | self.weight_scale).round().clamp(-134217728, 134217727) # INT28 111 | 112 | def dequantize(self): 113 | QIntegrate.dequantize(self) 114 | self.weight.data = self.weight.data.div(self.weight_scale) 115 | if self.bias is not None: 116 | if self.is_encoder: 117 | self.bias.data = self.bias.data.div( 118 | self.bias_scale) 119 | else: 120 | self.bias.data = self.bias.data.div(self.weight_scale) 121 | 122 | def aware(self): 123 | if self.quantization_mode: 124 | self.dequantize() 125 | QIntegrate.aware(self) 126 | self.weight.data = FakeQuantize.apply( 127 | self.weight.data, self.weight_scale) 128 | if self.bias is not None: 129 | if self.is_encoder: 130 | self.bias.data = FakeQuantizeINT28.apply( 131 | self.bias.data, self.bias_scale) 132 | else: 133 | self.bias.data = FakeQuantizeINT28.apply( 134 | self.bias.data, self.weight_scale) -------------------------------------------------------------------------------- /docs/sphinx/source/HNN编程与量化框架.rst: -------------------------------------------------------------------------------- 1 | ======================================================================== 2 | HNN编程与量化框架 3 | ======================================================================== 4 | 5 | 本文档作者:曲环宇 6 | 7 | HNN框架包括基本的HNN编程框架和HNN量化框架,其中量化框架在不置于量化模式时的功能等同于编程框架。 8 | 9 | HNN编程 10 | ###################### 11 | 12 | 和ANN、SNN相比,HNN编程框架主要提供了可扩展的Hybrid Unit的实现,通过Hybrid Unit可以实现ANN到SNN或SNN到ANN的转换,进而构建起HNN模型。 13 | 14 | Hybrid Unit's UML 15 | ************************ 16 | 17 | .. image:: _static/hu.png 18 | :width: 100% 19 | :align: center 20 | 21 | HU整体上可以分成两类:A2SHU和S2AHU,其中A2SHU由Sampler,Nonlinear和PrecisionConvert组成,S2AHU由WindowSet,WindowConv和Nonlinear组成。 22 | 23 | 在A2SHU中,Sampler负责将ANN输出的连续值采样到多个时间步,Nonlinear对得到的中间结果进行一次非线性变换,最后Precision Convert将非线性变换后的结果转换成多个时间步上的脉冲值;在S2AHU中,WindowSet用于对SNN输出的脉冲设置时间窗,然后由WindowConv在时间窗内进行时域上的卷积,得到的结果再经过Nonlinear的非线性变换后即为ANN的输入。 24 | 25 | 每个类的介绍待补充。 26 | 27 | 28 | HNN编程框架使用 29 | *********************** 30 | 31 | 通过HU将ANN转换到SNN 32 | 33 | .. 
code:: python 34 | 35 | import torch 36 | from hnn.snn.lif import LIF 37 | from hnn.hu.a2s_poisson_coding_sign_convert import A2SPoissonCodingSignConvert 38 | 39 | 40 | class A2SHNN(torch.nn.Module): 41 | def __init__(self, T): 42 | super(A2SHNN, self).__init__() 43 | self.T = T 44 | self.conv1 = torch.nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=2, bias=False) 45 | self.maxpool1 = torch.nn.MaxPool2d(kernel_size=2, stride=2) 46 | self.conv2 = torch.nn.Conv2d(in_channels=6, out_channels=16, 47 | kernel_size=5, stride=1, padding=0, bias=False) 48 | self.maxpool2 = torch.nn.MaxPool2d(kernel_size=2, stride=2) 49 | self.a2shu = A2SPoissonCodingSignConvert(window_size=T, non_linear=torch.nn.ReLU()) 50 | self.linear1 = torch.nn.Linear(in_features=400, out_features=10) 51 | self.lif1 = LIF(v_th=1, v_leaky_alpha=0.9, 52 | v_leaky_beta=0, v_reset=0) 53 | 54 | def forward(self, x: torch.Tensor): 55 | x = self.conv1(x) 56 | x = self.maxpool1(x) 57 | x = self.conv2(x) 58 | x = self.maxpool2(x) 59 | # A2SHU 60 | x = self.a2shu(x) # [N, C, H, W] -> [N, C, H, W, T] 61 | spike = torch.zeros((self.T, x.size(0), 10)) 62 | x = x.permute(4, 0, 1, 2, 3) # [T, N, C, H, W] 63 | input = x.reshape(x.size(0), x.size(1), -1) # [T, N, C * H * W] 64 | v1 = None 67 | for i in range(self.T): 68 | x = self.linear1(input[i]) 69 | out, v1 = self.lif1(x, v1) 74 | spike[i] = out 75 | return spike.mean(dim=0) 76 | 77 | 通过HU将SNN转换到ANN 78 | 79 | .. code:: python 80 | 81 | import torch 82 | from hnn.snn.lif import LIF 83 | from hnn.hu.s2a_global_rate_coding import S2AGlobalRateCoding 84 | 85 | 86 | class S2AHNN(torch.nn.Module): 87 | def __init__(self, T): 88 | super(S2AHNN, self).__init__() 89 | self.T = T 90 | self.conv1 = torch.nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=2, bias=False) 91 | self.lif1 = LIF(v_th=1, v_leaky_alpha=0.9, 92 | v_leaky_beta=0, v_reset=0) 93 | self.maxpool1 = torch.nn.MaxPool2d(kernel_size=2, stride=2) 94 | self.conv2 = torch.nn.Conv2d(in_channels=6, out_channels=16, 95 | kernel_size=5, stride=1, padding=0, bias=False) 96 | self.lif2 = LIF(v_th=1, v_leaky_alpha=0.9, 97 | v_leaky_beta=0, v_reset=0) 98 | self.maxpool2 = torch.nn.MaxPool2d(kernel_size=2, stride=2) 99 | self.s2ahu = S2AGlobalRateCoding(window_size=T, non_linear=torch.nn.ReLU()) 100 | self.linear1 = torch.nn.Linear(in_features=400, out_features=10) 101 | 102 | def forward(self, x: torch.Tensor): 103 | # [N, C, H, W] 104 | inputs = x.unsqueeze(0).repeat(self.T, 1, 1, 1, 1) 105 | spike = torch.zeros((x.size(0), 16, 5, 5, self.T)) 106 | v1 = None 107 | v2 = None 108 | for i in range(self.T): 109 | x = self.conv1(inputs[i]) 110 | out, v1 = self.lif1(x, v1) 111 | out = self.maxpool1(out) 112 | x = self.conv2(out) 113 | out, v2 = self.lif2(x, v2) 114 | out = self.maxpool2(out) 115 | spike[..., i] = out 116 | # S2AHU 117 | x = self.s2ahu(spike) 118 | x = self.linear1(x.view(x.size(0), -1)) 119 | return x 120 | 121 | HNN量化 122 | ######################################## 123 | 124 | 待补充 -------------------------------------------------------------------------------- /hnn/snn/extended_lif.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) OpenBII 2 | # Team: CBICR 3 | # SPDX-License-Identifier: Apache-2.0 4 | # See: https://spdx.org/licenses/ 5 | 6 | import torch 7 | from hnn.snn.q_model import QModel 8 | from hnn.snn.reset_mode import ResetMode 9 | from 
hnn.snn.refractory import Refractory 10 | from hnn.snn.threshold_dynamics import ThresholdDynamics, QThresholdDynamics 11 | from hnn.snn.saturate import Saturate, QSaturate 12 | from hnn.snn.threshold_accumulate_with_saturate import ThresholdAccumulateWithSaturate, QThresholdAccumulateWithSaturate 13 | from hnn.snn.lif_with_tensor_threshold_and_reset_mode_and_refractory import LIFWithTensorThresholdAndResetModeAndRefractory, QLIFWithTensorThresholdAndResetModeAndRefractory 14 | 15 | 16 | class ExtendedLIF(torch.nn.Module): 17 | '''扩展LIF神经元 18 | 19 | 包括以下阶段: 20 | 1. 膜电位阈值累加, 支持下限饱和 21 | 2. LIF神经元计算, 支持不同神经元可以有不同的阈值, 支持可配置的膜电位复位模式和不应期 22 | 3. 膜电位下限饱和 23 | 4. 膜电位阈值动力学, 包括阈值自适应分量的指数衰减和发放后导致的阈值增加 24 | 5. 不应期减计数, 发放脉冲后不应期复位 25 | 26 | Integrate阶段通过其他算子完成 27 | 28 | Args: 29 | threshold_accumulate.accumulate.vth0 = v_th0 30 | threshold_accumulate.saturate.v_l = v_l 31 | lif.if_node.reset.reset_mode = reset_mode 32 | lif.if_node.reset.v_reset = v_reset 33 | lif.if_node.reset.dv = dv 34 | lif.if_node.accumulate.v_init = v_init 35 | lif.if_node.fire.surrogate_function: 默认为Rectangle 36 | window_size: Rectangle的矩形窗宽度, default = 1 37 | lif.v_leaky.alpha = v_leaky_alpha 38 | lif.v_leaky.beta = v_leaky_beta 39 | lif.v_leaky.adpt_en = v_leaky_adpt_en 40 | saturate.v_l = v_l 41 | refractory.reset.value = ref_len 42 | threshold_dynamics.decay.alpha = v_th_alpha 43 | threshold_dynamics.decay.beta = v_th_beta 44 | threshold_dynamics.decay.adpt_en = v_th_adpt_en 45 | threshold_dynamics.update.value = v_th_incre 46 | ''' 47 | def __init__(self, v_th0, 48 | v_leaky_alpha=1, v_leaky_beta=0, 49 | v_leaky_adpt_en=False, 50 | v_reset=0, v_init=None, 51 | v_th_alpha=1, v_th_beta=0, v_th_adpt_en=True, 52 | v_th_incre=0, v_l=None, dv=0, 53 | ref_len=0, reset_mode=ResetMode.HARD, 54 | window_size=1): 55 | super(ExtendedLIF, self).__init__() 56 | self.threshold_accumulate = ThresholdAccumulateWithSaturate(v_th0=v_th0, v_l=v_l) 57 | self.lif = LIFWithTensorThresholdAndResetModeAndRefractory( 58 | v_leaky_alpha=v_leaky_alpha, v_leaky_beta=v_leaky_beta, reset_mode=reset_mode, 59 | v_reset=v_reset, dv=dv, v_leaky_adpt_en=v_leaky_adpt_en, v_init=v_init, window_size=window_size) 60 | self.saturate = Saturate(v_l=v_l) 61 | self.refractory = Refractory(ref_len=ref_len) 62 | self.threshold_dynamics = ThresholdDynamics( 63 | v_th_alpha=v_th_alpha, v_th_beta=v_th_beta, v_th_incre=v_th_incre, v_th_adpt_en=v_th_adpt_en) 64 | 65 | def forward(self, u_in, v_th_adpt, v=None, ref_cnt=None): 66 | v_th = self.threshold_accumulate.forward(v_th_adpt) 67 | spike, v = self.lif.forward(u_in, v_th, v, ref_cnt) 68 | v = self.saturate.forward(v) 69 | v_th_adpt = self.threshold_dynamics.forward(v_th_adpt, spike) 70 | ref_cnt = self.refractory.forward(ref_cnt, spike) 71 | return spike, v, v_th_adpt, ref_cnt 72 | 73 | 74 | class QExtendedLIF(QModel): 75 | '''支持量化的扩展LIF神经元 76 | ''' 77 | def __init__(self, v_th0, 78 | v_leaky_alpha=1, v_leaky_beta=0, 79 | v_reset=0, v_init=None, 80 | v_th_alpha=1, v_th_beta=0, 81 | v_th_incre=0, v_l=None, dv=0, 82 | ref_len=0, reset_mode=ResetMode.HARD, 83 | window_size=1): 84 | QModel.__init__(self) 85 | self.threshold_accumulate = QThresholdAccumulateWithSaturate( 86 | v_th0=v_th0, v_l=v_l) 87 | self.lif = QLIFWithTensorThresholdAndResetModeAndRefractory( 88 | v_leaky_alpha=v_leaky_alpha, v_leaky_beta=v_leaky_beta, reset_mode=reset_mode, 89 | v_reset=v_reset, dv=dv, v_init=v_init, window_size=window_size) 90 | self.saturate = QSaturate(v_l=v_l) 91 | self.refractory = Refractory(ref_len=ref_len) 92 | self.threshold_dynamics = QThresholdDynamics( 93 |
v_th_alpha=v_th_alpha, v_th_beta=v_th_beta, v_th_incre=v_th_incre) 94 | 95 | def forward(self, u_in, v_th_adpt, scale, v=None, ref_cnt=None): 96 | v_th = self.threshold_accumulate.forward(v_th_adpt, scale) 97 | spike, v = self.lif.forward(u_in, v_th, scale, v, ref_cnt) 98 | v = self.saturate.forward(v, scale) 99 | v_th_adpt = self.threshold_dynamics.forward(v_th_adpt, spike, scale) 100 | ref_cnt = self.refractory.forward(ref_cnt, spike) 101 | return spike, v, v_th_adpt, ref_cnt -------------------------------------------------------------------------------- /examples/ann/squeezenet.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) OpenBII 2 | # Team: CBICR 3 | # SPDX-License-Identifier: Apache-2.0 4 | # See: https://spdx.org/licenses/ 5 | 6 | import pickle 7 | 8 | import numpy as np 9 | import torch 10 | import torch.nn as nn 11 | 12 | from hnn.ann.q_adaptive_avgpool2d import QAdaptiveAvgPool2d 13 | from hnn.ann.q_conv2d import QConv2d 14 | from hnn.ann.q_model import QModel 15 | 16 | 17 | class Fire(nn.Module): 18 | def __init__( 19 | self, 20 | inplanes: int, 21 | squeeze_planes: int, 22 | expand1x1_planes: int, 23 | expand3x3_planes: int 24 | ) -> None: 25 | super(Fire, self).__init__() 26 | self.inplanes = inplanes 27 | self.squeeze = QConv2d(inplanes, squeeze_planes, kernel_size=1) 28 | self.squeeze_activation = nn.ReLU(inplace=True) 29 | self.expand1x1 = QConv2d(squeeze_planes, expand1x1_planes, 30 | kernel_size=1) 31 | self.expand1x1_activation = nn.ReLU(inplace=True) 32 | self.expand3x3 = QConv2d(squeeze_planes, expand3x3_planes, 33 | kernel_size=3, padding=1) 34 | self.expand3x3_activation = nn.ReLU(inplace=True) 35 | 36 | def forward(self, x: torch.Tensor) -> torch.Tensor: 37 | x = self.squeeze_activation(self.squeeze(x)) 38 | return torch.cat([ 39 | self.expand1x1_activation(self.expand1x1(x)), 40 | self.expand3x3_activation(self.expand3x3(x)) 41 | ], 1) 42 | 43 | 44 | class QSqueezeNet(QModel): 45 | def __init__( 46 | self, 47 | version: str = '1_0', 48 | num_classes: int = 1000 49 | ) -> None: 50 | super(QSqueezeNet, self).__init__() 51 | self.num_classes = num_classes 52 | if version == '1_0': 53 | self.features = nn.Sequential( 54 | QConv2d(3, 96, kernel_size=7, stride=2), 55 | nn.ReLU(inplace=True), 56 | nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True), 57 | Fire(96, 16, 64, 64), 58 | Fire(128, 16, 64, 64), 59 | Fire(128, 32, 128, 128), 60 | nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True), 61 | Fire(256, 32, 128, 128), 62 | Fire(256, 48, 192, 192), 63 | Fire(384, 48, 192, 192), 64 | Fire(384, 64, 256, 256), 65 | nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True), 66 | Fire(512, 64, 256, 256), 67 | ) 68 | elif version == '1_1': 69 | self.features = nn.Sequential( 70 | QConv2d(3, 64, kernel_size=3, stride=2), 71 | nn.ReLU(inplace=True), 72 | nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True), 73 | Fire(64, 16, 64, 64), 74 | Fire(128, 16, 64, 64), 75 | nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True), 76 | Fire(128, 32, 128, 128), 77 | Fire(256, 32, 128, 128), 78 | nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True), 79 | Fire(256, 48, 192, 192), 80 | Fire(384, 48, 192, 192), 81 | Fire(384, 64, 256, 256), 82 | Fire(512, 64, 256, 256), 83 | ) 84 | else: 85 | raise ValueError("Unsupported SqueezeNet version {version}:" 86 | "1_0 or 1_1 expected".format(version=version)) 87 | 88 | # Final convolution is initialized differently from the rest 89 | final_conv = QConv2d(512, self.num_classes, 
kernel_size=1) 90 | self.classifier = nn.Sequential( 91 | final_conv, 92 | nn.ReLU(inplace=True), 93 | QAdaptiveAvgPool2d((1, 1), 13) 94 | ) 95 | 96 | self.model_name = 'QSqueezeNet' 97 | self.input_shape = (1, 3, 224, 224) 98 | 99 | def _forward(self, x: torch.Tensor) -> torch.Tensor: 100 | x = self.features(x) 101 | x = self.classifier(x) 102 | return torch.flatten(x, 1) 103 | 104 | def forward(self, x: torch.Tensor, record_path=None) -> torch.Tensor: 105 | if record_path is not None: 106 | record_dict = {} 107 | record_dict.update({0: x.squeeze(0).permute( 108 | 1, 2, 0).detach().numpy().astype(np.int32)}) 109 | x = self.features(x) 110 | x = self.classifier(x) 111 | record_dict.update( 112 | {188: x.view(-1).unsqueeze(0).unsqueeze(0).detach().numpy().astype(np.int32)}) 113 | with open(record_path, 'wb') as f: 114 | pickle.dump(record_dict, f) 115 | return torch.flatten(x, 1) 116 | else: 117 | return self._forward(x) 118 | 119 | 120 | if __name__ == '__main__': 121 | model = QSqueezeNet() 122 | model.execute(is_random_input=True, fix_random_seed=True, 123 | result_path='temp/QSqueezeNet/o_0_0_0.dat', export_onnx_path='temp/QSqueezeNet/QSqueezeNet.onnx') 124 | -------------------------------------------------------------------------------- /examples/ann/small_squeezenet.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) OpenBII 2 | # Team: CBICR 3 | # SPDX-License-Identifier: Apache-2.0 4 | # See: https://spdx.org/licenses/ 5 | 6 | import pickle 7 | from typing import Dict 8 | 9 | import numpy as np 10 | import torch 11 | import torch.nn as nn 12 | 13 | from hnn.ann.q_adaptive_avgpool2d import QAdaptiveAvgPool2d 14 | from hnn.ann.q_conv2d import QConv2d 15 | from hnn.ann.q_model import QModel 16 | 17 | 18 | class Fire(nn.Module): 19 | def __init__( 20 | self, 21 | inplanes: int, 22 | squeeze_planes: int, 23 | expand1x1_planes: int, 24 | expand3x3_planes: int 25 | ) -> None: 26 | super(Fire, self).__init__() 27 | self.inplanes = inplanes 28 | self.squeeze = QConv2d(inplanes, squeeze_planes, kernel_size=1) 29 | self.squeeze_activation = nn.ReLU(inplace=True) 30 | self.expand1x1 = QConv2d(squeeze_planes, expand1x1_planes, 31 | kernel_size=1) 32 | self.expand1x1_activation = nn.ReLU(inplace=True) 33 | self.expand3x3 = QConv2d(squeeze_planes, expand3x3_planes, 34 | kernel_size=3, padding=1) 35 | self.expand3x3_activation = nn.ReLU(inplace=True) 36 | 37 | def _forward(self, x: torch.Tensor) -> torch.Tensor: 38 | x = self.squeeze_activation(self.squeeze(x)) 39 | return torch.cat([ 40 | self.expand1x1_activation(self.expand1x1(x)), 41 | self.expand3x3_activation(self.expand3x3(x)) 42 | ], 1) 43 | 44 | def forward(self, x: torch.Tensor, record_dict: Dict = None, block_ids=None) -> torch.Tensor: 45 | if record_dict is not None: 46 | x = self.squeeze_activation(self.squeeze(x)) 47 | record_dict.update({block_ids[0]: x.squeeze(0).permute( 48 | 1, 2, 0).detach().numpy().astype(np.int32)}) 49 | x0 = self.expand1x1_activation(self.expand1x1(x)) 50 | record_dict.update({block_ids[1]: x0.squeeze(0).permute( 51 | 1, 2, 0).detach().numpy().astype(np.int32)}) 52 | x1 = self.expand3x3_activation(self.expand3x3(x)) 53 | record_dict.update({block_ids[2]: x1.squeeze(0).permute( 54 | 1, 2, 0).detach().numpy().astype(np.int32)}) 55 | x = torch.cat([x0, x1], 1) 56 | record_dict.update({block_ids[3]: x.squeeze(0).permute( 57 | 1, 2, 0).detach().numpy().astype(np.int32)}) 58 | return x 59 | else: 60 | return self._forward(x) 61 | 62 | 63 | class 
QSmallSqueezeNet(QModel): 64 | def __init__( 65 | self, 66 | num_classes: int = 10 67 | ) -> None: 68 | super(QSmallSqueezeNet, self).__init__() 69 | self.num_classes = num_classes 70 | self.conv0 = QConv2d(3, 64, kernel_size=3, stride=2) 71 | self.maxpool0 = nn.Sequential( 72 | nn.ReLU(inplace=True), 73 | nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True) 74 | ) 75 | self.fire0 = Fire(64, 16, 64, 64) 76 | self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True) 77 | self.fire1 = Fire(128, 32, 256, 256) 78 | self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True) 79 | 80 | # Final convolution is initialized differently from the rest 81 | final_conv = QConv2d(512, self.num_classes, kernel_size=1) 82 | self.classifier = nn.Sequential( 83 | final_conv, 84 | nn.ReLU(inplace=True), 85 | QAdaptiveAvgPool2d((1, 1), 13) 86 | ) 87 | 88 | self.model_name = 'QSmallSqueezeNet' 89 | self.input_shape = (1, 3, 224, 224) 90 | 91 | def _forward(self, x: torch.Tensor) -> torch.Tensor: 92 | x = self.conv0(x) 93 | x = self.maxpool0(x) 94 | x = self.fire0(x) 95 | x = self.maxpool1(x) 96 | x = self.fire1(x) 97 | x = self.maxpool2(x) 98 | x = self.classifier(x) 99 | return torch.flatten(x, 1) 100 | 101 | def forward(self, x: torch.Tensor, record_path=None) -> torch.Tensor: 102 | if record_path is not None: 103 | record_dict = {} 104 | record_dict.update({0: x.squeeze(0).permute( 105 | 1, 2, 0).detach().numpy().astype(np.int32)}) 106 | x = self.conv0(x) 107 | x = self.maxpool0(x) 108 | record_dict.update({8: x.squeeze(0).permute( 109 | 1, 2, 0).detach().numpy().astype(np.int32)}) 110 | x = self.fire0(x, record_dict, [13, 19, 25, 29]) 111 | x = self.maxpool1(x) 112 | record_dict.update({31: x.squeeze(0).permute( 113 | 1, 2, 0).detach().numpy().astype(np.int32)}) 114 | x = self.fire1(x) 115 | x = self.maxpool2(x) 116 | record_dict.update({54: x.squeeze(0).permute( 117 | 1, 2, 0).detach().numpy().astype(np.int32)}) 118 | x = self.classifier(x) 119 | record_dict.update( 120 | {62: x.view(-1).unsqueeze(0).unsqueeze(0).detach().numpy().astype(np.int32)}) 121 | with open(record_path, 'wb') as f: 122 | pickle.dump(record_dict, f) 123 | return torch.flatten(x, 1) 124 | else: 125 | return self._forward(x) 126 | 127 | 128 | if __name__ == '__main__': 129 | model = QSmallSqueezeNet() 130 | model.execute(is_random_input=True, fix_random_seed=True, 131 | result_path='temp/QSmallSqueezeNet/o_0_0_0.dat', export_onnx_path='temp/QSmallSqueezeNet/QSmallSqueezeNet.onnx') 132 | -------------------------------------------------------------------------------- /docs/sphinx/source/conf.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Configuration file for the Sphinx documentation builder. 4 | # 5 | # This file does only contain a selection of the most common options. For a 6 | # full list see the documentation: 7 | # http://www.sphinx-doc.org/en/master/config 8 | 9 | # -- Path setup -------------------------------------------------------------- 10 | 11 | # If extensions (or modules to document with autodoc) are in another directory, 12 | # add these directories to sys.path here. If the directory is relative to the 13 | # documentation root, use os.path.abspath to make it absolute, like shown here. 
14 | # 15 | import os 16 | import sys 17 | sys.path.insert(0, os.path.abspath('.')) 18 | 19 | 20 | # -- Project information ----------------------------------------------------- 21 | 22 | project = 'HNN编程框架' 23 | copyright = 'OpenBII 2023' 24 | author = 'CBICR' 25 | 26 | # The short X.Y version 27 | version = '1.0' 28 | # The full version, including alpha/beta/rc tags 29 | release = '0.1' 30 | 31 | 32 | # -- General configuration --------------------------------------------------- 33 | 34 | # If your documentation needs a minimal Sphinx version, state it here. 35 | # 36 | # needs_sphinx = '1.0' 37 | 38 | # Add any Sphinx extension module names here, as strings. They can be 39 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 40 | # ones. 41 | extensions = ['sphinx.ext.autosectionlabel'] 42 | 43 | # Add any paths that contain templates here, relative to this directory. 44 | templates_path = ['_templates'] 45 | 46 | # The suffix(es) of source filenames. 47 | # You can specify multiple suffix as a list of string: 48 | # 49 | # source_suffix = ['.rst', '.md'] 50 | source_suffix = '.rst' 51 | 52 | # The master toctree document. 53 | master_doc = 'index' 54 | 55 | # The language for content autogenerated by Sphinx. Refer to documentation 56 | # for a list of supported languages. 57 | # 58 | # This is also used if you do content translation via gettext catalogs. 59 | # Usually you set "language" from the command line for these cases. 60 | language = 'zh' 61 | 62 | # List of patterns, relative to source directory, that match files and 63 | # directories to ignore when looking for source files. 64 | # This pattern also affects html_static_path and html_extra_path. 65 | exclude_patterns = [] 66 | 67 | # The name of the Pygments (syntax highlighting) style to use. 68 | pygments_style = None 69 | 70 | 71 | # -- Options for HTML output ------------------------------------------------- 72 | 73 | # The theme to use for HTML and HTML Help pages. See the documentation for 74 | # a list of builtin themes. 75 | # 76 | html_theme = 'sphinx_rtd_theme' 77 | #html_theme = '' 78 | #html_theme = 'yummy_sphinx_theme' 79 | # html_theme = 'renku' 80 | 81 | # Theme options are theme-specific and customize the look and feel of a theme 82 | # further. For a list of options available for each theme, see the 83 | # documentation. 84 | # 85 | # html_theme_options = {} 86 | 87 | # Add any paths that contain custom static files (such as style sheets) here, 88 | # relative to this directory. They are copied after the builtin static files, 89 | # so a file named "default.css" will overwrite the builtin "default.css". 90 | html_static_path = ['_static'] 91 | 92 | # Custom sidebar templates, must be a dictionary that maps document names 93 | # to template names. 94 | # 95 | # The default sidebars (for documents that don't match any pattern) are 96 | # defined by theme itself. Builtin themes are using these templates by 97 | # default: ``['localtoc.html', 'relations.html', 'sourcelink.html', 98 | # 'searchbox.html']``. 99 | # 100 | # html_sidebars = {} 101 | 102 | 103 | # -- Options for HTMLHelp output --------------------------------------------- 104 | 105 | # Output file base name for HTML help builder. 106 | htmlhelp_basename = 'HNN_framework' 107 | 108 | 109 | # -- Options for LaTeX output ------------------------------------------------ 110 | 111 | latex_elements = { 112 | # The paper size ('letterpaper' or 'a4paper'). 
113 | # 114 | # 'papersize': 'letterpaper', 115 | 116 | # The font size ('10pt', '11pt' or '12pt'). 117 | # 118 | # 'pointsize': '10pt', 119 | 120 | # Additional stuff for the LaTeX preamble. 121 | # 122 | # 'preamble': '', 123 | 124 | # Latex figure (float) alignment 125 | # 126 | # 'figure_align': 'htbp', 127 | } 128 | 129 | # Grouping the document tree into LaTeX files. List of tuples 130 | # (source start file, target name, title, 131 | # author, documentclass [howto, manual, or own class]). 132 | latex_documents = [ 133 | ] 134 | 135 | 136 | # -- Options for manual page output ------------------------------------------ 137 | 138 | # One entry per manual page. List of tuples 139 | # (source start file, name, description, authors, manual section). 140 | man_pages = [ 141 | ] 142 | 143 | 144 | # -- Options for Texinfo output ---------------------------------------------- 145 | 146 | # Grouping the document tree into Texinfo files. List of tuples 147 | # (source start file, target name, title, author, 148 | # dir menu entry, description, category) 149 | texinfo_documents = [ 150 | ] 151 | 152 | 153 | # -- Options for Epub output ------------------------------------------------- 154 | 155 | # Bibliographic Dublin Core info. 156 | epub_title = project 157 | 158 | # The unique identifier of the text. This can be a ISBN number 159 | # or the project homepage. 160 | # 161 | # epub_identifier = '' 162 | 163 | # A unique identification for the text. 164 | # 165 | # epub_uid = '' 166 | 167 | # A list of files that should not be packed into the epub file. 168 | epub_exclude_files = ['search.html'] 169 | -------------------------------------------------------------------------------- /examples/snn/snn_vgg16.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) OpenBII 2 | # Team: CBICR 3 | # SPDX-License-Identifier: Apache-2.0 4 | # See: https://spdx.org/licenses/ 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | from hnn.snn.lif import QLIF 10 | from hnn.snn.output_rate_coding import OutputRateCoding 11 | from hnn.snn.q_conv2d import QConv2d 12 | from hnn.snn.q_linear import QLinear 13 | from hnn.snn.q_model import QModel 14 | 15 | 16 | class SNNVGG16(QModel): 17 | def __init__(self, T=10, num_classes=1000): 18 | super(SNNVGG16, self).__init__(time_window_size=T) 19 | self.classes = num_classes 20 | self.relu = nn.ReLU(inplace=True) 21 | self.conv0 = QConv2d(3, 64, 3, stride=1, padding=1) 22 | self.lif0 = QLIF(1, 0.9, 0, 0) 23 | self.conv1 = QConv2d(64, 64, 3, stride=1, padding=1) 24 | self.lif1 = QLIF(1, 0.9, 0, 0) 25 | self.maxpool0 = nn.MaxPool2d(2, stride=2) 26 | self.conv2 = QConv2d(64, 128, kernel_size=3, stride=1, padding=1) 27 | self.lif2 = QLIF(1, 0.9, 0, 0) 28 | self.conv3 = QConv2d(128, 128, kernel_size=3, stride=1, padding=1) 29 | self.lif3 = QLIF(1, 0.9, 0, 0) 30 | self.maxpool1 = nn.MaxPool2d(2, stride=2) 31 | self.conv4 = QConv2d(128, 256, kernel_size=3, stride=1, padding=1) 32 | self.lif4 = QLIF(1, 0.9, 0, 0) 33 | self.conv5 = QConv2d(256, 256, kernel_size=3, stride=1, padding=1) 34 | self.lif5 = QLIF(1, 0.9, 0, 0) 35 | self.conv6 = QConv2d(256, 256, kernel_size=3, stride=1, padding=1) 36 | self.lif6 = QLIF(1, 0.9, 0, 0) 37 | self.maxpool2 = nn.MaxPool2d(2, stride=2) 38 | self.conv7 = QConv2d(256, 512, kernel_size=3, stride=1, padding=1) 39 | self.lif7 = QLIF(1, 0.9, 0, 0) 40 | self.conv8 = QConv2d(512, 512, kernel_size=3, stride=1, padding=1) 41 | self.lif8 = QLIF(1, 0.9, 0, 0) 42 | self.conv9 = QConv2d(512, 512, 
kernel_size=3, stride=1, padding=1) 43 | self.lif9 = QLIF(1, 0.9, 0, 0) 44 | self.maxpool3 = nn.MaxPool2d(2, stride=2) 45 | self.conv10 = QConv2d(512, 512, kernel_size=3, stride=1, padding=1) 46 | self.lif10 = QLIF(1, 0.9, 0, 0) 47 | self.conv11 = QConv2d(512, 512, kernel_size=3, stride=1, padding=1) 48 | self.lif11 = QLIF(1, 0.9, 0, 0) 49 | self.conv12 = QConv2d(512, 512, kernel_size=3, stride=1, padding=1) 50 | self.lif12 = QLIF(1, 0.9, 0, 0) 51 | self.maxpool4 = nn.MaxPool2d(2, stride=2) 52 | 53 | self.fc0 = QLinear(512 * 7 * 7, 4096) 54 | self.fc0lif0 = QLIF(1, 0.9, 0, 0) 55 | self.fc1 = QLinear(4096, 4096) 56 | self.fc1lif1 = QLIF(1, 0.9, 0, 0) 57 | self.fc2 = QLinear(4096, num_classes) 58 | self.fc2lif2 = QLIF(1, 0.9, 0, 0) 59 | self.coding = OutputRateCoding() 60 | 61 | def forward(self, x: torch.Tensor): 62 | spike = torch.zeros((self.T, 2, self.classes)) 63 | v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v_fc0, v_fc1, v_fc2 = None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None 64 | x_copy = x 65 | for i in range(self.T): 66 | x, q = self.conv0(x_copy) 67 | out, v0 = self.lif0(x, q, v0) 68 | x = self.relu(out) 69 | 70 | x, q = self.conv1(x) 71 | out, v1 = self.lif1(x, q, v1) 72 | x = self.relu(out) 73 | x = self.maxpool0(x) 74 | 75 | x, q = self.conv2(x) 76 | out, v2 = self.lif2(x, q, v2) 77 | x = self.relu(out) 78 | 79 | x, q = self.conv3(x) 80 | out, v3 = self.lif3(x, q, v3) 81 | x = self.relu(out) 82 | x = self.maxpool1(x) 83 | 84 | x, q = self.conv4(x) 85 | out, v4 = self.lif4(x, q, v4) 86 | x = self.relu(out) 87 | 88 | x, q = self.conv5(x) 89 | out, v5 = self.lif5(x, q, v5) 90 | x = self.relu(out) 91 | 92 | x, q = self.conv6(x) 93 | out, v6 = self.lif6(x, q, v6) 94 | x = self.relu(out) 95 | x = self.maxpool2(x) 96 | 97 | x, q = self.conv7(x) 98 | out, v7 = self.lif7(x, q, v7) 99 | x = self.relu(out) 100 | 101 | x, q = self.conv8(x) 102 | out, v8 = self.lif8(x, q, v8) 103 | x = self.relu(out) 104 | 105 | x, q = self.conv9(x) 106 | out, v9 = self.lif9(x, q, v9) 107 | x = self.relu(out) 108 | x = self.maxpool3(x) 109 | 110 | x, q = self.conv10(x) 111 | out, v10 = self.lif10(x, q, v10) 112 | x = self.relu(out) 113 | 114 | x, q = self.conv11(x) 115 | out, v11 = self.lif11(x, q, v11) 116 | x = self.relu(out) 117 | 118 | x, q = self.conv12(x) 119 | out, v12 = self.lif12(x, q, v12) 120 | x = self.relu(out) 121 | x = self.maxpool4(x) 122 | 123 | x = torch.flatten(x, 1) 124 | x, q = self.fc0(x) 125 | out, v_fc0 = self.fc0lif0(x, q, v_fc0) 126 | x = self.relu(out) 127 | 128 | x, q = self.fc1(x) 129 | out, v_fc1 = self.fc1lif1(x, q, v_fc1) 130 | x = self.relu(out) 131 | 132 | x, q = self.fc2(x) 133 | out, v_fc2 = self.fc2lif2(x, q, v_fc2) 134 | spike[i] = out 135 | return self.coding(spike) 136 | 137 | 138 | if __name__ == '__main__': 139 | model = SNNVGG16(T=10) 140 | x = torch.randn([2, 3, 224, 224]) 141 | y = model(x) 142 | 143 | torch.onnx.export(model, x, 'temp/SNNVGG16.onnx', 144 | custom_opsets={'snn': 1}, opset_version=11) 145 | -------------------------------------------------------------------------------- /examples/snn/snn_vgg19.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) OpenBII 2 | # Team: CBICR 3 | # SPDX-License-Identifier: Apache-2.0 4 | # See: https://spdx.org/licenses/ 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | from hnn.snn.lif import QLIF 10 | from hnn.snn.output_rate_coding import OutputRateCoding 11 | from hnn.snn.q_conv2d import QConv2d 12 | 
from hnn.snn.q_linear import QLinear 13 | from hnn.snn.q_model import QModel 14 | 15 | 16 | class SNNVGG19(QModel): 17 | def __init__(self, T=10, num_classes=1000): 18 | super(SNNVGG19, self).__init__(time_window_size=T) 19 | self.classes = num_classes 20 | self.relu = nn.ReLU(inplace=True) 21 | self.conv0 = QConv2d(3, 64, 3, stride=1, padding=1) 22 | self.lif0 = QLIF(1, 0.9, 0, 0) 23 | self.conv1 = QConv2d(64, 64, 3, stride=1, padding=1) 24 | self.lif1 = QLIF(1, 0.9, 0, 0) 25 | self.maxpool0 = nn.MaxPool2d(2, stride=2) 26 | self.conv2 = QConv2d(64, 128, kernel_size=3, stride=1, padding=1) 27 | self.lif2 = QLIF(1, 0.9, 0, 0) 28 | self.conv3 = QConv2d(128, 128, kernel_size=3, stride=1, padding=1) 29 | self.lif3 = QLIF(1, 0.9, 0, 0) 30 | self.maxpool1 = nn.MaxPool2d(2, stride=2) 31 | self.conv4 = QConv2d(128, 256, kernel_size=3, stride=1, padding=1) 32 | self.lif4 = QLIF(1, 0.9, 0, 0) 33 | self.conv5 = QConv2d(256, 256, kernel_size=3, stride=1, padding=1) 34 | self.lif5 = QLIF(1, 0.9, 0, 0) 35 | self.conv6 = QConv2d(256, 256, kernel_size=3, stride=1, padding=1) 36 | self.lif6 = QLIF(1, 0.9, 0, 0) 37 | self.conv7 = QConv2d(256, 256, kernel_size=3, stride=1, padding=1) 38 | self.lif7 = QLIF(1, 0.9, 0, 0) 39 | self.maxpool2 = nn.MaxPool2d(2, stride=2) 40 | self.conv8 = QConv2d(256, 512, kernel_size=3, stride=1, padding=1) 41 | self.lif8 = QLIF(1, 0.9, 0, 0) 42 | self.conv9 = QConv2d(512, 512, kernel_size=3, stride=1, padding=1) 43 | self.lif9 = QLIF(1, 0.9, 0, 0) 44 | self.conv10 = QConv2d(512, 512, kernel_size=3, stride=1, padding=1) 45 | self.lif10 = QLIF(1, 0.9, 0, 0) 46 | self.conv11 = QConv2d(512, 512, kernel_size=3, stride=1, padding=1) 47 | self.lif11 = QLIF(1, 0.9, 0, 0) 48 | self.maxpool3 = nn.MaxPool2d(2, stride=2) 49 | self.conv12 = QConv2d(512, 512, kernel_size=3, stride=1, padding=1) 50 | self.lif12 = QLIF(1, 0.9, 0, 0) 51 | self.conv13 = QConv2d(512, 512, kernel_size=3, stride=1, padding=1) 52 | self.lif13 = QLIF(1, 0.9, 0, 0) 53 | self.conv14 = QConv2d(512, 512, kernel_size=3, stride=1, padding=1) 54 | self.lif14 = QLIF(1, 0.9, 0, 0) 55 | self.conv15 = QConv2d(512, 512, kernel_size=3, stride=1, padding=1) 56 | self.lif15 = QLIF(1, 0.9, 0, 0) 57 | self.maxpool4 = nn.MaxPool2d(2, stride=2) 58 | 59 | self.fc0 = QLinear(512 * 7 * 7, 4096) 60 | self.fc0lif0 = QLIF(1, 0.9, 0, 0) 61 | self.fc1 = QLinear(4096, 4096) 62 | self.fc1lif1 = QLIF(1, 0.9, 0, 0) 63 | self.fc2 = QLinear(4096, num_classes) 64 | self.fc2lif2 = QLIF(1, 0.9, 0, 0) 65 | self.coding = OutputRateCoding() 66 | 67 | def forward(self, x: torch.Tensor): 68 | spike = torch.zeros((self.T, 2, self.classes)) 69 | v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v_fc0, v_fc1, v_fc2 = None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None 70 | x_copy = x 71 | for i in range(self.T): 72 | x, q = self.conv0(x_copy) 73 | out, v0 = self.lif0(x, q, v0) 74 | x = self.relu(out) 75 | 76 | x, q = self.conv1(x) 77 | out, v1 = self.lif1(x, q, v1) 78 | x = self.relu(out) 79 | x = self.maxpool0(x) 80 | 81 | x, q = self.conv2(x) 82 | out, v2 = self.lif2(x, q, v2) 83 | x = self.relu(out) 84 | 85 | x, q = self.conv3(x) 86 | out, v3 = self.lif3(x, q, v3) 87 | x = self.relu(out) 88 | x = self.maxpool1(x) 89 | 90 | x, q = self.conv4(x) 91 | out, v4 = self.lif4(x, q, v4) 92 | x = self.relu(out) 93 | 94 | x, q = self.conv5(x) 95 | out, v5 = self.lif5(x, q, v5) 96 | x = self.relu(out) 97 | 98 | x, q = self.conv6(x) 99 | out, v6 = self.lif6(x, q, v6) 100 | x = self.relu(out) 
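            # Same pattern for every block below: the quantized layer returns both
            # the integrated input x and the quantization scale q, and the QLIF
            # neuron consumes (x, q, membrane state), returning the spike output
            # together with the updated membrane potential for the next time step.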
101 | 102 | x, q = self.conv7(x) 103 | out, v7 = self.lif7(x, q, v7) 104 | x = self.relu(out) 105 | x = self.maxpool2(x) 106 | 107 | x, q = self.conv8(x) 108 | out, v8 = self.lif8(x, q, v8) 109 | x = self.relu(out) 110 | 111 | x, q = self.conv9(x) 112 | out, v9 = self.lif9(x, q, v9) 113 | x = self.relu(out) 114 | 115 | x, q = self.conv10(x) 116 | out, v10 = self.lif10(x, q, v10) 117 | x = self.relu(out) 118 | 119 | x, q = self.conv11(x) 120 | out, v11 = self.lif11(x, q, v11) 121 | x = self.relu(out) 122 | x = self.maxpool3(x) 123 | 124 | x, q = self.conv12(x) 125 | out, v12 = self.lif12(x, q, v12) 126 | x = self.relu(out) 127 | 128 | x, q = self.conv13(x) 129 | out, v13 = self.lif13(x, q, v13) 130 | x = self.relu(out) 131 | 132 | x, q = self.conv14(x) 133 | out, v14 = self.lif14(x, q, v14) 134 | x = self.relu(out) 135 | 136 | x, q = self.conv15(x) 137 | out, v15 = self.lif15(x, q, v15) 138 | x = self.relu(out) 139 | x = self.maxpool4(x) 140 | 141 | x = torch.flatten(x, 1) 142 | x, q = self.fc0(x) 143 | out, v_fc0 = self.fc0lif0(x, q, v_fc0) 144 | x = self.relu(out) 145 | 146 | x, q = self.fc1(x) 147 | out, v_fc1 = self.fc1lif1(x, q, v_fc1) 148 | x = self.relu(out) 149 | 150 | x, q = self.fc2(x) 151 | out, v_fc2 = self.fc2lif2(x, q, v_fc2) 152 | spike[i] = out 153 | return self.coding(spike) 154 | 155 | 156 | if __name__ == '__main__': 157 | model = SNNVGG19(T=10) 158 | x = torch.randn([2, 3, 224, 224]) 159 | y = model(x) 160 | 161 | torch.onnx.export(model, x, 'temp/SNNVGG19.onnx', 162 | custom_opsets={'snn': 1}, opset_version=11) 163 | -------------------------------------------------------------------------------- /hnn/fuse_bn.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) OpenBII 2 | # Team: CBICR 3 | # SPDX-License-Identifier: Apache-2.0 4 | # See: https://spdx.org/licenses/ 5 | 6 | from __future__ import (absolute_import, division, print_function, 7 | unicode_literals) 8 | 9 | import copy 10 | 11 | import torch 12 | 13 | from hnn.ann.q_conv2d import QConv2d 14 | from hnn.ann.q_linear import QLinear 15 | from hnn.snn.q_conv2d import QConv2d as SQConv2d 16 | from hnn.snn.q_linear import QLinear as SQLinear 17 | 18 | 19 | def fuse2d_conv_bn(conv: torch.nn.Conv2d, bn: torch.nn.BatchNorm2d): 20 | assert (conv.training == bn.training), \ 21 | "Conv and BN both must be in the same mode (train or eval)." 
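    # BatchNorm folding: with g = bn.weight / sqrt(running_var + eps), the fused
    # convolution uses W' = W * g (broadcast over output channels) and
    # b' = g * (b - running_mean) + bn.bias; when affine is False, bn.weight and
    # bn.bias are taken as 1 and 0. The two branches below compute exactly this.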
22 | new_conv = torch.nn.Conv2d(in_channels=conv.in_channels, out_channels=conv.out_channels, 23 | kernel_size=conv.kernel_size, stride=conv.stride, padding=conv.padding, 24 | dilation=conv.dilation, groups=conv.groups, bias=True, padding_mode=conv.padding_mode) 25 | 26 | if bn.affine: 27 | gamma = bn.weight.data / torch.sqrt(bn.running_var + bn.eps) 28 | new_conv.weight.data = conv.weight.data * gamma.view(-1, 1, 1, 1) 29 | if conv.bias is not None: 30 | new_conv.bias.data = gamma * conv.bias.data - \ 31 | gamma * bn.running_mean + bn.bias.data 32 | else: 33 | new_conv.bias.data = bn.bias.data - gamma * bn.running_mean 34 | else: 35 | "affine 为 False 的情况, gamma=1, beta=0" 36 | gamma = 1 / torch.sqrt(bn.running_var + bn.eps) 37 | new_conv.weight.data = conv.weight.data * gamma 38 | if conv.bias is not None: 39 | new_conv.bias.data = gamma * conv.bias.data - gamma * bn.running_mean 40 | else: 41 | new_conv.bias.data = - gamma * bn.running_mean 42 | return new_conv 43 | 44 | 45 | def fuse1d_linear_bn(linear: torch.nn.Linear, bn: torch.nn.BatchNorm1d): 46 | assert (linear.training == bn.training), \ 47 | "Linear and BN both must be in the same mode (train or eval)." 48 | new_linear = torch.nn.Linear( 49 | in_features=linear.in_features, out_features=linear.out_features, bias=True) 50 | 51 | if bn.affine: 52 | gamma = bn.weight.data / torch.sqrt(bn.running_var + bn.eps) 53 | new_linear.weight.data = linear.weight.data * gamma.view(-1, 1) 54 | if linear.bias is not None: 55 | new_linear.bias.data = gamma * linear.bias.data - \ 56 | gamma * bn.running_mean + bn.bias.data 57 | else: 58 | new_linear.bias.data = bn.bias.data - gamma * bn.running_mean 59 | else: 60 | "affine 为 False 的情况, gamma=1, beta=0" 61 | gamma = 1 / torch.sqrt(bn.running_var + bn.eps) 62 | new_linear.weight.data = linear.weight.data * gamma 63 | if linear.bias is not None: 64 | new_linear.bias.data = gamma * linear.bias.data - gamma * bn.running_mean 65 | else: 66 | new_linear.bias.data = - gamma * bn.running_mean 67 | return new_linear 68 | 69 | 70 | # Generalization of getattr 71 | def _get_module(model, submodule_key): 72 | tokens = submodule_key.split('.') 73 | cur_mod = model 74 | for s in tokens: 75 | cur_mod = getattr(cur_mod, s) 76 | return cur_mod 77 | 78 | 79 | # Generalization of setattr 80 | def _set_module(model, submodule_key, module): 81 | tokens = submodule_key.split('.') 82 | sub_tokens = tokens[:-1] 83 | cur_mod = model 84 | for s in sub_tokens: 85 | cur_mod = getattr(cur_mod, s) 86 | 87 | setattr(cur_mod, tokens[-1], module) 88 | 89 | 90 | def fuse_known_modules(mod_list): 91 | OP_LIST_TO_FUSER_METHOD = { 92 | # (torch.nn.Conv1d, torch.nn.BatchNorm1d): fuse_conv_bn, 93 | # (torch.nn.Conv1d, torch.nn.BatchNorm1d, torch.nn.ReLU): fuse_conv_bn_relu, 94 | (torch.nn.Conv2d, torch.nn.BatchNorm2d): fuse2d_conv_bn, 95 | (torch.nn.Linear, torch.nn.BatchNorm1d): fuse1d_linear_bn, 96 | (QConv2d, torch.nn.BatchNorm2d): fuse2d_conv_bn, 97 | (QLinear, torch.nn.BatchNorm1d): fuse1d_linear_bn, 98 | (SQConv2d, torch.nn.BatchNorm2d): fuse2d_conv_bn, 99 | (SQLinear, torch.nn.BatchNorm1d): fuse1d_linear_bn, 100 | # (torch.nn.Conv2d, torch.nn.BatchNorm2d, torch.nn.ReLU): fuse_conv_bn_relu, 101 | # (torch.nn.Conv3d, torch.nn.BatchNorm3d): fuse_conv_bn, 102 | # (torch.nn.Conv3d, torch.nn.BatchNorm3d, torch.nn.ReLU): fuse_conv_bn_relu, 103 | # (torch.nn.Conv1d, torch.nn.ReLU): torch.nn.intrinsic.ConvReLU1d, 104 | # (torch.nn.Conv2d, torch.nn.ReLU): torch.nn.intrinsic.ConvReLU2d, 105 | # (torch.nn.Conv3d, torch.nn.ReLU): 
torch.nn.intrinsic.ConvReLU3d, 106 | # (torch.nn.Linear, torch.nn.ReLU): torch.nn.intrinsic.LinearReLU, 107 | # (torch.nn.BatchNorm2d, torch.nn.ReLU): torch.nn.intrinsic.BNReLU2d, 108 | # (torch.nn.BatchNorm3d, torch.nn.ReLU): torch.nn.intrinsic.BNReLU3d, 109 | } 110 | 111 | types = tuple(type(m) for m in mod_list) 112 | fuser_method = OP_LIST_TO_FUSER_METHOD.get(types, None) 113 | if fuser_method is None: 114 | raise NotImplementedError("Cannot fuse modules: {}".format(types)) 115 | new_mod = [None] * len(mod_list) 116 | new_mod[0] = fuser_method(*mod_list) 117 | 118 | for i in range(1, len(mod_list)): 119 | new_mod[i] = torch.nn.Identity() 120 | new_mod[i].training = mod_list[0].training 121 | 122 | return new_mod 123 | 124 | 125 | def _fuse_modules(model, modules_to_fuse, fuser_func=fuse_known_modules): 126 | mod_list = [] 127 | for item in modules_to_fuse: 128 | mod_list.append(_get_module(model, item)) 129 | 130 | # Fuse list of modules 131 | new_mod_list = fuser_func(mod_list) 132 | 133 | # Replace original module list with fused module list 134 | for i, item in enumerate(modules_to_fuse): 135 | _set_module(model, item, new_mod_list[i]) 136 | 137 | 138 | def fuse_modules(model, modules_to_fuse, inplace=False, fuser_func=fuse_known_modules): 139 | if not inplace: 140 | model = copy.deepcopy(model) 141 | 142 | if all(isinstance(module_element, str) for module_element in modules_to_fuse): 143 | # Handle case of modules_to_fuse being a list 144 | _fuse_modules(model, modules_to_fuse, fuser_func) 145 | else: 146 | # Handle case of modules_to_fuse being a list of lists 147 | for module_list in modules_to_fuse: 148 | _fuse_modules(model, module_list, fuser_func) 149 | return model 150 | --------------------------------------------------------------------------------