├── Fig
│   ├── __init__.py
│   ├── Capture.PNG
│   ├── Training_procedure.png
│   └── intro_vgg11_sa_and_attack_performance_plot.png
├── LICENSE
├── run_snn_tradit_test.py
├── run_snn_hire_test.py
├── README.md
├── attack_model_spike_cnt.py
├── vgg_spiking_nodewise.py
├── snn_free_nonorm_for_test_purpose.py
└── snn_free_nonorm_for_bbtest_purpose.py

/Fig/__init__.py:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/Fig/Capture.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ksouvik52/hiresnn2021/HEAD/Fig/Capture.PNG
--------------------------------------------------------------------------------
/Fig/Training_procedure.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ksouvik52/hiresnn2021/HEAD/Fig/Training_procedure.png
--------------------------------------------------------------------------------
/Fig/intro_vgg11_sa_and_attack_performance_plot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ksouvik52/hiresnn2021/HEAD/Fig/intro_vgg11_sa_and_attack_performance_plot.png
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2021 Souvik Kundu
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/run_snn_tradit_test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | 
4 | ########################################################################
5 | #### ~~~~~~ White box testing ~~~~~~
6 | ## Here, pretrained_snn is the model on which the wb test will be done.
7 | ########################################################################
8 | 
9 | cmd1 = "python snn_free_nonorm_for_test_purpose.py --dataset CIFAR100 --batch_size 32 --architecture VGG11 \
10 |        --epochs 1 --timesteps 8 --leak 1.0 --devices 1 \
11 |        --pretrained_snn='traditional_models/vgg11_cifar100_tradit_model.pt'"
12 | os.system(cmd1)
13 | 
14 | 
15 | ########################################################################
16 | #### ~~~~~~ Black box testing ~~~~~~
17 | ## Here, pretrained_snn is the model on which the bb test will be done,
18 | ## and pretrained_snn_bb is the model that generates the adversarial images.
19 | ## We take models of the same variant trained with different seeds.
20 | ########################################################################
21 | cmd2 = "python snn_free_nonorm_for_bbtest_purpose.py --dataset CIFAR100 --batch_size 32 --architecture VGG11 \
22 |        --epochs 1 --timesteps 8 --leak 1.0 --devices 1 \
23 |        --pretrained_snn='traditional_models/vgg11_cifar100_tradit_model.pt'\
24 |        --pretrained_snn_bb='traditional_models/vgg11_cifar100_tradit_bb_test_model.pt'"
25 | os.system(cmd2)
26 | 
--------------------------------------------------------------------------------
/run_snn_hire_test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | 
4 | ########################################################################
5 | #### ~~~~~~ White box testing ~~~~~~
6 | ## Here, pretrained_snn is the model on which the wb test will be done.
7 | ########################################################################
8 | 
9 | cmd1 = "python snn_free_nonorm_for_test_purpose.py --dataset CIFAR100 --batch_size 32 --architecture VGG11 \
10 |        --epochs 1 --timesteps 8 --leak 1.0 --devices 1 \
11 |        --pretrained_snn='HIRE_SNN_models/vgg11_cifar100_hiresnn_tstep8_model.pt'"
12 | os.system(cmd1)
13 | 
14 | 
15 | ########################################################################
16 | #### ~~~~~~ Black box testing ~~~~~~
17 | ## Here, pretrained_snn is the model on which the bb test will be done,
18 | ## and pretrained_snn_bb is the model that generates the adversarial images.
19 | ## We take models of the same variant trained with different seeds.
20 | ########################################################################
21 | cmd2 = "python snn_free_nonorm_for_bbtest_purpose.py --dataset CIFAR100 --batch_size 32 --architecture VGG11 \
22 |        --epochs 1 --timesteps 8 --leak 1.0 --devices 1 \
23 |        --pretrained_snn='HIRE_SNN_models/vgg11_cifar100_hiresnn_tstep8_model.pt'\
24 |        --pretrained_snn_bb='HIRE_SNN_models/vgg11_cifar100_hiresnn_tstep8_bb_test_model.pt'"
25 | os.system(cmd2)
26 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | 


2 | 
3 | [![License: MIT](https://img.shields.io/badge/License-MIT-green.svg)](https://opensource.org/licenses/MIT)
4 | 
5 | 
6 | Welcome to the official repo of the `ICCV 2021` paper **`HIRE-SNN: Harnessing the Inherent Robustness of Energy-Efficient Deep Spiking Neural Networks by Training with Crafted Input Noise`**.
7 | 
8 | **This repo currently contains the test code. Training code will be added soon!**
9 | 
10 | ### Authors:
11 | 1. **Souvik Kundu** (souvikku@usc.edu)
12 | 2. Massoud Pedram (pedram@usc.edu)
13 | 3. Peter A. Beerel (pabeerel@usc.edu)
14 | 
15 | ### Abstract:
16 | Low-latency deep spiking neural networks (SNNs) have become a promising alternative to conventional artificial neural networks (ANNs) because of their potential for increased energy efficiency on event-driven neuromorphic hardware. Neural networks, including SNNs, however, are subject to various adversarial attacks and must be trained to remain resilient against such attacks for many applications. Nevertheless, due to the prohibitively high training costs associated with SNNs, the analysis and optimization of deep SNNs under various adversarial attacks have been largely overlooked. In this paper, we first present a detailed analysis of the inherent robustness of low-latency SNNs against popular gradient-based attacks, namely the fast gradient sign method (FGSM) and projected gradient descent (PGD). Motivated by this analysis, to harness the model's robustness against these attacks we present an SNN training algorithm that uses crafted input noise and incurs no additional training time. To evaluate the merits of our algorithm, we conducted extensive experiments with variants of VGG and ResNet on both the CIFAR-10 and CIFAR-100 datasets. Compared to standard trained direct-input SNNs, our trained models yield improved classification accuracy of up to 13.7% and 10.1% on FGSM and PGD attack-generated images, respectively, with negligible loss in clean image accuracy. Our models also outperform inherently robust SNNs trained on rate-coded inputs with improved or similar classification performance on attack-generated images while having up to 25x and ∼4.6x lower latency and computation energy, respectively.
17 | 
18 | 


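For orientation, the two gradient-based attacks discussed above in their textbook form, as a minimal sketch only: it assumes a `model` that maps `[0,1]` images directly to logits, which is an assumption of this sketch. The repo's actual SNN-aware implementation, which also handles the input normalization manually, is in `attack_model_spike_cnt.py`.

```python
import torch
import torch.nn.functional as F

def fgsm(model, x, y, eps=0.031):
    # One gradient-sign step: x_adv = clip(x + eps * sign(dL/dx), 0, 1)
    x = x.clone().requires_grad_(True)
    F.cross_entropy(model(x), y).backward()
    return (x + eps * x.grad.sign()).clamp(0, 1).detach()

def pgd(model, x, y, eps=0.031, a=0.01, k=7):
    # k gradient-sign steps of size a, each projected back into the eps L-inf ball
    x_adv = (x + torch.empty_like(x).uniform_(-eps, eps)).clamp(0, 1)
    for _ in range(k):
        x_adv = x_adv.detach().requires_grad_(True)
        F.cross_entropy(model(x_adv), y).backward()
        x_adv = x_adv + a * x_adv.grad.sign()
        x_adv = torch.min(torch.max(x_adv, x - eps), x + eps).clamp(0, 1)
    return x_adv.detach()
```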
19 | 
20 | ### Version on which the models were tested:
21 | 
22 | * PyTorch version: `1.5.1`.
23 | * Python version: `3.8.3`.
24 | 
25 | ### Model download:
26 | #### A. HIRE-SNN models:
27 | 1. [vgg11_cifar100_hiresnn_tstep8_bb_test_model](https://drive.google.com/file/d/1iDDaO3EZnEPk9JaAn2tCL5M80d8DWxv3/view?usp=sharing)
28 | 2. [vgg11_cifar100_hiresnn_tstep8_model](https://drive.google.com/file/d/1huAXyOzdwlVHS2xsbyLdwr_fgwbdLIGQ/view?usp=sharing)
29 | #### B. Traditional SNN models:
30 | 1. [vgg11_cifar100_tradit_bb_test_model](https://drive.google.com/file/d/1GNrDCu7uD8IBAea6sHue84YEwoGd8t1G/view?usp=sharing)
31 | 2. [vgg11_cifar100_tradit_model](https://drive.google.com/file/d/1sayiUKccwrn0X77hAm-eiZUzVxcO3KhH/view?usp=sharing)
32 | ### To test the adversarial accuracy of a saved model, please follow these steps:
33 | Create two folders named *`HIRE_SNN_models`* and *`traditional_models`*, and download the models to their respective folders. (A command summary is sketched at the end of this README.)
34 | #### 1. HIRE SNN testing:
35 | a) In the `run_snn_hire_test.py` file, set `HIRE_SNN_models/vgg11_cifar100_hiresnn_tstep8_model.pt` as `pretrained_snn`
36 | and
37 | `HIRE_SNN_models/vgg11_cifar100_hiresnn_tstep8_bb_test_model.pt` as `pretrained_snn_bb` (the latter is for black-box testing only).
38 | b) Run the command: `python run_snn_hire_test.py`
39 | 
40 | #### 2. Traditional SNN testing:
41 | a) In the `run_snn_tradit_test.py` file, set `traditional_models/vgg11_cifar100_tradit_model.pt` as `pretrained_snn`
42 | and
43 | `traditional_models/vgg11_cifar100_tradit_bb_test_model.pt` as `pretrained_snn_bb` (this is for black-box testing only).
44 | b) Run the command: `python run_snn_tradit_test.py`
45 | 
46 | ### Cite this work
47 | If you find this project useful, please cite our work:
48 | 
49 |     @InProceedings{Kundu_2021_ICCV,
50 |         author    = {Kundu, Souvik and Pedram, Massoud and Beerel, Peter A.},
51 |         title     = {HIRE-SNN: Harnessing the Inherent Robustness of Energy-Efficient Deep Spiking Neural Networks by Training With Crafted Input Noise},
52 |         booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
53 |         month     = {October},
54 |         year      = {2021},
55 |         pages     = {5209-5218}}
56 | 
57 | ### Acknowledgment
58 | [Hybrid SNN repo](https://github.com/nitin-rathi/hybrid-snn-conversion)
59 | 
60 | 
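In short, assuming the four checkpoints above have been downloaded (these commands only reproduce what the two run scripts already do):

    mkdir HIRE_SNN_models traditional_models
    # move each downloaded .pt file into its matching folder, then:
    python run_snn_hire_test.py      # white-box + black-box tests, HIRE-SNN model
    python run_snn_tradit_test.py    # white-box + black-box tests, traditional model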
--------------------------------------------------------------------------------
/attack_model_spike_cnt.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import math
3 | import torch.nn.functional as F
4 | import torch
5 | import copy
6 | import numpy as np
7 | 
8 | class Attack(object):
9 | 
10 |     def __init__(self, dataloader, criterion=None, gpu_id=0,
11 |                  epsilon=0.031, attack_method='pgd'):
12 | 
13 |         if criterion is not None:
14 |             self.criterion = criterion
15 |         else:
16 |             self.criterion = nn.CrossEntropyLoss()
17 | 
18 |         self.dataloader = dataloader
19 |         self.epsilon = epsilon
20 |         self.gpu_id = gpu_id  # this is an integer
21 | 
22 |         if attack_method == 'fgsm':
23 |             self.attack_method = self.fgsm
24 |         elif attack_method == 'pgd':
25 |             self.attack_method = self.pgd
26 | 
27 |     def update_params(self, epsilon=None, dataloader=None, attack_method=None):
28 |         if epsilon is not None:
29 |             self.epsilon = epsilon
30 |         if dataloader is not None:
31 |             self.dataloader = dataloader
32 | 
33 |         if attack_method is not None:
34 |             if attack_method == 'fgsm':
35 |                 self.attack_method = self.fgsm
36 |             elif attack_method == 'pgd':
37 |                 self.attack_method = self.pgd
38 | 
39 |     ## For SNNs there is no separate input-normalization layer, so fgsm normalizes
40 |     ## the raw data manually (dataset mean and std) before feeding it to the model.
41 |     def fgsm(self, model, data, target, args, data_min=0, data_max=1):
42 | 
43 |         if args.dataset == 'CIFAR10':
44 |             mean = torch.Tensor(np.array([0.4914, 0.4822, 0.4465])[:, np.newaxis, np.newaxis])
45 |             mean = mean.expand(3, 32, 32).cuda()
46 |             std = torch.Tensor(np.array([0.2023, 0.1994, 0.2010])[:, np.newaxis, np.newaxis])
47 |             std = std.expand(3, 32, 32).cuda()
48 |         if args.dataset == 'CIFAR100':
49 |             mean = torch.Tensor(np.array([0.5071,0.4867,0.4408])[:, np.newaxis, np.newaxis])
50 |             mean = mean.expand(3, 32, 32).cuda()
51 |             std = torch.Tensor(np.array([0.2675,0.2565,0.2761])[:, np.newaxis, np.newaxis])
52 |             std = std.expand(3, 32, 32).cuda()
53 | 
54 |         model.eval()
55 |         # perturbed_data = copy.deepcopy(data)
56 |         perturbed_data = data.clone()
57 | 
58 |         perturbed_data.requires_grad = True
59 |         # As we take the raw un-normalized data, we convert it to normalized data
60 |         # and then feed it to the model
61 |         perturbed_data_norm = perturbed_data - mean
62 |         perturbed_data_norm.div_(std)
63 |         output,_ = model(perturbed_data_norm)
64 |         #print('perturbed_data.requires_grad:', perturbed_data.requires_grad)
65 |         loss = F.cross_entropy(output, target)
66 |         if perturbed_data.grad is not None:
67 |             perturbed_data.grad.data.zero_()
68 | 
69 |         loss.backward()
70 | 
71 |         # Collect the element-wise sign of the data gradient
72 |         sign_data_grad = perturbed_data.grad.data.sign()
73 |         perturbed_data.requires_grad = False
74 | 
75 |         with torch.no_grad():
76 |             # Create the perturbed image by adjusting each pixel of the input image
77 |             perturbed_data += self.epsilon*sign_data_grad
78 |             # Add clipping to maintain the [min,max] range, default [0,1] for images
79 |             perturbed_data.clamp_(data_min, data_max)
80 | 
81 |         return perturbed_data
82 | 
83 |     ## For SNNs, pgd likewise performs the (mean, std) normalization manually on
84 |     ## each of the k intermediate perturbed inputs it generates.
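    ## (In this implementation k is the number of gradient steps, a is the per-step
    ## size, and self.epsilon is the L-infinity budget; the mean/std hard-coded
    ## below are the CIFAR-10 statistics.)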
85 |     def pgd(self, model, data, target, k=7, a=0.01, random_start=True,
86 |             d_min=0, d_max=1):  # to reduce attack time for SNNs, k can be cut to 3; for ANNs we use k=7
87 | 
88 |         mean = torch.Tensor(np.array([0.4914, 0.4822, 0.4465])[:, np.newaxis, np.newaxis])
89 |         mean = mean.expand(3, 32, 32).cuda()
90 |         std = torch.Tensor(np.array([0.2023, 0.1994, 0.2010])[:, np.newaxis, np.newaxis])
91 |         std = std.expand(3, 32, 32).cuda()
92 | 
93 |         model.eval()
94 |         # perturbed_data = copy.deepcopy(data)
95 |         perturbed_data = data.clone()
96 |         perturbed_data.requires_grad = True
97 | 
98 |         data_max = data + self.epsilon
99 |         data_min = data - self.epsilon
100 |         data_max.clamp_(d_min, d_max)
101 |         data_min.clamp_(d_min, d_max)
102 | 
103 |         if random_start:
104 |             with torch.no_grad():
105 |                 perturbed_data.data = data + perturbed_data.uniform_(-1*self.epsilon, self.epsilon)
106 |                 perturbed_data.data.clamp_(d_min, d_max)
107 | 
108 |         for _ in range(k):
109 |             ## For SNNs we don't have a separate mean/std layer, so we manually do the mean
110 |             ## subtraction here on every perturbed data generated
111 | 
112 |             in1 = perturbed_data - mean
113 |             in1.div_(std)
114 |             output,_ = model( in1 )
115 |             #print('output shape:{}, target shape:{}', output.shape, target.shape)
116 |             loss = F.cross_entropy(output, target)
117 | 
118 |             if perturbed_data.grad is not None:
119 |                 perturbed_data.grad.data.zero_()
120 | 
121 |             loss.backward()
122 |             data_grad = perturbed_data.grad.data
123 | 
124 |             with torch.no_grad():
125 |                 perturbed_data.data += a * torch.sign(data_grad)
126 |                 perturbed_data.data = torch.max(torch.min(perturbed_data, data_max),
127 |                                                 data_min)
128 |         perturbed_data.requires_grad = False
129 | 
130 |         return perturbed_data
131 | 
--------------------------------------------------------------------------------
/vgg_spiking_nodewise.py:
--------------------------------------------------------------------------------
1 | #---------------------------------------------------
2 | # Imports
3 | #---------------------------------------------------
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | import numpy as np
8 | import pdb
9 | import math
10 | from collections import OrderedDict
11 | from matplotlib import pyplot as plt
12 | import copy
13 | 
14 | cfg = {
15 |     'VGG4' : [64, 'A', 128, 'A'],
16 |     'VGG5' : [64, 'A', 128, 128, 'A'],
17 |     'VGG6' : [64, 'A', 128, 128, 'A'],
18 |     'VGG9': [64, 'A', 128, 256, 'A', 256, 512, 'A', 512, 'A', 512],
19 |     'VGG11': [64, 'A', 128, 256, 'A', 512, 512, 512, 'A', 512, 512],
20 |     'VGG13': [64, 64, 'A', 128, 128, 'A', 256, 256, 'A', 512, 512, 512, 'A', 512],
21 |     'VGG16': [64, 64, 'A', 128, 128, 'A', 256, 256, 256, 'A', 512, 512, 512, 'A', 512, 512, 512],
22 |     'VGG19': [64, 64, 'A', 128, 128, 'A', 256, 256, 256, 256, 'A', 512, 512, 512, 512, 'A', 512, 512, 512, 512]
23 | }
24 | 
25 | class LinearSpike(torch.autograd.Function):
26 |     """
27 |     Here we use the piecewise-linear surrogate gradient as was done
28 |     in Bellec et al. (2018).
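    Forward: emit a spike (1.0) wherever the input (the scaled membrane potential)
    exceeds 0. Backward: approximate the spike derivative by
    gamma * max(0, 1 - |input|), a triangle of height gamma around the threshold.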
29 | """ 30 | gamma = 0.3 # Controls the dampening of the piecewise-linear surrogate gradient 31 | 32 | @staticmethod 33 | def forward(ctx, input, last_spike): 34 | 35 | ctx.save_for_backward(input) 36 | out = torch.zeros_like(input).cuda() 37 | out[input > 0] = 1.0 38 | return out 39 | 40 | @staticmethod 41 | def backward(ctx, grad_output): 42 | 43 | input, = ctx.saved_tensors 44 | grad_input = grad_output.clone() 45 | grad = LinearSpike.gamma*F.threshold(1.0-torch.abs(input), 0, 0) 46 | return grad*grad_input, None 47 | 48 | class VGG_SNN(nn.Module): 49 | 50 | def __init__(self, vgg_name, activation='Linear', labels=10, timesteps=100, leak=1.0, default_threshold = 1.0, dropout=0.2, kernel_size=3, dataset='CIFAR10'): 51 | super().__init__() 52 | 53 | self.vgg_name = vgg_name 54 | self.act_func = LinearSpike.apply 55 | self.labels = labels 56 | self.timesteps = timesteps 57 | self.dropout = dropout 58 | self.kernel_size = kernel_size 59 | self.dataset = dataset 60 | self.mem = {} 61 | self.mask = {} 62 | self.spike = {} 63 | 64 | self.features, self.classifier = self._make_layers(cfg[self.vgg_name]) 65 | 66 | self._initialize_weights2() 67 | 68 | threshold = {} 69 | lk = {} 70 | for l in range(len(self.features)): 71 | if isinstance(self.features[l], nn.Conv2d): 72 | threshold['t'+str(l)] = nn.Parameter(torch.tensor(default_threshold)) 73 | lk['l'+str(l)] = nn.Parameter(torch.tensor(leak)) 74 | 75 | 76 | prev = len(self.features) 77 | for l in range(len(self.classifier)-1): 78 | if isinstance(self.classifier[l], nn.Linear): 79 | threshold['t'+str(prev+l)] = nn.Parameter(torch.tensor(default_threshold)) 80 | lk['l'+str(prev+l)] = nn.Parameter(torch.tensor(leak)) 81 | 82 | self.threshold = nn.ParameterDict(threshold) 83 | self.leak = nn.ParameterDict(lk) 84 | 85 | def _initialize_weights2(self): 86 | for m in self.modules(): 87 | 88 | if isinstance(m, nn.Conv2d): 89 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 90 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 91 | if m.bias is not None: 92 | m.bias.data.zero_() 93 | elif isinstance(m, nn.BatchNorm2d): 94 | m.weight.data.fill_(1) 95 | m.bias.data.zero_() 96 | elif isinstance(m, nn.Linear): 97 | n = m.weight.size(1) 98 | m.weight.data.normal_(0, 0.01) 99 | if m.bias is not None: 100 | m.bias.data.zero_() 101 | 102 | def threshold_update(self, scaling_factor=1.0, thresholds=[]): 103 | 104 | # Initialize thresholds 105 | self.scaling_factor = scaling_factor 106 | 107 | for pos in range(len(self.features)): 108 | if isinstance(self.features[pos], nn.Conv2d): 109 | if thresholds: 110 | self.threshold.update({'t'+str(pos): nn.Parameter(torch.tensor(thresholds.pop(0)*self.scaling_factor))}) 111 | #print('\t Layer{} : {:.2f}'.format(pos, self.threshold[pos])) 112 | 113 | prev = len(self.features) 114 | 115 | for pos in range(len(self.classifier)-1): 116 | if isinstance(self.classifier[pos], nn.Linear): 117 | if thresholds: 118 | self.threshold.update({'t'+str(prev+pos): nn.Parameter(torch.tensor(thresholds.pop(0)*self.scaling_factor))}) 119 | #print('\t Layer{} : {:.2f}'.format(prev+pos, self.threshold[prev+pos])) 120 | 121 | 122 | def _make_layers(self, cfg): 123 | layers = [] 124 | if self.dataset =='MNIST': 125 | in_channels = 1 126 | else: 127 | in_channels = 3 128 | 129 | for x in (cfg): 130 | stride = 1 131 | 132 | if x == 'A': 133 | layers.pop() 134 | layers += [nn.AvgPool2d(kernel_size=2, stride=2)] 135 | 136 | else: 137 | layers += [nn.Conv2d(in_channels, x, kernel_size=self.kernel_size, padding=(self.kernel_size-1)//2, stride=stride, bias=False), 138 | nn.ReLU(inplace=True) 139 | ] 140 | layers += [nn.Dropout(self.dropout)] 141 | in_channels = x 142 | 143 | features = nn.Sequential(*layers) 144 | 145 | layers = [] 146 | if self.dataset == 'IMAGENET': 147 | layers += [nn.Linear(512*7*7, 4096, bias=False)] 148 | layers += [nn.ReLU(inplace=True)] 149 | layers += [nn.Dropout(self.dropout)] 150 | layers += [nn.Linear(4096, 4096, bias=False)] 151 | layers += [nn.ReLU(inplace=True)] 152 | layers += [nn.Dropout(self.dropout)] 153 | layers += [nn.Linear(4096, self.labels, bias=False)] 154 | 155 | elif self.vgg_name == 'VGG6' and self.dataset != 'MNIST': 156 | layers += [nn.Linear(512*4*4, 4096, bias=False)] 157 | layers += [nn.ReLU(inplace=True)] 158 | layers += [nn.Dropout(self.dropout)] 159 | layers += [nn.Linear(4096, 4096, bias=False)] 160 | layers += [nn.ReLU(inplace=True)] 161 | layers += [nn.Dropout(self.dropout)] 162 | layers += [nn.Linear(4096, self.labels, bias=False)] 163 | 164 | elif self.vgg_name == 'VGG4' and self.dataset== 'MNIST': 165 | layers += [nn.Linear(128*7*7, 1024, bias=False)] 166 | layers += [nn.ReLU(inplace=True)] 167 | layers += [nn.Dropout(self.dropout)] 168 | #layers += [nn.Linear(4096, 4096, bias=False)] 169 | #layers += [nn.ReLU(inplace=True)] 170 | #layers += [nn.Dropout(self.dropout)] 171 | layers += [nn.Linear(1024, self.labels, bias=False)] 172 | 173 | elif self.vgg_name == 'VGG16' and self.dataset != 'MNIST': 174 | layers += [nn.Linear(512*2*2, 4096, bias=False)] 175 | layers += [nn.ReLU(inplace=True)] 176 | layers += [nn.Dropout(self.dropout)] 177 | layers += [nn.Linear(4096, 4096, bias=False)] 178 | layers += [nn.ReLU(inplace=True)] 179 | layers += [nn.Dropout(self.dropout)] 180 | layers += [nn.Linear(4096, self.labels, bias=False)] 181 | 182 | elif (self.vgg_name == 'VGG5' or self.vgg_name == 'VGG11') and self.dataset != 'MNIST': 183 | layers += [nn.Linear(512*2*2*4, 4096, bias=False)] 184 | layers += [nn.ReLU(inplace=True)] 185 | layers += 
[nn.Dropout(self.dropout)]
186 |             layers += [nn.Linear(4096, 4096, bias=False)]
187 |             layers += [nn.ReLU(inplace=True)]
188 |             layers += [nn.Dropout(self.dropout)]
189 |             layers += [nn.Linear(4096, self.labels, bias=False)]
190 | 
191 |         elif self.vgg_name == 'VGG6' and self.dataset == 'MNIST':
192 |             layers += [nn.Linear(128*7*7, 4096, bias=False)]
193 |             layers += [nn.ReLU(inplace=True)]
194 |             layers += [nn.Dropout(self.dropout)]
195 |             layers += [nn.Linear(4096, 4096, bias=False)]
196 |             layers += [nn.ReLU(inplace=True)]
197 |             layers += [nn.Dropout(self.dropout)]
198 |             layers += [nn.Linear(4096, self.labels, bias=False)]
199 | 
200 |         elif self.vgg_name != 'VGG6' and self.dataset == 'MNIST':
201 |             layers += [nn.Linear(512*1*1, 4096, bias=False)]
202 |             layers += [nn.ReLU(inplace=True)]
203 |             layers += [nn.Dropout(self.dropout)]
204 |             layers += [nn.Linear(4096, 4096, bias=False)]
205 |             layers += [nn.ReLU(inplace=True)]
206 |             layers += [nn.Dropout(self.dropout)]
207 |             layers += [nn.Linear(4096, self.labels, bias=False)]
208 | 
209 | 
210 |         classifier = nn.Sequential(*layers)
211 |         return (features, classifier)
212 | 
213 |     def network_update(self, timesteps, leak):
214 |         self.timesteps = timesteps  # (leak is accepted for interface compatibility but not updated here)
215 | 
216 |     def timestep_update(self, timesteps):
217 |         self.timesteps = timesteps
218 | 
219 |     def neuron_init(self, x):
220 |         self.batch_size = x.size(0)
221 |         self.width = x.size(2)
222 |         self.height = x.size(3)
223 | 
224 |         self.mem = {}
225 |         self.spike = {}
226 |         self.mask = {}
227 |         self.spike_count= {}
228 | 
229 |         for l in range(len(self.features)):
230 | 
231 |             if isinstance(self.features[l], nn.Conv2d):
232 |                 self.mem[l] = torch.zeros(self.batch_size, self.features[l].out_channels, self.width, self.height)
233 |                 # this just keeps track of what happens to the spike after conv
234 | 
235 | 
236 | 
237 |             elif isinstance(self.features[l], nn.ReLU):
238 |                 if isinstance(self.features[l-1], nn.Conv2d):
239 |                     self.spike[l] = torch.ones(self.mem[l-1].shape)*(-1000)
240 |                     self.spike_count[l] = torch.zeros(self.mem[l-1].size())
241 | 
242 |                 elif isinstance(self.features[l-1], nn.AvgPool2d):
243 |                     self.spike[l] = torch.ones(self.batch_size, self.features[l-2].out_channels, self.width, self.height)*(-1000)
244 | 
245 | 
246 |             elif isinstance(self.features[l], nn.Dropout):
247 |                 self.mask[l] = self.features[l](torch.ones(self.mem[l-2].shape).cuda())
248 | 
249 |             elif isinstance(self.features[l], nn.AvgPool2d):
250 |                 self.width = self.width//self.features[l].kernel_size
251 |                 self.height = self.height//self.features[l].kernel_size
252 |                 self.spike_count[l] = torch.zeros(self.batch_size, self.features[l-2].out_channels,self.width,self.height)
253 | 
254 | 
255 |         prev = len(self.features)
256 | 
257 |         for l in range(len(self.classifier)):
258 | 
259 |             if isinstance(self.classifier[l], nn.Linear):
260 |                 self.mem[prev+l] = torch.zeros(self.batch_size, self.classifier[l].out_features)
261 | 
262 | 
263 |             elif isinstance(self.classifier[l], nn.ReLU):
264 |                 self.spike[prev+l] = torch.ones(self.mem[prev+l-1].shape)*(-1000)
265 |                 self.spike_count[prev+l] = torch.zeros(self.mem[prev+l-1].size())
266 | 
267 |             elif isinstance(self.classifier[l], nn.Dropout):
268 |                 self.mask[prev+l] = self.classifier[l](torch.ones(self.mem[prev+l-2].shape).cuda())
269 | 
270 | 
271 |     def percentile(self, t, q):
272 | 
273 |         k = 1 + round(.01 * float(q) * (t.numel() - 1))
274 |         result = t.view(-1).kthvalue(k).values.item()
275 |         return result
276 | 
277 |     def forward(self, x, find_max_mem=False, spike_count=[], max_mem_layer=0):
278 | 
279 |         self.neuron_init(x)
280 |         max_mem=0.0
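        # (The timestep loop below implements the LIF dynamics: each Conv/Linear
        #  layer integrates mem = leak*mem + W*input, fires a spike wherever mem
        #  crosses its learned threshold, and is soft-reset by subtracting the
        #  threshold from mem.)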
281 | 
282 |         for t in range(self.timesteps):
283 |             out_prev = x
284 | 
285 |             for l in range(len(self.features)):
286 | 
287 |                 if isinstance(self.features[l], (nn.Conv2d)):
288 | 
289 |                     if find_max_mem and l==max_mem_layer:
290 |                         cur = self.percentile(self.features[l](out_prev).view(-1), 99.7)
291 |                         if (cur>max_mem):
292 |                             max_mem = torch.tensor([cur])
293 |                         break
294 | 
295 |                     delta_mem = self.features[l](out_prev)
296 |                     self.mem[l] = getattr(self.leak, 'l'+str(l)) * self.mem[l] + delta_mem
297 |                     mem_thr = (self.mem[l]/getattr(self.threshold, 't'+str(l))) - 1.0
298 |                     rst = getattr(self.threshold, 't'+str(l)) * (mem_thr>0).float()
299 |                     self.mem[l] = self.mem[l]-rst
300 | 
301 |                 elif isinstance(self.features[l], nn.ReLU):
302 | 
303 |                     out = self.act_func(mem_thr, (t-1-self.spike[l]))
304 |                     self.spike[l] = self.spike[l].masked_fill(out.bool(),t-1)
305 |                     self.spike_count[l][out.bool()] = self.spike_count[l][out.bool()] + 1
306 |                     out_prev = out.clone()
307 | 
308 |                 elif isinstance(self.features[l], nn.AvgPool2d):
309 |                     out_prev = self.features[l](out_prev)
310 |                     self.spike_count[l][out_prev.bool()] = self.spike_count[l][out_prev.bool()] + 1
311 | 
312 | 
313 |                 elif isinstance(self.features[l], nn.Dropout):
314 |                     out_prev = out_prev * self.mask[l]
315 | 
316 |             if find_max_mem and max_mem_layer<len(self.features):
317 |                 continue
318 | 
319 |             out_prev = out_prev.reshape(self.batch_size, -1)
320 |             prev = len(self.features)
321 | 
322 |             for l in range(len(self.classifier)):
323 | 
324 |                 if isinstance(self.classifier[l], (nn.Linear)):
325 | 
326 |                     if find_max_mem and (prev+l)==max_mem_layer:
327 |                         cur = self.percentile(self.classifier[l](out_prev).view(-1), 99.7)
328 |                         if (cur>max_mem):
329 |                             max_mem = torch.tensor([cur])
330 |                         break
331 | 
332 |                     delta_mem = self.classifier[l](out_prev)
333 |                     self.mem[prev+l] = getattr(self.leak, 'l'+str(prev+l)) * self.mem[prev+l] + delta_mem
334 |                     mem_thr = (self.mem[prev+l]/getattr(self.threshold, 't'+str(prev+l))) - 1.0
335 |                     rst = getattr(self.threshold,'t'+str(prev+l)) * (mem_thr>0).float()
336 |                     self.mem[prev+l] = self.mem[prev+l]-rst
337 | 
338 | 
339 |                 elif isinstance(self.classifier[l], nn.ReLU):
340 |                     out = self.act_func(mem_thr, (t-1-self.spike[prev+l]))
341 |                     self.spike[prev+l] = self.spike[prev+l].masked_fill(out.bool(),t-1)
342 |                     self.spike_count[prev+l][out.bool()] = self.spike_count[prev+l][out.bool()] + 1
343 |                     out_prev = out.clone()
344 | 
345 |                 elif isinstance(self.classifier[l], nn.Dropout):
346 |                     out_prev = out_prev * self.mask[prev+l]
347 | 
348 |             # Compute the classification layer outputs
349 |             if not find_max_mem:
350 |                 self.mem[prev+l+1] = self.mem[prev+l+1] + self.classifier[l+1](out_prev)
351 |         if find_max_mem:
352 |             return max_mem
353 | 
354 |         return self.mem[prev+l+1], self.spike_count
355 | 
356 | 
357 | 
358 | 
--------------------------------------------------------------------------------
/snn_free_nonorm_for_test_purpose.py:
--------------------------------------------------------------------------------
1 | #---------------------------------------------------
2 | # Imports
3 | #---------------------------------------------------
4 | from __future__ import print_function
5 | import argparse
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | import torch.optim as optim
10 | from torchvision import datasets, transforms, models
11 | from torch.utils.data.dataloader import DataLoader
12 | from torch.autograd import Variable
13 | from matplotlib import pyplot as plt
14 | from matplotlib.gridspec import GridSpec
15 | import numpy as np
16 | import datetime
17 | import pdb
18 | from vgg_spiking_nodewise import *
19 | import sys
20 | import os
21 | import shutil
22 | import argparse
23 | from attack_model_spike_cnt import Attack
24 | 
25 | 
26 | class AverageMeter(object):
27 |     """Computes and stores the average and current value"""
28 |     def __init__(self, name, fmt=':f'):
29 |         self.name = name
30 | 
self.fmt = fmt 31 | self.reset() 32 | 33 | def reset(self): 34 | self.val = 0 35 | self.avg = 0 36 | self.sum = 0 37 | self.count = 0 38 | 39 | def update(self, val, n=1): 40 | self.val = val 41 | self.sum += val * n 42 | self.count += n 43 | self.avg = self.sum / self.count 44 | 45 | def __str__(self): 46 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 47 | return fmtstr.format(**self.__dict__) 48 | 49 | def accuracy(output, target, topk=(1,)): 50 | """Computes the accuracy over the k top predictions for the specified values of k""" 51 | with torch.no_grad(): 52 | maxk = max(topk) 53 | batch_size = target.size(0) 54 | 55 | _, pred = output.topk(maxk, 1, True, True) 56 | pred = pred.t() 57 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 58 | 59 | res = [] 60 | for k in topk: 61 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 62 | res.append(correct_k.mul_(100.0 / batch_size)) 63 | return res 64 | 65 | def test(epoch, test_loader, best_Acc, args): 66 | 67 | losses = AverageMeter('Loss') 68 | top1 = AverageMeter('Acc@1') 69 | avg_spike_cnt = [] 70 | 71 | if args.dataset == 'CIFAR10': 72 | mean = torch.Tensor(np.array([0.4914, 0.4822, 0.4465])[:, np.newaxis, np.newaxis]) 73 | mean = mean.expand(3, 32, 32).cuda() 74 | std = torch.Tensor(np.array([0.2023, 0.1994, 0.2010])[:, np.newaxis, np.newaxis]) 75 | std = std.expand(3, 32, 32).cuda() 76 | if args.dataset == 'CIFAR100': 77 | mean = torch.Tensor(np.array([0.5071,0.4867,0.4408])[:, np.newaxis, np.newaxis]) 78 | mean = mean.expand(3, 32, 32).cuda() 79 | std = torch.Tensor(np.array([0.2675,0.2565,0.2761])[:, np.newaxis, np.newaxis]) 80 | std = std.expand(3, 32, 32).cuda() 81 | 82 | if args.test_only: 83 | temp1 = [] 84 | temp2 = [] 85 | for key, value in sorted(model.threshold.items(), key=lambda x: (int(x[0][1:]), (x[1]))): 86 | temp1 = temp1+[round(value.item(),2)] 87 | for key, value in sorted(model.leak.items(), key=lambda x: (int(x[0][1:]), (x[1]))): 88 | temp2 = temp2+[round(value.item(),2)] 89 | f.write('\n Thresholds: {}, leak: {}'.format(temp1, temp2)) 90 | 91 | with torch.no_grad(): 92 | model.eval() 93 | global max_accuracy 94 | 95 | for batch_idx, (data, target) in enumerate(test_loader): 96 | 97 | if torch.cuda.is_available() and args.gpu: 98 | data, target = data.cuda(), target.cuda() 99 | 100 | data = data - mean 101 | data.div_(std) 102 | output, spike_count = model(data) 103 | loss = F.cross_entropy(output,target) 104 | pred = output.max(1,keepdim=True)[1] 105 | correct = pred.eq(target.data.view_as(pred)).cpu().sum() 106 | 107 | losses.update(loss.item(),data.size(0)) 108 | top1.update(correct.item()/data.size(0), data.size(0)) 109 | 110 | if test_acc_every_batch: 111 | 112 | f.write('\n Images {}/{} Accuracy: {}/{}({:.4f})' 113 | .format( 114 | test_loader.batch_size*(batch_idx+1), 115 | len(test_loader.dataset), 116 | correct.item(), 117 | data.size(0), 118 | top1.avg*100 119 | ) 120 | ) 121 | 122 | temp1 = [] 123 | temp2 = [] 124 | for key, value in sorted(model.threshold.items(), key=lambda x: (int(x[0][1:]), (x[1]))): 125 | temp1 = temp1+[value.item()] 126 | for key, value in sorted(model.leak.items(), key=lambda x: (int(x[0][1:]), (x[1]))): 127 | temp2 = temp2+[value.item()] 128 | 129 | if epoch>5 and top1.avg<0.15: 130 | f.write('\n Quitting as the training is not progressing') 131 | exit(0) 132 | 133 | f.write(' test_loss: {:.4f}, test_acc: {:.4f}, time: {}' 134 | .format( 135 | losses.avg, 136 | top1.avg*100, 137 | datetime.timedelta(seconds=(datetime.datetime.now() - 
start_time).seconds) 138 | ) 139 | ) 140 | best_Acc.append(top1.avg*100) 141 | return top1.avg, best_Acc 142 | 143 | 144 | def validate_fgsm(val_loader, model, args, eps=0.031): 145 | if args.dataset == 'CIFAR10': 146 | mean = torch.Tensor(np.array([0.4914, 0.4822, 0.4465])[:, np.newaxis, np.newaxis]) 147 | mean = mean.expand(3, 32, 32).cuda() 148 | std = torch.Tensor(np.array([0.2023, 0.1994, 0.2010])[:, np.newaxis, np.newaxis]) 149 | std = std.expand(3, 32, 32).cuda() 150 | if args.dataset == 'CIFAR100': 151 | mean = torch.Tensor(np.array([0.5071,0.4867,0.4408])[:, np.newaxis, np.newaxis]) 152 | mean = mean.expand(3, 32, 32).cuda() 153 | std = torch.Tensor(np.array([0.2675,0.2565,0.2761])[:, np.newaxis, np.newaxis]) 154 | std = std.expand(3, 32, 32).cuda() 155 | losses = AverageMeter('Loss') 156 | prec1_fgsm = 0.0 157 | prec5_fgsm = 0.0 158 | n = 0 159 | model.eval() 160 | attacker = Attack(dataloader=val_loader, 161 | attack_method='fgsm', epsilon=0.031) 162 | for i, (data, target) in enumerate(val_loader): 163 | if torch.cuda.is_available() and args.gpu: 164 | data, target = data.cuda(), target.cuda() 165 | n += target.size(0) 166 | data.requires_grad = False 167 | perturbed_data = attacker.attack_method(model, data, target, args) 168 | perturbed_data.sub_(mean).div_(std) 169 | output_fgsm,_ = model(perturbed_data) 170 | loss_fgsm = F.cross_entropy(output_fgsm, target) 171 | _, pred_fgsm = output_fgsm.topk(5, 1, largest=True, sorted=True) 172 | target_fgsm = target.view(target.size(0),-1).expand_as(pred_fgsm) 173 | correct_fgsm = pred_fgsm.eq(target_fgsm).float() 174 | prec1_fgsm += correct_fgsm[:,:1].sum() 175 | prec5_fgsm += correct_fgsm[:,:5].sum() 176 | losses.update(loss_fgsm.item(), data.size(0)) 177 | 178 | top1_fgsm = 100.*(prec1_fgsm/float(n)) 179 | top5_fgsm = 100.*(prec5_fgsm/float(n)) 180 | print('\n Top1 FGSM:{}'.format(top1_fgsm)) 181 | return top1_fgsm 182 | 183 | 184 | def validate_pgd(val_loader, model, eps=0.031, K=7, a=0.01): 185 | if args.dataset == 'CIFAR10': 186 | mean = torch.Tensor(np.array([0.4914, 0.4822, 0.4465])[:, np.newaxis, np.newaxis]) 187 | mean = mean.expand(3, 32, 32).cuda() 188 | std = torch.Tensor(np.array([0.2023, 0.1994, 0.2010])[:, np.newaxis, np.newaxis]) 189 | std = std.expand(3, 32, 32).cuda() 190 | if args.dataset == 'CIFAR100': 191 | mean = torch.Tensor(np.array([0.5071,0.4867,0.4408])[:, np.newaxis, np.newaxis]) 192 | mean = mean.expand(3, 32, 32).cuda() 193 | std = torch.Tensor(np.array([0.2675,0.2565,0.2761])[:, np.newaxis, np.newaxis]) 194 | std = std.expand(3, 32, 32).cuda() 195 | losses = AverageMeter('Loss') 196 | prec1_pgd = 0.0 197 | prec5_pgd = 0.0 198 | n = 0 199 | model.eval() 200 | print('Value of K:{}, eps:{}'.format(K, eps)) 201 | for i, (data, target) in enumerate(val_loader): 202 | if torch.cuda.is_available() and args.gpu: 203 | data, target = data.cuda(), target.cuda() 204 | n += target.size(0) 205 | orig_input = data.clone() 206 | randn = torch.FloatTensor(data.size()).uniform_(-eps, eps).cuda() 207 | data += randn 208 | data.clamp_(0, 1.0) 209 | for _ in range(K): 210 | invar = Variable(data, requires_grad=True) 211 | in1 = invar - mean 212 | in1.div_(std) 213 | output,_ = model(in1) 214 | ascend_loss = F.cross_entropy(output, target) 215 | ascend_grad = torch.autograd.grad(ascend_loss, invar)[0] 216 | pert = torch.sign(ascend_grad)*a 217 | data += pert.data 218 | data = torch.max(orig_input-eps, data) 219 | data = torch.min(orig_input+eps, data) 220 | data.clamp_(0, 1.0) 221 | data.sub_(mean).div_(std) 222 | with 
torch.no_grad():
223 |             # compute output
224 |             output,_ = model(data)
225 |             loss_pgd = F.cross_entropy(output, target)
226 | 
227 |         # measure accuracy and record loss
228 |         _, pred_pgd = output.topk(5, 1, largest=True, sorted=True)
229 |         target_pgd = target.view(target.size(0),-1).expand_as(pred_pgd)
230 |         correct_pgd = pred_pgd.eq(target_pgd).float()
231 |         prec1_pgd += correct_pgd[:,:1].sum()
232 |         prec5_pgd += correct_pgd[:,:5].sum()
233 |         losses.update(loss_pgd.item(), data.size(0))
234 | 
235 |     top1_pgd = 100.*(prec1_pgd/float(n))
236 |     top5_pgd = 100.*(prec5_pgd/float(n))
237 |     print('\n Top1 PGD:{}'.format(top1_pgd))
238 |     return top1_pgd
239 | 
240 | 
241 | if __name__ == '__main__':
242 |     parser = argparse.ArgumentParser(description='SNN training')
243 |     parser.add_argument('--gpu', default=True, type=bool, help='use gpu')
244 |     parser.add_argument('--dataset', default='CIFAR10', type=str, help='dataset name', choices=['MNIST','CIFAR10','CIFAR100','IMAGENET', 'TINY_IMAGENET'])
245 |     parser.add_argument('--batch_size', default=64, type=int, help='minibatch size')
246 |     parser.add_argument('-a','--architecture', default='VGG16', type=str, help='network architecture', choices=['VGG4','VGG5','VGG6','VGG9','VGG11','VGG13','VGG16','VGG19','RESNET12','RESNET20','RESNET34'])
247 |     parser.add_argument('--pretrained_snn', default='', type=str, help='pretrained SNN for inference')
248 |     parser.add_argument('--test_only', action='store_true', help='perform only inference')
249 |     parser.add_argument('--log', action='store_true', help='write the output to a log file instead of the terminal')
250 |     parser.add_argument('--pgd_iter', default=7, type=int, help='number of pgd iterations')
251 |     parser.add_argument('--pgd_step', default=0.01, type=float, help='pgd attack step size')
252 |     parser.add_argument('--epochs', default=30, type=int, help='number of training epochs')
253 |     parser.add_argument('--timesteps', default=20, type=int, help='simulation timesteps')
254 |     parser.add_argument('--leak', default=1.0, type=float, help='membrane leak')
255 |     parser.add_argument('--default_threshold', default=1.0, type=float, help='initial threshold to train SNN from scratch')
256 |     parser.add_argument('--activation', default='Linear', type=str, help='SNN activation function', choices=['Linear'])
257 |     parser.add_argument('--optimizer', default='SGD', type=str, help='optimizer for SNN backpropagation', choices=['SGD', 'Adam'])
258 |     parser.add_argument('--kernel_size', default=3, type=int, help='filter size for the conv layers')
259 |     parser.add_argument('--test_acc_every_batch', action='store_true', help='print acc of every batch during inference')
260 |     parser.add_argument('--devices', default='0', type=str, help='list of gpu device(s)')
261 |     parser.add_argument('--resume', default='', type=str, help='resume training from this state')
262 |     parser.add_argument('--dont_save', action='store_true', help='don\'t save training model during testing')
263 | 
264 |     args = parser.parse_args()
265 | 
266 |     os.environ['CUDA_VISIBLE_DEVICES'] = args.devices
267 | 
268 |     #torch.backends.cudnn.deterministic = True
269 |     #torch.backends.cudnn.benchmark = False
270 | 
271 |     dataset = args.dataset
272 |     batch_size = args.batch_size
273 |     architecture = args.architecture
274 |     pretrained_snn = args.pretrained_snn
275 |     epochs = args.epochs
276 |     timesteps = args.timesteps
277 |     leak = args.leak
278 |     default_threshold = args.default_threshold
279 |     activation = args.activation
280 |     kernel_size = args.kernel_size
281 |     test_acc_every_batch= 
args.test_acc_every_batch 282 | resume = args.resume 283 | start_epoch = 1 284 | max_accuracy = 0.0 285 | 286 | log_file = './logs/' 287 | try: 288 | os.mkdir(log_file) 289 | except OSError: 290 | pass 291 | identifier = 'snn_'+architecture.lower()+'_'+dataset.lower()+'_'+str(timesteps)+'_'+str(datetime.datetime.now()) 292 | log_file+=identifier+'.log' 293 | 294 | if args.log: 295 | f = open(log_file, 'w', buffering=1) 296 | else: 297 | f = sys.stdout 298 | 299 | if torch.cuda.is_available() and args.gpu: 300 | torch.set_default_tensor_type('torch.cuda.FloatTensor') 301 | 302 | if dataset == 'CIFAR10': 303 | normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) 304 | elif dataset == 'CIFAR100': 305 | normalize = transforms.Normalize((0.5071,0.4867,0.4408), (0.2675,0.2565,0.2761)) 306 | 307 | if dataset in ['CIFAR10', 'CIFAR100']: 308 | transform_train = transforms.Compose([ 309 | transforms.RandomCrop(32, padding=4), 310 | transforms.RandomHorizontalFlip(), 311 | transforms.ToTensor() 312 | ]) 313 | transform_test = transforms.Compose([transforms.ToTensor()]) 314 | 315 | if dataset == 'CIFAR10': 316 | trainset = datasets.CIFAR10(root = '../SNN_adversary/cifar_data', train = True, download = True, transform = transform_train) 317 | testset = datasets.CIFAR10(root='../SNN_adversary/cifar_data', train=False, download=True, transform = transform_test) 318 | labels = 10 319 | train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True) 320 | test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False) 321 | 322 | elif dataset == 'CIFAR100': 323 | trainset = datasets.CIFAR100(root = './cifar_data', train = True, download = True, transform = transform_train) 324 | testset = datasets.CIFAR100(root='./cifar_data', train=False, download=True, transform = transform_test) 325 | labels = 100 326 | train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True) 327 | test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False) 328 | 329 | if architecture[0:3].lower() == 'vgg': 330 | model = VGG_SNN(vgg_name = architecture, activation = activation, labels=labels, timesteps=timesteps, leak=leak, default_threshold=default_threshold, dropout=0.2, kernel_size=kernel_size, dataset=dataset) 331 | 332 | if pretrained_snn: 333 | 334 | state = torch.load(pretrained_snn, map_location='cpu') 335 | missing_keys, unexpected_keys = model.load_state_dict(state, strict=False) 336 | else: 337 | print("Please provide an snn file name") 338 | 339 | #model = nn.DataParallel(model) 340 | 341 | if torch.cuda.is_available() and args.gpu: 342 | model.cuda() 343 | 344 | if resume: 345 | f.write('\n Resuming from checkpoint {}'.format(resume)) 346 | state = torch.load(resume, map_location='cpu') 347 | missing_keys, unexpected_keys = model.load_state_dict(state, strict=False) 348 | f.write('\n Missing keys : {}, Unexpected Keys: {}'.format(missing_keys, unexpected_keys)) 349 | #f.write('\n Info: Accuracy of loaded ANN model: {}'.format(state['accuracy'])) 350 | 351 | #epoch = state['epoch'] 352 | start_epoch = epoch + 1 353 | #max_accuracy = state['accuracy'] 354 | #optimizer.load_state_dict(state['optimizer']) 355 | for param_group in optimizer.param_groups: 356 | learning_rate = param_group['lr'] 357 | 358 | f.write('\n Loaded from resume epoch: {}, accuracy: {:.4f} lr: {:.1e}'.format(epoch, max_accuracy, learning_rate)) 359 | 360 | test_Acc = [0] 361 | pgd_test_acc = [0] 362 | fgsmtest_acc = [0] 363 | for epoch in range(start_epoch, epochs+1): 364 | 
start_time = datetime.datetime.now() 365 | top1, test_Acc = test(epoch, test_loader, test_Acc, args) 366 | top1_fgsm = validate_fgsm(test_loader, model, args, eps=0.031) 367 | top1_pgd = validate_pgd(test_loader, model, eps=0.031, K=7, a=0.01) 368 | fgsmtest_acc.append(top1_fgsm) 369 | pgd_test_acc.append(top1_pgd) 370 | #print('Epoch:{}, TestAcc:{}, PGD acc:{}'.format(epoch, top1, top1_pgd)) 371 | print('Epoch:{}, TestAcc:{}, FGSM acc:{}, PGD acc:{}'.format(epoch, top1*100, top1_fgsm, top1_pgd)) 372 | 373 | #f.write('\n Highest accuracy: {:.4f}'.format(max_accuracy)) 374 | 375 | 376 | 377 | 378 | -------------------------------------------------------------------------------- /snn_free_nonorm_for_bbtest_purpose.py: -------------------------------------------------------------------------------- 1 | #--------------------------------------------------- 2 | # Imports 3 | #--------------------------------------------------- 4 | from __future__ import print_function 5 | import argparse 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import torch.optim as optim 10 | from torchvision import datasets, transforms, models 11 | from torch.utils.data.dataloader import DataLoader 12 | from torch.autograd import Variable 13 | from matplotlib import pyplot as plt 14 | from matplotlib.gridspec import GridSpec 15 | import numpy as np 16 | import datetime 17 | import pdb 18 | from vgg_spiking_nodewise import * 19 | import sys 20 | import os 21 | import shutil 22 | import argparse 23 | from attack_model_spike_cnt import Attack 24 | 25 | 26 | class AverageMeter(object): 27 | """Computes and stores the average and current value""" 28 | def __init__(self, name, fmt=':f'): 29 | self.name = name 30 | self.fmt = fmt 31 | self.reset() 32 | 33 | def reset(self): 34 | self.val = 0 35 | self.avg = 0 36 | self.sum = 0 37 | self.count = 0 38 | 39 | def update(self, val, n=1): 40 | self.val = val 41 | self.sum += val * n 42 | self.count += n 43 | self.avg = self.sum / self.count 44 | 45 | def __str__(self): 46 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 47 | return fmtstr.format(**self.__dict__) 48 | 49 | def accuracy(output, target, topk=(1,)): 50 | """Computes the accuracy over the k top predictions for the specified values of k""" 51 | with torch.no_grad(): 52 | maxk = max(topk) 53 | batch_size = target.size(0) 54 | 55 | _, pred = output.topk(maxk, 1, True, True) 56 | pred = pred.t() 57 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 58 | 59 | res = [] 60 | for k in topk: 61 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 62 | res.append(correct_k.mul_(100.0 / batch_size)) 63 | return res 64 | 65 | def test(epoch, test_loader, best_Acc, args): 66 | 67 | losses = AverageMeter('Loss') 68 | top1 = AverageMeter('Acc@1') 69 | avg_spike_cnt = [] 70 | 71 | if args.dataset == 'CIFAR10': 72 | mean = torch.Tensor(np.array([0.4914, 0.4822, 0.4465])[:, np.newaxis, np.newaxis]) 73 | mean = mean.expand(3, 32, 32).cuda() 74 | std = torch.Tensor(np.array([0.2023, 0.1994, 0.2010])[:, np.newaxis, np.newaxis]) 75 | std = std.expand(3, 32, 32).cuda() 76 | if args.dataset == 'CIFAR100': 77 | mean = torch.Tensor(np.array([0.5071,0.4867,0.4408])[:, np.newaxis, np.newaxis]) 78 | mean = mean.expand(3, 32, 32).cuda() 79 | std = torch.Tensor(np.array([0.2675,0.2565,0.2761])[:, np.newaxis, np.newaxis]) 80 | std = std.expand(3, 32, 32).cuda() 81 | 82 | if args.test_only: 83 | temp1 = [] 84 | temp2 = [] 85 | for key, value in sorted(model.threshold.items(), 
key=lambda x: (int(x[0][1:]), (x[1]))): 86 | temp1 = temp1+[round(value.item(),2)] 87 | for key, value in sorted(model.leak.items(), key=lambda x: (int(x[0][1:]), (x[1]))): 88 | temp2 = temp2+[round(value.item(),2)] 89 | f.write('\n Thresholds: {}, leak: {}'.format(temp1, temp2)) 90 | 91 | with torch.no_grad(): 92 | model.eval() 93 | global max_accuracy 94 | 95 | for batch_idx, (data, target) in enumerate(test_loader): 96 | 97 | if torch.cuda.is_available() and args.gpu: 98 | data, target = data.cuda(), target.cuda() 99 | 100 | data = data - mean 101 | data.div_(std) 102 | output, spike_count = model(data) 103 | loss = F.cross_entropy(output,target) 104 | pred = output.max(1,keepdim=True)[1] 105 | correct = pred.eq(target.data.view_as(pred)).cpu().sum() 106 | 107 | losses.update(loss.item(),data.size(0)) 108 | top1.update(correct.item()/data.size(0), data.size(0)) 109 | 110 | if test_acc_every_batch: 111 | 112 | f.write('\n Images {}/{} Accuracy: {}/{}({:.4f})' 113 | .format( 114 | test_loader.batch_size*(batch_idx+1), 115 | len(test_loader.dataset), 116 | correct.item(), 117 | data.size(0), 118 | top1.avg*100 119 | ) 120 | ) 121 | 122 | temp1 = [] 123 | temp2 = [] 124 | for key, value in sorted(model.threshold.items(), key=lambda x: (int(x[0][1:]), (x[1]))): 125 | temp1 = temp1+[value.item()] 126 | for key, value in sorted(model.leak.items(), key=lambda x: (int(x[0][1:]), (x[1]))): 127 | temp2 = temp2+[value.item()] 128 | 129 | if epoch>5 and top1.avg<0.15: 130 | f.write('\n Quitting as the training is not progressing') 131 | exit(0) 132 | 133 | f.write(' test_loss: {:.4f}, test_acc: {:.4f}, time: {}' 134 | .format( 135 | losses.avg, 136 | top1.avg*100, 137 | datetime.timedelta(seconds=(datetime.datetime.now() - start_time).seconds) 138 | ) 139 | ) 140 | best_Acc.append(top1.avg*100) 141 | return top1.avg, best_Acc 142 | 143 | 144 | def validate_fgsm_bb(val_loader, model_bb, model, args, eps=0.031): 145 | if args.dataset == 'CIFAR10': 146 | mean = torch.Tensor(np.array([0.4914, 0.4822, 0.4465])[:, np.newaxis, np.newaxis]) 147 | mean = mean.expand(3, 32, 32).cuda() 148 | std = torch.Tensor(np.array([0.2023, 0.1994, 0.2010])[:, np.newaxis, np.newaxis]) 149 | std = std.expand(3, 32, 32).cuda() 150 | if args.dataset == 'CIFAR100': 151 | mean = torch.Tensor(np.array([0.5071,0.4867,0.4408])[:, np.newaxis, np.newaxis]) 152 | mean = mean.expand(3, 32, 32).cuda() 153 | std = torch.Tensor(np.array([0.2675,0.2565,0.2761])[:, np.newaxis, np.newaxis]) 154 | std = std.expand(3, 32, 32).cuda() 155 | losses = AverageMeter('Loss') 156 | prec1_fgsm = 0.0 157 | prec5_fgsm = 0.0 158 | prec1_pgd = 0.0 159 | prec5_pgd = 0.0 160 | n = 0 161 | model_bb.eval() 162 | model.eval() 163 | attacker = Attack(dataloader=val_loader, 164 | attack_method='fgsm', epsilon=0.031) 165 | for i, (data, target) in enumerate(val_loader): 166 | if torch.cuda.is_available() and args.gpu: 167 | data, target = data.cuda(), target.cuda() 168 | n += target.size(0) 169 | data.requires_grad = False 170 | #Here generate the perturbed data from model_bb 171 | perturbed_data = attacker.attack_method(model_bb, data, target, args) 172 | perturbed_data.sub_(mean).div_(std) 173 | # Use the model_bb generated perturbed data to evaluate on model 174 | output_fgsm,_ = model(perturbed_data) 175 | loss_fgsm = F.cross_entropy(output_fgsm, target) 176 | _, pred_fgsm = output_fgsm.topk(5, 1, largest=True, sorted=True) 177 | target_fgsm = target.view(target.size(0),-1).expand_as(pred_fgsm) 178 | correct_fgsm = pred_fgsm.eq(target_fgsm).float() 179 | 
prec1_fgsm += correct_fgsm[:,:1].sum() 180 | prec5_fgsm += correct_fgsm[:,:5].sum() 181 | losses.update(loss_fgsm.item(), data.size(0)) 182 | 183 | top1_fgsm = 100.*(prec1_fgsm/float(n)) 184 | top5_fgsm = 100.*(prec5_fgsm/float(n)) 185 | print('Top1 FGSM:{}'.format(top1_fgsm)) 186 | return top1_fgsm 187 | 188 | 189 | def validate_pgd_bb(val_loader, model_bb, model, eps=0.031, K=7, a=0.01): 190 | if args.dataset == 'CIFAR10': 191 | mean = torch.Tensor(np.array([0.4914, 0.4822, 0.4465])[:, np.newaxis, np.newaxis]) 192 | mean = mean.expand(3, 32, 32).cuda() 193 | std = torch.Tensor(np.array([0.2023, 0.1994, 0.2010])[:, np.newaxis, np.newaxis]) 194 | std = std.expand(3, 32, 32).cuda() 195 | if args.dataset == 'CIFAR100': 196 | mean = torch.Tensor(np.array([0.5071,0.4867,0.4408])[:, np.newaxis, np.newaxis]) 197 | mean = mean.expand(3, 32, 32).cuda() 198 | std = torch.Tensor(np.array([0.2675,0.2565,0.2761])[:, np.newaxis, np.newaxis]) 199 | std = std.expand(3, 32, 32).cuda() 200 | losses = AverageMeter('Loss') 201 | prec1_pgd = 0.0 202 | prec5_pgd = 0.0 203 | n = 0 204 | model_bb.eval() 205 | model.eval() 206 | print('Value of K:{}, eps:{}'.format(K, eps)) 207 | for i, (data, target) in enumerate(val_loader): 208 | if torch.cuda.is_available() and args.gpu: 209 | data, target = data.cuda(), target.cuda() 210 | n += target.size(0) 211 | orig_input = data.clone() 212 | randn = torch.FloatTensor(data.size()).uniform_(-eps, eps).cuda() 213 | data += randn 214 | data.clamp_(0, 1.0) 215 | for _ in range(K): 216 | invar = Variable(data, requires_grad=True) 217 | in1 = invar - mean 218 | in1.div_(std) 219 | #Here generate the output for grad computation from model_bb 220 | output,_ = model_bb(in1) 221 | ascend_loss = F.cross_entropy(output, target) 222 | ascend_grad = torch.autograd.grad(ascend_loss, invar)[0] 223 | pert = torch.sign(ascend_grad)*a 224 | data += pert.data 225 | data = torch.max(orig_input-eps, data) 226 | data = torch.min(orig_input+eps, data) 227 | data.clamp_(0, 1.0) 228 | data.sub_(mean).div_(std) 229 | with torch.no_grad(): 230 | # compute output on model 231 | output,_ = model(data) 232 | loss_pgd = F.cross_entropy(output, target) 233 | 234 | # measure accuracy and record loss 235 | _, pred_pgd = output.topk(5, 1, largest=True, sorted=True) 236 | target_pgd = target.view(target.size(0),-1).expand_as(pred_pgd) 237 | correct_pgd = pred_pgd.eq(target_pgd).float() 238 | prec1_pgd += correct_pgd[:,:1].sum() 239 | prec5_pgd += correct_pgd[:,:5].sum() 240 | losses.update(loss_pgd.item(), data.size(0)) 241 | 242 | top1_pgd = 100.*(prec1_pgd/float(n)) 243 | top5_pgd = 100.*(prec5_pgd/float(n)) 244 | print('\n Top1 PGD:{}'.format(top1_pgd)) 245 | return top1_pgd 246 | 247 | if __name__ == '__main__': 248 | parser = argparse.ArgumentParser(description='SNN training') 249 | parser.add_argument('--gpu', default=True, type=bool, help='use gpu') 250 | parser.add_argument('--dataset', default='CIFAR10', type=str, help='dataset name', choices=['MNIST','CIFAR10','CIFAR100','IMAGENET', 'TINY_IMAGENET']) 251 | parser.add_argument('--batch_size', default=64, type=int, help='minibatch size') 252 | parser.add_argument('-a','--architecture', default='VGG16', type=str, help='network architecture', choices=['VGG4','VGG5','VGG6','VGG9','VGG11','VGG13','VGG16','VGG19','RESNET12','RESNET20','RESNET34']) 253 | parser.add_argument('--pretrained_snn', default='', type=str, help='pretrained SNN for inference') 254 | parser.add_argument('--pretrained_snn_bb', default='', type=str, help='pretrained SNN for BB 
inference')
255 |     parser.add_argument('--test_only', action='store_true', help='perform only inference')
256 |     parser.add_argument('--log', action='store_true', help='write the output to a log file instead of the terminal')
257 |     parser.add_argument('--pgd_iter', default=7, type=int, help='number of pgd iterations')
258 |     parser.add_argument('--pgd_step', default=0.01, type=float, help='pgd attack step size')
259 |     parser.add_argument('--epochs', default=30, type=int, help='number of training epochs')
260 |     parser.add_argument('--timesteps', default=20, type=int, help='simulation timesteps')
261 |     parser.add_argument('--leak', default=1.0, type=float, help='membrane leak')
262 |     parser.add_argument('--default_threshold', default=1.0, type=float, help='initial threshold to train SNN from scratch')
263 |     parser.add_argument('--activation', default='Linear', type=str, help='SNN activation function', choices=['Linear'])
264 |     parser.add_argument('--optimizer', default='SGD', type=str, help='optimizer for SNN backpropagation', choices=['SGD', 'Adam'])
265 |     parser.add_argument('--kernel_size', default=3, type=int, help='filter size for the conv layers')
266 |     parser.add_argument('--test_acc_every_batch', action='store_true', help='print acc of every batch during inference')
267 |     parser.add_argument('--devices', default='0', type=str, help='list of gpu device(s)')
268 |     parser.add_argument('--resume', default='', type=str, help='resume training from this state')
269 |     parser.add_argument('--dont_save', action='store_true', help='don\'t save training model during testing')
270 | 
271 |     args = parser.parse_args()
272 | 
273 |     os.environ['CUDA_VISIBLE_DEVICES'] = args.devices
274 | 
275 |     #torch.backends.cudnn.deterministic = True
276 |     #torch.backends.cudnn.benchmark = False
277 | 
278 |     dataset = args.dataset
279 |     batch_size = args.batch_size
280 |     architecture = args.architecture
281 |     pretrained_snn = args.pretrained_snn
282 |     pretrained_snn_bb = args.pretrained_snn_bb
283 |     epochs = args.epochs
284 |     timesteps = args.timesteps
285 |     leak = args.leak
286 |     default_threshold = args.default_threshold
287 |     activation = args.activation
288 |     kernel_size = args.kernel_size
289 |     test_acc_every_batch= args.test_acc_every_batch
290 |     resume = args.resume
291 |     start_epoch = 1
292 |     max_accuracy = 0.0
293 | 
294 |     log_file = './logs/'
295 |     try:
296 |         os.mkdir(log_file)
297 |     except OSError:
298 |         pass
299 |     identifier = 'snn_'+architecture.lower()+'_'+dataset.lower()+'_'+str(timesteps)+'_'+str(datetime.datetime.now())
300 |     log_file+=identifier+'.log'
301 | 
302 |     if args.log:
303 |         f = open(log_file, 'w', buffering=1)
304 |     else:
305 |         f = sys.stdout
306 | 
307 |     if torch.cuda.is_available() and args.gpu:
308 |         torch.set_default_tensor_type('torch.cuda.FloatTensor')
309 | 
310 |     if dataset == 'CIFAR10':
311 |         normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
312 |     elif dataset == 'CIFAR100':
313 |         normalize = transforms.Normalize((0.5071,0.4867,0.4408), (0.2675,0.2565,0.2761))
314 | 
315 |     if dataset in ['CIFAR10', 'CIFAR100']:
316 |         transform_train = transforms.Compose([
317 |             transforms.RandomCrop(32, padding=4),
318 |             transforms.RandomHorizontalFlip(),
319 |             transforms.ToTensor()
320 |         ])
321 |         transform_test = transforms.Compose([transforms.ToTensor()])
322 | 
323 |     if dataset == 'CIFAR10':
324 |         trainset = datasets.CIFAR10(root = '../SNN_adversary/cifar_data', train = True, download = True, transform = transform_train)
325 |         testset = 
datasets.CIFAR10(root='../SNN_adversary/cifar_data', train=False, download=True, transform = transform_test) 326 | labels = 10 327 | train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True) 328 | test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False) 329 | 330 | elif dataset == 'CIFAR100': 331 | trainset = datasets.CIFAR100(root = './cifar_data', train = True, download = True, transform = transform_train) 332 | testset = datasets.CIFAR100(root='./cifar_data', train=False, download=True, transform = transform_test) 333 | labels = 100 334 | train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True) 335 | test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False) 336 | 337 | if architecture[0:3].lower() == 'vgg': 338 | model = VGG_SNN(vgg_name = architecture, activation = activation, labels=labels, timesteps=timesteps, leak=leak, default_threshold=default_threshold, dropout=0.2, kernel_size=kernel_size, dataset=dataset) 339 | model_bb = VGG_SNN(vgg_name = architecture, activation = activation, labels=labels, timesteps=timesteps, leak=leak, default_threshold=default_threshold, dropout=0.2, kernel_size=kernel_size, dataset=dataset) 340 | 341 | if pretrained_snn: 342 | if pretrained_snn_bb: 343 | state_bb = torch.load(pretrained_snn_bb, map_location='cpu') 344 | missing_keys, unexpected_keys = model_bb.load_state_dict(state_bb, strict=False) 345 | f.write('\n For the bb model:Missing keys : {}, Unexpected Keys: {}'.format(missing_keys, unexpected_keys)) 346 | state = torch.load(pretrained_snn, map_location='cpu') 347 | missing_keys, unexpected_keys = model.load_state_dict(state, strict=False) 348 | else: 349 | print("Please provide an snn file name") 350 | 351 | #model = nn.DataParallel(model) 352 | 353 | if torch.cuda.is_available() and args.gpu: 354 | model.cuda() 355 | if pretrained_snn_bb: 356 | model_bb.cuda() 357 | 358 | if resume: 359 | f.write('\n Resuming from checkpoint {}'.format(resume)) 360 | state = torch.load(resume, map_location='cpu') 361 | missing_keys, unexpected_keys = model.load_state_dict(state, strict=False) 362 | f.write('\n Missing keys : {}, Unexpected Keys: {}'.format(missing_keys, unexpected_keys)) 363 | #f.write('\n Info: Accuracy of loaded ANN model: {}'.format(state['accuracy'])) 364 | 365 | #epoch = state['epoch'] 366 | start_epoch = epoch + 1 367 | #max_accuracy = state['accuracy'] 368 | #optimizer.load_state_dict(state['optimizer']) 369 | for param_group in optimizer.param_groups: 370 | learning_rate = param_group['lr'] 371 | 372 | f.write('\n Loaded from resume epoch: {}, accuracy: {:.4f} lr: {:.1e}'.format(epoch, max_accuracy, learning_rate)) 373 | 374 | test_Acc = [0] 375 | pgd_test_acc = [0] 376 | fgsmtest_acc = [0] 377 | for epoch in range(start_epoch, epochs+1): 378 | start_time = datetime.datetime.now() 379 | top1, test_Acc = test(epoch, test_loader, test_Acc, args) 380 | top1_fgsm = validate_fgsm_bb(test_loader, model_bb, model, args, eps=0.031) 381 | top1_pgd = validate_pgd_bb(test_loader, model_bb, model, eps=0.031, K=7, a=0.01) 382 | fgsmtest_acc.append(top1_fgsm) 383 | pgd_test_acc.append(top1_pgd) 384 | #print('Epoch:{}, TestAcc:{}, PGD acc:{}'.format(epoch, top1, top1_pgd)) 385 | print('Epoch:{}, TestAcc:{}, FGSM acc:{}, PGD acc:{}'.format(epoch, top1*100, top1_fgsm, top1_pgd)) 386 | 387 | #f.write('\n Highest accuracy: {:.4f}'.format(max_accuracy)) 388 | 389 | 390 | 391 | 392 | --------------------------------------------------------------------------------
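For quick experimentation, a minimal interactive sketch of how the pieces above fit together. This block is illustrative only and not a file in the repo; it assumes the HIRE-SNN CIFAR-100 checkpoint has been downloaded as described in the README, and it mirrors the white-box flow of snn_free_nonorm_for_test_purpose.py on a single batch:

```python
import torch
from argparse import Namespace
from torch.utils.data.dataloader import DataLoader
from torchvision import datasets, transforms

from vgg_spiking_nodewise import VGG_SNN
from attack_model_spike_cnt import Attack

# The repo's scripts rely on cuda tensors being the default type.
torch.set_default_tensor_type('torch.cuda.FloatTensor')

# Build the SNN exactly as run_snn_hire_test.py configures it (VGG11, CIFAR100, T=8).
model = VGG_SNN(vgg_name='VGG11', labels=100, timesteps=8, leak=1.0, dataset='CIFAR100').cuda()
state = torch.load('HIRE_SNN_models/vgg11_cifar100_hiresnn_tstep8_model.pt', map_location='cpu')
model.load_state_dict(state, strict=False)
model.eval()

testset = datasets.CIFAR100(root='./cifar_data', train=False, download=True,
                            transform=transforms.ToTensor())
test_loader = DataLoader(testset, batch_size=32, shuffle=False)

# White-box FGSM on one batch; Attack.fgsm reads the dataset name from an args object.
attacker = Attack(dataloader=test_loader, attack_method='fgsm', epsilon=0.031)
args = Namespace(dataset='CIFAR100')
data, target = next(iter(test_loader))
data, target = data.cuda(), target.cuda()
perturbed = attacker.attack_method(model, data, target, args)

# The SNN has no normalization layer, so normalize manually before the forward pass.
mean = torch.tensor([0.5071, 0.4867, 0.4408]).view(3, 1, 1)
std = torch.tensor([0.2675, 0.2565, 0.2761]).view(3, 1, 1)
with torch.no_grad():
    output, _ = model((perturbed - mean) / std)
print('adversarial accuracy on this batch:',
      (output.argmax(dim=1) == target).float().mean().item())
```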