├── README.md
├── args_config.py
├── cifar_dvs_dataset.py
├── network_utils.py
├── quant_net.py
├── quant_resnet.py
├── spike_related.py
├── train_snn.py
└── training_utils.py

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# MINT_Quantization

## TODO:
I will clean up the code soon...

## Notice:
The code currently raises errors under some PyTorch versions; a fix is planned.
For now, please run the code with PyTorch 1.13.0, which is the version that has been tested and works. Thanks.

## Citing
If you find MINT useful for your research, please use the following bibtex to cite us:

```
@inproceedings{yin2024mint,
  title={MINT: Multiplier-less INTeger Quantization for Energy Efficient Spiking Neural Networks},
  author={Yin, Ruokai and Li, Yuhang and Moitra, Abhishek and Panda, Priyadarshini},
  booktitle={2024 29th Asia and South Pacific Design Automation Conference (ASP-DAC)},
  pages={830--835},
  year={2024},
  organization={IEEE}
}
```
--------------------------------------------------------------------------------
/args_config.py:
--------------------------------------------------------------------------------
import argparse


def get_args():

    parser = argparse.ArgumentParser("UQSNN")

    parser.add_argument("--batch_size", default=256, type=int, help="Batch size")
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument("--arch", default="res19", type=str, help="Arch [vgg9, vgg16, res19]")
    parser.add_argument('--dataset_dir', type=str, default='../dataset/', help='path to the dataset')
    parser.add_argument("--dataset", default="dvs", type=str, help="Dataset [cifar10, svhn, tiny, dvs]")
    parser.add_argument("--optim", default='adam', type=str, help="Optimizer [adam, sgd]")
    parser.add_argument('--leak_mem', default=0.5, type=float)
    parser.add_argument('--th', default=0.5, type=float)
    parser.add_argument('--rst', default="hard", type=str)
    parser.add_argument('--T', type=int, default=10)
    parser.add_argument('-uq', action='store_true')      # quantize the membrane potential (u)
    parser.add_argument('-bq', action='store_true')      # quantize the bias
    parser.add_argument('-wq', action='store_true')      # quantize the weights
    parser.add_argument('-share', action='store_true')   # share one learned scale between weights and potentials (MINT)
    parser.add_argument('-sft_rst', action='store_true') # use soft reset in the LIF neurons
    parser.add_argument('-conv_b', action='store_true')  # give conv layers a bias term
    parser.add_argument('-bn_a', action='store_true')    # make BatchNorm affine
    parser.add_argument('-xa', action='store_true')
    parser.add_argument('-ts', action='store_true')

    parser.add_argument('--epoch', type=int, default=200)
    parser.add_argument("--num_workers", default=4, type=int, help="number of workers")
    parser.add_argument("--train_display_freq", default=1, type=int, help="display freq for train")
    parser.add_argument("--test_display_freq", default=1, type=int, help="display freq for test")

    args = parser.parse_args()

    return args
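
# Minimal smoke test for the parser above (added as an illustration; not part
# of the original file). Flag semantics are inferred from how the rest of the
# repo consumes them, e.g. `-wq -uq -share` enables MINT-style shared-scale
# quantization of weights and membrane potentials.
if __name__ == "__main__":
    # e.g.: python args_config.py --arch vgg16 --dataset cifar10 --T 4 -wq -uq -share
    print(get_args())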
--------------------------------------------------------------------------------
/cifar_dvs_dataset.py:
--------------------------------------------------------------------------------
from spikingjelly.datasets.cifar10_dvs import CIFAR10DVS
from torchvision.transforms import Resize
from torch.utils.data import Subset
import numpy as np
import math
import tqdm
import torch
import torchvision.transforms as transforms

class MyCIFAR10DVS(Subset):
    """Wraps CIFAR10-DVS and produces a stratified train/test split."""

    def __init__(self, root, train_ratio=0.9, data_type="frame", frames_number=10, split_by="number", random_split=False, size=(48, 48)):
        transform_dvs = transforms.Compose([
            lambda x: torch.from_numpy(x),
            Resize(size=size)])

        dataset_dvs = CIFAR10DVS(root=root, data_type=data_type, frames_number=frames_number, split_by=split_by, transform=transform_dvs)

        # Split per class so the train/test label distributions match.
        train_set, test_set = self.split_to_train_test_set(train_ratio=train_ratio, origin_dataset=dataset_dvs, num_classes=10, random_split=random_split)
        self.train_dvs = train_set
        self.test_dvs = test_set

    def split(self):
        return self.train_dvs, self.test_dvs

    @staticmethod
    def split_to_train_test_set(train_ratio: float, origin_dataset: torch.utils.data.Dataset, num_classes: int, random_split: bool = False):
        # Collect the sample indices of each class.
        label_idx = [[] for _ in range(num_classes)]

        for i, item in enumerate(tqdm.tqdm(origin_dataset)):
            y = item[1]
            if isinstance(y, np.ndarray) or isinstance(y, torch.Tensor):
                y = y.item()
            label_idx[y].append(i)

        train_idx = []
        test_idx = []
        if random_split:
            for i in range(num_classes):
                np.random.shuffle(label_idx[i])

        # Take the first train_ratio fraction of each class for training and
        # the remainder for testing; the two index sets never overlap.
        for i in range(num_classes):
            pos = math.ceil(len(label_idx[i]) * train_ratio)
            train_idx.extend(label_idx[i][0:pos])
            test_idx.extend(label_idx[i][pos:])

        return torch.utils.data.Subset(origin_dataset, train_idx), torch.utils.data.Subset(origin_dataset, test_idx)
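
# Illustrative sketch (added; not in the original repo) of how the
# `train_dataset_dvs_8.pt` / `test_dataset_dvs_8.pt` files consumed by
# train_snn.py could be produced with this class. The root path and
# frames_number=8 are assumptions matching the "_8" filename suffix.
# dill is needed because the transform above contains a lambda.
if __name__ == "__main__":
    import dill
    dvs = MyCIFAR10DVS(root="../dataset/cifar10_dvs", frames_number=8)
    train_set, test_set = dvs.split()
    torch.save(train_set, "./train_dataset_dvs_8.pt", pickle_module=dill)
    torch.save(test_set, "./test_dataset_dvs_8.pt", pickle_module=dill)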
--------------------------------------------------------------------------------
/network_utils.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
import torchvision
import args_config
from torchvision import datasets, transforms
import gc
from torch.autograd import Variable
import torch.optim as optim
import torch.backends.cudnn as cudnn
from statistics import mean
import math
from training_utils import *

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
cudnn.benchmark = True
cudnn.deterministic = True

args = args_config.get_args()


class QFConvBN2dLIF(nn.Module):
    """Folds the conv2d and batchnorm2d for inference."""

    def __init__(self, conv_module, bn_module, lif_module, num_bits_w=4, num_bits_bias=4, num_bits_u=4):
        super(QFConvBN2dLIF, self).__init__()

        self.conv_module = conv_module
        self.bn_module = bn_module
        self.lif_module = lif_module

        self.num_bits_w = num_bits_w
        self.num_bits_bias = num_bits_bias
        self.num_bits_u = num_bits_u

        # LSQ-style initialization of the shared scale: 2*E[|w|]/sqrt(Qp).
        # clone().detach() keeps this robust across PyTorch versions.
        initial_beta = (conv_module.weight.abs().mean() * 2 / math.sqrt(2**(self.num_bits_w - 1) - 1)).clone().detach()
        self.beta = nn.ParameterList([nn.Parameter(initial_beta) for i in range(1)]).cuda()

    def fold_bn(self, mean, std):
        if self.bn_module.affine:
            gamma_ = self.bn_module.weight / std
            weight = self.conv_module.weight * gamma_.reshape(self.conv_module.out_channels, 1, 1, 1)
            if self.conv_module.bias is not None:
                bias = gamma_ * self.conv_module.bias - gamma_ * mean + self.bn_module.bias
            else:
                bias = self.bn_module.bias - gamma_ * mean
        else:
            gamma_ = 1 / std
            weight = self.conv_module.weight * gamma_.reshape(self.conv_module.out_channels, 1, 1, 1)
            if self.conv_module.bias is not None:
                bias = gamma_ * self.conv_module.bias - gamma_ * mean
            else:
                bias = -gamma_ * mean

        return weight, bias

    def forward(self, x):
        if self.training:
            ### Get the BN statistics first by running the conv
            y = F.conv2d(x, self.conv_module.weight, self.conv_module.bias,
                         stride=self.conv_module.stride,
                         padding=self.conv_module.padding,
                         dilation=self.conv_module.dilation,
                         groups=self.conv_module.groups)

            y = y.permute(1, 0, 2, 3)  # NCHW -> CNHW
            y = y.reshape(self.conv_module.out_channels, -1)  # CNHW -> (C, NHW)
            mean = y.mean(1)
            var = y.var(1)

            # EMA update following PyTorch's BatchNorm convention:
            # running = (1 - momentum) * running + momentum * batch_stat.
            self.bn_module.running_mean = \
                (1 - self.bn_module.momentum) * self.bn_module.running_mean + \
                self.bn_module.momentum * mean
            self.bn_module.running_var = \
                (1 - self.bn_module.momentum) * self.bn_module.running_var + \
                self.bn_module.momentum * var

        else:
            #### Use the long-term mean and var during inference
            mean = self.bn_module.running_mean
            var = self.bn_module.running_var

        std = torch.sqrt(var + self.bn_module.eps)
        weight, bias = self.fold_bn(mean, std)
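
        # For reference, BN folding uses the batch/running statistics
        # (mean, var) and the affine parameters (gamma, beta_bn):
        #   w_fold = gamma * w / sqrt(var + eps)
        #   b_fold = beta_bn + gamma * (b - mean) / sqrt(var + eps)
        # which is what fold_bn() computes channel-wise above.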
        # Default shared scale; refined by w_q when weight quantization is on.
        beta = self.beta[0]
        if args.wq:
            if args.share:
                qweight, beta = w_q(weight, self.num_bits_w, beta)
            else:
                qweight = b_q(weight, self.num_bits_w)
        else:
            qweight = weight

        if args.bq:
            if args.share:
                # The bias shares the same scale as the weights.
                qbias, beta = w_q(bias, self.num_bits_bias, beta)
            else:
                qbias = b_q(bias, self.num_bits_bias)
        else:
            qbias = bias

        x = F.conv2d(x, qweight, qbias,
                     stride=self.conv_module.stride,
                     padding=self.conv_module.padding,
                     dilation=self.conv_module.dilation,
                     groups=self.conv_module.groups)

        if args.share:
            s = self.lif_module(x, args.share, beta, bias=0)
        else:
            s = self.lif_module(x, args.share, 0, bias=0)

        return s


class QConv2dLIF(nn.Module):
    """Integrates the conv2d and LIF for inference."""

    def __init__(self, conv_module, lif_module, num_bits_w=4, num_bits_u=4):
        super(QConv2dLIF, self).__init__()

        self.conv_module = conv_module
        self.lif_module = lif_module

        self.num_bits_w = num_bits_w
        self.num_bits_u = num_bits_u

        # LSQ-style initialization of the shared scale: 2*E[|w|]/sqrt(Qp).
        initial_beta = (conv_module.weight.abs().mean() * 2 / math.sqrt(2**(self.num_bits_w - 1) - 1)).clone().detach()
        self.beta = nn.ParameterList([nn.Parameter(initial_beta) for i in range(1)]).cuda()

    def forward(self, x):
        # Default shared scale; refined by w_q when weight quantization is on.
        beta = self.beta[0]
        if args.wq:
            if args.share:
                qweight, beta = w_q(self.conv_module.weight, self.num_bits_w, beta)
            else:
                qweight = b_q(self.conv_module.weight, self.num_bits_w)
        else:
            qweight = self.conv_module.weight

        x = F.conv2d(x, qweight, self.conv_module.bias, stride=self.conv_module.stride,
                     padding=self.conv_module.padding,
                     dilation=self.conv_module.dilation,
                     groups=self.conv_module.groups)

        if args.share:
            s = self.lif_module(x, args.share, beta, bias=0)
        else:
            s = self.lif_module(x, args.share, 0, bias=0)

        return s
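
# Added summary: the conv wrappers in this file differ in how BatchNorm is handled:
#   - QFConvBN2dLIF folds BN into the conv weight/bias before quantizing;
#   - QConv2dLIF is a quantized conv followed directly by a LIF neuron (no BN);
#   - QConvBN2dLIF keeps BN as a separate op between the quantized conv and the LIF;
#   - QConvBN2d is the LIF-less variant (used in the ResNet blocks, e.g. for shortcuts).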

class QConvBN2dLIF(nn.Module):
    """Integrates the conv2d, BN, and LIF for inference."""

    def __init__(self, conv_module, bn_module, lif_module, num_bits_w=4, num_bits_b=4, num_bits_u=4):
        super(QConvBN2dLIF, self).__init__()

        self.conv_module = conv_module
        self.lif_module = lif_module
        self.bn_module = bn_module

        self.num_bits_w = num_bits_w
        self.num_bits_b = num_bits_b
        self.num_bits_u = num_bits_u

        # LSQ-style initialization of the shared scale: 2*E[|w|]/sqrt(Qp).
        initial_beta = (conv_module.weight.abs().mean() * 2 / math.sqrt(2**(self.num_bits_w - 1) - 1)).clone().detach()
        self.beta = nn.ParameterList([nn.Parameter(initial_beta) for i in range(1)]).cuda()

    def forward(self, x):
        # Default shared scale; refined by w_q when weight quantization is on.
        beta = self.beta[0]
        if args.wq:
            if args.share:
                qweight, beta = w_q(self.conv_module.weight, self.num_bits_w, beta)
            else:
                qweight = b_q(self.conv_module.weight, self.num_bits_w)
        else:
            qweight = self.conv_module.weight

        x = F.conv2d(x, qweight, self.conv_module.bias, stride=self.conv_module.stride,
                     padding=self.conv_module.padding,
                     dilation=self.conv_module.dilation,
                     groups=self.conv_module.groups)
        x = self.bn_module(x)

        if args.share:
            s = self.lif_module(x, args.share, beta, bias=0)
        else:
            s = self.lif_module(x, args.share, 0, bias=0)

        return s

class QConvBN2d(nn.Module):
    """Integrates the conv2d and BN for inference (no LIF)."""

    def __init__(self, conv_module, bn_module, num_bits_w=4, num_bits_u=4, short_cut=False):
        super(QConvBN2d, self).__init__()

        self.conv_module = conv_module
        self.bn_module = bn_module

        self.num_bits_w = num_bits_w
        self.num_bits_u = num_bits_u
        self.short_cut = short_cut

        # LSQ-style initialization of the shared scale: 2*E[|w|]/sqrt(Qp).
        initial_beta = (conv_module.weight.abs().mean() * 2 / math.sqrt(2**(self.num_bits_w - 1) - 1)).clone().detach()
        self.beta = nn.ParameterList([nn.Parameter(initial_beta) for i in range(1)]).cuda()

    def forward(self, x):
        if args.wq:
            if args.share:
                qweight, beta = w_q(self.conv_module.weight, self.num_bits_w, self.beta[0])
            else:
                qweight = b_q(self.conv_module.weight, self.num_bits_w)
        else:
            qweight = self.conv_module.weight

        x = F.conv2d(x, qweight, self.conv_module.bias, stride=self.conv_module.stride,
                     padding=self.conv_module.padding,
                     dilation=self.conv_module.dilation,
                     groups=self.conv_module.groups)
        x = self.bn_module(x)

        return x
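
# Illustrative sketch (added): exercising QConvBN2d on random data. Running
# this file directly parses the default args, so -wq and -share are off and
# the wrapper behaves like a plain conv+BN. CUDA is required because beta is
# moved to the GPU at construction time.
if __name__ == "__main__" and torch.cuda.is_available():
    demo_conv = nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False)
    demo_bn = nn.BatchNorm2d(8)
    demo_block = QConvBN2d(demo_conv, demo_bn, num_bits_w=4, num_bits_u=4).cuda()
    demo_out = demo_block(torch.randn(2, 3, 32, 32).cuda())
    print(demo_out.shape)  # expected: torch.Size([2, 8, 32, 32])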

--------------------------------------------------------------------------------
/quant_net.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
import torchvision
import args_config
from torchvision import datasets, transforms
import gc
from torch.autograd import Variable
import torch.optim as optim
import torch.backends.cudnn as cudnn
from statistics import mean
from training_utils import Firing, w_q, b_q
from network_utils import *
from spike_related import LIFSpike

args = args_config.get_args()


class Q_ShareScale_VGG9(nn.Module):
    def __init__(self, time_step, dataset):
        super(Q_ShareScale_VGG9, self).__init__()

        #### Set bitwidth for quantization
        self.num_bits_w = 2
        self.num_bits_u = 2

        #### Print out the parameters for quantization
        print("quant bw for w: " + str(self.num_bits_w))
        print("quant bw for u: " + str(self.num_bits_u))

        #### Other parameters for SNNs
        self.time_step = time_step

        input_dim = 3

        # The first conv encodes the static input; its output is replayed
        # through direct_lif at every timestep.
        self.conv1 = nn.Conv2d(input_dim, 64, kernel_size=3, padding=1, bias=False)
        self.direct_lif = LIFSpike(thresh=args.th, leak=args.leak_mem, gamma=1.0, soft_reset=args.sft_rst, quant_u=False)

        conv2 = nn.Conv2d(64, 64, kernel_size=3, padding=1, bias=False)
        lif2 = LIFSpike(thresh=args.th, leak=args.leak_mem, gamma=1.0, soft_reset=args.sft_rst, quant_u=args.uq, num_bits_u=self.num_bits_u)
        self.ConvLif2 = QConv2dLIF(conv2, lif2, self.num_bits_w, self.num_bits_u)

        self.pool1 = nn.MaxPool2d(kernel_size=2)

        conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1, bias=False)
        lif3 = LIFSpike(thresh=args.th, leak=args.leak_mem, gamma=1.0, soft_reset=args.sft_rst, quant_u=args.uq, num_bits_u=self.num_bits_u)
        self.ConvLif3 = QConv2dLIF(conv3, lif3, self.num_bits_w, self.num_bits_u)

        conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1, bias=False)
        lif4 = LIFSpike(thresh=args.th, leak=args.leak_mem, gamma=1.0, soft_reset=args.sft_rst, quant_u=args.uq, num_bits_u=self.num_bits_u)
        self.ConvLif4 = QConv2dLIF(conv4, lif4, self.num_bits_w, self.num_bits_u)

        self.pool2 = nn.MaxPool2d(kernel_size=2)

        conv5 = nn.Conv2d(128, 256, kernel_size=3, padding=1, bias=False)
        lif5 = LIFSpike(thresh=args.th, leak=args.leak_mem, gamma=1.0, soft_reset=args.sft_rst, quant_u=args.uq, num_bits_u=self.num_bits_u)
        self.ConvLif5 = QConv2dLIF(conv5, lif5, self.num_bits_w, self.num_bits_u)

        conv6 = nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False)
        lif6 = LIFSpike(thresh=args.th, leak=args.leak_mem, gamma=1.0, soft_reset=args.sft_rst, quant_u=args.uq, num_bits_u=self.num_bits_u)
        self.ConvLif6 = QConv2dLIF(conv6, lif6, self.num_bits_w, self.num_bits_u)

        conv7 = nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False)
        lif7 = LIFSpike(thresh=args.th, leak=args.leak_mem, gamma=1.0, soft_reset=args.sft_rst, quant_u=args.uq, num_bits_u=self.num_bits_u)
        self.ConvLif7 = QConv2dLIF(conv7, lif7, self.num_bits_w, self.num_bits_u)

        self.pool3 = nn.AdaptiveAvgPool2d((1, 1))

        if dataset == 'tiny':
            size = 1
            clas = 200
        else:
            size = 1
            clas = 10
        self.fc_out = nn.Linear(256 * (size**2), clas, bias=True)

        self.weight_init()

    def reset_dynamics(self):
        for m in self.modules():
            if isinstance(m, QConv2dLIF):
                m.lif_module.reset_mem()
        self.direct_lif.reset_mem()
        return 0

    def weight_init(self):
        # VGG9 wraps its convs in QConv2dLIF (not QConvBN2dLIF).
        for m in self.modules():
            if isinstance(m, QConv2dLIF):
                nn.init.kaiming_uniform_(m.conv_module.weight)
            if isinstance(m, nn.Linear):
                nn.init.kaiming_uniform_(m.weight)

    def forward(self, inp):

        u_out = []
        self.reset_dynamics()
        static_input = self.conv1(inp)

        for t in range(self.time_step):
            s = self.direct_lif.direct_forward(static_input, False, 0)

            s = self.ConvLif2(s)
            s = self.pool1(s)

            s = self.ConvLif3(s)
            s = self.ConvLif4(s)

            s = self.pool2(s)

            s = self.ConvLif5(s)
            s = self.ConvLif6(s)
            s = self.ConvLif7(s)

            s = self.pool3(s)

            s = s.view(s.shape[0], -1)
            s = self.fc_out(s)

            u_out += [s]

        return u_out


class Q_ShareScale_VGG16(nn.Module):
    def __init__(self, time_step, dataset):
        super(Q_ShareScale_VGG16, self).__init__()

        #### Set bitwidth for quantization
        self.num_bits_w = 4
        self.num_bits_b = 4
        self.num_bits_u = 4

        #### Print out the parameters for quantization
        print("quant bw for w: " + str(self.num_bits_w))
        print("quant bw for b: " + str(self.num_bits_b))
        print("quant bw for u: " + str(self.num_bits_u))

        #### Other parameters for SNNs
        self.time_step = time_step

        if dataset == 'dvs':
            input_dim = 2
        else:
            input_dim = 3

        self.conv1 = nn.Conv2d(input_dim, 64, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64, affine=True)
        self.direct_lif = LIFSpike(thresh=args.th, leak=args.leak_mem, gamma=1.0, soft_reset=args.sft_rst, quant_u=False)

        # Separate input layer for DVS data, which is fed frame by frame.
        conv1dvs = nn.Conv2d(input_dim, 64, kernel_size=3, padding=1, bias=False)
        bn1dvs = nn.BatchNorm2d(64, affine=True)
        lif1dvs = LIFSpike(thresh=args.th, leak=args.leak_mem, gamma=1.0, soft_reset=args.sft_rst, quant_u=args.uq, num_bits_u=self.num_bits_u)
        self.ConvBnLif1 = QConvBN2dLIF(conv1dvs, bn1dvs, lif1dvs, self.num_bits_w, self.num_bits_b, self.num_bits_u)

        conv2 = nn.Conv2d(64, 64,
kernel_size=3, padding=1, bias=args.conv_b) 172 | bn2 = nn.BatchNorm2d(64,affine=args.bn_a) 173 | lif2 = LIFSpike(thresh=args.th, leak=args.leak_mem, gamma=1.0, soft_reset=args.sft_rst, quant_u=args.uq, num_bits_u=self.num_bits_u) 174 | self.ConvBnLif2 = QConvBN2dLIF(conv2,bn2,lif2,self.num_bits_w,self.num_bits_b,self.num_bits_u) 175 | 176 | self.pool1 = nn.MaxPool2d(kernel_size=2) 177 | 178 | conv3 = nn.Conv2d(64,128, kernel_size=3, padding=1, bias=args.conv_b) 179 | bn3 = nn.BatchNorm2d(128,affine=args.bn_a) 180 | lif3 = LIFSpike(thresh=args.th, leak=args.leak_mem, gamma=1.0, soft_reset=args.sft_rst, quant_u=args.uq, num_bits_u=self.num_bits_u) 181 | self.ConvBnLif3 = QConvBN2dLIF(conv3,bn3,lif3,self.num_bits_w,self.num_bits_b,self.num_bits_u) 182 | 183 | conv4 = nn.Conv2d(128,128, kernel_size=3, padding=1, bias=args.conv_b) 184 | bn4 = nn.BatchNorm2d(128,affine=args.bn_a) 185 | lif4 = LIFSpike(thresh=args.th, leak=args.leak_mem, gamma=1.0, soft_reset=args.sft_rst, quant_u=args.uq, num_bits_u=self.num_bits_u) 186 | self.ConvBnLif4 = QConvBN2dLIF(conv4,bn4,lif4,self.num_bits_w,self.num_bits_b,self.num_bits_u) 187 | 188 | self.pool2 = nn.MaxPool2d(kernel_size=2) 189 | 190 | conv5 = nn.Conv2d(128,256, kernel_size=3, padding=1, bias=args.conv_b) 191 | bn5 = nn.BatchNorm2d(256,affine=args.bn_a) 192 | lif5 = LIFSpike(thresh=args.th, leak=args.leak_mem, gamma=1.0, soft_reset=args.sft_rst, quant_u=args.uq, num_bits_u=self.num_bits_u) 193 | self.ConvBnLif5 = QConvBN2dLIF(conv5,bn5,lif5,self.num_bits_w,self.num_bits_b,self.num_bits_u) 194 | 195 | conv6 = nn.Conv2d(256,256, kernel_size=3, padding=1, bias=args.conv_b) 196 | bn6 = nn.BatchNorm2d(256,affine=args.bn_a) 197 | lif6 = LIFSpike(thresh=args.th, leak=args.leak_mem, gamma=1.0, soft_reset=args.sft_rst, quant_u=args.uq, num_bits_u=self.num_bits_u) 198 | self.ConvBnLif6 = QConvBN2dLIF(conv6,bn6,lif6,self.num_bits_w,self.num_bits_b,self.num_bits_u) 199 | 200 | conv7 = nn.Conv2d(256,256, kernel_size=3, padding=1, bias=args.conv_b) 201 | bn7 = nn.BatchNorm2d(256,affine=args.bn_a) 202 | lif7 = LIFSpike(thresh=args.th, leak=args.leak_mem, gamma=1.0, soft_reset=args.sft_rst, quant_u=args.uq, num_bits_u=self.num_bits_u) 203 | self.ConvBnLif7 = QConvBN2dLIF(conv7,bn7,lif7,self.num_bits_w,self.num_bits_b,self.num_bits_u) 204 | 205 | self.pool3 = nn.MaxPool2d(kernel_size=2) 206 | 207 | conv8 = nn.Conv2d(256,512, kernel_size=3, padding=1, bias=args.conv_b) 208 | bn8 = nn.BatchNorm2d(512,affine=args.bn_a) 209 | lif8 = LIFSpike(thresh=args.th, leak=args.leak_mem, gamma=1.0, soft_reset=args.sft_rst, quant_u=args.uq, num_bits_u=self.num_bits_u) 210 | self.ConvBnLif8 = QConvBN2dLIF(conv8,bn8,lif8,self.num_bits_w,self.num_bits_b,self.num_bits_u) 211 | 212 | conv9 = nn.Conv2d(512,512, kernel_size=3, padding=1, bias=args.conv_b) 213 | bn9 = nn.BatchNorm2d(512,affine=args.bn_a) 214 | lif9 = LIFSpike(thresh=args.th, leak=args.leak_mem, gamma=1.0, soft_reset=args.sft_rst, quant_u=args.uq, num_bits_u=self.num_bits_u) 215 | self.ConvBnLif9 = QConvBN2dLIF(conv9,bn9,lif9,self.num_bits_w,self.num_bits_b,self.num_bits_u) 216 | 217 | conv10 = nn.Conv2d(512,512, kernel_size=3, padding=1, bias=args.conv_b) 218 | bn10 = nn.BatchNorm2d(512,affine=args.bn_a) 219 | lif10 = LIFSpike(thresh=args.th, leak=args.leak_mem, gamma=1.0, soft_reset=args.sft_rst, quant_u=args.uq, num_bits_u=self.num_bits_u) 220 | self.ConvBnLif10 = QConvBN2dLIF(conv10,bn10,lif10,self.num_bits_w,self.num_bits_b,self.num_bits_u) 221 | 222 | self.pool4 = nn.MaxPool2d(kernel_size=2) 223 | 224 | conv11 = 
nn.Conv2d(512,512, kernel_size=3, padding=1, bias=args.conv_b) 225 | bn11 = nn.BatchNorm2d(512,affine=args.bn_a) 226 | lif11 = LIFSpike(thresh=args.th, leak=args.leak_mem, gamma=1.0, soft_reset=args.sft_rst, quant_u=args.uq, num_bits_u=self.num_bits_u) 227 | self.ConvBnLif11 = QConvBN2dLIF(conv11,bn11,lif11,self.num_bits_w,self.num_bits_b,self.num_bits_u) 228 | 229 | conv12 = nn.Conv2d(512,512, kernel_size=3, padding=1, bias=args.conv_b) 230 | bn12 = nn.BatchNorm2d(512,affine=args.bn_a) 231 | lif12 = LIFSpike(thresh=args.th, leak=args.leak_mem, gamma=1.0, soft_reset=args.sft_rst, quant_u=args.uq, num_bits_u=self.num_bits_u) 232 | self.ConvBnLif12 = QConvBN2dLIF(conv12,bn12,lif12,self.num_bits_w,self.num_bits_b,self.num_bits_u) 233 | 234 | conv13 = nn.Conv2d(512,512, kernel_size=3, padding=1, bias=args.conv_b) 235 | bn13 = nn.BatchNorm2d(512,affine=args.bn_a) 236 | lif13 = LIFSpike(thresh=args.th, leak=args.leak_mem, gamma=1.0, soft_reset=args.sft_rst, quant_u=args.uq, num_bits_u=self.num_bits_u) 237 | self.ConvBnLif13 = QConvBN2dLIF(conv13,bn13,lif13,self.num_bits_w,self.num_bits_b,self.num_bits_u) 238 | 239 | self.pool5 = nn.AvgPool2d(kernel_size=2) 240 | 241 | if dataset == 'tiny': 242 | size = 2 243 | clas = 200 244 | else: 245 | size = 1 246 | clas = 10 247 | self.fc_out = nn.Linear(512*(size**2), clas, bias=True) 248 | 249 | self.dataset = dataset 250 | 251 | self.weight_init() 252 | 253 | def reset_dynamics(self): 254 | for m in self.modules(): 255 | if isinstance(m,QConvBN2dLIF): 256 | m.lif_module.reset_mem() 257 | self.direct_lif.reset_mem() 258 | return 0 259 | 260 | def weight_init(self): 261 | for m in self.modules(): 262 | if isinstance(m,QConvBN2dLIF): 263 | nn.init.kaiming_uniform_(m.conv_module.weight) 264 | if isinstance(m,nn.Linear): 265 | nn.init.kaiming_uniform_(m.weight) 266 | 267 | 268 | def forward(self, inp): 269 | 270 | u_out = [] 271 | self.reset_dynamics() 272 | if self.dataset != 'dvs': 273 | static_input = self.bn1(self.conv1(inp)) 274 | 275 | for t in range(self.time_step): 276 | if self.dataset == 'dvs': 277 | s = inp[:,t].to(torch.float32).cuda() 278 | s = self.ConvBnLif1(s) 279 | else: 280 | s = self.direct_lif.direct_forward(static_input,False,0) 281 | 282 | s = self.ConvBnLif2(s) 283 | # print(torch.sum(s)) 284 | s = self.pool1(s) 285 | 286 | s = self.ConvBnLif3(s) 287 | s = self.ConvBnLif4(s) 288 | 289 | s = self.pool2(s) 290 | 291 | s = self.ConvBnLif5(s) 292 | s = self.ConvBnLif6(s) 293 | s = self.ConvBnLif7(s) 294 | 295 | s = self.pool3(s) 296 | 297 | s = self.ConvBnLif8(s) 298 | s = self.ConvBnLif9(s) 299 | s = self.ConvBnLif10(s) 300 | 301 | s = self.pool4(s) 302 | 303 | s = self.ConvBnLif11(s) 304 | s = self.ConvBnLif12(s) 305 | s = self.ConvBnLif13(s) 306 | # print(torch.sum(s)) 307 | s = self.pool5(s) 308 | s = s.view(s.shape[0],-1) 309 | s = self.fc_out(s) 310 | 311 | u_out += [s] 312 | 313 | return u_out 314 | 315 | 316 | 317 | 318 | 319 | class Q_ShareScale_Fold_VGG16(nn.Module): 320 | def __init__(self,time_step,dataset): 321 | super(Q_ShareScale_Fold_VGG16, self).__init__() 322 | 323 | #### Set bitwidth for quantization 324 | self.num_bits_w = 8 325 | self.num_bits_b = 8 326 | self.num_bits_u = 8 327 | 328 | #### Print out the parameters for quantization 329 | 330 | print("quant bw for w: " + str(self.num_bits_w)) 331 | print("quant bw for b: " + str(self.num_bits_b)) 332 | print("quant bw for u: " + str(self.num_bits_u)) 333 | 334 | #### Other parameters for SNNs 335 | self.time_step = time_step 336 | 337 | input_dim = 3 338 | 339 | # 
print(args.th) 340 | 341 | self.conv1 = nn.Conv2d(input_dim, 64, kernel_size=3, padding=1, bias=False) 342 | self.bn1 = nn.BatchNorm2d(64,affine=True) 343 | self.direct_lif = LIFSpike(thresh=args.th, leak=args.leak_mem, gamma=1.0, soft_reset=args.sft_rst, quant_u=False) 344 | # self.ConvBnLif1 = QConvBN2dLIF(conv1,bn1,self.num_bits_w,self.num_bits_b,self.num_bits_u) 345 | 346 | conv2 = nn.Conv2d(64,64, kernel_size=3, padding=1, bias=args.conv_b) 347 | bn2 = nn.BatchNorm2d(64,affine=args.bn_a) 348 | lif2 = LIFSpike(thresh=args.th, leak=args.leak_mem, gamma=1.0, soft_reset=args.sft_rst, quant_u=args.uq, num_bits_u=self.num_bits_u) 349 | self.ConvBnLif2 = QConvBN2dLIF(conv2,bn2,lif2,self.num_bits_w,self.num_bits_b,self.num_bits_u) 350 | 351 | self.pool1 = nn.MaxPool2d(kernel_size=2) 352 | 353 | conv3 = nn.Conv2d(64,128, kernel_size=3, padding=1, bias=args.conv_b) 354 | bn3 = nn.BatchNorm2d(128,affine=args.bn_a) 355 | lif3 = LIFSpike(thresh=args.th, leak=args.leak_mem, gamma=1.0, soft_reset=args.sft_rst, quant_u=args.uq, num_bits_u=self.num_bits_u) 356 | self.ConvBnLif3 = QConvBN2dLIF(conv3,bn3,lif3,self.num_bits_w,self.num_bits_b,self.num_bits_u) 357 | 358 | conv4 = nn.Conv2d(128,128, kernel_size=3, padding=1, bias=args.conv_b) 359 | bn4 = nn.BatchNorm2d(128,affine=args.bn_a) 360 | lif4 = LIFSpike(thresh=args.th, leak=args.leak_mem, gamma=1.0, soft_reset=args.sft_rst, quant_u=args.uq, num_bits_u=self.num_bits_u) 361 | self.ConvBnLif4 = QConvBN2dLIF(conv4,bn4,lif4,self.num_bits_w,self.num_bits_b,self.num_bits_u) 362 | 363 | self.pool2 = nn.MaxPool2d(kernel_size=2) 364 | 365 | conv5 = nn.Conv2d(128,256, kernel_size=3, padding=1, bias=args.conv_b) 366 | bn5 = nn.BatchNorm2d(256,affine=args.bn_a) 367 | lif5 = LIFSpike(thresh=args.th, leak=args.leak_mem, gamma=1.0, soft_reset=args.sft_rst, quant_u=args.uq, num_bits_u=self.num_bits_u) 368 | self.ConvBnLif5 = QConvBN2dLIF(conv5,bn5,lif5,self.num_bits_w,self.num_bits_b,self.num_bits_u) 369 | 370 | conv6 = nn.Conv2d(256,256, kernel_size=3, padding=1, bias=args.conv_b) 371 | bn6 = nn.BatchNorm2d(256,affine=args.bn_a) 372 | lif6 = LIFSpike(thresh=args.th, leak=args.leak_mem, gamma=1.0, soft_reset=args.sft_rst, quant_u=args.uq, num_bits_u=self.num_bits_u) 373 | self.ConvBnLif6 = QConvBN2dLIF(conv6,bn6,lif6,self.num_bits_w,self.num_bits_b,self.num_bits_u) 374 | 375 | conv7 = nn.Conv2d(256,256, kernel_size=3, padding=1, bias=args.conv_b) 376 | bn7 = nn.BatchNorm2d(256,affine=args.bn_a) 377 | lif7 = LIFSpike(thresh=args.th, leak=args.leak_mem, gamma=1.0, soft_reset=args.sft_rst, quant_u=args.uq, num_bits_u=self.num_bits_u) 378 | self.ConvBnLif7 = QConvBN2dLIF(conv7,bn7,lif7,self.num_bits_w,self.num_bits_b,self.num_bits_u) 379 | 380 | self.pool3 = nn.MaxPool2d(kernel_size=2) 381 | 382 | conv8 = nn.Conv2d(256,512, kernel_size=3, padding=1, bias=args.conv_b) 383 | bn8 = nn.BatchNorm2d(512,affine=args.bn_a) 384 | lif8 = LIFSpike(thresh=args.th, leak=args.leak_mem, gamma=1.0, soft_reset=args.sft_rst, quant_u=args.uq, num_bits_u=self.num_bits_u) 385 | self.ConvBnLif8 = QConvBN2dLIF(conv8,bn8,lif8,self.num_bits_w,self.num_bits_b,self.num_bits_u) 386 | 387 | conv9 = nn.Conv2d(512,512, kernel_size=3, padding=1, bias=args.conv_b) 388 | bn9 = nn.BatchNorm2d(512,affine=args.bn_a) 389 | lif9 = LIFSpike(thresh=args.th, leak=args.leak_mem, gamma=1.0, soft_reset=args.sft_rst, quant_u=args.uq, num_bits_u=self.num_bits_u) 390 | self.ConvBnLif9 = QConvBN2dLIF(conv9,bn9,lif9,self.num_bits_w,self.num_bits_b,self.num_bits_u) 391 | 392 | conv10 = nn.Conv2d(512,512, kernel_size=3, 
padding=1, bias=args.conv_b)
        bn10 = nn.BatchNorm2d(512, affine=args.bn_a)
        lif10 = LIFSpike(thresh=args.th, leak=args.leak_mem, gamma=1.0, soft_reset=args.sft_rst, quant_u=args.uq, num_bits_u=self.num_bits_u)
        self.ConvBnLif10 = QConvBN2dLIF(conv10, bn10, lif10, self.num_bits_w, self.num_bits_b, self.num_bits_u)

        self.pool4 = nn.MaxPool2d(kernel_size=2)

        conv11 = nn.Conv2d(512, 512, kernel_size=3, padding=1, bias=args.conv_b)
        bn11 = nn.BatchNorm2d(512, affine=args.bn_a)
        lif11 = LIFSpike(thresh=args.th, leak=args.leak_mem, gamma=1.0, soft_reset=args.sft_rst, quant_u=args.uq, num_bits_u=self.num_bits_u)
        self.ConvBnLif11 = QConvBN2dLIF(conv11, bn11, lif11, self.num_bits_w, self.num_bits_b, self.num_bits_u)

        conv12 = nn.Conv2d(512, 512, kernel_size=3, padding=1, bias=args.conv_b)
        bn12 = nn.BatchNorm2d(512, affine=args.bn_a)
        lif12 = LIFSpike(thresh=args.th, leak=args.leak_mem, gamma=1.0, soft_reset=args.sft_rst, quant_u=args.uq, num_bits_u=self.num_bits_u)
        self.ConvBnLif12 = QConvBN2dLIF(conv12, bn12, lif12, self.num_bits_w, self.num_bits_b, self.num_bits_u)

        conv13 = nn.Conv2d(512, 512, kernel_size=3, padding=1, bias=args.conv_b)
        bn13 = nn.BatchNorm2d(512, affine=args.bn_a)
        lif13 = LIFSpike(thresh=args.th, leak=args.leak_mem, gamma=1.0, soft_reset=args.sft_rst, quant_u=args.uq, num_bits_u=self.num_bits_u)
        self.ConvBnLif13 = QConvBN2dLIF(conv13, bn13, lif13, self.num_bits_w, self.num_bits_b, self.num_bits_u)

        self.pool5 = nn.AvgPool2d(kernel_size=2)

        if dataset == 'tiny':
            size = 2
            clas = 200
        else:
            size = 1
            clas = 10
        self.fc_out = nn.Linear(512 * (size**2), clas, bias=True)

        self.weight_init()

    def reset_dynamics(self):
        for m in self.modules():
            if isinstance(m, QConvBN2dLIF):
                m.lif_module.reset_mem()
        self.direct_lif.reset_mem()
        return 0

    def weight_init(self):
        for m in self.modules():
            if isinstance(m, QConvBN2dLIF):
                nn.init.kaiming_uniform_(m.conv_module.weight)
            if isinstance(m, nn.Linear):
                nn.init.kaiming_uniform_(m.weight)

    def forward(self, inp):

        u_out = []
        self.reset_dynamics()
        static_input = self.bn1(self.conv1(inp))

        for t in range(self.time_step):
            # LIFSpike.forward also expects a bias argument, so use
            # direct_forward here, as the non-folded VGG16 does.
            s = self.direct_lif.direct_forward(static_input, False, 0)

            s = self.ConvBnLif2(s)
            s = self.pool1(s)

            s = self.ConvBnLif3(s)
            s = self.ConvBnLif4(s)

            s = self.pool2(s)

            s = self.ConvBnLif5(s)
            s = self.ConvBnLif6(s)
            s = self.ConvBnLif7(s)

            s = self.pool3(s)

            s = self.ConvBnLif8(s)
            s = self.ConvBnLif9(s)
            s = self.ConvBnLif10(s)

            s = self.pool4(s)

            s = self.ConvBnLif11(s)
            s = self.ConvBnLif12(s)
            s = self.ConvBnLif13(s)

            s = self.pool5(s)
            s = s.view(s.shape[0], -1)
            s = self.fc_out(s)

            u_out += [s]

        return u_out

--------------------------------------------------------------------------------
/quant_resnet.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
import torchvision
import args_config
from torchvision import datasets, transforms
import gc
from torch.autograd import Variable
import
torch.optim as optim 11 | import torch.backends.cudnn as cudnn 12 | from statistics import mean 13 | from training_utils import Firing, w_q, b_q 14 | from network_utils import * 15 | from spike_related import LIFSpike 16 | 17 | args = args_config.get_args() 18 | # firing = Firing.apply 19 | 20 | 21 | import torch 22 | import torch.nn as nn 23 | import torch.nn.functional as F 24 | from spikingjelly.clock_driven import functional, layer, surrogate, neuron 25 | 26 | # tau_global = 1./(1. - 0.5) 27 | 28 | class BasicBlock(nn.Module): 29 | 30 | expansion = 1 31 | 32 | def __init__(self, in_planes, planes, n_w, n_u, n_b, stride=1): 33 | super(BasicBlock, self).__init__() 34 | 35 | 36 | self.num_bits_w = n_w 37 | self.num_bits_b = n_b 38 | self.num_bits_u = n_u 39 | 40 | conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 41 | bn1 = nn.BatchNorm2d(planes) 42 | lif1 = LIFSpike(thresh=args.th, leak=args.leak_mem, gamma=1.0, soft_reset=args.sft_rst, quant_u=args.uq, num_bits_u=self.num_bits_u) 43 | self.ConvBnLif1 = QConvBN2dLIF(conv1,bn1,lif1,self.num_bits_w,self.num_bits_b,self.num_bits_u) 44 | 45 | conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 46 | bn2 = nn.BatchNorm2d(planes) 47 | self.ConvBn2 = QConvBN2d(conv2,bn2,self.num_bits_w,self.num_bits_u) 48 | 49 | 50 | self.lif2 = LIFSpike(thresh=args.th, leak=args.leak_mem, gamma=1.0, soft_reset=args.sft_rst, quant_u=args.uq, num_bits_u=self.num_bits_u) 51 | 52 | self.shortcut = nn.Sequential() 53 | conv_sh = nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False) 54 | bn_sh = nn.BatchNorm2d(self.expansion*planes) 55 | # self.ConvBn_sh = QConvBN2d(conv_sh,bn_sh,self.num_bits_w,self.num_bits_u) 56 | if stride != 1 or in_planes != self.expansion*planes: 57 | self.shortcut = nn.Sequential( 58 | QConvBN2d(conv_sh,bn_sh,self.num_bits_w,self.num_bits_u,short_cut=True) 59 | ) 60 | 61 | def forward(self, x): 62 | out = self.ConvBnLif1(x) 63 | out = self.ConvBn2(out) 64 | out += self.shortcut(x) 65 | out = self.lif2(out, args.share, self.ConvBn2.beta[0], bias=0) 66 | return out 67 | 68 | 69 | class ResNet(nn.Module): 70 | def __init__(self, block, num_blocks, num_classes=10, total_timestep =4): 71 | super(ResNet, self).__init__() 72 | self.in_planes = 64 73 | self.total_timestep = total_timestep 74 | if args.dataset == 'dvs': 75 | input_dim = 2 76 | else: 77 | input_dim = 3 78 | 79 | self.conv1 = nn.Conv2d(input_dim, 64, kernel_size=3, padding=1, bias=False) 80 | self.bn1 = nn.BatchNorm2d(64) 81 | 82 | self.num_bits_u = 16 83 | self.num_bits_w = 16 84 | self.num_bits_b = 8 85 | 86 | print("ResNet-basic-block weight bits: ", self.num_bits_w) 87 | print("ResNet-basic-block potential bits: ", self.num_bits_u) 88 | 89 | conv1dvs = nn.Conv2d(input_dim, 64, kernel_size=3, padding=1, bias=False) 90 | bn1dvs = nn.BatchNorm2d(64,affine=True) 91 | lif1dvs = LIFSpike(thresh=args.th, leak=args.leak_mem, gamma=1.0, soft_reset=args.sft_rst, quant_u=args.uq, num_bits_u=self.num_bits_u) 92 | self.ConvBnLif1 = QConvBN2dLIF(conv1dvs,bn1dvs,lif1dvs,self.num_bits_w,self.num_bits_b,self.num_bits_u) 93 | 94 | self.direct_lif = LIFSpike(thresh=args.th, leak=args.leak_mem, gamma=1.0, soft_reset=args.sft_rst, quant_u=False) 95 | # self.lif_input = neuron.ParametricLIFNode(v_threshold=1.0, v_reset=0.0, init_tau=2., 96 | # surrogate_function=surrogate.ATan(), 97 | # detach_reset=True) 98 | 99 | 100 | self.layer1 = self._make_layer(block, 128, num_blocks[0], self.num_bits_w, 
self.num_bits_u, self.num_bits_b, stride=2) 101 | self.layer2 = self._make_layer(block, 256, num_blocks[1], self.num_bits_w, self.num_bits_u, self.num_bits_b, stride=2) 102 | self.layer3 = self._make_layer(block, 512, num_blocks[2], self.num_bits_w, self.num_bits_u, self.num_bits_b, stride=2) 103 | 104 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 105 | 106 | self.fc1 = nn.Linear(512, 256) 107 | self.lif_fc = LIFSpike(thresh=args.th, leak=args.leak_mem, gamma=1.0, soft_reset=args.sft_rst, quant_u=False) 108 | # self.lif_fc = neuron.ParametricLIFNode(v_threshold=1.0, v_reset=0.0, init_tau=2., 109 | # surrogate_function=surrogate.ATan(), 110 | # detach_reset=True) 111 | self.fc2 = nn.Linear(256, num_classes) 112 | 113 | # for m in self.modules(): 114 | # if isinstance(m, Bottleneck): 115 | # nn.init.constant_(m.bn3.weight, 0) 116 | # elif isinstance(m, BasicBlock): 117 | # nn.init.constant_(m.bn2.weight, 0) 118 | # elif isinstance(m, nn.Conv2d): 119 | # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 120 | 121 | 122 | def _make_layer(self, block, planes, num_blocks, n_w, n_u, n_b, stride): 123 | strides = [stride] + [1]*(num_blocks-1) 124 | layers = [] 125 | for stride in strides: 126 | layers.append(block(self.in_planes, planes, n_w, n_u, n_b, stride)) 127 | self.in_planes = planes * block.expansion 128 | return nn.Sequential(*layers) 129 | 130 | def reset_dynamics(self): 131 | for m in self.modules(): 132 | if isinstance(m,QConv2dLIF): 133 | m.lif_module.reset_mem() 134 | elif isinstance(m,QConvBN2dLIF): 135 | m.lif_module.reset_mem() 136 | elif isinstance(m,LIFSpike): 137 | m.reset_mem() 138 | self.direct_lif.reset_mem() 139 | self.lif_fc.reset_mem() 140 | return 0 141 | 142 | def weight_init(self): 143 | for m in self.modules(): 144 | if isinstance(m,QConvBN2dLIF): 145 | nn.init.kaiming_uniform_(m.conv_module.weight) 146 | nn.init.kaiming_uniform_(m.bn_module.weight) 147 | elif isinstance(m,QConvBN2d): 148 | nn.init.kaiming_uniform_(m.conv_module.weight) 149 | nn.init.kaiming_uniform_(m.bn_module.weight) 150 | if isinstance(m,nn.Linear): 151 | nn.init.kaiming_uniform_(m.weight) 152 | 153 | def forward(self, x): 154 | 155 | # acc_voltage = 0 156 | u_out = [] 157 | self.reset_dynamics() 158 | if args.dataset != 'dvs': 159 | static_x = self.bn1(self.conv1(x)) 160 | # static_x = self.bn1(self.conv1(x)) 161 | 162 | for t in range(self.total_timestep): 163 | if args.dataset == 'dvs': 164 | out = x[:,t].to(torch.float32).cuda() 165 | out = self.ConvBnLif1(out) 166 | else: 167 | out = self.direct_lif.direct_forward(static_x,False,0) 168 | # out = self.direct_lif.direct_forward(static_x,False,0) 169 | out = self.layer1(out) 170 | out = self.layer2(out) 171 | out = self.layer3(out) 172 | out = self.avgpool(out) 173 | out = out.view(out.size(0), -1) 174 | out = self.lif_fc(self.fc1(out),False,0,bias=0) 175 | out = self.fc2(out) 176 | 177 | # acc_voltage = acc_voltage + out 178 | u_out += [out] 179 | 180 | # acc_voltage = acc_voltage / self.total_timestep 181 | 182 | return u_out 183 | 184 | 185 | def resnet18(): 186 | return ResNet(BasicBlock, [2,2,2,2]) 187 | 188 | def ResNet19(num_classes, total_timestep): 189 | return ResNet(BasicBlock, [3,3,2], num_classes, total_timestep) 190 | 191 | def ResNet34(): 192 | return ResNet(BasicBlock, [3,4,6,3]) 193 | 194 | def ResNet50(): 195 | return ResNet(Bottleneck, [3,4,6,3]) 196 | 197 | def ResNet101(): 198 | return ResNet(Bottleneck, [3,4,23,3]) 199 | 200 | def ResNet152(): 201 | return ResNet(Bottleneck, [3,8,36,3]) 202 | 203 | 204 
def test():
    # Smoke test for the factories above. The model returns a list of
    # per-timestep outputs, so sum them before inspecting the shape.
    net = resnet18()
    y = net(torch.randn(1, 3, 32, 32))
    print(sum(y).size())

--------------------------------------------------------------------------------
/spike_related.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from training_utils import *
import tracemalloc
import gc

class ZIF(torch.autograd.Function):
    """Heaviside spike function with a triangular surrogate gradient."""

    @staticmethod
    def forward(ctx, input):
        out = (input > 0).float()
        ctx.save_for_backward(input)

        return out

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()

        # Triangular surrogate: d(spike)/d(input) ~ max(0, 1 - |input|).
        grad_input = grad_input * ((1.0 - torch.abs(input)).clamp(min=0))

        return grad_input


class LIFSpike(nn.Module):
    def __init__(self, thresh=0.5, leak=0.5, gamma=1.0, soft_reset=True, quant_u=False, num_bits_u=4):
        """
        Implements the LIF neuron.
        @param thresh: firing threshold;
        @param leak: membrane potential decay factor;
        @param gamma: hyper-parameter controlling the sharpness of the surrogate gradient;
        @param soft_reset: whether to use soft reset or hard reset;
        @param quant_u: whether to quantize the membrane potential;
        @param num_bits_u: bitwidth of the quantized membrane potential.
        """
        super(LIFSpike, self).__init__()

        self.quant_u = quant_u
        self.num_bits_u = num_bits_u

        self.thresh = thresh
        self.leak = leak
        self.gamma = gamma
        self.soft_reset = soft_reset

        self.membrane_potential = 0

    def reset_mem(self):
        self.membrane_potential = 0

    def forward(self, s, share, beta, bias):
        # Charge: H[t] = input + U[t-1].
        H = s + self.membrane_potential

        # Fire with a straight-through estimator: the forward value is the
        # Heaviside step, the backward pass sees the triangular surrogate.
        grad = ((1.0 - torch.abs(H - self.thresh)).clamp(min=0))
        s = (((H - self.thresh) > 0).float() - H * grad).detach() + H * grad

        # Reset (soft subtracts the threshold, hard zeroes the potential), then leak.
        if self.soft_reset:
            U = (H - s * self.thresh) * self.leak
        else:
            U = H * self.leak * (1 - s)

        if self.quant_u:
            if share:
                self.membrane_potential = u_q(U, self.num_bits_u, beta)
            else:
                self.membrane_potential = b_q(U, self.num_bits_u)
        else:
            self.membrane_potential = U
        # Sketch of the integer-only inference path (condensed from the
        # commented-out original): the threshold folds into the shared scale,
        # so firing becomes a comparison against th/beta - bias and the
        # potential can stay integer:
        #     s = (H - (self.thresh/beta - bias) > 0).float()
        #     U = (H - s*self.thresh)*self.leak  (or the hard-reset variant)
        #     if self.quant_u and share:
        #         self.membrane_potential = U.round().clamp(min=-2**(self.num_bits_u-1), max=2**(self.num_bits_u-1)-1)
        #     else:
        #
self.membrane_potential = U 110 | 111 | return s 112 | 113 | def direct_forward(self, s, share, beta): 114 | # act = ZIF.apply 115 | # if self.training: 116 | # beta* 117 | # x = gamma_*x + bias 118 | H = s + self.membrane_potential 119 | 120 | # s = act(H-self.thresh) 121 | grad = ((1.0 - torch.abs(H-self.thresh)).clamp(min=0)) 122 | s = (((H-self.thresh) > 0).float() - H*grad).detach() + H*grad.detach() 123 | if self.soft_reset: 124 | U = (H - s*self.thresh)*self.leak 125 | else: 126 | U = H*self.leak*(1-s) 127 | 128 | if self.quant_u: 129 | if share: 130 | self.membrane_potential = u_q(U,self.num_bits_u,beta) 131 | else: 132 | self.membrane_potential= b_q(U,self.num_bits_u) 133 | else: 134 | self.membrane_potential = U 135 | 136 | return s -------------------------------------------------------------------------------- /train_snn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import args_config 3 | import torchvision 4 | import torchvision.transforms as transforms 5 | import torch.nn as nn 6 | import os, time 7 | import torch.backends.cudnn as cudnn 8 | import dill 9 | import pickle 10 | 11 | from quant_net import * 12 | from quant_resnet import * 13 | from training_utils import * 14 | import tracemalloc 15 | import math 16 | import gc 17 | 18 | 19 | 20 | 21 | def main(): 22 | 23 | torch.manual_seed(23) 24 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 25 | cudnn.benchmark = True 26 | cudnn.deterministic = True 27 | 28 | args = args_config.get_args() 29 | print("********** SNN simulation parameters **********") 30 | print(args) 31 | 32 | if args.dataset == 'cifar10': 33 | transform_train = transforms.Compose([ 34 | transforms.RandomCrop(32, padding=4), 35 | transforms.RandomHorizontalFlip(), 36 | transforms.ToTensor(), 37 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) 38 | ]) 39 | 40 | transform_test = transforms.Compose([ 41 | transforms.ToTensor(), 42 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) 43 | ]) 44 | 45 | train_dataset = torchvision.datasets.CIFAR10( 46 | root=args.dataset_dir, 47 | train=True, 48 | transform=transform_train, 49 | download=True) 50 | 51 | test_dataset = torchvision.datasets.CIFAR10( 52 | root=args.dataset_dir, 53 | train=False, 54 | transform=transform_test, 55 | download=True) 56 | 57 | train_data_loader = torch.utils.data.DataLoader( 58 | dataset=train_dataset, 59 | batch_size=args.batch_size, 60 | shuffle=True, 61 | drop_last=True, 62 | num_workers=4, 63 | pin_memory=True) 64 | 65 | test_data_loader = torch.utils.data.DataLoader( 66 | dataset=test_dataset, 67 | batch_size=args.batch_size, 68 | shuffle=False, 69 | drop_last=False, 70 | num_workers=4, 71 | pin_memory=True) 72 | 73 | num_classes = 10 74 | 75 | elif args.dataset == 'svhn': 76 | 77 | transform_train = transforms.Compose([ 78 | transforms.RandomCrop(32, padding=4), 79 | transforms.RandomHorizontalFlip(), 80 | transforms.ToTensor(), 81 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 82 | ]) 83 | 84 | transform_test = transforms.Compose([ 85 | transforms.ToTensor(), 86 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 87 | ]) 88 | train_dataset = torchvision.datasets.SVHN( 89 | root=args.dataset_dir, 90 | split='train', 91 | transform=transform_train, 92 | download=True) 93 | test_dataset = torchvision.datasets.SVHN( 94 | root=args.dataset_dir, 95 | split='test', 96 | transform=transform_test, 97 | download=True) 98 | train_data_loader = 
torch.utils.data.DataLoader( 99 | dataset=train_dataset, 100 | batch_size=args.batch_size, 101 | shuffle=True, 102 | drop_last=True, 103 | num_workers=0, 104 | pin_memory=True) 105 | test_data_loader = torch.utils.data.DataLoader( 106 | dataset=test_dataset, 107 | batch_size=args.batch_size, 108 | shuffle=False, 109 | drop_last=False, 110 | num_workers=0, 111 | pin_memory=True) 112 | 113 | num_classes = 10 114 | 115 | elif args.dataset == 'tiny': 116 | traindir = os.path.join('/gpfs/gibbs/project/panda/shared/tiny-imagenet-200/train') 117 | valdir = os.path.join('/gpfs/gibbs/project/panda/shared/tiny-imagenet-200/val') 118 | # normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 119 | # std=[0.229, 0.224, 0.225]) 120 | train_transforms = transforms.Compose([ 121 | transforms.RandomRotation(20), 122 | transforms.RandomHorizontalFlip(0.5), 123 | transforms.ToTensor(), 124 | transforms.Normalize([0.4802, 0.4481, 0.3975], [0.2302, 0.2265, 0.2262]), 125 | ]) 126 | test_transforms = transforms.Compose([ 127 | transforms.ToTensor(), 128 | transforms.Normalize([0.4802, 0.4481, 0.3975], [0.2302, 0.2265, 0.2262]), 129 | ]) 130 | 131 | train_dataset = torchvision.datasets.ImageFolder(traindir, train_transforms) 132 | test_dataset = torchvision.datasets.ImageFolder(valdir, test_transforms) 133 | 134 | train_data_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=4,pin_memory=True) 135 | test_data_loader = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4,pin_memory=True) 136 | 137 | num_classes = 200 138 | elif args.dataset == 'dvs': 139 | train_dataset_dvs=torch.load("./train_dataset_dvs_8.pt",pickle_module=dill) 140 | test_dataset_dvs=torch.load("./test_dataset_dvs_8.pt",pickle_module=dill) 141 | 142 | train_data_loader = torch.utils.data.DataLoader(train_dataset_dvs, 143 | batch_size=args.batch_size, 144 | shuffle=True, 145 | num_workers=4, 146 | pin_memory=True) 147 | test_data_loader = torch.utils.data.DataLoader(test_dataset_dvs, 148 | batch_size=args.batch_size, 149 | shuffle=False, 150 | num_workers=4, 151 | pin_memory=True) 152 | num_classes = 10 153 | # print(type(train_dataset_dvs)) 154 | # print(len(train_dataset_dvs)) 155 | # print(train_dataset_dvs[0]) 156 | 157 | # print(type(test_dataset_dvs)) 158 | # print(len(test_dataset_dvs)) 159 | # print(test_dataset_dvs[0]) 160 | # exit() 161 | 162 | # # check the test and train sets 163 | # train_indices = set(train_dataset_dvs.indices) 164 | # test_indices = set(test_dataset_dvs.indices) 165 | 166 | # # The intersection should be an empty set, if they have no common elements 167 | # common_indices = train_indices.intersection(test_indices) 168 | # print(f"Common indices between train and test datasets: {common_indices}") 169 | # exit() 170 | criterion = nn.CrossEntropyLoss() 171 | if args.arch == 'vgg16': 172 | model = Q_ShareScale_VGG16(args.T,args.dataset).cuda() 173 | elif args.arch == 'vgg9': 174 | model = Q_ShareScale_VGG9(args.T,args.dataset).cuda() 175 | elif args.arch == 'res19': 176 | model = ResNet19(num_classes, args.T).cuda() 177 | # model = VGG19_Direct_TS_UQ(args.T, args.leak_mem, args.th, args.rst, args.uq, args.xq, args.wq, args.xa).cuda() 178 | # else: 179 | # model = VGG9_Direct_Uniform_UQ_List(args.T, args.leak_mem, args.th, args.rst).cuda() 180 | 181 | # print(model) 182 | 183 | if args.optim == 'sgd': 184 | optimizer = torch.optim.SGD(model.parameters(), args.lr, 0.9, weight_decay=5e-4) 185 | elif args.optim == 
'adam':
        optimizer = torch.optim.Adam(model.parameters(), args.lr, weight_decay=1e-4)
    else:
        print("Currently only the sgd and adam optimizers are supported.")
        exit()
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epoch, eta_min=0)

    best_accuracy = 0
    for epoch_ in range(args.epoch):
        loss = 0
        accuracy = 0

        loss = train(args, train_data_loader, model, criterion, optimizer, epoch_)

        accuracy = test(model, test_data_loader, criterion)

        scheduler.step()

        if accuracy > best_accuracy:
            best_accuracy = accuracy
            checkdir(f"{os.getcwd()}/model_dumps/{args.arch}/{args.dataset}/{args.rst}/T10/baseline")
            torch.save(model, f"{os.getcwd()}/model_dumps/{args.arch}/{args.dataset}/{args.rst}/T10/baseline/final_dict.pth.tar")

        if (epoch_ + 1) % args.test_display_freq == 0:
            print(f'Train Epoch: {epoch_}/{args.epoch} Loss: {loss:.6f} Accuracy: {accuracy:.3f}% Best Accuracy: {best_accuracy:.3f}%')


def train(args, train_data, model, criterion, optimizer, epoch):
    model.train()

    for batch_idx, (imgs, targets) in enumerate(train_data):
        optimizer.zero_grad()
        imgs, targets = imgs.cuda(), targets.cuda()

        output = model(imgs)

        # Average the cross-entropy over the T per-timestep outputs.
        train_loss = sum([criterion(s, targets) for s in output]) / args.T

        train_loss.backward()
        if args.share:
            # LSQ-style gradient scaling for the learned shared scale:
            # g = 1 / sqrt(numel(w) * Qp), with Qp = 2^(b-1) - 1.
            for m in model.modules():
                if isinstance(m, QConvBN2dLIF):
                    m.beta[0].grad.data = m.beta[0].grad / math.sqrt(torch.numel(m.conv_module.weight) * (2**(m.num_bits_w - 1) - 1))
                elif isinstance(m, QConvBN2d):
                    m.beta[0].grad.data = m.beta[0].grad / math.sqrt(torch.numel(m.conv_module.weight) * (2**(m.num_bits_w - 1) - 1))
        optimizer.step()

    # Note: this returns the loss of the last batch of the epoch.
    return train_loss.item()

if __name__ == '__main__':
    main()
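
# Illustrative (added): evaluating a saved checkpoint with the helpers from
# training_utils. The path mirrors the format string used in main() above;
# adjust it to wherever your run saved the model.
#
#   model = torch.load("model_dumps/res19/dvs/hard/T10/baseline/final_dict.pth.tar")
#   acc = test(model, test_data_loader, criterion)
#   print(f"accuracy: {acc:.3f}%")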
--------------------------------------------------------------------------------
/training_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import matplotlib.pyplot as plt
4 | import numpy as np
5 | import os
6 | 
7 | class Firing(torch.autograd.Function):
8 |     """Heaviside spike function with a rectangular surrogate gradient."""
9 | 
10 |     @staticmethod
11 |     def forward(ctx, inp):
12 |         out = (inp > 0).float()
13 |         ctx.save_for_backward(inp)
14 |         return out
15 | 
16 |     @staticmethod
17 |     def backward(ctx, grad_output):
18 |         inp, = ctx.saved_tensors
19 |         grad_input = grad_output.clone()
20 |         # Rectangular surrogate: pass gradients only in a window around the threshold.
21 |         grad = grad_input * (torch.abs(inp - 0.5) < 1).float()
22 |         # forward() takes a single input, so backward() returns a single gradient.
23 |         return grad
24 | 
25 | def lif_forward(H, th, beta):
26 |     # Fire once the integer membrane potential reaches ceil(th/beta).
27 |     out = torch.zeros_like(H)
28 |     out[H >= torch.ceil(th / beta)] = 1
29 |     return out
30 | 
31 | ### Quantizers sharing one learned scaling factor alpha
32 | def w_q(w, b, alpha):
33 |     """Fake-quantize weights to b bits with shared scale alpha (straight-through estimator)."""
34 |     w = torch.tanh(w)
35 |     w = torch.clamp(w / alpha, min=-1, max=1)
36 |     w = w * (2**(b - 1) - 1)
37 |     # STE: round() in the forward pass, identity in the backward pass.
38 |     w_hat = (w.round() - w).detach() + w
39 |     return w_hat * alpha / (2**(b - 1) - 1), alpha
40 | 
41 | def u_q(u, b, alpha):
42 |     """Fake-quantize membrane potentials to b bits with shared scale alpha."""
43 |     u = torch.tanh(u)
44 |     u = torch.clamp(u / alpha, min=-1, max=1)
45 |     u = u * (2**(b - 1) - 1)
46 |     u_hat = (u.round() - u).detach() + u
47 |     return u_hat * alpha / (2**(b - 1) - 1)
48 | 
49 | ### Quantizer computing its own per-tensor scaling factor (alpha not shared)
50 | def b_q(w, b):
51 |     w = torch.tanh(w)
52 |     alpha = w.data.abs().max()
53 |     w = torch.clamp(w / alpha, min=-1, max=1)
54 |     w = w * (2**(b - 1) - 1)
55 |     w_hat = (w.round() - w).detach() + w
56 |     return w_hat * alpha / (2**(b - 1) - 1)
57 | 
58 | def w_q_inference(w, b, alpha):
59 |     # Inference-time variant: return the integer weights and the scale separately.
60 |     w = torch.tanh(w)
61 |     w = torch.clamp(w / alpha, min=-1, max=1)
62 |     w = w * (2**(b - 1) - 1)
63 |     w_hat = w.round()
64 |     return w_hat, alpha / (2**(b - 1) - 1)
65 | 
66 | def b_q_inference(w, b):
67 |     w = torch.tanh(w)
68 |     alpha = w.data.abs().max()
69 |     w = torch.clamp(w / alpha, min=-1, max=1)
70 |     w = w * (2**(b - 1) - 1)
71 |     w_hat = w.round()
72 |     return w_hat, alpha / (2**(b - 1) - 1)
73 | 
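74 | # [Editor's example] A hedged sketch of how the fake quantizers above behave;
75 | # this block is illustrative only and is not used elsewhere in the repo. With
76 | # b = 4 and alpha = 1.0, w_q snaps tanh(w) onto 2*(2**(b-1)-1)+1 = 15 integer
77 | # levels while the STE keeps round() transparent to gradients:
78 | #
79 | #   w = torch.randn(8, requires_grad=True)
80 | #   w_hat, alpha = w_q(w, b=4, alpha=torch.tensor(1.0))
81 | #   assert torch.unique(w_hat).numel() <= 15  # at most 15 distinct values
82 | #   w_hat.sum().backward()                    # gradients flow despite rounding
83 | 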
84 | def checkdir(directory):
85 |     if not os.path.exists(directory):
86 |         os.makedirs(directory)
87 | 
88 | def adjust_learning_rate(optimizer, cur_epoch, max_epoch):
89 |     # Step decay: divide the learning rate by 10 at 50%, 70%, and 90% of training.
90 |     if (
91 |         cur_epoch == (max_epoch * 0.5)
92 |         or cur_epoch == (max_epoch * 0.7)
93 |         or cur_epoch == (max_epoch * 0.9)
94 |     ):
95 |         for param_group in optimizer.param_groups:
96 |             param_group["lr"] /= 10
97 | 
98 | def test(model, test_loader, criterion):
99 |     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
100 | 
101 |     model.eval()
102 |     correct = 0
103 | 
104 |     with torch.no_grad():
105 |         for data, target in test_loader:
106 |             data, target = data.to(device), target.to(device)
107 |             # Sum the per-timestep logits before taking the prediction.
108 |             output = sum(model(data))
109 |             _, idx = output.data.max(1, keepdim=True)  # index of the max logit
110 |             correct += idx.eq(target.data.view_as(idx)).sum().item()
111 | 
112 |     accuracy = 100. * correct / len(test_loader.dataset)
113 | 
114 |     return accuracy
115 | 
116 | def top_k_accuracy(outputs, targets, k=5):
117 |     _, top_pred = outputs.topk(k, 1, True, True)
118 |     top_pred = top_pred.t()
119 |     correct = top_pred.eq(targets.view(1, -1).expand_as(top_pred))
120 |     # reshape() instead of view(): the slice of a transposed tensor is non-contiguous.
121 |     top_k_acc = correct[:k].reshape(-1).float().sum(0, keepdim=True) / targets.size(0)
122 |     return top_k_acc.item()
123 | 
124 | def test_5(model, test_loader, criterion):
125 |     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
126 | 
127 |     model.eval()
128 |     test_loss = 0
129 |     top5_correct = 0
130 | 
131 |     with torch.no_grad():
132 |         for data, target in test_loader:
133 |             data, target = data.to(device), target.to(device)
134 |             output = sum(model(data))
135 |             test_loss += criterion(output, target).item()  # accumulate the loss
136 |             _, idx = output.data.topk(5, 1)  # indices of the top-5 predictions
137 |             top5_correct += torch.sum(idx == target.view(-1, 1).expand_as(idx)).item()
138 | 
139 |     top5_accuracy = 100. * top5_correct / len(test_loader.dataset)
140 | 
141 |     return top5_accuracy
142 | 
143 | def computing_firerate(module, inp, out):
144 |     # Forward hook: accumulate spike and neuron counts on the hooked module
145 |     # (the 8.0 normalization is kept from the original code).
146 |     fired_spikes = torch.count_nonzero(out)
147 |     module.spikerate += fired_spikes / 8.0
148 |     module.num_neuron += np.prod(out.shape[1:]) / 8.0
149 | 
150 | def test_spa(model, test_loader, criterion):
151 |     """Measure test accuracy together with the overall spike rate (sparsity)."""
152 |     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
153 |     model.eval()
154 |     correct = 0
155 | 
156 |     overall_neuron = 0
157 |     overall_spike = 0
158 | 
159 |     # Attach a spike-counting hook to every quantized conv-LIF layer.
160 |     neuron_type = 'QConv2dLIF'
161 |     for name, module in model.named_modules():
162 |         if neuron_type in str(type(module)):
163 |             module.lif_module.register_forward_hook(computing_firerate)
164 |             module.lif_module.spikerate = 0
165 |             module.lif_module.num_neuron = 0
166 | 
167 |     with torch.no_grad():
168 |         for data, target in test_loader:
169 |             data, target = data.to(device), target.to(device)
170 |             output = sum(model(data))
171 |             _, idx = output.data.max(1, keepdim=True)  # index of the max logit
172 |             correct += idx.eq(target.data.view_as(idx)).sum().item()
173 | 
174 |     accuracy = 100. * correct / len(test_loader.dataset)
175 | 
176 |     for name, module in model.named_modules():
177 |         if neuron_type in str(type(module)):
178 |             overall_neuron += module.lif_module.num_neuron / len(test_loader)
179 |             overall_spike += module.lif_module.spikerate / len(test_loader.dataset)
180 | 
181 |     print("Overall spike rate:", overall_spike / overall_neuron)
182 | 
183 |     return accuracy, overall_spike / overall_neuron
184 | 
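185 | # [Editor's example] The sparsity measurement above relies on PyTorch forward
186 | # hooks: computing_firerate runs after every forward pass of a hooked module and
187 | # accumulates counters on the module itself. A minimal sketch of the same
188 | # pattern (the ReLU stand-in is illustrative only):
189 | #
190 | #   layer = torch.nn.ReLU()
191 | #   layer.spikerate, layer.num_neuron = 0, 0
192 | #   handle = layer.register_forward_hook(computing_firerate)
193 | #   _ = layer(torch.randn(2, 4, 4))  # hook fires here and updates the counters
194 | #   handle.remove()                  # detach the hook when finished
195 | 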
196 | def get_u_distribution(data, l, t, color):
197 |     # Plot the histogram of membrane-potential values for layer l at timestep t.
198 |     bins = 128
199 |     i_max = 10
200 |     i_min = -10
201 |     step = (i_max - i_min) / bins
202 |     x = np.arange(i_min, i_max, step)
203 |     plt.bar(x, data, align='center', color='#1E97B0')
204 |     plt.xlabel('Bins')
205 |     plt.ylabel('Frequency')
206 |     plt.title('Frequency')
207 |     # Save, then close; plt.close() keeps repeated calls from overlaying bars.
208 |     plt.savefig(f'./u_fig/hist_u_layer{l}_{t}.pdf', bbox_inches='tight', pad_inches=0.1)
209 |     plt.close()
210 | 
211 | # def get_w_distribution(data, layer_i, e):
212 | #     bins = 10
213 | #     hist = torch.histc(data, bins).cpu()
214 | #     i_max = torch.max(data).cpu().item()
215 | #     i_min = torch.min(data).cpu().item()
216 | #     step = (i_max - i_min) / bins
217 | #     if step != 0:
218 | #         x = np.arange(i_min, i_max, step)
219 | #         plt.bar(x, hist, align='center', color=['forestgreen'])
220 | #         plt.xlabel('Bins')
221 | #         plt.ylabel('Frequency')
222 | #         plt.title('Frequency')
223 | #         plt.savefig(f'./w_fig/hist_w_{layer_i}_{e}.pdf', bbox_inches='tight', pad_inches=0.1)
224 | #     else:
225 | #         print(f'W are all zeros at layer:{layer_i} at epoch {e}')
226 | 
227 | class AverageMeter(object):
228 |     """Computes and stores the average and current value."""
229 | 
230 |     def __init__(self):
231 |         self.reset()
232 | 
233 |     def reset(self):
234 |         self.val = 0
235 |         self.avg = 0
236 |         self.sum = 0
237 |         self.count = 0
238 | 
239 |     def update(self, val, n=1):
240 |         self.val = val
241 |         self.sum += val * n
242 |         self.count += n
243 |         self.avg = self.sum / self.count
244 | 
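245 | # [Editor's example] Typical AverageMeter usage for tracking a running loss;
246 | # illustrative only (the training loop in train_snn.py currently reports just
247 | # the last mini-batch loss, so this meter is available but unused there):
248 | #
249 | #   loss_meter = AverageMeter()
250 | #   loss_meter.update(0.9, n=256)  # batch loss 0.9 averaged over 256 samples
251 | #   loss_meter.update(0.7, n=256)
252 | #   print(loss_meter.avg)          # -> 0.8
--------------------------------------------------------------------------------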