├── README.md ├── model.py ├── models └── model.py ├── train.py ├── train_models.sh └── util.py /README.md: -------------------------------------------------------------------------------- 1 | # Rate-vs-Direct 2 | 3 | This repository contains the source code associated with "RATE CODING OR DIRECT CODING: WHICH ONE IS BETTER FOR ACCURATE, ROBUST, and ENERGY-EFFICIENT SPIKING NEURAL NETWORKS?", accepted to ICASSP2022. (https://arxiv.org/abs/2202.03133) 4 | 5 | 6 | ## Introduction 7 | 8 | Spiking Neural Networks (SNNs) have recently emerged as the low-power alternative to Artificial Neural Networks (ANNs), because of their asynchronous, sparse, and binary event-driven processing. Recent SNN works focus on an image classification task, therefore various coding techniques have been proposed to convert an image into temporal binary spikes. Among them, rate coding and direct coding are regarded as prospective candidates for building a practical SNN system as they show state-of-the-art performance on large-scale datasets. Despite their usage, there is little attention to comparing these two coding schemes in a fair manner. In this paper, we conduct a comprehensive analysis of the two coding techniques from three perspectives: accuracy, adversarial robustness, and energy-efficiency. 9 | First, we compare the performance of two coding techniques with three different architectures on various datasets. Then, we attack SNNs with two adversarial attack methods to reveal the adversarial robustness of each coding scheme. Finally, we evaluate the energy-efficiency of two coding schemes on a digital hardware platform. Our results show that direct coding can achieve better accuracy especially for a small number of timesteps. On the other hand, rate coding shows better robustness to adversarial attacks owing to the non-differentiable spike generation process. Rate coding also yields higher energy-efficiency than direct coding which requires multi-bit precision for the first layer. 
Our study explores the advantages and disadvantages of the two coding schemes, which is an important design consideration for building SNNs. 10 | 11 | 12 | 13 | ## Prerequisites 14 | * Ubuntu 18.04 15 | * Python 3.6+ 16 | * PyTorch 1.5+ (a recent version is recommended) 17 | * Torchvision 0.8.0+ (a recent version is recommended) 18 | * NVIDIA GPU (>= 12GB) 19 | 20 | 21 | ## Training and testing 22 | 23 | * ```train.py```: code for training 24 | * ```model.py```: code for MLP/VGG5/VGG9 Spiking Neural Networks with Rate/Direct coding 25 | * ```util.py```: code for accuracy calculation / learning rate scheduler 26 | 27 | 28 | * Argparse configuration 29 | 30 | ``` 31 | --dataset [mnist, cifar10, cifar100] 32 | --encode [p, d] (p: Poisson/rate coding, d: direct coding) 33 | --arch [mlp, vgg5, vgg9] 34 | --T [5, 10, 15, 20, 30] (number of timesteps) 35 | --leak_mem [0.5] 36 | --epoch [100] 37 | --lr [1e-3] 38 | ``` 39 | 40 | * Run the following command for VGG5-SNN-Direct on CIFAR10 41 | 42 | ``` 43 | python train.py --dataset cifar10 --arch vgg5 --epoch 100 --encode d --leak_mem 0.5 --T 10 --lr 1e-3 --batch_size 128 44 | ``` 45 | 46 | * Run the following command for VGG9-SNN-Poisson on CIFAR100 47 | 48 | ``` 49 | python train.py --dataset cifar100 --arch vgg9 --encode p --leak_mem 0.5 --T 20 --lr 1e-3 --batch_size 128 50 | ``` 51 | 52 | 53 | 54 | 55 | 56 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | # File : model_cifar10.py 2 | # Descr: Define SNN models for the CIFAR10 dataset 3 | # Date : March 22, 2019 4 | 5 | # -------------------------------------------------- 6 | # Imports 7 | # -------------------------------------------------- 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import math 12 | import sys 13 | import numpy as np 14 | import numpy.linalg as LA 15 | from torch.autograd import Variable 16 | import pdb 17 | 18 | # 
-------------------------------------------------- 19 | # Spiking neuron with a triangular (piecewise-linear) surrogate gradient 20 | # The class structure is adapted from: 21 | # https://github.com/fzenke/spytorch/blob/master/notebooks/SpyTorchTutorial2.ipynb 22 | # -------------------------------------------------- 23 | 24 | # Note: Only VGG9 is supported currently 25 | # To Do: Make a generic design for all VGG models 26 | 27 | 28 | class Surrogate_BP_Function(torch.autograd.Function): 29 | @staticmethod 30 | def forward(ctx, input): 31 | ctx.save_for_backward(input) 32 | out = torch.zeros_like(input).cuda() 33 | out[input > 0] = 1.0 34 | return out 35 | 36 | @staticmethod 37 | def backward(ctx, grad_output): 38 | (input,) = ctx.saved_tensors 39 | grad_input = grad_output.clone() 40 | grad = grad_input * 0.3 * F.threshold(1.0 - torch.abs(input), 0, 0) 41 | return grad 42 | 43 | 44 | def PoissonGen(inp, rescale_fac=2.0): 45 | rand_inp = torch.rand_like(inp).cuda() 46 | return torch.mul(torch.le(rand_inp * rescale_fac, torch.abs(inp)).float(), torch.sign(inp)) 47 | 48 | class MLP_Direct(nn.Module): 49 | def __init__(self, num_steps, leak_mem=0.95, img_size=28, num_cls=10, input_dim = 1): 50 | super(MLP_Direct, self).__init__() 51 | self.img_size = img_size 52 | self.num_cls = num_cls 53 | self.num_steps = num_steps 54 | self.spike_fn = Surrogate_BP_Function.apply 55 | self.leak_mem = leak_mem 56 | self.batch_num = self.num_steps 57 | self.arch = "SNN" 58 | 59 | print(">>>>>>>>>>>>>>>>>>> MLP_Direct Coding >>>>>>>>>>>>>>>>>>>>>>") 60 | 61 | #cifar10 & cifar 100 62 | if input_dim == 3: 63 | dim = 6 64 | #mnist 65 | elif input_dim == 1 : 66 | dim = 5 67 | 68 | self.fc1 = nn.Linear(img_size * img_size * input_dim, 800, bias=False) 69 | self.fc2 = nn.Linear(800, num_cls, bias=False) 70 | 71 | # Initialize the firing thresholds of all the layers 72 | for m in self.modules(): 73 | if isinstance(m, nn.Conv2d): 74 | m.threshold = 1.0 75 | torch.nn.init.xavier_uniform_(m.weight, gain=2) 76 | elif 
isinstance(m, nn.Linear): 77 | m.threshold = 1.0 78 | torch.nn.init.xavier_uniform_(m.weight, gain=2) 79 | 80 | def forward(self, inp): 81 | 82 | batch_size = inp.size(0) 83 | 84 | mem_fc1 = torch.zeros(batch_size, 800).cuda() 85 | mem_fc2 = torch.zeros(batch_size, self.num_cls).cuda() 86 | inp = inp.view(batch_size, -1) 87 | static_input = self.fc1(inp) 88 | 89 | for t in range(self.num_steps): 90 | # Charging and Firing 91 | mem_fc1 = self.leak_mem * mem_fc1 + (1-self.leak_mem)*static_input 92 | mem_thr = (mem_fc1 / self.fc1.threshold) - 1.0 93 | out = self.spike_fn(mem_thr) 94 | 95 | # Soft Reset 96 | rst = torch.zeros_like(mem_fc1).cuda() 97 | rst[mem_thr > 0] = self.fc1.threshold 98 | mem_fc1 = mem_fc1 - rst 99 | out_prev = out.clone() 100 | 101 | # accumulate voltage in the last layer 102 | mem_fc2 = mem_fc2 + self.fc2(out_prev) 103 | 104 | out_voltage = mem_fc2 / self.num_steps 105 | 106 | return out_voltage 107 | 108 | class MLP_Poisson(nn.Module): 109 | def __init__(self, num_steps, leak_mem=0.95, img_size=28, num_cls=10, input_dim = 1): 110 | super(MLP_Poisson, self).__init__() 111 | self.img_size = img_size 112 | self.num_cls = num_cls 113 | self.num_steps = num_steps 114 | self.spike_fn = Surrogate_BP_Function.apply 115 | self.leak_mem = leak_mem 116 | self.batch_num = self.num_steps 117 | self.arch = "SNN" 118 | 119 | print(">>>>>>>>>>>>>>>>>>> MLP_Poisson Coding >>>>>>>>>>>>>>>>>>>>>>") 120 | 121 | #cifar10 & cifar 100 122 | if input_dim == 3: 123 | dim = 6 124 | #mnist 125 | elif input_dim == 1 : 126 | dim = 5 127 | 128 | self.fc1 = nn.Linear(img_size * img_size * input_dim, 800, bias=False) 129 | self.fc2 = nn.Linear(800, num_cls, bias=False) 130 | 131 | # Initialize the firing thresholds of all the layers 132 | for m in self.modules(): 133 | if isinstance(m, nn.Conv2d): 134 | m.threshold = 1.0 135 | torch.nn.init.xavier_uniform_(m.weight, gain=2) 136 | elif isinstance(m, nn.Linear): 137 | m.threshold = 1.0 138 | torch.nn.init.xavier_uniform_(m.weight, gain=2) 139 
| 140 | def forward(self, inp): 141 | 142 | batch_size = inp.size(0) 143 | inp = inp.view(batch_size, -1) 144 | 145 | mem_fc1 = torch.zeros(batch_size, 800).cuda() 146 | mem_fc2 = torch.zeros(batch_size, self.num_cls).cuda() 147 | 148 | for t in range(self.num_steps): 149 | spike_inp = PoissonGen(inp) 150 | out_prev = spike_inp 151 | 152 | # Charging and Firing 153 | mem_fc1 = self.leak_mem * mem_fc1 + (1-self.leak_mem)*self.fc1(out_prev) 154 | mem_thr = (mem_fc1 / self.fc1.threshold) - 1.0 155 | out = self.spike_fn(mem_thr) 156 | 157 | # Soft Reset 158 | rst = torch.zeros_like(mem_fc1).cuda() 159 | rst[mem_thr > 0] = self.fc1.threshold 160 | mem_fc1 = mem_fc1 - rst 161 | out_prev = out.clone() 162 | 163 | # accumulate voltage in the last layer 164 | mem_fc2 = mem_fc2 + self.fc2(out_prev) 165 | 166 | out_voltage = mem_fc2 / self.num_steps 167 | 168 | return out_voltage 169 | 170 | 171 | 172 | # Models 173 | class VGG5_Direct(nn.Module): 174 | def __init__(self, num_steps, leak_mem=0.95, img_size=32, num_cls=10, input_dim=3): 175 | super(VGG5_Direct, self).__init__() 176 | self.img_size = img_size 177 | self.num_cls = num_cls 178 | self.num_steps = num_steps 179 | self.spike_fn = Surrogate_BP_Function.apply 180 | self.leak_mem = leak_mem 181 | self.batch_num = self.num_steps 182 | self.arch = "SNN" 183 | print(">>>>>>>>>>>>>>>>>>> LeNet_Direct Coding >>>>>>>>>>>>>>>>>>>>>>") 184 | # cifar10 & cifar 100 185 | if input_dim == 3: 186 | dim = 8 187 | # mnist 188 | elif input_dim == 1: 189 | dim = 5 190 | 191 | self.conv1 = nn.Conv2d(input_dim, 64, kernel_size=3, padding=1, bias=False) 192 | self.pool1 = nn.AvgPool2d(kernel_size=2) 193 | self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1, bias=False) 194 | self.conv3 = nn.Conv2d(128, 128, kernel_size=3, padding=1, bias=False) 195 | self.pool2 = nn.AvgPool2d(kernel_size=2) 196 | 197 | self.fc1 = nn.Linear(128 * dim * dim, 1024, bias=False) 198 | self.fc2 = nn.Linear(1024, num_cls, bias=False) 199 | 200 | # Initialize the 
firing thresholds of all the layers 201 | for m in self.modules(): 202 | if isinstance(m, nn.Conv2d): 203 | m.threshold = 1.0 204 | torch.nn.init.xavier_uniform_(m.weight, gain=3) 205 | elif isinstance(m, nn.Linear): 206 | m.threshold = 1.0 207 | torch.nn.init.xavier_uniform_(m.weight, gain=3) 208 | 209 | self.conv_list = [self.conv1, self.conv2, self.conv3] 210 | 211 | self.pool_list = [self.pool1, self.pool2] 212 | 213 | self.fc_list = [self.fc1, self.fc2] 214 | 215 | def forward(self, inp): 216 | 217 | batch_size = inp.size(0) 218 | 219 | mem_conv1 = torch.zeros(batch_size, 64, self.img_size, self.img_size).cuda() 220 | mem_conv2 = torch.zeros(batch_size, 128, (self.img_size) // 2, (self.img_size) // 2).cuda() 221 | mem_conv3 = torch.zeros(batch_size, 128, (self.img_size) // 2, (self.img_size) // 2).cuda() 222 | 223 | mem_conv_list = [mem_conv1, mem_conv2, mem_conv3] 224 | 225 | mem_fc1 = torch.zeros(batch_size, 1024).cuda() 226 | mem_fc2 = torch.zeros(batch_size, self.num_cls).cuda() 227 | 228 | mem_fc_list = [mem_fc1, mem_fc2] 229 | 230 | # Direct coding - static input from conv1 231 | 232 | static_input = self.conv1(inp) 233 | 234 | for t in range(self.num_steps): 235 | # Charging and firing (lif for conv1) 236 | mem_conv_list[0] = self.leak_mem * mem_conv_list[0] + (1 - self.leak_mem) * static_input 237 | mem_thr = (mem_conv_list[0] / self.conv_list[0].threshold) - 1.0 238 | out = self.spike_fn(mem_thr) 239 | 240 | # Soft reset 241 | rst = torch.zeros_like(mem_conv_list[0]).cuda() 242 | rst[mem_thr > 0] = self.conv_list[0].threshold 243 | mem_conv_list[0] = mem_conv_list[0] - rst 244 | out_prev = out.clone() 245 | 246 | # Pooling 247 | out = self.pool_list[0](out_prev) 248 | out_prev = out.clone() 249 | 250 | mem_conv_list[1] = self.leak_mem * mem_conv_list[1] + (1 - self.leak_mem) * self.conv2(out_prev) 251 | mem_thr = (mem_conv_list[1] / self.conv_list[1].threshold) - 1.0 252 | out = self.spike_fn(mem_thr) 253 | rst = 
torch.zeros_like(mem_conv_list[1]).cuda() 254 | rst[mem_thr > 0] = self.conv_list[1].threshold 255 | mem_conv_list[1] = mem_conv_list[1] - rst 256 | out_prev = out.clone() 257 | 258 | # print ("aa", out_prev.sum()) 259 | 260 | mem_conv_list[2] = self.leak_mem * mem_conv_list[2] + (1 - self.leak_mem) * self.conv3(out_prev) 261 | mem_thr = (mem_conv_list[2] / self.conv_list[2].threshold) - 1.0 262 | out = self.spike_fn(mem_thr) 263 | rst = torch.zeros_like(mem_conv_list[2]).cuda() 264 | rst[mem_thr > 0] = self.conv_list[2].threshold 265 | mem_conv_list[2] = mem_conv_list[2] - rst 266 | out_prev = out.clone() 267 | 268 | # print ("bb",out_prev.sum()) 269 | 270 | # Pooling 271 | out = self.pool_list[1](out_prev) 272 | out_prev = out.clone() 273 | 274 | out_prev = out_prev.reshape(batch_size, -1) 275 | 276 | for i in range(len(self.fc_list) - 1): 277 | mem_fc_list[i] = self.leak_mem * mem_fc_list[i] + (1 - self.leak_mem) * self.fc_list[i](out_prev) 278 | mem_thr = (mem_fc_list[i] / self.fc_list[i].threshold) - 1.0 279 | out = self.spike_fn(mem_thr) 280 | 281 | rst = torch.zeros_like(mem_fc_list[i]).cuda() 282 | rst[mem_thr > 0] = self.fc_list[i].threshold 283 | mem_fc_list[i] = mem_fc_list[i] - rst 284 | out_prev = out.clone() 285 | 286 | # accumulate voltage in the last layer 287 | mem_fc2 = mem_fc2 + self.fc2(out_prev) 288 | 289 | out_voltage = mem_fc2 / self.num_steps 290 | 291 | return out_voltage 292 | 293 | 294 | class VGG5_Poisson(nn.Module): 295 | def __init__(self, num_steps, leak_mem=0.95, img_size=32, num_cls=10, input_dim=3): 296 | super(VGG5_Poisson, self).__init__() 297 | self.img_size = img_size 298 | self.num_cls = num_cls 299 | self.num_steps = num_steps 300 | self.spike_fn = Surrogate_BP_Function.apply 301 | self.leak_mem = leak_mem 302 | self.batch_num = self.num_steps 303 | self.arch = "SNN" 304 | print(">>>>>>>>>>>>>>>>>>> LeNet_Poisson Coding >>>>>>>>>>>>>>>>>>>>>>") 305 | 306 | # cifar10 & cifar 100 307 | if input_dim == 3: 308 | dim = 8 309 | # 
mnist 310 | elif input_dim == 1: 311 | dim = 5 312 | 313 | self.conv1 = nn.Conv2d(input_dim, 64, kernel_size=3, padding=1, bias=False) 314 | self.pool1 = nn.AvgPool2d(kernel_size=2) 315 | self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1, bias=False) 316 | self.conv3 = nn.Conv2d(128, 128, kernel_size=3, padding=1, bias=False) 317 | self.pool2 = nn.AvgPool2d(kernel_size=2) 318 | 319 | self.fc1 = nn.Linear(128 * dim * dim, 1024, bias=False) 320 | self.fc2 = nn.Linear(1024, num_cls, bias=False) 321 | 322 | # Initialize the firing thresholds of all the layers 323 | for m in self.modules(): 324 | if isinstance(m, nn.Conv2d): 325 | m.threshold = 1.0 326 | torch.nn.init.xavier_uniform_(m.weight, gain=4) 327 | elif isinstance(m, nn.Linear): 328 | m.threshold = 1.0 329 | torch.nn.init.xavier_uniform_(m.weight, gain=4) 330 | 331 | self.conv_list = [self.conv1, self.conv2, self.conv3] 332 | 333 | self.pool_list = [self.pool1, self.pool2] 334 | 335 | self.fc_list = [self.fc1, self.fc2] 336 | 337 | def forward(self, inp): 338 | 339 | batch_size = inp.size(0) 340 | 341 | mem_conv1 = torch.zeros(batch_size, 64, self.img_size, self.img_size).cuda() 342 | mem_conv2 = torch.zeros(batch_size, 128, (self.img_size) // 2, (self.img_size) // 2).cuda() 343 | mem_conv3 = torch.zeros(batch_size, 128, (self.img_size) // 2, (self.img_size) // 2).cuda() 344 | 345 | mem_conv_list = [mem_conv1, mem_conv2, mem_conv3] 346 | 347 | mem_fc1 = torch.zeros(batch_size, 1024).cuda() 348 | mem_fc2 = torch.zeros(batch_size, self.num_cls).cuda() 349 | 350 | mem_fc_list = [mem_fc1, mem_fc2] 351 | 352 | for t in range(self.num_steps): 353 | spike_inp = PoissonGen(inp) 354 | out_prev = spike_inp 355 | 356 | mem_conv_list[0] = self.leak_mem * mem_conv_list[0] + (1 - self.leak_mem) * self.conv1(out_prev) 357 | mem_thr = (mem_conv_list[0] / self.conv_list[0].threshold) - 1.0 358 | out = self.spike_fn(mem_thr) 359 | 360 | # Soft reset 361 | rst = torch.zeros_like(mem_conv_list[0]).cuda() 362 | rst[mem_thr > 
0] = self.conv_list[0].threshold 363 | mem_conv_list[0] = mem_conv_list[0] - rst 364 | out_prev = out.clone() 365 | 366 | # Pooling 367 | out = self.pool_list[0](out_prev) 368 | out_prev = out.clone() 369 | 370 | mem_conv_list[1] = self.leak_mem * mem_conv_list[1] + (1 - self.leak_mem) * self.conv2(out_prev) 371 | mem_thr = (mem_conv_list[1] / self.conv_list[1].threshold) - 1.0 372 | out = self.spike_fn(mem_thr) 373 | rst = torch.zeros_like(mem_conv_list[1]).cuda() 374 | rst[mem_thr > 0] = self.conv_list[1].threshold 375 | mem_conv_list[1] = mem_conv_list[1] - rst 376 | out_prev = out.clone() 377 | 378 | mem_conv_list[2] = self.leak_mem * mem_conv_list[2] + (1 - self.leak_mem) * self.conv3(out_prev) 379 | mem_thr = (mem_conv_list[2] / self.conv_list[2].threshold) - 1.0 380 | out = self.spike_fn(mem_thr) 381 | rst = torch.zeros_like(mem_conv_list[2]).cuda() 382 | rst[mem_thr > 0] = self.conv_list[2].threshold 383 | mem_conv_list[2] = mem_conv_list[2] - rst 384 | out_prev = out.clone() 385 | 386 | # Pooling 387 | out = self.pool_list[1](out_prev) 388 | out_prev = out.clone() 389 | 390 | out_prev = out_prev.reshape(batch_size, -1) 391 | 392 | for i in range(len(self.fc_list) - 1): 393 | # Charging and Firing 394 | mem_fc_list[i] = self.leak_mem * mem_fc_list[i] + (1 - self.leak_mem) * self.fc_list[i](out_prev) 395 | mem_thr = (mem_fc_list[i] / self.fc_list[i].threshold) - 1.0 396 | out = self.spike_fn(mem_thr) 397 | 398 | # Soft Reset 399 | rst = torch.zeros_like(mem_fc_list[i]).cuda() 400 | rst[mem_thr > 0] = self.fc_list[i].threshold 401 | mem_fc_list[i] = mem_fc_list[i] - rst 402 | out_prev = out.clone() 403 | 404 | # accumulate voltage in the last layer 405 | mem_fc2 = mem_fc2 + self.fc2(out_prev) 406 | 407 | out_voltage = mem_fc2 / self.num_steps 408 | 409 | return out_voltage 410 | 411 | class VGG9_Direct(nn.Module): 412 | def __init__(self, num_steps, leak_mem=0.95, img_size=32, num_cls=10, input_dim = 3): 413 | super(VGG9_Direct, self).__init__() 414 | 415 
| self.img_size = img_size 416 | self.num_cls = num_cls 417 | self.num_steps = num_steps 418 | self.spike_fn = Surrogate_BP_Function.apply 419 | self.leak_mem = leak_mem 420 | self.batch_num = self.num_steps 421 | self.arch = "SNN" 422 | 423 | print(">>>>>>>>>>>>>>>>>>> VGG 9_Direct Coding >>>>>>>>>>>>>>>>>>>>>>") 424 | 425 | affine_flag = True 426 | bias_flag = False 427 | 428 | self.conv1 = nn.Conv2d(input_dim, 64, kernel_size=3, stride=1, padding=1, bias=bias_flag) 429 | self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=bias_flag) 430 | self.pool1 = nn.AvgPool2d(kernel_size=2) 431 | 432 | self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=bias_flag) 433 | self.conv4 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=bias_flag) 434 | self.pool2 = nn.AvgPool2d(kernel_size=2) 435 | 436 | self.conv5 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1, bias=bias_flag) 437 | 438 | self.conv6 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=bias_flag) 439 | 440 | self.conv7 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=bias_flag) 441 | self.pool3 = nn.AvgPool2d(kernel_size=2) 442 | 443 | # Test 444 | # self.drop1 = nn.Dropout(p=0.5) 445 | 446 | self.fc1 = nn.Linear((self.img_size // 8) * (self.img_size // 8) * 256, 1024, bias=bias_flag) 447 | self.fc2 = nn.Linear(1024, self.num_cls, bias=bias_flag) 448 | 449 | self.conv_list = [ 450 | self.conv1, 451 | self.conv2, 452 | self.conv3, 453 | self.conv4, 454 | self.conv5, 455 | self.conv6, 456 | self.conv7, 457 | ] 458 | 459 | self.pool_list = [ 460 | False, 461 | self.pool1, 462 | False, 463 | self.pool2, 464 | False, 465 | False, 466 | self.pool3, 467 | ] 468 | 469 | # Initialize the firing thresholds of all the layers 470 | for m in self.modules(): 471 | if isinstance(m, nn.Conv2d): 472 | m.threshold = 1.0 473 | torch.nn.init.xavier_uniform_(m.weight, gain=5) 474 | elif isinstance(m, nn.Linear): 475 | m.threshold = 1.0 476 | 
torch.nn.init.xavier_uniform_(m.weight, gain=5) 477 | 478 | 479 | def forward(self, inp): 480 | 481 | batch_size = inp.size(0) 482 | mem_conv1 = torch.zeros(batch_size, 64, self.img_size, self.img_size).cuda() 483 | mem_conv2 = torch.zeros(batch_size, 64, self.img_size, self.img_size).cuda() 484 | mem_conv3 = torch.zeros( 485 | batch_size, 128, self.img_size // 2, self.img_size // 2 486 | ).cuda() 487 | mem_conv4 = torch.zeros( 488 | batch_size, 128, self.img_size // 2, self.img_size // 2 489 | ).cuda() 490 | mem_conv5 = torch.zeros( 491 | batch_size, 256, self.img_size // 4, self.img_size // 4 492 | ).cuda() 493 | mem_conv6 = torch.zeros( 494 | batch_size, 256, self.img_size // 4, self.img_size // 4 495 | ).cuda() 496 | mem_conv7 = torch.zeros( 497 | batch_size, 256, self.img_size // 4, self.img_size // 4 498 | ).cuda() 499 | mem_conv_list = [ 500 | mem_conv1, 501 | mem_conv2, 502 | mem_conv3, 503 | mem_conv4, 504 | mem_conv5, 505 | mem_conv6, 506 | mem_conv7, 507 | ] 508 | 509 | mem_fc1 = torch.zeros(batch_size, 1024).cuda() 510 | mem_fc2 = torch.zeros(batch_size, self.num_cls).cuda() 511 | 512 | 513 | 514 | #Direct coding - static input from conv1 515 | 516 | static_input = self.conv1(inp) 517 | 518 | 519 | 520 | for t in range(self.num_steps): 521 | # Charging and firing (lif for conv1) 522 | mem_conv_list[0] = (1-self.leak_mem) * mem_conv_list[0] + self.leak_mem * static_input 523 | mem_thr = (mem_conv_list[0] / self.conv_list[0].threshold) - 1.0 524 | out = self.spike_fn(mem_thr) 525 | 526 | 527 | # Soft reset 528 | rst = torch.zeros_like(mem_conv_list[0]).cuda() 529 | rst[mem_thr > 0] = self.conv_list[0].threshold 530 | mem_conv_list[0] = mem_conv_list[0] - rst 531 | out_prev = out.clone() 532 | 533 | for i in range(1, len(self.conv_list)): 534 | 535 | mem_conv_list[i] = (1-self.leak_mem) * mem_conv_list[i] + self.leak_mem * self.conv_list[i](out_prev) 536 | mem_thr = (mem_conv_list[i] / self.conv_list[i].threshold) - 1.0 537 | out = self.spike_fn(mem_thr) 
538 | rst = torch.zeros_like(mem_conv_list[i]).cuda() 539 | rst[mem_thr > 0] = self.conv_list[i].threshold 540 | mem_conv_list[i] = mem_conv_list[i] - rst 541 | out_prev = out.clone() 542 | if self.pool_list[i] is not False: 543 | out = self.pool_list[i](out_prev) 544 | out_prev = out.clone() 545 | 546 | 547 | 548 | out_prev = out_prev.reshape(batch_size, -1) 549 | 550 | # Test 551 | # out = self.drop1(out_prev) 552 | # out_prev = out.clone() 553 | 554 | mem_fc1 = (1-self.leak_mem) * mem_fc1 + self.leak_mem * self.fc1(out_prev) 555 | 556 | mem_thr = (mem_fc1 / self.fc1.threshold) - 1.0 557 | out = self.spike_fn(mem_thr) 558 | rst = torch.zeros_like(mem_fc1).cuda() 559 | rst[mem_thr > 0] = self.fc1.threshold 560 | mem_fc1 = mem_fc1 - rst 561 | out_prev = out.clone() 562 | 563 | # accumulate voltage in the last layer 564 | mem_fc2 = mem_fc2 + self.fc2(out_prev) 565 | 566 | out_voltage = mem_fc2 / self.num_steps 567 | 568 | return out_voltage 569 | 570 | 571 | 572 | 573 | class VGG9_Poisson(nn.Module): 574 | def __init__(self, num_steps, leak_mem=0.95, img_size=32, num_cls=10, input_dim = 3): 575 | super(VGG9_Poisson, self).__init__() 576 | self.img_size = img_size 577 | self.num_cls = num_cls 578 | self.num_steps = num_steps 579 | self.spike_fn = Surrogate_BP_Function.apply 580 | self.leak_mem = leak_mem 581 | self.batch_num = self.num_steps 582 | self.arch = "SNN" 583 | print(">>>>>>>>>>>>>>>>>>> VGG 9_Poisson Coding >>>>>>>>>>>>>>>>>>>>>>") 584 | 585 | affine_flag = True 586 | bias_flag = False 587 | 588 | self.conv1 = nn.Conv2d( 589 | input_dim, 64, kernel_size=3, stride=1, padding=1, bias=bias_flag 590 | ) 591 | 592 | self.conv2 = nn.Conv2d( 593 | 64, 64, kernel_size=3, stride=1, padding=1, bias=bias_flag 594 | ) 595 | self.pool1 = nn.AvgPool2d(kernel_size=2) 596 | 597 | self.conv3 = nn.Conv2d( 598 | 64, 128, kernel_size=3, stride=1, padding=1, bias=bias_flag 599 | ) 600 | 601 | self.conv4 = nn.Conv2d( 602 | 128, 128, kernel_size=3, stride=1, padding=1, 
bias=bias_flag 603 | ) 604 | 605 | self.pool2 = nn.AvgPool2d(kernel_size=2) 606 | 607 | self.conv5 = nn.Conv2d( 608 | 128, 256, kernel_size=3, stride=1, padding=1, bias=bias_flag 609 | ) 610 | 611 | self.conv6 = nn.Conv2d( 612 | 256, 256, kernel_size=3, stride=1, padding=1, bias=bias_flag 613 | ) 614 | 615 | self.conv7 = nn.Conv2d( 616 | 256, 256, kernel_size=3, stride=1, padding=1, bias=bias_flag 617 | ) 618 | self.pool3 = nn.AvgPool2d(kernel_size=2) 619 | 620 | self.fc1 = nn.Linear( 621 | (self.img_size // 8) * (self.img_size // 8) * 256, 1024, bias=bias_flag 622 | ) 623 | 624 | self.fc2 = nn.Linear(1024, self.num_cls, bias=bias_flag) 625 | 626 | self.conv_list = [ 627 | self.conv1, 628 | self.conv2, 629 | self.conv3, 630 | self.conv4, 631 | self.conv5, 632 | self.conv6, 633 | self.conv7, 634 | ] 635 | 636 | self.pool_list = [ 637 | False, 638 | self.pool1, 639 | False, 640 | self.pool2, 641 | False, 642 | False, 643 | self.pool3, 644 | ] 645 | 646 | # Initialize the firing thresholds of all the layers 647 | for m in self.modules(): 648 | if isinstance(m, nn.Conv2d): 649 | m.threshold = 1.0 650 | torch.nn.init.xavier_uniform_(m.weight, gain=5) 651 | elif isinstance(m, nn.Linear): 652 | m.threshold = 1.0 653 | torch.nn.init.xavier_uniform_(m.weight, gain=5) 654 | 655 | 656 | def forward(self, inp): 657 | 658 | batch_size = inp.size(0) 659 | mem_conv1 = torch.zeros(batch_size, 64, self.img_size, self.img_size).cuda() 660 | mem_conv2 = torch.zeros(batch_size, 64, self.img_size, self.img_size).cuda() 661 | mem_conv3 = torch.zeros( 662 | batch_size, 128, self.img_size // 2, self.img_size // 2 663 | ).cuda() 664 | mem_conv4 = torch.zeros( 665 | batch_size, 128, self.img_size // 2, self.img_size // 2 666 | ).cuda() 667 | mem_conv5 = torch.zeros( 668 | batch_size, 256, self.img_size // 4, self.img_size // 4 669 | ).cuda() 670 | mem_conv6 = torch.zeros( 671 | batch_size, 256, self.img_size // 4, self.img_size // 4 672 | ).cuda() 673 | mem_conv7 = torch.zeros( 674 | 
batch_size, 256, self.img_size // 4, self.img_size // 4 675 | ).cuda() 676 | mem_conv_list = [ 677 | mem_conv1, 678 | mem_conv2, 679 | mem_conv3, 680 | mem_conv4, 681 | mem_conv5, 682 | mem_conv6, 683 | mem_conv7, 684 | ] 685 | 686 | mem_fc1 = torch.zeros(batch_size, 1024).cuda() 687 | mem_fc2 = torch.zeros(batch_size, self.num_cls).cuda() 688 | 689 | for t in range(self.num_steps): 690 | 691 | spike_inp = PoissonGen(inp) 692 | out_prev = spike_inp 693 | 694 | for i in range(len(self.conv_list)): 695 | # charging and firing 696 | mem_conv_list[i] = (1-self.leak_mem) * mem_conv_list[i] + self.leak_mem*self.conv_list[i](out_prev) 697 | mem_thr = (mem_conv_list[i] / self.conv_list[i].threshold) - 1.0 698 | out = self.spike_fn(mem_thr) 699 | 700 | # Soft reset 701 | rst = torch.zeros_like(mem_conv_list[i]).cuda() 702 | rst[mem_thr > 0] = self.conv_list[i].threshold 703 | mem_conv_list[i] = mem_conv_list[i] - rst 704 | out_prev = out.clone() 705 | 706 | if self.pool_list[i] is not False: 707 | out = self.pool_list[i](out_prev) 708 | out_prev = out.clone() 709 | 710 | 711 | 712 | out_prev = out_prev.reshape(batch_size, -1) 713 | mem_fc1 = (1-self.leak_mem) * mem_fc1 + self.leak_mem* self.fc1(out_prev) 714 | mem_thr = (mem_fc1 / self.fc1.threshold) - 1.0 715 | out = self.spike_fn(mem_thr) 716 | rst = torch.zeros_like(mem_fc1).cuda() 717 | rst[mem_thr > 0] = self.fc1.threshold 718 | mem_fc1 = mem_fc1 - rst 719 | out_prev = out.clone() 720 | 721 | # accumulate voltage in the last layer 722 | mem_fc2 = mem_fc2 + self.fc2(out_prev) 723 | 724 | out_voltage = mem_fc2 / self.num_steps 725 | 726 | return out_voltage 727 | 728 | 729 | 730 | class AverageMeter(object): 731 | """ 732 | Computes and stores the average and current value 733 | """ 734 | 735 | def __init__(self): 736 | self.reset() 737 | 738 | def reset(self): 739 | self.val = 0 740 | self.avg = 0 741 | self.sum = 0 742 | self.count = 0 743 | 744 | def update(self, val, n=1): 745 | self.val = val 746 | self.sum += val * 
n 747 | self.count += n 748 | self.avg = self.sum / self.count 749 | 750 | 751 | 752 | -------------------------------------------------------------------------------- /models/model.py: -------------------------------------------------------------------------------- 1 | # File : model_cifar10.py 2 | # Descr: Define SNN models for the CIFAR10 dataset 3 | # Date : March 22, 2019 4 | 5 | # -------------------------------------------------- 6 | # Imports 7 | # -------------------------------------------------- 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import math 12 | import sys 13 | import numpy as np 14 | import numpy.linalg as LA 15 | from torch.autograd import Variable 16 | import pdb 17 | 18 | # -------------------------------------------------- 19 | # Spiking neuron with fast-sigmoid surrogate gradient 20 | # This class is replicated from: 21 | # https://github.com/fzenke/spytorch/blob/master/notebooks/SpyTorchTutorial2.ipynb 22 | # -------------------------------------------------- 23 | 24 | # Note: Only VGG9 is supported currently 25 | # To Do: Make a generic design for all VGG models 26 | 27 | 28 | class Surrogate_BP_Function(torch.autograd.Function): 29 | @staticmethod 30 | def forward(ctx, input): 31 | ctx.save_for_backward(input) 32 | out = torch.zeros_like(input).cuda() 33 | out[input > 0] = 1.0 34 | return out 35 | 36 | @staticmethod 37 | def backward(ctx, grad_output): 38 | (input,) = ctx.saved_tensors 39 | grad_input = grad_output.clone() 40 | grad = grad_input * 0.3 * F.threshold(1.0 - torch.abs(input), 0, 0) 41 | return grad 42 | 43 | 44 | def PoissonGen(inp, rescale_fac=2.0): 45 | rand_inp = torch.rand_like(inp).cuda() 46 | return torch.mul(torch.le(rand_inp * rescale_fac, torch.abs(inp)).float(), torch.sign(inp)) 47 | 48 | class MLP_Direct(nn.Module): 49 | def __init__(self, num_steps, leak_mem=0.95, img_size=28, num_cls=10, input_dim = 1): 50 | super(MLP_Direct, self).__init__() 51 | self.img_size = 
img_size 52 | self.num_cls = num_cls 53 | self.num_steps = num_steps 54 | self.spike_fn = Surrogate_BP_Function.apply 55 | self.leak_mem = leak_mem 56 | self.batch_num = self.num_steps 57 | self.arch = "SNN" 58 | 59 | print(">>>>>>>>>>>>>>>>>>> MLP_Direct Coding >>>>>>>>>>>>>>>>>>>>>>") 60 | 61 | #cifar10 & cifar 100 62 | if input_dim == 3: 63 | dim = 6 64 | #mnist 65 | elif input_dim == 1 : 66 | dim = 5 67 | 68 | self.fc1 = nn.Linear(img_size * img_size * input_dim, 800, bias=False) 69 | self.fc2 = nn.Linear(800, 10, bias=False) 70 | 71 | # Initialize the firing thresholds of all the layers 72 | for m in self.modules(): 73 | if isinstance(m, nn.Conv2d): 74 | m.threshold = 1.0 75 | torch.nn.init.xavier_uniform_(m.weight, gain=2) 76 | elif isinstance(m, nn.Linear): 77 | m.threshold = 1.0 78 | torch.nn.init.xavier_uniform_(m.weight, gain=2) 79 | 80 | def forward(self, inp): 81 | 82 | batch_size = inp.size(0) 83 | 84 | mem_fc1 = torch.zeros(batch_size, 800).cuda() 85 | mem_fc2 = torch.zeros(batch_size, 10).cuda() 86 | inp = inp.view(batch_size, -1) 87 | static_input = self.fc1(inp) 88 | 89 | for t in range(self.num_steps): 90 | # Charging and Firing 91 | mem_fc1 = self.leak_mem * mem_fc1 + (1-self.leak_mem)*static_input 92 | mem_thr = (mem_fc1 / self.fc1.threshold) - 1.0 93 | out = self.spike_fn(mem_thr) 94 | 95 | # Soft Reset 96 | rst = torch.zeros_like(mem_fc1).cuda() 97 | rst[mem_thr > 0] = self.fc1.threshold 98 | mem_fc1 = mem_fc1 - rst 99 | out_prev = out.clone() 100 | 101 | # accumulate voltage in the last layer 102 | mem_fc2 = mem_fc2 + self.fc2(out_prev) 103 | 104 | out_voltage = mem_fc2 / self.num_steps 105 | 106 | return out_voltage 107 | 108 | class MLP_Poisson(nn.Module): 109 | def __init__(self, num_steps, leak_mem=0.95, img_size=28, num_cls=10, input_dim = 1): 110 | super(MLP_Poisson, self).__init__() 111 | self.img_size = img_size 112 | self.num_cls = num_cls 113 | self.num_steps = num_steps 114 | self.spike_fn = Surrogate_BP_Function.apply 115 | 
self.leak_mem = leak_mem 116 | self.batch_num = self.num_steps 117 | self.arch = "SNN" 118 | 119 | print(">>>>>>>>>>>>>>>>>>> MLP_Poisson Coding >>>>>>>>>>>>>>>>>>>>>>") 120 | 121 | #cifar10 & cifar 100 122 | if input_dim == 3: 123 | dim = 6 124 | #mnist 125 | elif input_dim == 1 : 126 | dim = 5 127 | 128 | self.fc1 = nn.Linear(img_size * img_size * input_dim, 800, bias=False) 129 | self.fc2 = nn.Linear(800, 10, bias=False) 130 | 131 | # Initialize the firing thresholds of all the layers 132 | for m in self.modules(): 133 | if isinstance(m, nn.Conv2d): 134 | m.threshold = 1.0 135 | torch.nn.init.xavier_uniform_(m.weight, gain=2) 136 | elif isinstance(m, nn.Linear): 137 | m.threshold = 1.0 138 | torch.nn.init.xavier_uniform_(m.weight, gain=2) 139 | 140 | def forward(self, inp): 141 | 142 | batch_size = inp.size(0) 143 | inp = inp.view(batch_size, -1) 144 | 145 | mem_fc1 = torch.zeros(batch_size, 800).cuda() 146 | mem_fc2 = torch.zeros(batch_size, 10).cuda() 147 | 148 | for t in range(self.num_steps): 149 | spike_inp = PoissonGen(inp) 150 | out_prev = spike_inp 151 | 152 | # Charging and Firing 153 | mem_fc1 = self.leak_mem * mem_fc1 + (1-self.leak_mem)*self.fc1(out_prev) 154 | mem_thr = (mem_fc1 / self.fc1.threshold) - 1.0 155 | out = self.spike_fn(mem_thr) 156 | 157 | # Soft Reset 158 | rst = torch.zeros_like(mem_fc1).cuda() 159 | rst[mem_thr > 0] = self.fc1.threshold 160 | mem_fc1 = mem_fc1 - rst 161 | out_prev = out.clone() 162 | 163 | # accumulate voltage in the last layer 164 | mem_fc2 = mem_fc2 + self.fc2(out_prev) 165 | 166 | out_voltage = mem_fc2 / self.num_steps 167 | 168 | return out_voltage 169 | 170 | 171 | 172 | # Models 173 | class VGG5_Direct(nn.Module): 174 | def __init__(self, num_steps, leak_mem=0.95, img_size=32, num_cls=10, input_dim=3): 175 | super(VGG5_Direct, self).__init__() 176 | self.img_size = img_size 177 | self.num_cls = num_cls 178 | self.num_steps = num_steps 179 | self.spike_fn = Surrogate_BP_Function.apply 180 | self.leak_mem = 
leak_mem 181 | self.batch_num = self.num_steps 182 | self.arch = "SNN" 183 | print(">>>>>>>>>>>>>>>>>>> LeNet_Direct Coding >>>>>>>>>>>>>>>>>>>>>>") 184 | # cifar10 & cifar 100 185 | if input_dim == 3: 186 | dim = 8 187 | # mnist 188 | elif input_dim == 1: 189 | dim = 5 190 | 191 | self.conv1 = nn.Conv2d(input_dim, 64, kernel_size=3, padding=1, bias=False) 192 | self.pool1 = nn.AvgPool2d(kernel_size=2) 193 | self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1, bias=False) 194 | self.conv3 = nn.Conv2d(128, 128, kernel_size=3, padding=1, bias=False) 195 | self.pool2 = nn.AvgPool2d(kernel_size=2) 196 | 197 | self.fc1 = nn.Linear(128 * dim * dim, 1024, bias=False) 198 | self.fc2 = nn.Linear(1024, num_cls, bias=False) 199 | 200 | # Initialize the firing thresholds of all the layers 201 | for m in self.modules(): 202 | if isinstance(m, nn.Conv2d): 203 | m.threshold = 1.0 204 | torch.nn.init.xavier_uniform_(m.weight, gain=3) 205 | elif isinstance(m, nn.Linear): 206 | m.threshold = 1.0 207 | torch.nn.init.xavier_uniform_(m.weight, gain=3) 208 | 209 | self.conv_list = [self.conv1, self.conv2, self.conv3] 210 | 211 | self.pool_list = [self.pool1, self.pool2] 212 | 213 | self.fc_list = [self.fc1, self.fc2] 214 | 215 | def forward(self, inp): 216 | 217 | batch_size = inp.size(0) 218 | 219 | mem_conv1 = torch.zeros(batch_size, 64, self.img_size, self.img_size).cuda() 220 | mem_conv2 = torch.zeros(batch_size, 128, (self.img_size) // 2, (self.img_size) // 2).cuda() 221 | mem_conv3 = torch.zeros(batch_size, 128, (self.img_size) // 2, (self.img_size) // 2).cuda() 222 | 223 | mem_conv_list = [mem_conv1, mem_conv2, mem_conv3] 224 | 225 | mem_fc1 = torch.zeros(batch_size, 1024).cuda() 226 | mem_fc2 = torch.zeros(batch_size, self.num_cls).cuda() 227 | 228 | mem_fc_list = [mem_fc1, mem_fc2] 229 | 230 | # Direct coding - static input from conv1 231 | 232 | static_input = self.conv1(inp) 233 | 234 | for t in range(self.num_steps): 235 | # Charging and firing (lif for conv1) 236 | 
mem_conv_list[0] = self.leak_mem * mem_conv_list[0] + (1 - self.leak_mem) * static_input 237 | mem_thr = (mem_conv_list[0] / self.conv_list[0].threshold) - 1.0 238 | out = self.spike_fn(mem_thr) 239 | 240 | # Soft reset 241 | rst = torch.zeros_like(mem_conv_list[0]).cuda() 242 | rst[mem_thr > 0] = self.conv_list[0].threshold 243 | mem_conv_list[0] = mem_conv_list[0] - rst 244 | out_prev = out.clone() 245 | 246 | # Pooling 247 | out = self.pool_list[0](out_prev) 248 | out_prev = out.clone() 249 | 250 | mem_conv_list[1] = self.leak_mem * mem_conv_list[1] + (1 - self.leak_mem) * self.conv2(out_prev) 251 | mem_thr = (mem_conv_list[1] / self.conv_list[1].threshold) - 1.0 252 | out = self.spike_fn(mem_thr) 253 | rst = torch.zeros_like(mem_conv_list[1]).cuda() 254 | rst[mem_thr > 0] = self.conv_list[1].threshold 255 | mem_conv_list[1] = mem_conv_list[1] - rst 256 | out_prev = out.clone() 257 | 258 | # print ("aa", out_prev.sum()) 259 | 260 | mem_conv_list[2] = self.leak_mem * mem_conv_list[2] + (1 - self.leak_mem) * self.conv3(out_prev) 261 | mem_thr = (mem_conv_list[2] / self.conv_list[2].threshold) - 1.0 262 | out = self.spike_fn(mem_thr) 263 | rst = torch.zeros_like(mem_conv_list[2]).cuda() 264 | rst[mem_thr > 0] = self.conv_list[2].threshold 265 | mem_conv_list[2] = mem_conv_list[2] - rst 266 | out_prev = out.clone() 267 | 268 | # print ("bb",out_prev.sum()) 269 | 270 | # Pooling 271 | out = self.pool_list[1](out_prev) 272 | out_prev = out.clone() 273 | 274 | out_prev = out_prev.reshape(batch_size, -1) 275 | 276 | for i in range(len(self.fc_list) - 1): 277 | mem_fc_list[i] = self.leak_mem * mem_fc_list[i] + (1 - self.leak_mem) * self.fc_list[i](out_prev) 278 | mem_thr = (mem_fc_list[i] / self.fc_list[i].threshold) - 1.0 279 | out = self.spike_fn(mem_thr) 280 | 281 | rst = torch.zeros_like(mem_fc_list[i]).cuda() 282 | rst[mem_thr > 0] = self.fc_list[i].threshold 283 | mem_fc_list[i] = mem_fc_list[i] - rst 284 | out_prev = out.clone() 285 | 286 | # accumulate voltage in 
the last layer 287 | mem_fc2 = mem_fc2 + self.fc2(out_prev) 288 | 289 | out_voltage = mem_fc2 / self.num_steps 290 | 291 | return out_voltage 292 | 293 | 294 | class VGG5_Poisson(nn.Module): 295 | def __init__(self, num_steps, leak_mem=0.95, img_size=32, num_cls=10, input_dim=3): 296 | super(VGG5_Poisson, self).__init__() 297 | self.img_size = img_size 298 | self.num_cls = num_cls 299 | self.num_steps = num_steps 300 | self.spike_fn = Surrogate_BP_Function.apply 301 | self.leak_mem = leak_mem 302 | self.batch_num = self.num_steps 303 | self.arch = "SNN" 304 | print(">>>>>>>>>>>>>>>>>>> LeNet_Poisson Coding >>>>>>>>>>>>>>>>>>>>>>") 305 | 306 | # cifar10 & cifar 100 307 | if input_dim == 3: 308 | dim = 8 309 | # mnist 310 | elif input_dim == 1: 311 | dim = 5 312 | 313 | self.conv1 = nn.Conv2d(input_dim, 64, kernel_size=3, padding=1, bias=False) 314 | self.pool1 = nn.AvgPool2d(kernel_size=2) 315 | self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1, bias=False) 316 | self.conv3 = nn.Conv2d(128, 128, kernel_size=3, padding=1, bias=False) 317 | self.pool2 = nn.AvgPool2d(kernel_size=2) 318 | 319 | self.fc1 = nn.Linear(128 * dim * dim, 1024, bias=False) 320 | self.fc2 = nn.Linear(1024, num_cls, bias=False) 321 | 322 | # Initialize the firing thresholds of all the layers 323 | for m in self.modules(): 324 | if isinstance(m, nn.Conv2d): 325 | m.threshold = 1.0 326 | torch.nn.init.xavier_uniform_(m.weight, gain=4) 327 | elif isinstance(m, nn.Linear): 328 | m.threshold = 1.0 329 | torch.nn.init.xavier_uniform_(m.weight, gain=4) 330 | 331 | self.conv_list = [self.conv1, self.conv2, self.conv3] 332 | 333 | self.pool_list = [self.pool1, self.pool2] 334 | 335 | self.fc_list = [self.fc1, self.fc2] 336 | 337 | def forward(self, inp): 338 | 339 | batch_size = inp.size(0) 340 | 341 | mem_conv1 = torch.zeros(batch_size, 64, self.img_size, self.img_size).cuda() 342 | mem_conv2 = torch.zeros(batch_size, 128, (self.img_size) // 2, (self.img_size) // 2).cuda() 343 | mem_conv3 = 
torch.zeros(batch_size, 128, (self.img_size) // 2, (self.img_size) // 2).cuda() 344 | 345 | mem_conv_list = [mem_conv1, mem_conv2, mem_conv3] 346 | 347 | mem_fc1 = torch.zeros(batch_size, 1024).cuda() 348 | mem_fc2 = torch.zeros(batch_size, self.num_cls).cuda() 349 | 350 | mem_fc_list = [mem_fc1, mem_fc2] 351 | 352 | for t in range(self.num_steps): 353 | spike_inp = PoissonGen(inp) 354 | out_prev = spike_inp 355 | 356 | mem_conv_list[0] = self.leak_mem * mem_conv_list[0] + (1 - self.leak_mem) * self.conv1(out_prev) 357 | mem_thr = (mem_conv_list[0] / self.conv_list[0].threshold) - 1.0 358 | out = self.spike_fn(mem_thr) 359 | 360 | # Soft reset 361 | rst = torch.zeros_like(mem_conv_list[0]).cuda() 362 | rst[mem_thr > 0] = self.conv_list[0].threshold 363 | mem_conv_list[0] = mem_conv_list[0] - rst 364 | out_prev = out.clone() 365 | 366 | # Pooling 367 | out = self.pool_list[0](out_prev) 368 | out_prev = out.clone() 369 | 370 | mem_conv_list[1] = self.leak_mem * mem_conv_list[1] + (1 - self.leak_mem) * self.conv2(out_prev) 371 | mem_thr = (mem_conv_list[1] / self.conv_list[1].threshold) - 1.0 372 | out = self.spike_fn(mem_thr) 373 | rst = torch.zeros_like(mem_conv_list[1]).cuda() 374 | rst[mem_thr > 0] = self.conv_list[1].threshold 375 | mem_conv_list[1] = mem_conv_list[1] - rst 376 | out_prev = out.clone() 377 | 378 | mem_conv_list[2] = self.leak_mem * mem_conv_list[2] + (1 - self.leak_mem) * self.conv3(out_prev) 379 | mem_thr = (mem_conv_list[2] / self.conv_list[2].threshold) - 1.0 380 | out = self.spike_fn(mem_thr) 381 | rst = torch.zeros_like(mem_conv_list[2]).cuda() 382 | rst[mem_thr > 0] = self.conv_list[2].threshold 383 | mem_conv_list[2] = mem_conv_list[2] - rst 384 | out_prev = out.clone() 385 | 386 | # Pooling 387 | out = self.pool_list[1](out_prev) 388 | out_prev = out.clone() 389 | 390 | out_prev = out_prev.reshape(batch_size, -1) 391 | 392 | for i in range(len(self.fc_list) - 1): 393 | # Charging and Firing 394 | mem_fc_list[i] = self.leak_mem * 
mem_fc_list[i] + (1 - self.leak_mem) * self.fc_list[i](out_prev) 395 | mem_thr = (mem_fc_list[i] / self.fc_list[i].threshold) - 1.0 396 | out = self.spike_fn(mem_thr) 397 | 398 | # Soft Reset 399 | rst = torch.zeros_like(mem_fc_list[i]).cuda() 400 | rst[mem_thr > 0] = self.fc_list[i].threshold 401 | mem_fc_list[i] = mem_fc_list[i] - rst 402 | out_prev = out.clone() 403 | 404 | # accumulate voltage in the last layer 405 | mem_fc2 = mem_fc2 + self.fc2(out_prev) 406 | 407 | out_voltage = mem_fc2 / self.num_steps 408 | 409 | return out_voltage 410 | 411 | class VGG9_Direct(nn.Module): 412 | def __init__(self, num_steps, leak_mem=0.95, img_size=32, num_cls=10, input_dim = 3): 413 | super(VGG9_Direct, self).__init__() 414 | 415 | self.img_size = img_size 416 | self.num_cls = num_cls 417 | self.num_steps = num_steps 418 | self.spike_fn = Surrogate_BP_Function.apply 419 | self.leak_mem = leak_mem 420 | self.batch_num = self.num_steps 421 | self.arch = "SNN" 422 | 423 | print(">>>>>>>>>>>>>>>>>>> VGG 9_Direct Coding >>>>>>>>>>>>>>>>>>>>>>") 424 | 425 | affine_flag = True 426 | bias_flag = False 427 | 428 | self.conv1 = nn.Conv2d(input_dim, 64, kernel_size=3, stride=1, padding=1, bias=bias_flag) 429 | self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=bias_flag) 430 | self.pool1 = nn.AvgPool2d(kernel_size=2) 431 | 432 | self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=bias_flag) 433 | self.conv4 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=bias_flag) 434 | self.pool2 = nn.AvgPool2d(kernel_size=2) 435 | 436 | self.conv5 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1, bias=bias_flag) 437 | 438 | self.conv6 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=bias_flag) 439 | 440 | self.conv7 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=bias_flag) 441 | self.pool3 = nn.AvgPool2d(kernel_size=2) 442 | 443 | # Test 444 | # self.drop1 = nn.Dropout(p=0.5) 445 | 446 | self.fc1 = 
nn.Linear((self.img_size // 8) * (self.img_size // 8) * 256, 1024, bias=bias_flag) 447 | self.fc2 = nn.Linear(1024, self.num_cls, bias=bias_flag) 448 | 449 | self.conv_list = [ 450 | self.conv1, 451 | self.conv2, 452 | self.conv3, 453 | self.conv4, 454 | self.conv5, 455 | self.conv6, 456 | self.conv7, 457 | ] 458 | 459 | self.pool_list = [ 460 | False, 461 | self.pool1, 462 | False, 463 | self.pool2, 464 | False, 465 | False, 466 | self.pool3, 467 | ] 468 | 469 | # Initialize the firing thresholds of all the layers 470 | for m in self.modules(): 471 | if isinstance(m, nn.Conv2d): 472 | m.threshold = 1.0 473 | torch.nn.init.xavier_uniform_(m.weight, gain=5) 474 | elif isinstance(m, nn.Linear): 475 | m.threshold = 1.0 476 | torch.nn.init.xavier_uniform_(m.weight, gain=5) 477 | 478 | 479 | def forward(self, inp): 480 | 481 | batch_size = inp.size(0) 482 | mem_conv1 = torch.zeros(batch_size, 64, self.img_size, self.img_size).cuda() 483 | mem_conv2 = torch.zeros(batch_size, 64, self.img_size, self.img_size).cuda() 484 | mem_conv3 = torch.zeros( 485 | batch_size, 128, self.img_size // 2, self.img_size // 2 486 | ).cuda() 487 | mem_conv4 = torch.zeros( 488 | batch_size, 128, self.img_size // 2, self.img_size // 2 489 | ).cuda() 490 | mem_conv5 = torch.zeros( 491 | batch_size, 256, self.img_size // 4, self.img_size // 4 492 | ).cuda() 493 | mem_conv6 = torch.zeros( 494 | batch_size, 256, self.img_size // 4, self.img_size // 4 495 | ).cuda() 496 | mem_conv7 = torch.zeros( 497 | batch_size, 256, self.img_size // 4, self.img_size // 4 498 | ).cuda() 499 | mem_conv_list = [ 500 | mem_conv1, 501 | mem_conv2, 502 | mem_conv3, 503 | mem_conv4, 504 | mem_conv5, 505 | mem_conv6, 506 | mem_conv7, 507 | ] 508 | 509 | mem_fc1 = torch.zeros(batch_size, 1024).cuda() 510 | mem_fc2 = torch.zeros(batch_size, self.num_cls).cuda() 511 | 512 | 513 | 514 | #Direct coding - static input from conv1 515 | 516 | static_input = self.conv1(inp) 517 | 518 | 519 | 520 | for t in range(self.num_steps): 
521 | # Charging and firing (lif for conv1) 522 | mem_conv_list[0] = (1-self.leak_mem) * mem_conv_list[0] + self.leak_mem * static_input 523 | mem_thr = (mem_conv_list[0] / self.conv_list[0].threshold) - 1.0 524 | out = self.spike_fn(mem_thr) 525 | 526 | 527 | # Soft reset 528 | rst = torch.zeros_like(mem_conv_list[0]).cuda() 529 | rst[mem_thr > 0] = self.conv_list[0].threshold 530 | mem_conv_list[0] = mem_conv_list[0] - rst 531 | out_prev = out.clone() 532 | 533 | for i in range(1, len(self.conv_list)): 534 | 535 | mem_conv_list[i] = (1-self.leak_mem) * mem_conv_list[i] + self.leak_mem * self.conv_list[i](out_prev) 536 | mem_thr = (mem_conv_list[i] / self.conv_list[i].threshold) - 1.0 537 | out = self.spike_fn(mem_thr) 538 | rst = torch.zeros_like(mem_conv_list[i]).cuda() 539 | rst[mem_thr > 0] = self.conv_list[i].threshold 540 | mem_conv_list[i] = mem_conv_list[i] - rst 541 | out_prev = out.clone() 542 | if self.pool_list[i] is not False: 543 | out = self.pool_list[i](out_prev) 544 | out_prev = out.clone() 545 | 546 | 547 | 548 | out_prev = out_prev.reshape(batch_size, -1) 549 | 550 | # Test 551 | # out = self.drop1(out_prev) 552 | # out_prev = out.clone() 553 | 554 | mem_fc1 = (1-self.leak_mem) * mem_fc1 + self.leak_mem * self.fc1(out_prev) 555 | 556 | mem_thr = (mem_fc1 / self.fc1.threshold) - 1.0 557 | out = self.spike_fn(mem_thr) 558 | rst = torch.zeros_like(mem_fc1).cuda() 559 | rst[mem_thr > 0] = self.fc1.threshold 560 | mem_fc1 = mem_fc1 - rst 561 | out_prev = out.clone() 562 | 563 | # accumulate voltage in the last layer 564 | mem_fc2 = mem_fc2 + self.fc2(out_prev) 565 | 566 | out_voltage = mem_fc2 / self.num_steps 567 | 568 | return out_voltage 569 | 570 | 571 | 572 | 573 | class VGG9_Poisson(nn.Module): 574 | def __init__(self, num_steps, leak_mem=0.95, img_size=32, num_cls=10, input_dim = 3): 575 | super(VGG9_Poisson, self).__init__() 576 | self.img_size = img_size 577 | self.num_cls = num_cls 578 | self.num_steps = num_steps 579 | self.spike_fn = 
Surrogate_BP_Function.apply 580 | self.leak_mem = leak_mem 581 | self.batch_num = self.num_steps 582 | self.arch = "SNN" 583 | print(">>>>>>>>>>>>>>>>>>> VGG 9_Poisson Coding >>>>>>>>>>>>>>>>>>>>>>") 584 | 585 | affine_flag = True 586 | bias_flag = False 587 | 588 | self.conv1 = nn.Conv2d( 589 | input_dim, 64, kernel_size=3, stride=1, padding=1, bias=bias_flag 590 | ) 591 | 592 | self.conv2 = nn.Conv2d( 593 | 64, 64, kernel_size=3, stride=1, padding=1, bias=bias_flag 594 | ) 595 | self.pool1 = nn.AvgPool2d(kernel_size=2) 596 | 597 | self.conv3 = nn.Conv2d( 598 | 64, 128, kernel_size=3, stride=1, padding=1, bias=bias_flag 599 | ) 600 | 601 | self.conv4 = nn.Conv2d( 602 | 128, 128, kernel_size=3, stride=1, padding=1, bias=bias_flag 603 | ) 604 | 605 | self.pool2 = nn.AvgPool2d(kernel_size=2) 606 | 607 | self.conv5 = nn.Conv2d( 608 | 128, 256, kernel_size=3, stride=1, padding=1, bias=bias_flag 609 | ) 610 | 611 | self.conv6 = nn.Conv2d( 612 | 256, 256, kernel_size=3, stride=1, padding=1, bias=bias_flag 613 | ) 614 | 615 | self.conv7 = nn.Conv2d( 616 | 256, 256, kernel_size=3, stride=1, padding=1, bias=bias_flag 617 | ) 618 | self.pool3 = nn.AvgPool2d(kernel_size=2) 619 | 620 | self.fc1 = nn.Linear( 621 | (self.img_size // 8) * (self.img_size // 8) * 256, 1024, bias=bias_flag 622 | ) 623 | 624 | self.fc2 = nn.Linear(1024, self.num_cls, bias=bias_flag) 625 | 626 | self.conv_list = [ 627 | self.conv1, 628 | self.conv2, 629 | self.conv3, 630 | self.conv4, 631 | self.conv5, 632 | self.conv6, 633 | self.conv7, 634 | ] 635 | 636 | self.pool_list = [ 637 | False, 638 | self.pool1, 639 | False, 640 | self.pool2, 641 | False, 642 | False, 643 | self.pool3, 644 | ] 645 | 646 | # Initialize the firing thresholds of all the layers 647 | for m in self.modules(): 648 | if isinstance(m, nn.Conv2d): 649 | m.threshold = 1.0 650 | torch.nn.init.xavier_uniform_(m.weight, gain=5) 651 | elif isinstance(m, nn.Linear): 652 | m.threshold = 1.0 653 | torch.nn.init.xavier_uniform_(m.weight, 
gain=5) 654 | 655 | 656 | def forward(self, inp): 657 | 658 | batch_size = inp.size(0) 659 | mem_conv1 = torch.zeros(batch_size, 64, self.img_size, self.img_size).cuda() 660 | mem_conv2 = torch.zeros(batch_size, 64, self.img_size, self.img_size).cuda() 661 | mem_conv3 = torch.zeros( 662 | batch_size, 128, self.img_size // 2, self.img_size // 2 663 | ).cuda() 664 | mem_conv4 = torch.zeros( 665 | batch_size, 128, self.img_size // 2, self.img_size // 2 666 | ).cuda() 667 | mem_conv5 = torch.zeros( 668 | batch_size, 256, self.img_size // 4, self.img_size // 4 669 | ).cuda() 670 | mem_conv6 = torch.zeros( 671 | batch_size, 256, self.img_size // 4, self.img_size // 4 672 | ).cuda() 673 | mem_conv7 = torch.zeros( 674 | batch_size, 256, self.img_size // 4, self.img_size // 4 675 | ).cuda() 676 | mem_conv_list = [ 677 | mem_conv1, 678 | mem_conv2, 679 | mem_conv3, 680 | mem_conv4, 681 | mem_conv5, 682 | mem_conv6, 683 | mem_conv7, 684 | ] 685 | 686 | mem_fc1 = torch.zeros(batch_size, 1024).cuda() 687 | mem_fc2 = torch.zeros(batch_size, self.num_cls).cuda() 688 | 689 | for t in range(self.num_steps): 690 | 691 | spike_inp = PoissonGen(inp) 692 | out_prev = spike_inp 693 | 694 | for i in range(len(self.conv_list)): 695 | # charging and firing 696 | mem_conv_list[i] = (1-self.leak_mem) * mem_conv_list[i] + self.leak_mem*self.conv_list[i](out_prev) 697 | mem_thr = (mem_conv_list[i] / self.conv_list[i].threshold) - 1.0 698 | out = self.spike_fn(mem_thr) 699 | 700 | # Soft reset 701 | rst = torch.zeros_like(mem_conv_list[i]).cuda() 702 | rst[mem_thr > 0] = self.conv_list[i].threshold 703 | mem_conv_list[i] = mem_conv_list[i] - rst 704 | out_prev = out.clone() 705 | 706 | if self.pool_list[i] is not False: 707 | out = self.pool_list[i](out_prev) 708 | out_prev = out.clone() 709 | 710 | 711 | 712 | out_prev = out_prev.reshape(batch_size, -1) 713 | mem_fc1 = (1-self.leak_mem) * mem_fc1 + self.leak_mem* self.fc1(out_prev) 714 | mem_thr = (mem_fc1 / self.fc1.threshold) - 1.0 715 | out 
= self.spike_fn(mem_thr) 716 | rst = torch.zeros_like(mem_fc1).cuda() 717 | rst[mem_thr > 0] = self.fc1.threshold 718 | mem_fc1 = mem_fc1 - rst 719 | out_prev = out.clone() 720 | 721 | # accumulate voltage in the last layer 722 | mem_fc2 = mem_fc2 + self.fc2(out_prev) 723 | 724 | out_voltage = mem_fc2 / self.num_steps 725 | 726 | return out_voltage 727 | 728 | 729 | 730 | class AverageMeter(object): 731 | """ 732 | Computes and stores the average and current value 733 | """ 734 | 735 | def __init__(self): 736 | self.reset() 737 | 738 | def reset(self): 739 | self.val = 0 740 | self.avg = 0 741 | self.sum = 0 742 | self.count = 0 743 | 744 | def update(self, val, n=1): 745 | self.val = val 746 | self.sum += val * n 747 | self.count += n 748 | self.avg = self.sum / self.count 749 | 750 | 751 | 752 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import torch.optim as optim 5 | import torch.nn.functional as F 6 | # from torch.utils.tensorboard import SummaryWriter 7 | 8 | 9 | import models.model as model 10 | from util import adjust_learning_rate, accuracy, AverageMeter 11 | import torchvision 12 | from torchvision import transforms 13 | 14 | 15 | import numpy as np 16 | import os 17 | import sys 18 | import time 19 | import argparse 20 | 21 | 22 | ############## Reproducibility ############## 23 | # seed = 2021 24 | # np.random.seed(seed) 25 | # torch.manual_seed(seed) 26 | # torch.cuda.manual_seed_all(seed) 27 | # torch.backends.cudnn.deterministic = True 28 | # torch.backends.cudnn.benchmark = False 29 | ############################################# 30 | 31 | parser = argparse.ArgumentParser() 32 | 33 | parser.add_argument("--batch_size", default=128, type=int, help="Batch size") 34 | parser.add_argument('--lr', type=float, default=1e-3) 35 | parser.add_argument('--gpu', type=str, 
default='0') 36 | parser.add_argument('--dump-dir', type=str, default="logdir") 37 | parser.add_argument("--encode", default="d", type=str, help="Encoding [p, d]") 38 | parser.add_argument("--arch", default="vgg9", type=str, help="Arch [mlp, vgg5, vgg9]") 39 | parser.add_argument("--dataset", default="cifar10", type=str, help="Dataset [mnist, cifar10, cifar100]") 40 | parser.add_argument("--optim", default='adam', type=str, help="Optimizer [adam, sgd]") 41 | parser.add_argument('--leak_mem',default=0.5, type=float) 42 | parser.add_argument('--T', type=int, default=8) 43 | parser.add_argument('--epoch', type=int, default=100) 44 | parser.add_argument("--seed", default=0, type=int, help="Random seed") 45 | parser.add_argument("--num_workers", default=4, type=int, help="number of workers") 46 | parser.add_argument("--train_display_freq", default=10, type=int, help="display_freq for train") 47 | parser.add_argument("--test_display_freq", default=10, type=int, help="display_freq for test") 48 | parser.add_argument("--setting", type=str, help="Free-form label for this run (printed at startup)") 49 | 50 | 51 | 52 | 53 | args = parser.parse_args() 54 | 55 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 56 | 57 | batch_size = args.batch_size 58 | lr = args.lr 59 | leak_mem = args.leak_mem 60 | 61 | dataset_dir = '/gpfs/project/panda/shared' 62 | dump_dir = args.dump_dir 63 | 64 | arch_prefix = args.dataset +"_" + args.arch + "_" + args.encode 65 | file_prefix = "T" + str(args.T) + "_lr" + str(args.lr) + "_epoch" + str(args.epoch) + "_leak" + str(args.leak_mem) 66 | 67 | print('{}'.format(args.setting)) 68 | 69 | print("arch : {} ".format(arch_prefix)) 70 | print("hyperparam : {} ".format(file_prefix)) 71 | 72 | log_dir = os.path.join(dump_dir, 'logs', arch_prefix, file_prefix) 73 | model_dir = os.path.join(dump_dir, 'models', arch_prefix, file_prefix) 74 | 75 | file_prefix = file_prefix + '.pkg' 76 | 77 | if not os.path.exists(log_dir): 78 | os.makedirs(log_dir) 79 | 80 | if not 
os.path.exists(model_dir): 81 | os.makedirs(model_dir) 82 | 83 | 84 | T = args.T 85 | N = args.epoch 86 | 87 | file_prefix = 'lr-' + np.format_float_scientific(lr, exp_digits=1, trim='-') + f'-b-{batch_size}-T-{T}' 88 | 89 | # Data augmentation 90 | img_size = { 91 | 'mnist' : 28, 92 | 'cifar10': 32, 93 | 'cifar100': 32, 94 | } 95 | 96 | num_cls = { 97 | 'mnist' : 10, 98 | 'cifar10': 10, 99 | 'cifar100': 100, 100 | } 101 | 102 | mean = { 103 | 'mnist' : 0.1307, 104 | 'cifar10': (0.4914, 0.4822, 0.4465), 105 | 'cifar100': (0.5071, 0.4867, 0.4408), 106 | } 107 | 108 | std = { 109 | 'mnist' : 0.3081, 110 | 'cifar10': (0.2023, 0.1994, 0.2010), 111 | 'cifar100': (0.2675, 0.2565, 0.2761), 112 | } 113 | 114 | if args.dataset == 'mnist': 115 | input_dim = 1 116 | else: 117 | input_dim = 3 118 | 119 | 120 | img_size = img_size[args.dataset] 121 | num_cls = num_cls[args.dataset] 122 | 123 | if args.dataset == 'mnist': 124 | # Data augmentation 125 | transform_train = transforms.Compose([ 126 | transforms.RandomAffine(degrees=30, translate=(0.15, 0.15), scale=(0.85, 1.11)), 127 | transforms.ToTensor(), 128 | transforms.Normalize(mean[args.dataset], std[args.dataset]), 129 | ]) 130 | transform_test = transforms.Compose([ 131 | transforms.ToTensor(), 132 | transforms.Normalize(0.1307, 0.3081), 133 | ]) 134 | 135 | train_dataset = torchvision.datasets.MNIST( 136 | root=dataset_dir, 137 | train=True, 138 | transform=transform_train, 139 | download=True) 140 | 141 | test_dataset = torchvision.datasets.MNIST( 142 | root=dataset_dir, 143 | train=False, 144 | transform=transform_test, 145 | download=True) 146 | 147 | train_data_loader = torch.utils.data.DataLoader( 148 | dataset=train_dataset, 149 | batch_size=batch_size, 150 | shuffle=True, 151 | drop_last=True, 152 | num_workers=8, 153 | pin_memory=True) 154 | 155 | test_data_loader = torch.utils.data.DataLoader( 156 | dataset=test_dataset, 157 | batch_size=batch_size, 158 | shuffle=False, 159 | drop_last=False, 160 | 
num_workers=8, 161 | pin_memory=True) 162 | 163 | elif args.dataset == 'cifar10': 164 | transform_train = transforms.Compose([ 165 | transforms.RandomCrop(32, padding=4), 166 | transforms.RandomHorizontalFlip(), 167 | transforms.ToTensor(), 168 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 169 | ]) 170 | 171 | transform_test = transforms.Compose([ 172 | transforms.ToTensor(), 173 | transforms.Normalize(mean[args.dataset], std[args.dataset]) 174 | ]) 175 | 176 | train_dataset = torchvision.datasets.CIFAR10( 177 | root=dataset_dir, 178 | train=True, 179 | transform=transform_train, 180 | download=True) 181 | 182 | test_dataset = torchvision.datasets.CIFAR10( 183 | root=dataset_dir, 184 | train=False, 185 | transform=transform_test, 186 | download=True) 187 | 188 | train_data_loader = torch.utils.data.DataLoader( 189 | dataset=train_dataset, 190 | batch_size=batch_size, 191 | shuffle=True, 192 | drop_last=True, 193 | num_workers=4, 194 | pin_memory=True) 195 | 196 | test_data_loader = torch.utils.data.DataLoader( 197 | dataset=test_dataset, 198 | batch_size=batch_size, 199 | shuffle=False, 200 | drop_last=False, 201 | num_workers=4, 202 | pin_memory=True) 203 | 204 | 205 | elif args.dataset == 'cifar100': 206 | 207 | transform_train = transforms.Compose([ 208 | transforms.RandomCrop(32, padding=4), 209 | transforms.RandomHorizontalFlip(), 210 | transforms.ToTensor(), 211 | transforms.Normalize(mean[args.dataset], std[args.dataset]), 212 | ]) 213 | 214 | transform_test = transforms.Compose([ 215 | transforms.ToTensor(), 216 | transforms.Normalize(mean[args.dataset], std[args.dataset]) 217 | ]) 218 | 219 | train_dataset = torchvision.datasets.CIFAR100( 220 | root=dataset_dir, 221 | train=True, 222 | transform=transform_train, 223 | download=True) 224 | 225 | test_dataset = torchvision.datasets.CIFAR100( 226 | root=dataset_dir, 227 | train=False, 228 | transform=transform_test, 229 | download=True) 230 | 231 | train_data_loader = 
torch.utils.data.DataLoader( 232 | dataset=train_dataset, 233 | batch_size=batch_size, 234 | shuffle=True, 235 | drop_last=True, 236 | num_workers=4, 237 | pin_memory=True) 238 | 239 | test_data_loader = torch.utils.data.DataLoader( 240 | dataset=test_dataset, 241 | batch_size=batch_size, 242 | shuffle=False, 243 | drop_last=False, 244 | num_workers=4, 245 | pin_memory=True) 246 | 247 | if args.encode == 'd': 248 | if args.arch == 'mlp': 249 | net = model.MLP_Direct(num_steps=T, leak_mem= leak_mem, img_size = img_size, input_dim = input_dim).cuda() 250 | print(f'Create new model') 251 | elif args.arch == 'vgg5': 252 | net = model.VGG5_Direct(num_steps=T, leak_mem= leak_mem, img_size = img_size, input_dim = input_dim, num_cls = num_cls).cuda() 253 | print(f'Create new model') 254 | elif args.arch == 'vgg9': 255 | net = model.VGG9_Direct(num_steps=T, leak_mem= leak_mem, img_size = img_size, input_dim = input_dim, num_cls = num_cls).cuda() 256 | print(f'Create new model') 257 | else: 258 | print(f'Not implemented Err - Architecture') 259 | exit() 260 | 261 | elif args.encode == 'p': 262 | if args.arch == 'mlp': 263 | # Pass img_size explicitly; the default of 28 only matches MNIST 264 | net = model.MLP_Poisson(num_steps=T, leak_mem= leak_mem, img_size = img_size, input_dim = input_dim).cuda() 264 | print(f'Create new model') 265 | elif args.arch == 'vgg5': 266 | net = model.VGG5_Poisson(num_steps=T, leak_mem= leak_mem, img_size = img_size, input_dim = input_dim, num_cls = num_cls).cuda() 267 | print(f'Create new model') 268 | elif args.arch == 'vgg9': 269 | net = model.VGG9_Poisson(num_steps=T, leak_mem= leak_mem, input_dim = input_dim, img_size=img_size, num_cls = num_cls).cuda() 270 | print(f'Create new model') 271 | else: 272 | print(f'Not implemented Err - Architecture') 273 | exit() 274 | 275 | else: 276 | print(f'Not implemented Err - Encoding') 277 | exit() 278 | 279 | 280 | # print(net) 281 | 282 | max_test_accuracy = 0 283 | 284 | # Training Loop 285 | net= net.cuda() 286 | 287 | # Configure the loss function and optimizer 288 | criterion = 
nn.CrossEntropyLoss() 289 | if args.optim == 'sgd': 290 | optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum = 0.9, weight_decay=1e-4) 291 | else: 292 | optimizer = optim.Adam(net.parameters(), lr=args.lr) 293 | # scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=40, gamma=0.1) 294 | best_acc = 0 295 | 296 | # Print the SNN model, optimizer, and simulation parameters 297 | print("********** SNN simulation parameters **********") 298 | print("Simulation # time-step : {}".format(T)) 299 | print("Membrane decay rate : {0:.2f}\n".format(args.leak_mem)) 300 | print("********** SNN learning parameters **********") 301 | print("Backprop optimizer : {}".format("SGD" if args.optim == 'sgd' else "Adam")) 302 | print("Batch size (training) : {}".format(batch_size)) 303 | print("Batch size (testing) : {}".format(batch_size)) 304 | print("Number of epochs : {}".format(args.epoch)) 305 | print("Learning rate : {}".format(lr)) 306 | 307 | # -------------------------------------------------- 308 | # Train the SNN using surrogate gradients 309 | # -------------------------------------------------- 310 | print("********** SNN training and evaluation **********") 311 | train_loss_list = [] 312 | test_acc_list = [] 313 | start_epoch = 0 314 | 315 | 316 | for epoch in range(args.epoch): 317 | time_start = time.time() 318 | 319 | train_loss = AverageMeter() 320 | net.train() 321 | for i, data in enumerate(train_data_loader): 322 | inputs, labels = data 323 | inputs = inputs.cuda() 324 | labels = labels.cuda() 325 | 326 | optimizer.zero_grad() 327 | output = net(inputs) 328 | 329 | loss = criterion(output, labels) 330 | prec1, prec5 = accuracy(output, labels, topk=(1, 5)) 331 | train_loss.update(loss.item(), labels.size(0)) 332 | loss.backward() 333 | optimizer.step() 334 | 335 | if (epoch + 1) % args.train_display_freq == 0: 336 | print( 337 | "Epoch: {}/{};".format(epoch + 1, args.epoch), 338 | "########## Training loss: {}".format(train_loss.avg), 339 | ) 340 | 341 | adjust_learning_rate(optimizer, epoch, 
args.epoch) 342 | 343 | if (epoch + 1) % args.test_display_freq == 0: 344 | acc_top1, acc_top5 = [], [] 345 | net.eval() 346 | with torch.no_grad(): 347 | for j, data in enumerate(test_data_loader): 348 | images, labels = data 349 | images = images.cuda() 350 | labels = labels.cuda() 351 | 352 | out = net(images) 353 | prec1, prec5 = accuracy(out, labels, topk=(1, 5)) 354 | acc_top1.append(float(prec1)) 355 | acc_top5.append(float(prec5)) 356 | 357 | test_accuracy = np.mean(acc_top1) 358 | 359 | # Model save 360 | if best_acc < test_accuracy: 361 | best_acc = test_accuracy 362 | 363 | net_dict = { 364 | "global_step": epoch + 1, 365 | "state_dict": net.state_dict(), 366 | "optim" : optimizer.state_dict(), 367 | "accuracy": test_accuracy, 368 | } 369 | 370 | torch.save( 371 | net_dict, model_dir + "/" + "_bestmodel.pth.tar" 372 | ) 373 | print("best_accuracy : {}".format(best_acc)) 374 | 375 | time_end = time.time() 376 | print("best accuracy in {} is : {}".format(arch_prefix + file_prefix, best_acc)) 377 | # print(f'Elapse: {time_end - time_start:.2f}s') 378 | 379 | 380 | sys.exit(0) 381 | 382 | -------------------------------------------------------------------------------- /train_models.sh: -------------------------------------------------------------------------------- 1 | # dataset / encode / arch / leak_mem / T (timestep) / 2 | # dataset [mnist, cifar10, cifar100] 3 | # encode [p, d] 4 | # arch [mlp, vgg5, vgg9] 5 | # T = [8, 10, 15, 20] for d / [15 ,20, 30, 50] for p 6 | # leak_mem = [0.25, 0.5, 0.75, 0.9, 0.95, 0.99] 7 | # N = [120, 200] 8 | # lr [0.01, 0.001] 9 | 10 | 11 | python train.py --dataset cifar100 --arch vgg9 --encode d --leak_mem 0.5 --T 30 --lr 1e-3 --batch_size 128 12 | -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | class 
Surrogate_BP_Function(torch.autograd.Function): 5 | @staticmethod 6 | def forward(ctx, input): 7 | ctx.save_for_backward(input) 8 | out = torch.zeros_like(input).cuda() 9 | out[input > 0] = 1.0 10 | return out 11 | 12 | @staticmethod 13 | def backward(ctx, grad_output): 14 | (input,) = ctx.saved_tensors 15 | grad_input = grad_output.clone() 16 | grad = grad_input * 0.3 * F.threshold(1.0 - torch.abs(input), 0, 0) 17 | return grad 18 | 19 | 20 | 21 | import torch 22 | 23 | 24 | 25 | 26 | def adjust_learning_rate(optimizer, cur_epoch, max_epoch): 27 | if ( 28 | cur_epoch == (max_epoch * 0.5) 29 | or cur_epoch == (max_epoch * 0.7) 30 | or cur_epoch == (max_epoch * 0.9) 31 | ): 32 | for param_group in optimizer.param_groups: 33 | param_group["lr"] /= 10 34 | 35 | def accuracy(outp, target, topk=(1,)): 36 | """Computes the precision@k for the specified values of k""" 37 | with torch.no_grad(): 38 | maxk = max(topk) 39 | batch_size = target.size(0) 40 | 41 | _, pred = outp.topk(maxk, 1, True, True) 42 | pred = pred.t() 43 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 44 | 45 | res = [] 46 | for k in topk: 47 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) 48 | res.append(correct_k.mul_(100.0 / batch_size)) 49 | return res 50 | 51 | 52 | class AverageMeter(object): 53 | """ 54 | Computes and stores the average and current value 55 | """ 56 | 57 | def __init__(self): 58 | self.reset() 59 | 60 | def reset(self): 61 | self.val = 0 62 | self.avg = 0 63 | self.sum = 0 64 | self.count = 0 65 | 66 | def update(self, val, n=1): 67 | self.val = val 68 | self.sum += val * n 69 | self.count += n 70 | self.avg = self.sum / self.count 71 | --------------------------------------------------------------------------------
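The charge, fire, and soft-reset update that every layer in `models/model.py` repeats can be traced by hand on a single neuron. The sketch below is illustrative only: `lif_soft_reset` is a hypothetical helper that is not part of this repository, it works on plain Python floats rather than CUDA tensors, and it follows the `leak_mem * mem + (1 - leak_mem) * input` form used by the MLP and VGG5 models (the VGG9 models swap the two coefficients, which gives the same result when `leak_mem` is 0.5).

```python
def lif_soft_reset(inputs, leak_mem=0.5, threshold=1.0):
    """Simulate one leaky integrate-and-fire neuron with soft reset.

    Hypothetical scalar helper for illustration; the repository applies
    the same three steps to whole tensors inside each model's forward().
    """
    mem = 0.0
    spikes = []
    for x in inputs:
        # Charging: leak the old membrane potential and add the new input
        mem = leak_mem * mem + (1.0 - leak_mem) * x
        # Firing: spike only when the potential strictly exceeds the threshold
        fired = (mem / threshold) - 1.0 > 0
        spikes.append(1 if fired else 0)
        # Soft reset: subtract the threshold instead of zeroing the potential
        if fired:
            mem -= threshold
    return spikes, mem


# Direct coding feeds the same analog value at every timestep
direct_spikes, direct_mem = lif_soft_reset([2.0, 2.0, 2.0, 2.0])
# A single strong input followed by silence
burst_spikes, burst_mem = lif_soft_reset([4.0, 0.0, 0.0])
print(direct_spikes, direct_mem)
print(burst_spikes, burst_mem)
```

With a constant input of 2.0 the neuron stays silent on the first step, because the potential only reaches the threshold and the comparison is strict, and then fires on every later step. The soft reset keeps the residual potential, so the excess above the threshold carries over to the next timestep.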