├── utills.py
├── README.md
├── train.py
└── model.py


/utills.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | 
4 | def adjust_learning_rate(optimizer, cur_epoch, max_epoch):
5 |     if cur_epoch in (int(max_epoch * 0.5), int(max_epoch * 0.7), int(max_epoch * 0.9)):  # decay LR by 10x at 50%, 70%, and 90% of training
6 |         for param_group in optimizer.param_groups:
7 |             param_group['lr'] /= 10
8 | 
9 | 
10 | def accuracy(outp, target, topk=(1,)):
11 |     """Computes the precision@k for the specified values of k"""
12 |     with torch.no_grad():
13 |         maxk = max(topk)
14 |         batch_size = target.size(0)
15 | 
16 |         _, pred = outp.topk(maxk, 1, True, True)
17 |         pred = pred.t()
18 |         correct = pred.eq(target.view(1, -1).expand_as(pred))
19 | 
20 |         res = []
21 |         for k in topk:
22 |             correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)  # reshape, not view: the slice may be non-contiguous on recent PyTorch
23 |             res.append(correct_k.mul_(100.0 / batch_size))
24 |         return res
25 | 
26 | 
27 | class AverageMeter(object):
28 |     """
29 |     Computes and stores the average and current value
30 |     """
31 | 
32 |     def __init__(self):
33 |         self.reset()
34 | 
35 |     def reset(self):
36 |         self.val = 0
37 |         self.avg = 0
38 |         self.sum = 0
39 |         self.count = 0
40 | 
41 |     def update(self, val, n=1):
42 |         self.val = val
43 |         self.sum += val * n
44 |         self.count += n
45 |         self.avg = self.sum / self.count
46 | 
47 | 
48 | 
49 | 
50 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # BNTT-Batch-Normalization-Through-Time
2 | 
3 | This repository contains the source code associated with the paper [arXiv preprint arXiv:2010.01729][arXiv preprint arXiv:2010.01729].
4 | 
5 | Accepted to Frontiers in Neuroscience (2021)
6 | 
7 | [arXiv preprint arXiv:2010.01729]: https://arxiv.org/abs/2010.01729
8 | 
9 | ## Introduction
10 | 
11 | Spiking Neural Networks (SNNs) have recently emerged as an alternative to deep learning owing to their sparse, asynchronous, and binary event- (or spike-) driven processing, which can yield substantial energy-efficiency benefits on neuromorphic hardware. However, training high-accuracy and low-latency SNNs from scratch suffers from the non-differentiable nature of a spiking neuron. To address this training issue in SNNs, we revisit batch normalization and propose a temporal Batch Normalization Through Time (BNTT) technique. Most prior SNN works have disregarded batch normalization, deeming it ineffective for training temporal SNNs. Different from previous works, our proposed BNTT decouples the parameters in a BNTT layer along the time axis to capture the temporal dynamics of spikes. The temporally evolving learnable parameters in BNTT allow a neuron to control its spike rate across different time-steps, enabling low-latency and low-energy training from scratch. We conduct experiments on CIFAR-10, CIFAR-100, Tiny-ImageNet and event-driven DVS-CIFAR10 datasets. BNTT allows us to train deep SNN architectures from scratch, for the first time, on complex datasets with just 25-30 time-steps. We also propose an early-exit algorithm that uses the distribution of parameters in BNTT to reduce latency at inference, which further improves energy efficiency.
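The core idea can be sketched in a few lines of PyTorch: instead of sharing one batch-norm layer across all time-steps, BNTT keeps a separate `nn.BatchNorm2d` for every time-step and indexes it by the current step. The helper class below is illustrative only (`model.py` implements the same pattern inline in the VGG9/VGG11 architectures), and the layer sizes are arbitrary:

```python
import torch
import torch.nn as nn

class BNTTConv(nn.Module):
    # Hypothetical helper for illustration; the repository writes this pattern
    # out explicitly inside SNN_VGG9_BNTT / SNN_VGG11_BNTT in model.py.
    def __init__(self, in_ch, out_ch, num_steps):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False)
        # One BatchNorm2d per time-step: statistics and learnable scales are
        # decoupled along the time axis.
        self.bntt = nn.ModuleList([nn.BatchNorm2d(out_ch, eps=1e-4, momentum=0.1)
                                   for _ in range(num_steps)])

    def forward(self, x, t):
        # At time-step t, normalize with the t-th batch-norm layer
        return self.bntt[t](self.conv(x))

# Usage sketch: the layer is applied once per simulation time-step
layer = BNTTConv(3, 64, num_steps=25)
spikes = torch.rand(8, 3, 32, 32)   # stand-in for Poisson-coded input spikes
out_t0 = layer(spikes, t=0)         # shape: (8, 64, 32, 32)
```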
12 | 
13 | 
14 | ## Prerequisites
15 | * Ubuntu 18.04
16 | * Python 3.6+
17 | * PyTorch 1.5+ (a recent version is recommended)
18 | * NVIDIA GPU (>= 12GB memory)
19 | 
20 | ## Getting Started
21 | 
22 | ### Installation
23 | * Configure a virtual (Anaconda) environment
24 | ```
25 | conda create -n env_name python=3.7
26 | source activate env_name
27 | conda install pytorch torchvision cudatoolkit=10.0 -c pytorch
28 | ```
29 | 
30 | 
31 | ## Training and testing
32 | 
33 | * We provide VGG9/VGG11 architectures for the CIFAR10/CIFAR100 datasets
34 | * ```train.py```: code for training
35 | * ```model.py```: code for VGG9/VGG11 Spiking Neural Networks with BNTT
36 | * ```utills.py```: code for accuracy calculation / learning rate scheduling
37 | 
38 | * Run the following command for VGG9 SNN on CIFAR10
39 | 
40 | ```
41 | python train.py --num_steps 25 --lr 0.3 --arch 'vgg9' --dataset 'cifar10' --batch_size 256 --leak_mem 0.95 --num_workers 4 --num_epochs 100
42 | ```
43 | 
44 | * Run the following command for VGG11 SNN on CIFAR100
45 | 
46 | ```
47 | python train.py --num_steps 30 --lr 0.3 --arch 'vgg11' --dataset 'cifar100' --batch_size 128 --leak_mem 0.99 --num_workers 4 --num_epochs 100
48 | ```
49 | 
50 | 
51 | ## Citation
52 | 
53 | Please consider citing our paper:
54 | ```
55 | @article{kim2020revisiting,
56 |   title={Revisiting Batch Normalization for Training Low-latency Deep Spiking Neural Networks from Scratch},
57 |   author={Kim, Youngeun and Panda, Priyadarshini},
58 |   journal={arXiv preprint arXiv:2010.01729},
59 |   year={2020}
60 | }
61 | ```
62 | 
63 | 
64 | 
65 | 
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | #############################################
2 | # @author: Youngeun Kim and Priya Panda #
3 | #############################################
4 | #--------------------------------------------------
5 | # Imports
6 | #--------------------------------------------------
7 | import torch.optim as optim
8 | import torchvision
9 | from torch.utils.data.dataloader import DataLoader
10 | from torchvision import transforms
11 | from model import *
12 | 
13 | from PIL import ImageFile
14 | ImageFile.LOAD_TRUNCATED_IMAGES = True
15 | import argparse
16 | import os.path
17 | import numpy as np
18 | import torch.backends.cudnn as cudnn
19 | from utills import *
20 | import sys  # used by sys.exit() at the end of training
21 | cudnn.benchmark = True
22 | cudnn.deterministic = True
23 | 
24 | #--------------------------------------------------
25 | # Parse input arguments
26 | #--------------------------------------------------
27 | parser = argparse.ArgumentParser(description='SNN trained with BNTT', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
28 | parser.add_argument('--seed', default=0, type=int, help='Random seed')
29 | parser.add_argument('--num_steps', default=25, type=int, help='Number of time-steps')
30 | parser.add_argument('--batch_size', default=128, type=int, help='Batch size')
31 | parser.add_argument('--lr', default=0.1, type=float, help='Learning rate')
32 | parser.add_argument('--leak_mem', default=0.99, type=float, help='Membrane leak factor')
33 | parser.add_argument('--arch', default='vgg11', type=str, help='Architecture [vgg9, vgg11]')
34 | parser.add_argument('--dataset', default='cifar100', type=str, help='Dataset [cifar10, cifar100]')
35 | parser.add_argument('--num_epochs', default=120, type=int, help='Number of epochs')
36 | parser.add_argument('--num_workers', default=4, type=int, help='Number of workers')
37 | parser.add_argument('--train_display_freq', default=10, type=int, help='Display frequency (epochs) for training loss')
38 | parser.add_argument('--test_display_freq', default=10, type=int, help='Display frequency (epochs) for test accuracy')
39 | 
40 | 
41 | global args
42 | args = parser.parse_args()
43 | 
44 | 
45 | #--------------------------------------------------
46 | # Initialize the model save directory
47 | #--------------------------------------------------
48 | log_dir = 'modelsave'
49 | if not os.path.isdir(log_dir):
50 |     os.mkdir(log_dir)
51 | 
52 | 
53 | user_foldername = (args.dataset)+(args.arch)+'_timestep'+str(args.num_steps) +'_lr'+str(args.lr) + '_epoch' + str(args.num_epochs) + '_leak' + str(args.leak_mem)
54 | 
55 | 
56 | 
57 | #--------------------------------------------------
58 | # Initialize seed
59 | #--------------------------------------------------
60 | seed = args.seed
61 | np.random.seed(seed)
62 | torch.manual_seed(seed)
63 | torch.cuda.manual_seed_all(seed)
64 | 
65 | #--------------------------------------------------
66 | # SNN configuration parameters
67 | #--------------------------------------------------
68 | # Leaky-Integrate-and-Fire (LIF) neuron parameters
69 | leak_mem = args.leak_mem
70 | 
71 | # SNN learning and evaluation parameters
72 | batch_size = args.batch_size
73 | batch_size_test = args.batch_size*2
74 | num_epochs = args.num_epochs
75 | num_steps = args.num_steps
76 | lr = args.lr
77 | 
78 | 
79 | #--------------------------------------------------
80 | # Load dataset
81 | #--------------------------------------------------
82 | 
83 | transform_train = transforms.Compose([
84 |     transforms.RandomCrop(32, padding=4),
85 |     transforms.RandomHorizontalFlip(),
86 |     transforms.ToTensor(),
87 |     transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
88 | ])
89 | 
90 | transform_test = transforms.Compose([
91 |     transforms.ToTensor(),
92 |     transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
93 | ])
94 | 
95 | 
96 | if args.dataset == 'cifar10':
97 |     num_cls = 10
98 |     img_size = 32
99 | 
100 |     train_set = torchvision.datasets.CIFAR10(root='./data', train=True,
101 |                                              download=True, transform=transform_train)
102 |     test_set = torchvision.datasets.CIFAR10(root='./data', train=False,
103 |                                             download=True, transform=transform_test)
104 | elif args.dataset == 'cifar100':
105 |     num_cls = 100
106 |     img_size = 32
107 | 
108 |     train_set = torchvision.datasets.CIFAR100(root='./data', train=True,
109 |                                               download=True, transform=transform_train)
110 |     test_set = torchvision.datasets.CIFAR100(root='./data', train=False,
111 |                                              download=True, transform=transform_test)
112 | else:
113 |     print("Dataset not implemented yet..")
114 |     exit()
115 | 
116 | 
117 | 
118 | trainloader = torch.utils.data.DataLoader(train_set, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=True, drop_last=True)
119 | testloader = torch.utils.data.DataLoader(test_set, batch_size=args.batch_size*2, shuffle=False, num_workers=args.num_workers, pin_memory=True, drop_last=True)
120 | 
121 | 
122 | 
123 | #--------------------------------------------------
124 | # Instantiate the SNN model and optimizer
125 | #--------------------------------------------------
126 | if args.arch == 'vgg9':
127 |     model = SNN_VGG9_BNTT(num_steps = num_steps, leak_mem=leak_mem, img_size=img_size, num_cls=num_cls)
128 | elif args.arch == 'vgg11':
129 |     model = SNN_VGG11_BNTT(num_steps = num_steps, leak_mem=leak_mem, img_size=img_size, num_cls=num_cls)
130 | else:
131 |     print("Architecture not implemented yet..")
132 |     exit()
133
| 134 | model = model.cuda() 135 | 136 | # Configure the loss function and optimizer 137 | criterion = nn.CrossEntropyLoss() 138 | optimizer = optim.SGD(model.parameters(), lr=args.lr,momentum=0.9,weight_decay=1e-4) 139 | scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=40, gamma=0.1) 140 | best_acc = 0 141 | 142 | # Print the SNN model, optimizer, and simulation parameters 143 | print('********** SNN simulation parameters **********') 144 | print('Simulation # time-step : {}'.format(num_steps)) 145 | print('Membrane decay rate : {0:.2f}\n'.format(leak_mem)) 146 | 147 | print('********** SNN learning parameters **********') 148 | print('Backprop optimizer : SGD') 149 | print('Batch size (training) : {}'.format(batch_size)) 150 | print('Batch size (testing) : {}'.format(batch_size_test)) 151 | print('Number of epochs : {}'.format(num_epochs)) 152 | print('Learning rate : {}'.format(lr)) 153 | 154 | #-------------------------------------------------- 155 | # Train the SNN using surrogate gradients 156 | #-------------------------------------------------- 157 | print('********** SNN training and evaluation **********') 158 | train_loss_list = [] 159 | test_acc_list = [] 160 | 161 | for epoch in range(num_epochs): 162 | train_loss = AverageMeter() 163 | model.train() 164 | for i, data in enumerate(trainloader): 165 | inputs, labels = data 166 | inputs = inputs.cuda() 167 | labels = labels.cuda() 168 | 169 | optimizer.zero_grad() 170 | output = model(inputs) 171 | 172 | loss = criterion(output, labels) 173 | 174 | prec1, prec5 = accuracy(output, labels, topk=(1, 5)) 175 | train_loss.update(loss.item(), labels.size(0)) 176 | 177 | loss.backward() 178 | optimizer.step() 179 | 180 | if (epoch+1) % args.train_display_freq ==0: 181 | print("Epoch: {}/{};".format(epoch+1, num_epochs), "########## Training loss: {}".format(train_loss.avg)) 182 | 183 | adjust_learning_rate(optimizer, epoch, num_epochs) 184 | 185 | 186 | 187 | if (epoch+1) % args.test_display_freq ==0: 188 | acc_top1, acc_top5 = [], [] 189 | model.eval() 190 | with torch.no_grad(): 191 | for j, data in enumerate(testloader, 0): 192 | 193 | images, labels = data 194 | images = images.cuda() 195 | labels = labels.cuda() 196 | 197 | out = model(images) 198 | prec1, prec5 = accuracy(out, labels, topk=(1, 5)) 199 | acc_top1.append(float(prec1)) 200 | acc_top5.append(float(prec5)) 201 | 202 | 203 | test_accuracy = np.mean(acc_top1) 204 | print ("test_accuracy : {}". 
format(test_accuracy)) 205 | 206 | 207 | # Model save 208 | if best_acc < test_accuracy: 209 | best_acc = test_accuracy 210 | 211 | model_dict = { 212 | 'global_step': epoch + 1, 213 | 'state_dict': model.state_dict(), 214 | 'accuracy': test_accuracy} 215 | 216 | torch.save(model_dict, log_dir+'/'+user_foldername+'_bestmodel.pth.tar') 217 | 218 | 219 | sys.exit(0) 220 | 221 | 222 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import sys 5 | 6 | 7 | 8 | 9 | class Surrogate_BP_Function(torch.autograd.Function): 10 | 11 | 12 | @staticmethod 13 | def forward(ctx, input): 14 | ctx.save_for_backward(input) 15 | out = torch.zeros_like(input).cuda() 16 | out[input > 0] = 1.0 17 | return out 18 | 19 | @staticmethod 20 | def backward(ctx, grad_output): 21 | input, = ctx.saved_tensors 22 | grad_input = grad_output.clone() 23 | grad = grad_input * 0.3 * F.threshold(1.0 - torch.abs(input), 0, 0) 24 | return grad 25 | 26 | 27 | def PoissonGen(inp, rescale_fac=2.0): 28 | rand_inp = torch.rand_like(inp).cuda() 29 | return torch.mul(torch.le(rand_inp * rescale_fac, torch.abs(inp)).float(), torch.sign(inp)) 30 | 31 | 32 | 33 | 34 | 35 | 36 | class SNN_VGG9_BNTT(nn.Module): 37 | def __init__(self, num_steps, leak_mem=0.95, img_size=32, num_cls=10): 38 | super(SNN_VGG9_BNTT, self).__init__() 39 | 40 | self.img_size = img_size 41 | self.num_cls = num_cls 42 | self.num_steps = num_steps 43 | self.spike_fn = Surrogate_BP_Function.apply 44 | self.leak_mem = leak_mem 45 | self.batch_num = self.num_steps 46 | 47 | print (">>>>>>>>>>>>>>>>>>> VGG 9 >>>>>>>>>>>>>>>>>>>>>>") 48 | print ("***** time step per batchnorm".format(self.batch_num)) 49 | print (">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>") 50 | 51 | affine_flag = True 52 | bias_flag = False 53 | 54 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=bias_flag) 55 | self.bntt1 = nn.ModuleList([nn.BatchNorm2d(64, eps=1e-4, momentum=0.1, affine=affine_flag) for i in range(self.batch_num)]) 56 | self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=bias_flag) 57 | self.bntt2 = nn.ModuleList([nn.BatchNorm2d(64, eps=1e-4, momentum=0.1, affine=affine_flag) for i in range(self.batch_num)]) 58 | self.pool1 = nn.AvgPool2d(kernel_size=2) 59 | 60 | self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=bias_flag) 61 | self.bntt3 = nn.ModuleList([nn.BatchNorm2d(128, eps=1e-4, momentum=0.1, affine=affine_flag) for i in range(self.batch_num)]) 62 | self.conv4 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=bias_flag) 63 | self.bntt4 = nn.ModuleList([nn.BatchNorm2d(128, eps=1e-4, momentum=0.1, affine=affine_flag) for i in range(self.batch_num)]) 64 | self.pool2 = nn.AvgPool2d(kernel_size=2) 65 | 66 | self.conv5 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1, bias=bias_flag) 67 | self.bntt5 = nn.ModuleList([nn.BatchNorm2d(256, eps=1e-4, momentum=0.1, affine=affine_flag) for i in range(self.batch_num)]) 68 | self.conv6 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=bias_flag) 69 | self.bntt6 = nn.ModuleList([nn.BatchNorm2d(256, eps=1e-4, momentum=0.1, affine=affine_flag) for i in range(self.batch_num)]) 70 | self.conv7 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=bias_flag) 71 | self.bntt7 = nn.ModuleList([nn.BatchNorm2d(256, eps=1e-4, momentum=0.1, 
affine=affine_flag) for i in range(self.batch_num)]) 72 | self.pool3 = nn.AvgPool2d(kernel_size=2) 73 | 74 | 75 | self.fc1 = nn.Linear((self.img_size//8)*(self.img_size//8)*256, 1024, bias=bias_flag) 76 | self.bntt_fc = nn.ModuleList([nn.BatchNorm1d(1024, eps=1e-4, momentum=0.1, affine=affine_flag) for i in range(self.batch_num)]) 77 | self.fc2 = nn.Linear(1024, self.num_cls, bias=bias_flag) 78 | 79 | self.conv_list = [self.conv1, self.conv2, self.conv3, self.conv4, self.conv5, self.conv6, self.conv7] 80 | self.bntt_list = [self.bntt1, self.bntt2, self.bntt3, self.bntt4, self.bntt5, self.bntt6, self.bntt7, self.bntt_fc] 81 | self.pool_list = [False, self.pool1, False, self.pool2, False, False, self.pool3] 82 | 83 | # Turn off bias of BNTT 84 | for bn_list in self.bntt_list: 85 | for bn_temp in bn_list: 86 | bn_temp.bias = None 87 | 88 | 89 | # Initialize the firing thresholds of all the layers 90 | for m in self.modules(): 91 | if (isinstance(m, nn.Conv2d)): 92 | m.threshold = 1.0 93 | torch.nn.init.xavier_uniform_(m.weight, gain=2) 94 | elif (isinstance(m, nn.Linear)): 95 | m.threshold = 1.0 96 | torch.nn.init.xavier_uniform_(m.weight, gain=2) 97 | 98 | 99 | 100 | 101 | def forward(self, inp): 102 | 103 | batch_size = inp.size(0) 104 | mem_conv1 = torch.zeros(batch_size, 64, self.img_size, self.img_size).cuda() 105 | mem_conv2 = torch.zeros(batch_size, 64, self.img_size, self.img_size).cuda() 106 | mem_conv3 = torch.zeros(batch_size, 128, self.img_size//2, self.img_size//2).cuda() 107 | mem_conv4 = torch.zeros(batch_size, 128, self.img_size//2, self.img_size//2).cuda() 108 | mem_conv5 = torch.zeros(batch_size, 256, self.img_size//4, self.img_size//4).cuda() 109 | mem_conv6 = torch.zeros(batch_size, 256, self.img_size//4, self.img_size//4).cuda() 110 | mem_conv7 = torch.zeros(batch_size, 256, self.img_size//4, self.img_size//4).cuda() 111 | mem_conv_list = [mem_conv1, mem_conv2, mem_conv3, mem_conv4, mem_conv5, mem_conv6, mem_conv7] 112 | 113 | mem_fc1 = torch.zeros(batch_size, 1024).cuda() 114 | mem_fc2 = torch.zeros(batch_size, self.num_cls).cuda() 115 | 116 | 117 | 118 | for t in range(self.num_steps): 119 | 120 | spike_inp = PoissonGen(inp) 121 | out_prev = spike_inp 122 | 123 | for i in range(len(self.conv_list)): 124 | mem_conv_list[i] = self.leak_mem * mem_conv_list[i] + self.bntt_list[i][t](self.conv_list[i](out_prev)) 125 | mem_thr = (mem_conv_list[i] / self.conv_list[i].threshold) - 1.0 126 | out = self.spike_fn(mem_thr) 127 | rst = torch.zeros_like(mem_conv_list[i]).cuda() 128 | rst[mem_thr > 0] = self.conv_list[i].threshold 129 | mem_conv_list[i] = mem_conv_list[i] - rst 130 | out_prev = out.clone() 131 | 132 | 133 | if self.pool_list[i] is not False: 134 | out = self.pool_list[i](out_prev) 135 | out_prev = out.clone() 136 | 137 | 138 | out_prev = out_prev.reshape(batch_size, -1) 139 | 140 | mem_fc1 = self.leak_mem * mem_fc1 + self.bntt_fc[t](self.fc1(out_prev)) 141 | mem_thr = (mem_fc1 / self.fc1.threshold) - 1.0 142 | out = self.spike_fn(mem_thr) 143 | rst = torch.zeros_like(mem_fc1).cuda() 144 | rst[mem_thr > 0] = self.fc1.threshold 145 | mem_fc1 = mem_fc1 - rst 146 | out_prev = out.clone() 147 | 148 | # accumulate voltage in the last layer 149 | mem_fc2 = mem_fc2 + self.fc2(out_prev) 150 | 151 | out_voltage = mem_fc2 / self.num_steps 152 | 153 | 154 | return out_voltage 155 | 156 | 157 | class SNN_VGG11_BNTT(nn.Module): 158 | def __init__(self, num_steps, leak_mem=0.95, img_size=32, num_cls=10): 159 | super(SNN_VGG11_BNTT, self).__init__() 160 | 161 | self.img_size = 
img_size 162 | self.num_cls = num_cls 163 | self.num_steps = num_steps 164 | self.spike_fn = Surrogate_BP_Function.apply 165 | self.leak_mem = leak_mem 166 | self.batch_num = self.num_steps 167 | 168 | print (">>>>>>>>>>>>>>>>> VGG11 >>>>>>>>>>>>>>>>>>>>>>>") 169 | print ("***** time step per batchnorm".format(self.batch_num)) 170 | print (">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>") 171 | 172 | affine_flag = True 173 | bias_flag = False 174 | 175 | 176 | 177 | 178 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=bias_flag) 179 | self.bntt1 = nn.ModuleList([nn.BatchNorm2d(64, eps=1e-4, momentum=0.1, affine=affine_flag) for i in range(self.batch_num)]) 180 | self.pool1 = nn.AvgPool2d(kernel_size=2) 181 | 182 | self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=bias_flag) 183 | self.bntt2 = nn.ModuleList([nn.BatchNorm2d(128, eps=1e-4, momentum=0.1, affine=affine_flag) for i in range(self.batch_num)]) 184 | self.pool2 = nn.AvgPool2d(kernel_size=2) 185 | 186 | self.conv3 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1, bias=bias_flag) 187 | self.bntt3 = nn.ModuleList([nn.BatchNorm2d(256, eps=1e-4, momentum=0.1, affine=affine_flag) for i in range(self.batch_num)]) 188 | self.conv4 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=bias_flag) 189 | self.bntt4 = nn.ModuleList([nn.BatchNorm2d(256, eps=1e-4, momentum=0.1, affine=affine_flag) for i in range(self.batch_num)]) 190 | self.pool3 = nn.AvgPool2d(kernel_size=2) 191 | 192 | self.conv5 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1, bias=bias_flag) 193 | self.bntt5 = nn.ModuleList([nn.BatchNorm2d(512, eps=1e-4, momentum=0.1, affine=affine_flag) for i in range(self.batch_num)]) 194 | self.conv6 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=bias_flag) 195 | self.bntt6 = nn.ModuleList([nn.BatchNorm2d(512, eps=1e-4, momentum=0.1, affine=affine_flag) for i in range(self.batch_num)]) 196 | self.pool4 = nn.AvgPool2d(kernel_size=2) 197 | 198 | self.conv7 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=bias_flag) 199 | self.bntt7 = nn.ModuleList([nn.BatchNorm2d(512, eps=1e-4, momentum=0.1, affine=affine_flag) for i in range(self.batch_num)]) 200 | self.conv8 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=bias_flag) 201 | self.bntt8 = nn.ModuleList([nn.BatchNorm2d(512, eps=1e-4, momentum=0.1, affine=affine_flag) for i in range(self.batch_num)]) 202 | self.pool5 = nn.AdaptiveAvgPool2d((1,1)) 203 | 204 | 205 | self.fc1 = nn.Linear(512, 4096, bias=bias_flag) 206 | self.bntt_fc = nn.ModuleList([nn.BatchNorm1d(4096, eps=1e-4, momentum=0.1, affine=affine_flag) for i in range(self.batch_num)]) 207 | self.fc2 = nn.Linear(4096, self.num_cls, bias=bias_flag) 208 | 209 | self.conv_list = [self.conv1, self.conv2, self.conv3, self.conv4, self.conv5, self.conv6, self.conv7, self.conv8] 210 | self.bntt_list = [self.bntt1, self.bntt2, self.bntt3, self.bntt4, self.bntt5, self.bntt6, self.bntt7, self.bntt8, self.bntt_fc] 211 | self.pool_list = [self.pool1, self.pool2, False, self.pool3, False, self.pool4, False, self.pool5] 212 | 213 | # Turn off bias of BNTT 214 | for bn_list in self.bntt_list: 215 | for bn_temp in bn_list: 216 | bn_temp.bias = None 217 | 218 | 219 | # Initialize the firing thresholds of all the layers 220 | for m in self.modules(): 221 | if (isinstance(m, nn.Conv2d)): 222 | m.threshold = 1.0 223 | torch.nn.init.xavier_uniform_(m.weight, gain=2) 224 | elif (isinstance(m, nn.Linear)): 225 | m.threshold = 1.0 226 | 
torch.nn.init.xavier_uniform_(m.weight, gain=2) 227 | 228 | 229 | 230 | 231 | def forward(self, inp): 232 | 233 | batch_size = inp.size(0) 234 | mem_conv1 = torch.zeros(batch_size, 64, self.img_size, self.img_size).cuda() 235 | mem_conv2 = torch.zeros(batch_size, 128, self.img_size // 2, self.img_size // 2).cuda() 236 | mem_conv3 = torch.zeros(batch_size, 256, self.img_size // 4, self.img_size // 4).cuda() 237 | mem_conv4 = torch.zeros(batch_size, 256, self.img_size // 4, self.img_size // 4).cuda() 238 | mem_conv5 = torch.zeros(batch_size, 512, self.img_size // 8, self.img_size // 8).cuda() 239 | mem_conv6 = torch.zeros(batch_size, 512, self.img_size // 8, self.img_size // 8).cuda() 240 | mem_conv7 = torch.zeros(batch_size, 512, self.img_size // 16, self.img_size // 16).cuda() 241 | mem_conv8 = torch.zeros(batch_size, 512, self.img_size // 16, self.img_size // 16).cuda() 242 | mem_conv_list = [mem_conv1, mem_conv2, mem_conv3, mem_conv4, mem_conv5, mem_conv6, mem_conv7, mem_conv8] 243 | 244 | mem_fc1 = torch.zeros(batch_size, 4096).cuda() 245 | mem_fc2 = torch.zeros(batch_size, self.num_cls).cuda() 246 | 247 | 248 | 249 | for t in range(self.num_steps): 250 | 251 | spike_inp = PoissonGen(inp) 252 | out_prev = spike_inp 253 | 254 | for i in range(len(self.conv_list)): 255 | mem_conv_list[i] = self.leak_mem * mem_conv_list[i] + self.bntt_list[i][t](self.conv_list[i](out_prev)) 256 | mem_thr = (mem_conv_list[i] / self.conv_list[i].threshold) - 1.0 257 | out = self.spike_fn(mem_thr) 258 | rst = torch.zeros_like(mem_conv_list[i]).cuda() 259 | rst[mem_thr > 0] = self.conv_list[i].threshold 260 | mem_conv_list[i] = mem_conv_list[i] - rst 261 | out_prev = out.clone() 262 | 263 | 264 | if self.pool_list[i] is not False: 265 | out = self.pool_list[i](out_prev) 266 | out_prev = out.clone() 267 | 268 | 269 | out_prev = out_prev.reshape(batch_size, -1) 270 | 271 | mem_fc1 = self.leak_mem * mem_fc1 + self.bntt_fc[t](self.fc1(out_prev)) 272 | mem_thr = (mem_fc1 / self.fc1.threshold) - 1.0 273 | out = self.spike_fn(mem_thr) 274 | rst = torch.zeros_like(mem_fc1).cuda() 275 | rst[mem_thr > 0] = self.fc1.threshold 276 | mem_fc1 = mem_fc1 - rst 277 | out_prev = out.clone() 278 | 279 | # accumulate voltage in the last layer 280 | mem_fc2 = mem_fc2 + self.fc2(out_prev) 281 | 282 | 283 | out_voltage = mem_fc2 / self.num_steps 284 | 285 | return out_voltage 286 | 287 | 288 | 289 | --------------------------------------------------------------------------------
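The checkpoint written by `train.py` stores the epoch, the model state dict, and the best test accuracy. A minimal evaluation-side sketch, not part of the repository, could look as follows; the checkpoint filename depends on the arguments used for training (the path below assumes the VGG9/CIFAR10 command from the README) and a CUDA device is assumed, as in the rest of the code:

```python
import torch
from model import SNN_VGG9_BNTT

# Assumed path: train.py names checkpoints <dataset><arch>_timestep<T>_lr<lr>_epoch<E>_leak<l>_bestmodel.pth.tar
ckpt_path = 'modelsave/cifar10vgg9_timestep25_lr0.3_epoch100_leak0.95_bestmodel.pth.tar'

# Re-create the architecture with the same hyper-parameters used for training
model = SNN_VGG9_BNTT(num_steps=25, leak_mem=0.95, img_size=32, num_cls=10).cuda()

ckpt = torch.load(ckpt_path)
model.load_state_dict(ckpt['state_dict'])
model.eval()
print("epoch: {}, best test accuracy: {:.2f}".format(ckpt['global_step'], ckpt['accuracy']))
```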