├── Fig
│   ├── __init__.py
│   ├── Capture.PNG
│   ├── Training_procedure.png
│   └── intro_vgg11_sa_and_attack_performance_plot.png
├── LICENSE
├── run_snn_tradit_test.py
├── run_snn_hire_test.py
├── README.md
├── attack_model_spike_cnt.py
├── vgg_spiking_nodewise.py
├── snn_free_nonorm_for_test_purpose.py
└── snn_free_nonorm_for_bbtest_purpose.py
/Fig/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/Fig/Capture.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ksouvik52/hiresnn2021/HEAD/Fig/Capture.PNG
--------------------------------------------------------------------------------
/Fig/Training_procedure.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ksouvik52/hiresnn2021/HEAD/Fig/Training_procedure.png
--------------------------------------------------------------------------------
/Fig/intro_vgg11_sa_and_attack_performance_plot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ksouvik52/hiresnn2021/HEAD/Fig/intro_vgg11_sa_and_attack_performance_plot.png
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Souvik Kundu
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/run_snn_tradit_test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | ########################################################################
5 | #### ~~~~~~ White box testing ~~~~~~
6 | ## Here, pretrained_snn is the model on which the white-box (wb) test will be run.
7 | ########################################################################
8 |
9 | cmd1 = "python snn_free_nonorm_for_test_purpose.py --dataset CIFAR100 --batch_size 32 --architecture VGG11 \
10 | --epochs 1 --timesteps 8 --leak 1.0 --devices 1 \
11 | --pretrained_snn='traditional_models/vgg11_cifar100_tradit_model.pt'"
12 | os.system(cmd1)
13 |
14 |
15 | ########################################################################
16 | #### ~~~~~~ Black box testing ~~~~~~
17 | ## Here, pretrained_snn is the model on which the black-box (bb) test will be run,
18 | ## and pretrained_snn_bb is the model that generates the adversarial images.
19 | ## We use two models of the same variant trained with different seeds.
20 | ########################################################################
21 | cmd2 = "python snn_free_nonorm_for_bbtest_purpose.py --dataset CIFAR100 --batch_size 32 --architecture VGG11 \
22 | --epochs 1 --timesteps 8 --leak 1.0 --devices 1 \
23 | --pretrained_snn='traditional_models/vgg11_cifar100_tradit_model.pt' \
24 | --pretrained_snn_bb='traditional_models/vgg11_cifar100_tradit_bb_test_model.pt'"
25 | os.system(cmd2)
26 |
--------------------------------------------------------------------------------
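Aside: os.system passes the whole command through a shell, so any quoting slip in the long flag strings above silently changes arguments. A minimal alternative sketch using subprocess.run with the same flags and file layout (not part of the original script):

# Sketch: the white-box test invocation above, rewritten with subprocess.run.
# Arguments are passed as a list, so no shell quoting or '\' continuations
# are needed; check=True raises if the test script exits with an error.
import subprocess

subprocess.run([
    "python", "snn_free_nonorm_for_test_purpose.py",
    "--dataset", "CIFAR100",
    "--batch_size", "32",
    "--architecture", "VGG11",
    "--epochs", "1",
    "--timesteps", "8",
    "--leak", "1.0",
    "--devices", "1",
    "--pretrained_snn", "traditional_models/vgg11_cifar100_tradit_model.pt",
], check=True)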
/run_snn_hire_test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | ########################################################################
5 | #### ~~~~~~ White box testing ~~~~~~
6 | ## Here, pretrained_snn is the model on which the white-box (wb) test will be run.
7 | ########################################################################
8 |
9 | cmd1 = "python snn_free_nonorm_for_test_purpose.py --dataset CIFAR100 --batch_size 32 --architecture VGG11 \
10 | --epochs 1 --timesteps 8 --leak 1.0 --devices 1 \
11 | --pretrained_snn='HIRE_SNN_models/vgg11_cifar100_hiresnn_tstep8_model.pt'"
12 | os.system(cmd1)
13 |
14 |
15 | ########################################################################
16 | #### ~~~~~~ Black box testing ~~~~~~
17 | ## Here, pretrained_snn is the model on which the black-box (bb) test will be run,
18 | ## and pretrained_snn_bb is the model that generates the adversarial images.
19 | ## We use two models of the same variant trained with different seeds.
20 | ########################################################################
21 | cmd2 = "python snn_free_nonorm_for_bbtest_purpose.py --dataset CIFAR100 --batch_size 32 --architecture VGG11 \
22 | --epochs 1 --timesteps 8 --leak 1.0 --devices 1 \
23 | --pretrained_snn='HIRE_SNN_models/vgg11_cifar100_hiresnn_tstep8_model.pt' \
24 | --pretrained_snn_bb='HIRE_SNN_models/vgg11_cifar100_hiresnn_tstep8_bb_test_model.pt'"
25 | os.system(cmd2)
26 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | [License: MIT](https://opensource.org/licenses/MIT)
4 |
5 |
6 | Welcome to the official repo of the `ICCV 2021` paper **`HIRE-SNN: Harnessing the Inherent Robustness of Energy-Efficient Deep Spiking Neural Networks by Training with Crafted Input Noise`**.
7 |
8 | **This repo currently contains the test codes. Training code will be updated soon!**
9 |
10 | ### Authors:
11 | 1. **Souvik Kundu** (souvikku@usc.edu)
12 | 2. Massoud Pedram (pedram@usc.edu)
13 | 3. Peter A. Beerel (pabeerel@usc.edu)
14 |
15 | ### Abstract:
16 | Low-latency deep spiking neural networks (SNNs) have become a promising alternative to conventional artificial neural networks (ANNs) because of their potential for increased energy efficiency on event-driven neuromorphic hardware. Neural networks, including SNNs, however, are subject to various adversarial attacks and must be trained to remain resilient against such attacks for many applications. Nevertheless, due to the prohibitively high training costs associated with SNNs, analysis and optimization of deep SNNs under various adversarial attacks have been largely overlooked. In this paper, we first present a detailed analysis of the inherent robustness of low-latency SNNs against popular gradient-based attacks, namely the fast gradient sign method (FGSM) and projected gradient descent (PGD). Motivated by this analysis, to harness the model's robustness against these attacks we present an SNN training algorithm that uses crafted input noise and incurs no additional training time. To evaluate the merits of our algorithm, we conducted extensive experiments with variants of VGG and ResNet on both the CIFAR-10 and CIFAR-100 datasets. Compared to standard trained direct-input SNNs, our trained models yield improved classification accuracy of up to 13.7% and 10.1% on FGSM and PGD attack-generated images, respectively, with negligible loss in clean image accuracy. Our models also outperform inherently robust SNNs trained on rate-coded inputs with improved or similar classification performance on attack-generated images while having up to 25x and ∼4.6x lower latency and computation energy, respectively.
17 |
18 |
19 |
20 | ### Versions on which the models were tested:
21 |
22 | * PyTorch version: `1.5.1`.
23 | * Python version: `3.8.3`.
24 |
25 | ### Model download:
26 | #### A. HIRE-SNN models:
27 | 1. [vgg11_cifar100_hiresnn_tstep8_bb_test_model](https://drive.google.com/file/d/1iDDaO3EZnEPk9JaAn2tCL5M80d8DWxv3/view?usp=sharing)
28 | 2. [vgg11_cifar100_hiresnn_tstep8_model](https://drive.google.com/file/d/1huAXyOzdwlVHS2xsbyLdwr_fgwbdLIGQ/view?usp=sharing)
29 | #### B. Traditional SNN models:
30 | 1. [vgg11_cifar100_tradit_bb_test_model](https://drive.google.com/file/d/1GNrDCu7uD8IBAea6sHue84YEwoGd8t1G/view?usp=sharing)
31 | 2. [vgg11_cifar100_tradit_model](https://drive.google.com/file/d/1sayiUKccwrn0X77hAm-eiZUzVxcO3KhH/view?usp=sharing)
32 | ### To test adversarial accuracy of a saved model, please follow these steps:
33 | Create two folders named *`HIRE_SNN_models`* and *`traditional_models`*. Download the models to their respective folder locations.
34 | #### 1. HIRE SNN testing:
35 | a) In `run_snn_hire_test.py`, set `HIRE_SNN_models/vgg11_cifar100_hiresnn_tstep8_model.pt` as `pretrained_snn`
36 | and
37 | `HIRE_SNN_models/vgg11_cifar100_hiresnn_tstep8_bb_test_model.pt` as `pretrained_snn_bb` (used for black-box testing only).
38 | b) Run: `python run_snn_hire_test.py`
39 |
40 | #### 2. Traditional SNN testing:
41 | a) In `run_snn_tradit_test.py`, set `traditional_models/vgg11_cifar100_tradit_model.pt` as `pretrained_snn`
42 | and
43 | `traditional_models/vgg11_cifar100_tradit_bb_test_model.pt` as `pretrained_snn_bb` (used for black-box testing only).
44 | b) Run: `python run_snn_tradit_test.py`
45 |
46 | ### Cite this work
47 | If you find this project useful to you, please cite our work:
48 |
49 | @InProceedings{Kundu_2021_ICCV,
50 | author = {Kundu, Souvik and Pedram, Massoud and Beerel, Peter A.},
51 | title = {HIRE-SNN: Harnessing the Inherent Robustness of Energy-Efficient Deep Spiking Neural Networks by Training With Crafted Input Noise},
52 | booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
53 | month = {October},
54 | year = {2021},
55 | pages = {5209-5218}}
56 |
57 | ### Acknowledgment
58 | [Hybrid SNN repo](https://github.com/nitin-rathi/hybrid-snn-conversion)
59 |
60 |
--------------------------------------------------------------------------------
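The model download and folder setup described in the README can also be scripted. A minimal sketch; the third-party gdown package (pip install gdown) is an assumption, while the Drive file IDs are taken from the README links above:

# Sketch: create the two model folders and fetch the four checkpoints.
import os
import gdown

MODELS = {
    "HIRE_SNN_models/vgg11_cifar100_hiresnn_tstep8_model.pt":
        "1huAXyOzdwlVHS2xsbyLdwr_fgwbdLIGQ",
    "HIRE_SNN_models/vgg11_cifar100_hiresnn_tstep8_bb_test_model.pt":
        "1iDDaO3EZnEPk9JaAn2tCL5M80d8DWxv3",
    "traditional_models/vgg11_cifar100_tradit_model.pt":
        "1sayiUKccwrn0X77hAm-eiZUzVxcO3KhH",
    "traditional_models/vgg11_cifar100_tradit_bb_test_model.pt":
        "1GNrDCu7uD8IBAea6sHue84YEwoGd8t1G",
}

for path, file_id in MODELS.items():
    os.makedirs(os.path.dirname(path), exist_ok=True)  # HIRE_SNN_models / traditional_models
    gdown.download(f"https://drive.google.com/uc?id={file_id}", path, quiet=False)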
/attack_model_spike_cnt.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import math
3 | import torch.nn.functional as F
4 | import torch
5 | import copy
6 | import numpy as np
7 |
8 | class Attack(object):
9 |
10 | def __init__(self, dataloader, criterion=None, gpu_id=0,
11 | epsilon=0.031, attack_method='pgd'):
12 |
13 | if criterion is not None:
14 | self.criterion = criterion
15 | else:
16 | self.criterion = nn.CrossEntropyLoss()
17 |
18 | self.dataloader = dataloader
19 | self.epsilon = epsilon
20 | self.gpu_id = gpu_id # integer GPU id
21 |
22 | if attack_method == 'fgsm':
23 | self.attack_method = self.fgsm
24 | elif attack_method == 'pgd':
25 | self.attack_method = self.pgd
26 |
27 | def update_params(self, epsilon=None, dataloader=None, attack_method=None):
28 | if epsilon is not None:
29 | self.epsilon = epsilon
30 | if dataloader is not None:
31 | self.dataloader = dataloader
32 |
33 | if attack_method is not None:
34 | if attack_method == 'fgsm':
35 | self.attack_method = self.fgsm
36 | elif attack_method == 'pgd':
37 | self.attack_method = self.pgd
38 |
39 | ## For SNNs, fgsm manually normalizes the un-normalized input with the dataset
40 | ## mean/std (selected via args.dataset), since the SNN has no normalization layer.
41 | def fgsm(self, model, data, target, args, data_min=0, data_max=1):
42 |
43 | if args.dataset == 'CIFAR10':
44 | mean = torch.Tensor(np.array([0.4914, 0.4822, 0.4465])[:, np.newaxis, np.newaxis])
45 | mean = mean.expand(3, 32, 32).cuda()
46 | std = torch.Tensor(np.array([0.2023, 0.1994, 0.2010])[:, np.newaxis, np.newaxis])
47 | std = std.expand(3, 32, 32).cuda()
48 | if args.dataset == 'CIFAR100':
49 | mean = torch.Tensor(np.array([0.5071,0.4867,0.4408])[:, np.newaxis, np.newaxis])
50 | mean = mean.expand(3, 32, 32).cuda()
51 | std = torch.Tensor(np.array([0.2675,0.2565,0.2761])[:, np.newaxis, np.newaxis])
52 | std = std.expand(3, 32, 32).cuda()
53 |
54 | model.eval()
55 | # perturbed_data = copy.deepcopy(data)
56 | perturbed_data = data.clone()
57 |
58 | perturbed_data.requires_grad = True
59 | # Since we take the raw un-normalized data, we normalize it here
60 | # before feeding it to the model
61 | perturbed_data_norm = perturbed_data - mean
62 | perturbed_data_norm.div_(std)
63 | output,_ = model(perturbed_data_norm)
64 | #print('perturbed_data.requires_grad:', perturbed_data.requires_grad)
65 | loss = F.cross_entropy(output, target)
66 | if perturbed_data.grad is not None:
67 | perturbed_data.grad.data.zero_()
68 |
69 | loss.backward()
70 |
71 | # Collect the element-wise sign of the data gradient
72 | sign_data_grad = perturbed_data.grad.data.sign()
73 | perturbed_data.requires_grad = False
74 |
75 | with torch.no_grad():
76 | # Create the perturbed image by adjusting each pixel of the input image
77 | perturbed_data += self.epsilon*sign_data_grad
78 | # Adding clipping to maintain [min,max] range, default 0,1 for image
79 | perturbed_data.clamp_(data_min, data_max)
80 |
81 | return perturbed_data
82 |
83 | ## For SNN pgd takes two more args: mean and std to manually perform normalization for
84 | ## each of the k iterated perturbed data generated intermediately.
85 | def pgd(self, model, data, target, k=7, a=0.01, random_start=True,
86 | d_min=0, d_max=1): # k: attack iterations; k=7 for ANNs, can be reduced (e.g. k=3) to cut SNN runtime
87 |
88 | mean = torch.Tensor(np.array([0.4914, 0.4822, 0.4465])[:, np.newaxis, np.newaxis]) # note: CIFAR-10 statistics hardcoded here (unlike fgsm, no args.dataset switch)
89 | mean = mean.expand(3, 32, 32).cuda()
90 | std = torch.Tensor(np.array([0.2023, 0.1994, 0.2010])[:, np.newaxis, np.newaxis])
91 | std = std.expand(3, 32, 32).cuda()
92 |
93 | model.eval()
94 | # perturbed_data = copy.deepcopy(data)
95 | perturbed_data = data.clone()
96 | perturbed_data.requires_grad = True
97 |
98 | data_max = data + self.epsilon
99 | data_min = data - self.epsilon
100 | data_max.clamp_(d_min, d_max)
101 | data_min.clamp_(d_min, d_max)
102 |
103 | if random_start:
104 | with torch.no_grad():
105 | perturbed_data.data = data + perturbed_data.uniform_(-1*self.epsilon, self.epsilon)
106 | perturbed_data.data.clamp_(d_min, d_max)
107 |
108 | for _ in range(k):
109 | ## For SNNs there is no separate mean/std normalization layer, so we manually
110 | ## normalize each intermediate perturbed input generated during the k iterations.
111 |
112 | in1 = perturbed_data - mean
113 | in1.div_(std)
114 | output,_ = model( in1 )
115 | #print('output shape:{}, target shape:{}', output.shape, target.shape)
116 | loss = F.cross_entropy(output, target)
117 |
118 | if perturbed_data.grad is not None:
119 | perturbed_data.grad.data.zero_()
120 |
121 | loss.backward()
122 | data_grad = perturbed_data.grad.data
123 |
124 | with torch.no_grad():
125 | perturbed_data.data += a * torch.sign(data_grad)
126 | perturbed_data.data = torch.max(torch.min(perturbed_data, data_max),
127 | data_min)
128 | perturbed_data.requires_grad = False
129 |
130 | return perturbed_data
131 |
--------------------------------------------------------------------------------
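For reference, a minimal standalone sketch of driving the Attack class above; an untrained VGG_SNN and random inputs stand in for the repo's pretrained checkpoints and real dataloader, and a CUDA device is assumed (fgsm/pgd move their normalization tensors to the GPU):

# Sketch: white-box FGSM via the Attack class. The model follows this repo's
# convention of returning (logits, spike_count); inputs are un-normalized
# images in [0, 1], and fgsm() normalizes them internally per args.dataset.
import argparse
import torch
from vgg_spiking_nodewise import VGG_SNN
from attack_model_spike_cnt import Attack

# The repo's test scripts set the default tensor type to CUDA so the SNN's
# internal state tensors (membrane potentials, spikes) land on the GPU.
torch.set_default_tensor_type('torch.cuda.FloatTensor')

model = VGG_SNN(vgg_name='VGG11', labels=10, timesteps=8).cuda()
args = argparse.Namespace(dataset='CIFAR10')      # fgsm() reads args.dataset

attacker = Attack(dataloader=None, attack_method='fgsm', epsilon=0.031)
data = torch.rand(4, 3, 32, 32).cuda()            # un-normalized [0, 1] images
target = torch.randint(0, 10, (4,)).cuda()
perturbed = attacker.attack_method(model, data, target, args)
print(perturbed.shape, (perturbed - data).abs().max())  # L-inf bound <= 0.031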
/vgg_spiking_nodewise.py:
--------------------------------------------------------------------------------
1 | #---------------------------------------------------
2 | # Imports
3 | #---------------------------------------------------
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | import numpy as np
8 | import pdb
9 | import math
10 | from collections import OrderedDict
11 | from matplotlib import pyplot as plt
12 | import copy
13 |
14 | cfg = {
15 | 'VGG4' : [64, 'A', 128, 'A'],
16 | 'VGG5' : [64, 'A', 128, 128, 'A'],
17 | 'VGG6' : [64, 'A', 128, 128, 'A'],
18 | 'VGG9': [64, 'A', 128, 256, 'A', 256, 512, 'A', 512, 'A', 512],
19 | 'VGG11': [64, 'A', 128, 256, 'A', 512, 512, 512, 'A', 512, 512],
20 | 'VGG13': [64, 64, 'A', 128, 128, 'A', 256, 256, 'A', 512, 512, 512, 'A', 512],
21 | 'VGG16': [64, 64, 'A', 128, 128, 'A', 256, 256, 256, 'A', 512, 512, 512, 'A', 512, 512, 512],
22 | 'VGG19': [64, 64, 'A', 128, 128, 'A', 256, 256, 256, 256, 'A', 512, 512, 512, 512, 'A', 512, 512, 512, 512]
23 | }
24 |
25 | class LinearSpike(torch.autograd.Function):
26 | """
27 | Here we use the piecewise-linear surrogate gradient as was done
28 | in Bellec et al. (2018).
29 | """
30 | gamma = 0.3 # Controls the dampening of the piecewise-linear surrogate gradient
31 |
32 | @staticmethod
33 | def forward(ctx, input, last_spike):
34 |
35 | ctx.save_for_backward(input)
36 | out = torch.zeros_like(input).cuda()
37 | out[input > 0] = 1.0
38 | return out
39 |
40 | @staticmethod
41 | def backward(ctx, grad_output):
42 |
43 | input, = ctx.saved_tensors
44 | grad_input = grad_output.clone()
45 | grad = LinearSpike.gamma*F.threshold(1.0-torch.abs(input), 0, 0)
46 | return grad*grad_input, None
47 |
48 | class VGG_SNN(nn.Module):
49 |
50 | def __init__(self, vgg_name, activation='Linear', labels=10, timesteps=100, leak=1.0, default_threshold = 1.0, dropout=0.2, kernel_size=3, dataset='CIFAR10'):
51 | super().__init__()
52 |
53 | self.vgg_name = vgg_name
54 | self.act_func = LinearSpike.apply
55 | self.labels = labels
56 | self.timesteps = timesteps
57 | self.dropout = dropout
58 | self.kernel_size = kernel_size
59 | self.dataset = dataset
60 | self.mem = {}
61 | self.mask = {}
62 | self.spike = {}
63 |
64 | self.features, self.classifier = self._make_layers(cfg[self.vgg_name])
65 |
66 | self._initialize_weights2()
67 |
68 | threshold = {}
69 | lk = {}
70 | for l in range(len(self.features)):
71 | if isinstance(self.features[l], nn.Conv2d):
72 | threshold['t'+str(l)] = nn.Parameter(torch.tensor(default_threshold))
73 | lk['l'+str(l)] = nn.Parameter(torch.tensor(leak))
74 |
75 |
76 | prev = len(self.features)
77 | for l in range(len(self.classifier)-1):
78 | if isinstance(self.classifier[l], nn.Linear):
79 | threshold['t'+str(prev+l)] = nn.Parameter(torch.tensor(default_threshold))
80 | lk['l'+str(prev+l)] = nn.Parameter(torch.tensor(leak))
81 |
82 | self.threshold = nn.ParameterDict(threshold)
83 | self.leak = nn.ParameterDict(lk)
84 |
85 | def _initialize_weights2(self):
86 | for m in self.modules():
87 |
88 | if isinstance(m, nn.Conv2d):
89 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
90 | m.weight.data.normal_(0, math.sqrt(2. / n))
91 | if m.bias is not None:
92 | m.bias.data.zero_()
93 | elif isinstance(m, nn.BatchNorm2d):
94 | m.weight.data.fill_(1)
95 | m.bias.data.zero_()
96 | elif isinstance(m, nn.Linear):
97 | n = m.weight.size(1)
98 | m.weight.data.normal_(0, 0.01)
99 | if m.bias is not None:
100 | m.bias.data.zero_()
101 |
102 | def threshold_update(self, scaling_factor=1.0, thresholds=[]):
103 |
104 | # Initialize thresholds
105 | self.scaling_factor = scaling_factor
106 |
107 | for pos in range(len(self.features)):
108 | if isinstance(self.features[pos], nn.Conv2d):
109 | if thresholds:
110 | self.threshold.update({'t'+str(pos): nn.Parameter(torch.tensor(thresholds.pop(0)*self.scaling_factor))})
111 | #print('\t Layer{} : {:.2f}'.format(pos, self.threshold[pos]))
112 |
113 | prev = len(self.features)
114 |
115 | for pos in range(len(self.classifier)-1):
116 | if isinstance(self.classifier[pos], nn.Linear):
117 | if thresholds:
118 | self.threshold.update({'t'+str(prev+pos): nn.Parameter(torch.tensor(thresholds.pop(0)*self.scaling_factor))})
119 | #print('\t Layer{} : {:.2f}'.format(prev+pos, self.threshold[prev+pos]))
120 |
121 |
122 | def _make_layers(self, cfg):
123 | layers = []
124 | if self.dataset =='MNIST':
125 | in_channels = 1
126 | else:
127 | in_channels = 3
128 |
129 | for x in (cfg):
130 | stride = 1
131 |
132 | if x == 'A':
133 | layers.pop()
134 | layers += [nn.AvgPool2d(kernel_size=2, stride=2)]
135 |
136 | else:
137 | layers += [nn.Conv2d(in_channels, x, kernel_size=self.kernel_size, padding=(self.kernel_size-1)//2, stride=stride, bias=False),
138 | nn.ReLU(inplace=True)
139 | ]
140 | layers += [nn.Dropout(self.dropout)]
141 | in_channels = x
142 |
143 | features = nn.Sequential(*layers)
144 |
145 | layers = []
146 | if self.dataset == 'IMAGENET':
147 | layers += [nn.Linear(512*7*7, 4096, bias=False)]
148 | layers += [nn.ReLU(inplace=True)]
149 | layers += [nn.Dropout(self.dropout)]
150 | layers += [nn.Linear(4096, 4096, bias=False)]
151 | layers += [nn.ReLU(inplace=True)]
152 | layers += [nn.Dropout(self.dropout)]
153 | layers += [nn.Linear(4096, self.labels, bias=False)]
154 |
155 | elif self.vgg_name == 'VGG6' and self.dataset != 'MNIST':
156 | layers += [nn.Linear(512*4*4, 4096, bias=False)]
157 | layers += [nn.ReLU(inplace=True)]
158 | layers += [nn.Dropout(self.dropout)]
159 | layers += [nn.Linear(4096, 4096, bias=False)]
160 | layers += [nn.ReLU(inplace=True)]
161 | layers += [nn.Dropout(self.dropout)]
162 | layers += [nn.Linear(4096, self.labels, bias=False)]
163 |
164 | elif self.vgg_name == 'VGG4' and self.dataset== 'MNIST':
165 | layers += [nn.Linear(128*7*7, 1024, bias=False)]
166 | layers += [nn.ReLU(inplace=True)]
167 | layers += [nn.Dropout(self.dropout)]
168 | #layers += [nn.Linear(4096, 4096, bias=False)]
169 | #layers += [nn.ReLU(inplace=True)]
170 | #layers += [nn.Dropout(self.dropout)]
171 | layers += [nn.Linear(1024, self.labels, bias=False)]
172 |
173 | elif self.vgg_name == 'VGG16' and self.dataset != 'MNIST':
174 | layers += [nn.Linear(512*2*2, 4096, bias=False)]
175 | layers += [nn.ReLU(inplace=True)]
176 | layers += [nn.Dropout(self.dropout)]
177 | layers += [nn.Linear(4096, 4096, bias=False)]
178 | layers += [nn.ReLU(inplace=True)]
179 | layers += [nn.Dropout(self.dropout)]
180 | layers += [nn.Linear(4096, self.labels, bias=False)]
181 |
182 | elif (self.vgg_name == 'VGG5' or self.vgg_name == 'VGG11') and self.dataset != 'MNIST':
183 | layers += [nn.Linear(512*2*2*4, 4096, bias=False)]
184 | layers += [nn.ReLU(inplace=True)]
185 | layers += [nn.Dropout(self.dropout)]
186 | layers += [nn.Linear(4096, 4096, bias=False)]
187 | layers += [nn.ReLU(inplace=True)]
188 | layers += [nn.Dropout(self.dropout)]
189 | layers += [nn.Linear(4096, self.labels, bias=False)]
190 |
191 | elif self.vgg_name == 'VGG6' and self.dataset == 'MNIST':
192 | layers += [nn.Linear(128*7*7, 4096, bias=False)]
193 | layers += [nn.ReLU(inplace=True)]
194 | layers += [nn.Dropout(self.dropout)]
195 | layers += [nn.Linear(4096, 4096, bias=False)]
196 | layers += [nn.ReLU(inplace=True)]
197 | layers += [nn.Dropout(self.dropout)]
198 | layers += [nn.Linear(4096, self.labels, bias=False)]
199 |
200 | elif self.vgg_name != 'VGG6' and self.dataset == 'MNIST':
201 | layers += [nn.Linear(512*1*1, 4096, bias=False)]
202 | layers += [nn.ReLU(inplace=True)]
203 | layers += [nn.Dropout(self.dropout)]
204 | layers += [nn.Linear(4096, 4096, bias=False)]
205 | layers += [nn.ReLU(inplace=True)]
206 | layers += [nn.Dropout(self.dropout)]
207 | layers += [nn.Linear(4096, self.labels, bias=False)]
208 |
209 |
210 | classifier = nn.Sequential(*layers)
211 | return (features, classifier)
212 |
213 | def network_update(self, timesteps, leak): # note: only timesteps is updated; leak is left unchanged
214 | self.timesteps = timesteps
215 |
216 | def timestep_update(self, timesteps):
217 | self.timesteps = timesteps
218 |
219 | def neuron_init(self, x):
220 | self.batch_size = x.size(0)
221 | self.width = x.size(2)
222 | self.height = x.size(3)
223 |
224 | self.mem = {}
225 | self.spike = {}
226 | self.mask = {}
227 | self.spike_count= {}
228 |
229 | for l in range(len(self.features)):
230 |
231 | if isinstance(self.features[l], nn.Conv2d):
232 | self.mem[l] = torch.zeros(self.batch_size, self.features[l].out_channels, self.width, self.height)
233 | #this is to just keep track what happens to spike after conv
234 |
235 |
236 |
237 | elif isinstance(self.features[l], nn.ReLU):
238 | if isinstance(self.features[l-1], nn.Conv2d):
239 | self.spike[l] = torch.ones(self.mem[l-1].shape)*(-1000)
240 | self.spike_count[l] = torch.zeros(self.mem[l-1].size())
241 |
242 | elif isinstance(self.features[l-1], nn.AvgPool2d):
243 | self.spike[l] = torch.ones(self.batch_size, self.features[l-2].out_channels, self.width, self.height)*(-1000)
244 |
245 |
246 | elif isinstance(self.features[l], nn.Dropout):
247 | self.mask[l] = self.features[l](torch.ones(self.mem[l-2].shape).cuda())
248 |
249 | elif isinstance(self.features[l], nn.AvgPool2d):
250 | self.width = self.width//self.features[l].kernel_size
251 | self.height = self.height//self.features[l].kernel_size
252 | self.spike_count[l] = torch.zeros(self.batch_size, self.features[l-2].out_channels,self.width,self.height)
253 |
254 |
255 | prev = len(self.features)
256 |
257 | for l in range(len(self.classifier)):
258 |
259 | if isinstance(self.classifier[l], nn.Linear):
260 | self.mem[prev+l] = torch.zeros(self.batch_size, self.classifier[l].out_features)
261 |
262 |
263 | elif isinstance(self.classifier[l], nn.ReLU):
264 | self.spike[prev+l] = torch.ones(self.mem[prev+l-1].shape)*(-1000)
265 | self.spike_count[prev+l] = torch.zeros(self.mem[prev+l-1].size())
266 |
267 | elif isinstance(self.classifier[l], nn.Dropout):
268 | self.mask[prev+l] = self.classifier[l](torch.ones(self.mem[prev+l-2].shape).cuda())
269 |
270 |
271 | def percentile(self, t, q):
272 |
273 | k = 1 + round(.01 * float(q) * (t.numel() - 1))
274 | result = t.view(-1).kthvalue(k).values.item()
275 | return result
276 |
277 | def forward(self, x, find_max_mem=False, spike_count=[], max_mem_layer=0):
278 |
279 | self.neuron_init(x)
280 | max_mem=0.0
281 |
282 | for t in range(self.timesteps):
283 | out_prev = x
284 |
285 | for l in range(len(self.features)):
286 |
287 | if isinstance(self.features[l], (nn.Conv2d)):
288 |
289 | if find_max_mem and l==max_mem_layer:
290 | cur = self.percentile(self.features[l](out_prev).view(-1), 99.7)
291 | if (cur>max_mem):
292 | max_mem = torch.tensor([cur])
293 | break
294 |
295 | delta_mem = self.features[l](out_prev)
296 | self.mem[l] = getattr(self.leak, 'l'+str(l)) *self.mem[l] + delta_mem
297 | mem_thr = (self.mem[l]/getattr(self.threshold, 't'+str(l))) - 1.0
298 | rst = getattr(self.threshold, 't'+str(l)) * (mem_thr>0).float()
299 | self.mem[l] = self.mem[l]-rst
300 |
301 | elif isinstance(self.features[l], nn.ReLU):
302 |
303 | out = self.act_func(mem_thr, (t-1-self.spike[l]))
304 | self.spike[l] = self.spike[l].masked_fill(out.bool(),t-1)
305 | self.spike_count[l][out.bool()] = self.spike_count[l][out.bool()] + 1
306 | out_prev = out.clone()
307 |
308 | elif isinstance(self.features[l], nn.AvgPool2d):
309 | out_prev = self.features[l](out_prev)
310 | self.spike_count[l][out_prev.bool()] = self.spike_count[l][out_prev.bool()] + 1
311 |
312 |
313 | elif isinstance(self.features[l], nn.Dropout):
314 | out_prev = out_prev * self.mask[l]
315 |
316 | if find_max_mem and max_mem_layer<len(self.features):
317 | continue
318 |
319 | out_prev = out_prev.reshape(self.batch_size, -1)
320 | prev = len(self.features)
321 |
322 | for l in range(len(self.classifier)-1):
323 |
324 | if isinstance(self.classifier[l], (nn.Linear)):
325 |
326 | if find_max_mem and (prev+l)==max_mem_layer:
327 | cur = self.percentile(self.classifier[l](out_prev).view(-1), 99.7)
328 | if (cur>max_mem):
329 | max_mem = torch.tensor([cur])
330 | break
331 |
332 | delta_mem = self.classifier[l](out_prev)
333 | self.mem[prev+l] = getattr(self.leak, 'l'+str(prev+l)) * self.mem[prev+l] + delta_mem
334 | mem_thr = (self.mem[prev+l]/getattr(self.threshold, 't'+str(prev+l))) - 1.0
335 | rst = getattr(self.threshold,'t'+str(prev+l)) * (mem_thr>0).float()
336 | self.mem[prev+l] = self.mem[prev+l]-rst
337 |
338 |
339 | elif isinstance(self.classifier[l], nn.ReLU):
340 | out = self.act_func(mem_thr, (t-1-self.spike[prev+l]))
341 | self.spike[prev+l] = self.spike[prev+l].masked_fill(out.bool(),t-1)
342 | self.spike_count[prev+l][out.bool()] = self.spike_count[prev+l][out.bool()] + 1
343 | out_prev = out.clone()
344 |
345 | elif isinstance(self.classifier[l], nn.Dropout):
346 | out_prev = out_prev * self.mask[prev+l]
347 |
348 | # Compute the classification layer outputs
349 | if not find_max_mem:
350 | self.mem[prev+l+1] = self.mem[prev+l+1] + self.classifier[l+1](out_prev)
351 | if find_max_mem:
352 | return max_mem
353 |
354 | return self.mem[prev+l+1], self.spike_count
355 |
356 |
357 |
358 |
--------------------------------------------------------------------------------
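A minimal numeric sketch of the LinearSpike surrogate gradient defined above (illustration only; a CUDA device is assumed because forward() moves its output to the GPU):

# Sketch: LinearSpike fires where the input (mem_thr) is positive, and its
# backward pass uses the piecewise-linear surrogate gamma*max(0, 1-|input|).
import torch
from vgg_spiking_nodewise import LinearSpike

mem_thr = torch.tensor([-0.5, 0.2, 1.5], device='cuda', requires_grad=True)
last_spike = torch.zeros(3, device='cuda')   # second argument; unused in backward
out = LinearSpike.apply(mem_thr, last_spike)
out.sum().backward()

print(out)           # tensor([0., 1., 1.]) -- spikes where mem_thr > 0
print(mem_thr.grad)  # tensor([0.1500, 0.2400, 0.0000]) -- 0.3*max(0, 1-|x|)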
/snn_free_nonorm_for_test_purpose.py:
--------------------------------------------------------------------------------
1 | #---------------------------------------------------
2 | # Imports
3 | #---------------------------------------------------
4 | from __future__ import print_function
5 | import argparse
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | import torch.optim as optim
10 | from torchvision import datasets, transforms, models
11 | from torch.utils.data.dataloader import DataLoader
12 | from torch.autograd import Variable
13 | from matplotlib import pyplot as plt
14 | from matplotlib.gridspec import GridSpec
15 | import numpy as np
16 | import datetime
17 | import pdb
18 | from vgg_spiking_nodewise import *
19 | import sys
20 | import os
21 | import shutil
22 | import argparse
23 | from attack_model_spike_cnt import Attack
24 |
25 |
26 | class AverageMeter(object):
27 | """Computes and stores the average and current value"""
28 | def __init__(self, name, fmt=':f'):
29 | self.name = name
30 | self.fmt = fmt
31 | self.reset()
32 |
33 | def reset(self):
34 | self.val = 0
35 | self.avg = 0
36 | self.sum = 0
37 | self.count = 0
38 |
39 | def update(self, val, n=1):
40 | self.val = val
41 | self.sum += val * n
42 | self.count += n
43 | self.avg = self.sum / self.count
44 |
45 | def __str__(self):
46 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
47 | return fmtstr.format(**self.__dict__)
48 |
49 | def accuracy(output, target, topk=(1,)):
50 | """Computes the accuracy over the k top predictions for the specified values of k"""
51 | with torch.no_grad():
52 | maxk = max(topk)
53 | batch_size = target.size(0)
54 |
55 | _, pred = output.topk(maxk, 1, True, True)
56 | pred = pred.t()
57 | correct = pred.eq(target.view(1, -1).expand_as(pred))
58 |
59 | res = []
60 | for k in topk:
61 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
62 | res.append(correct_k.mul_(100.0 / batch_size))
63 | return res
64 |
65 | def test(epoch, test_loader, best_Acc, args):
66 |
67 | losses = AverageMeter('Loss')
68 | top1 = AverageMeter('Acc@1')
69 | avg_spike_cnt = []
70 |
71 | if args.dataset == 'CIFAR10':
72 | mean = torch.Tensor(np.array([0.4914, 0.4822, 0.4465])[:, np.newaxis, np.newaxis])
73 | mean = mean.expand(3, 32, 32).cuda()
74 | std = torch.Tensor(np.array([0.2023, 0.1994, 0.2010])[:, np.newaxis, np.newaxis])
75 | std = std.expand(3, 32, 32).cuda()
76 | if args.dataset == 'CIFAR100':
77 | mean = torch.Tensor(np.array([0.5071,0.4867,0.4408])[:, np.newaxis, np.newaxis])
78 | mean = mean.expand(3, 32, 32).cuda()
79 | std = torch.Tensor(np.array([0.2675,0.2565,0.2761])[:, np.newaxis, np.newaxis])
80 | std = std.expand(3, 32, 32).cuda()
81 |
82 | if args.test_only:
83 | temp1 = []
84 | temp2 = []
85 | for key, value in sorted(model.threshold.items(), key=lambda x: (int(x[0][1:]), (x[1]))):
86 | temp1 = temp1+[round(value.item(),2)]
87 | for key, value in sorted(model.leak.items(), key=lambda x: (int(x[0][1:]), (x[1]))):
88 | temp2 = temp2+[round(value.item(),2)]
89 | f.write('\n Thresholds: {}, leak: {}'.format(temp1, temp2))
90 |
91 | with torch.no_grad():
92 | model.eval()
93 | global max_accuracy
94 |
95 | for batch_idx, (data, target) in enumerate(test_loader):
96 |
97 | if torch.cuda.is_available() and args.gpu:
98 | data, target = data.cuda(), target.cuda()
99 |
100 | data = data - mean
101 | data.div_(std)
102 | output, spike_count = model(data)
103 | loss = F.cross_entropy(output,target)
104 | pred = output.max(1,keepdim=True)[1]
105 | correct = pred.eq(target.data.view_as(pred)).cpu().sum()
106 |
107 | losses.update(loss.item(),data.size(0))
108 | top1.update(correct.item()/data.size(0), data.size(0))
109 |
110 | if test_acc_every_batch:
111 |
112 | f.write('\n Images {}/{} Accuracy: {}/{}({:.4f})'
113 | .format(
114 | test_loader.batch_size*(batch_idx+1),
115 | len(test_loader.dataset),
116 | correct.item(),
117 | data.size(0),
118 | top1.avg*100
119 | )
120 | )
121 |
122 | temp1 = []
123 | temp2 = []
124 | for key, value in sorted(model.threshold.items(), key=lambda x: (int(x[0][1:]), (x[1]))):
125 | temp1 = temp1+[value.item()]
126 | for key, value in sorted(model.leak.items(), key=lambda x: (int(x[0][1:]), (x[1]))):
127 | temp2 = temp2+[value.item()]
128 |
129 | if epoch>5 and top1.avg<0.15:
130 | f.write('\n Quitting as the training is not progressing')
131 | exit(0)
132 |
133 | f.write(' test_loss: {:.4f}, test_acc: {:.4f}, time: {}'
134 | .format(
135 | losses.avg,
136 | top1.avg*100,
137 | datetime.timedelta(seconds=(datetime.datetime.now() - start_time).seconds)
138 | )
139 | )
140 | best_Acc.append(top1.avg*100)
141 | return top1.avg, best_Acc
142 |
143 |
144 | def validate_fgsm(val_loader, model, args, eps=0.031):
145 | if args.dataset == 'CIFAR10':
146 | mean = torch.Tensor(np.array([0.4914, 0.4822, 0.4465])[:, np.newaxis, np.newaxis])
147 | mean = mean.expand(3, 32, 32).cuda()
148 | std = torch.Tensor(np.array([0.2023, 0.1994, 0.2010])[:, np.newaxis, np.newaxis])
149 | std = std.expand(3, 32, 32).cuda()
150 | if args.dataset == 'CIFAR100':
151 | mean = torch.Tensor(np.array([0.5071,0.4867,0.4408])[:, np.newaxis, np.newaxis])
152 | mean = mean.expand(3, 32, 32).cuda()
153 | std = torch.Tensor(np.array([0.2675,0.2565,0.2761])[:, np.newaxis, np.newaxis])
154 | std = std.expand(3, 32, 32).cuda()
155 | losses = AverageMeter('Loss')
156 | prec1_fgsm = 0.0
157 | prec5_fgsm = 0.0
158 | n = 0
159 | model.eval()
160 | attacker = Attack(dataloader=val_loader,
161 | attack_method='fgsm', epsilon=0.031)
162 | for i, (data, target) in enumerate(val_loader):
163 | if torch.cuda.is_available() and args.gpu:
164 | data, target = data.cuda(), target.cuda()
165 | n += target.size(0)
166 | data.requires_grad = False
167 | perturbed_data = attacker.attack_method(model, data, target, args)
168 | perturbed_data.sub_(mean).div_(std)
169 | output_fgsm,_ = model(perturbed_data)
170 | loss_fgsm = F.cross_entropy(output_fgsm, target)
171 | _, pred_fgsm = output_fgsm.topk(5, 1, largest=True, sorted=True)
172 | target_fgsm = target.view(target.size(0),-1).expand_as(pred_fgsm)
173 | correct_fgsm = pred_fgsm.eq(target_fgsm).float()
174 | prec1_fgsm += correct_fgsm[:,:1].sum()
175 | prec5_fgsm += correct_fgsm[:,:5].sum()
176 | losses.update(loss_fgsm.item(), data.size(0))
177 |
178 | top1_fgsm = 100.*(prec1_fgsm/float(n))
179 | top5_fgsm = 100.*(prec5_fgsm/float(n))
180 | print('\n Top1 FGSM:{}'.format(top1_fgsm))
181 | return top1_fgsm
182 |
183 |
184 | def validate_pgd(val_loader, model, eps=0.031, K=7, a=0.01):
185 | if args.dataset == 'CIFAR10':
186 | mean = torch.Tensor(np.array([0.4914, 0.4822, 0.4465])[:, np.newaxis, np.newaxis])
187 | mean = mean.expand(3, 32, 32).cuda()
188 | std = torch.Tensor(np.array([0.2023, 0.1994, 0.2010])[:, np.newaxis, np.newaxis])
189 | std = std.expand(3, 32, 32).cuda()
190 | if args.dataset == 'CIFAR100':
191 | mean = torch.Tensor(np.array([0.5071,0.4867,0.4408])[:, np.newaxis, np.newaxis])
192 | mean = mean.expand(3, 32, 32).cuda()
193 | std = torch.Tensor(np.array([0.2675,0.2565,0.2761])[:, np.newaxis, np.newaxis])
194 | std = std.expand(3, 32, 32).cuda()
195 | losses = AverageMeter('Loss')
196 | prec1_pgd = 0.0
197 | prec5_pgd = 0.0
198 | n = 0
199 | model.eval()
200 | print('Value of K:{}, eps:{}'.format(K, eps))
201 | for i, (data, target) in enumerate(val_loader):
202 | if torch.cuda.is_available() and args.gpu:
203 | data, target = data.cuda(), target.cuda()
204 | n += target.size(0)
205 | orig_input = data.clone()
206 | randn = torch.FloatTensor(data.size()).uniform_(-eps, eps).cuda()
207 | data += randn
208 | data.clamp_(0, 1.0)
209 | for _ in range(K):
210 | invar = Variable(data, requires_grad=True)
211 | in1 = invar - mean
212 | in1.div_(std)
213 | output,_ = model(in1)
214 | ascend_loss = F.cross_entropy(output, target)
215 | ascend_grad = torch.autograd.grad(ascend_loss, invar)[0]
216 | pert = torch.sign(ascend_grad)*a
217 | data += pert.data
218 | data = torch.max(orig_input-eps, data)
219 | data = torch.min(orig_input+eps, data)
220 | data.clamp_(0, 1.0)
221 | data.sub_(mean).div_(std)
222 | with torch.no_grad():
223 | # compute output
224 | output,_ = model(data)
225 | loss_pgd = F.cross_entropy(output, target)
226 |
227 | # measure accuracy and record loss
228 | _, pred_pgd = output.topk(5, 1, largest=True, sorted=True)
229 | target_pgd = target.view(target.size(0),-1).expand_as(pred_pgd)
230 | correct_pgd = pred_pgd.eq(target_pgd).float()
231 | prec1_pgd += correct_pgd[:,:1].sum()
232 | prec5_pgd += correct_pgd[:,:5].sum()
233 | losses.update(loss_pgd.item(), data.size(0))
234 |
235 | top1_pgd = 100.*(prec1_pgd/float(n))
236 | top5_pgd = 100.*(prec5_pgd/float(n))
237 | print('\n Top1 PGD:{}'.format(top1_pgd))
238 | return top1_pgd
239 |
240 |
241 | if __name__ == '__main__':
242 | parser = argparse.ArgumentParser(description='SNN training')
243 | parser.add_argument('--gpu', default=True, type=bool, help='use gpu')
244 | parser.add_argument('--dataset', default='CIFAR10', type=str, help='dataset name', choices=['MNIST','CIFAR10','CIFAR100','IMAGENET', 'TINY_IMAGENET'])
245 | parser.add_argument('--batch_size', default=64, type=int, help='minibatch size')
246 | parser.add_argument('-a','--architecture', default='VGG16', type=str, help='network architecture', choices=['VGG4','VGG5','VGG6','VGG9','VGG11','VGG13','VGG16','VGG19','RESNET12','RESNET20','RESNET34'])
247 | parser.add_argument('--pretrained_snn', default='', type=str, help='pretrained SNN for inference')
248 | parser.add_argument('--test_only', action='store_true', help='perform only inference')
249 | parser.add_argument('--log', action='store_true', help='to print the output on terminal or to log file')
250 | parser.add_argument('--pgd_iter', default=7, type=int, help='number of pgd iterations')
251 | parser.add_argument('--pgd_step', default=0.01, type=float, help='pgd attack step size')
252 | parser.add_argument('--epochs', default=30, type=int, help='number of training epochs')
253 | parser.add_argument('--timesteps', default=20, type=int, help='simulation timesteps')
254 | parser.add_argument('--leak', default=1.0, type=float, help='membrane leak')
255 | parser.add_argument('--default_threshold', default=1.0, type=float, help='initial threshold to train SNN from scratch')
256 | parser.add_argument('--activation', default='Linear', type=str, help='SNN activation function', choices=['Linear'])
257 | parser.add_argument('--optimizer', default='SGD', type=str, help='optimizer for SNN backpropagation', choices=['SGD', 'Adam'])
258 | parser.add_argument('--kernel_size', default=3, type=int, help='filter size for the conv layers')
259 | parser.add_argument('--test_acc_every_batch', action='store_true', help='print acc of every batch during inference')
260 | parser.add_argument('--devices', default='0', type=str, help='list of gpu device(s)')
261 | parser.add_argument('--resume', default='', type=str, help='resume training from this state')
262 | parser.add_argument('--dont_save', action='store_true', help='don\'t save training model during testing')
263 |
264 | args = parser.parse_args()
265 |
266 | os.environ['CUDA_VISIBLE_DEVICES'] = args.devices
267 |
268 | #torch.backends.cudnn.deterministic = True
269 | #torch.backends.cudnn.benchmark = False
270 |
271 | dataset = args.dataset
272 | batch_size = args.batch_size
273 | architecture = args.architecture
274 | pretrained_snn = args.pretrained_snn
275 | epochs = args.epochs
276 | timesteps = args.timesteps
277 | leak = args.leak
278 | default_threshold = args.default_threshold
279 | activation = args.activation
280 | kernel_size = args.kernel_size
281 | test_acc_every_batch= args.test_acc_every_batch
282 | resume = args.resume
283 | start_epoch = 1
284 | max_accuracy = 0.0
285 |
286 | log_file = './logs/'
287 | try:
288 | os.mkdir(log_file)
289 | except OSError:
290 | pass
291 | identifier = 'snn_'+architecture.lower()+'_'+dataset.lower()+'_'+str(timesteps)+'_'+str(datetime.datetime.now())
292 | log_file+=identifier+'.log'
293 |
294 | if args.log:
295 | f = open(log_file, 'w', buffering=1)
296 | else:
297 | f = sys.stdout
298 |
299 | if torch.cuda.is_available() and args.gpu:
300 | torch.set_default_tensor_type('torch.cuda.FloatTensor')
301 |
302 | if dataset == 'CIFAR10':
303 | normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
304 | elif dataset == 'CIFAR100':
305 | normalize = transforms.Normalize((0.5071,0.4867,0.4408), (0.2675,0.2565,0.2761))
306 |
307 | if dataset in ['CIFAR10', 'CIFAR100']:
308 | transform_train = transforms.Compose([
309 | transforms.RandomCrop(32, padding=4),
310 | transforms.RandomHorizontalFlip(),
311 | transforms.ToTensor()
312 | ])
313 | transform_test = transforms.Compose([transforms.ToTensor()])
314 |
315 | if dataset == 'CIFAR10':
316 | trainset = datasets.CIFAR10(root = '../SNN_adversary/cifar_data', train = True, download = True, transform = transform_train)
317 | testset = datasets.CIFAR10(root='../SNN_adversary/cifar_data', train=False, download=True, transform = transform_test)
318 | labels = 10
319 | train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True)
320 | test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False)
321 |
322 | elif dataset == 'CIFAR100':
323 | trainset = datasets.CIFAR100(root = './cifar_data', train = True, download = True, transform = transform_train)
324 | testset = datasets.CIFAR100(root='./cifar_data', train=False, download=True, transform = transform_test)
325 | labels = 100
326 | train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True)
327 | test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False)
328 |
329 | if architecture[0:3].lower() == 'vgg':
330 | model = VGG_SNN(vgg_name = architecture, activation = activation, labels=labels, timesteps=timesteps, leak=leak, default_threshold=default_threshold, dropout=0.2, kernel_size=kernel_size, dataset=dataset)
331 |
332 | if pretrained_snn:
333 |
334 | state = torch.load(pretrained_snn, map_location='cpu')
335 | missing_keys, unexpected_keys = model.load_state_dict(state, strict=False)
336 | else:
337 | print("Please provide an snn file name")
338 |
339 | #model = nn.DataParallel(model)
340 |
341 | if torch.cuda.is_available() and args.gpu:
342 | model.cuda()
343 |
344 | if resume:
345 | f.write('\n Resuming from checkpoint {}'.format(resume))
346 | state = torch.load(resume, map_location='cpu')
347 | missing_keys, unexpected_keys = model.load_state_dict(state, strict=False)
348 | f.write('\n Missing keys : {}, Unexpected Keys: {}'.format(missing_keys, unexpected_keys))
349 | #f.write('\n Info: Accuracy of loaded ANN model: {}'.format(state['accuracy']))
350 |
351 | #epoch = state['epoch']
352 | start_epoch = epoch + 1 # note: vestigial; 'epoch' and 'optimizer' are undefined unless the commented lines are restored
353 | #max_accuracy = state['accuracy']
354 | #optimizer.load_state_dict(state['optimizer'])
355 | for param_group in optimizer.param_groups:
356 | learning_rate = param_group['lr']
357 |
358 | f.write('\n Loaded from resume epoch: {}, accuracy: {:.4f} lr: {:.1e}'.format(epoch, max_accuracy, learning_rate))
359 |
360 | test_Acc = [0]
361 | pgd_test_acc = [0]
362 | fgsmtest_acc = [0]
363 | for epoch in range(start_epoch, epochs+1):
364 | start_time = datetime.datetime.now()
365 | top1, test_Acc = test(epoch, test_loader, test_Acc, args)
366 | top1_fgsm = validate_fgsm(test_loader, model, args, eps=0.031)
367 | top1_pgd = validate_pgd(test_loader, model, eps=0.031, K=7, a=0.01)
368 | fgsmtest_acc.append(top1_fgsm)
369 | pgd_test_acc.append(top1_pgd)
370 | #print('Epoch:{}, TestAcc:{}, PGD acc:{}'.format(epoch, top1, top1_pgd))
371 | print('Epoch:{}, TestAcc:{}, FGSM acc:{}, PGD acc:{}'.format(epoch, top1*100, top1_fgsm, top1_pgd))
372 |
373 | #f.write('\n Highest accuracy: {:.4f}'.format(max_accuracy))
374 |
375 |
376 |
377 |
378 |
--------------------------------------------------------------------------------
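A small sketch of why the loaders above use ToTensor() only: the SNN has no input-normalization layer, so the script normalizes manually with sub_/div_, which matches what transforms.Normalize would do (the statistics below are the CIFAR-10 values used in the script):

# Sketch: the manual mean/std normalization performed in test() and the
# attack routines is equivalent to torchvision's transforms.Normalize.
import numpy as np
import torch
from torchvision import transforms

mean = torch.Tensor(np.array([0.4914, 0.4822, 0.4465])[:, np.newaxis, np.newaxis])
std = torch.Tensor(np.array([0.2023, 0.1994, 0.2010])[:, np.newaxis, np.newaxis])

img = torch.rand(3, 32, 32)                 # what ToTensor() yields: [0, 1] floats
manual = (img - mean) / std                 # the script's data.sub_(mean).div_(std)
reference = transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010))(img.clone())
print(torch.allclose(manual, reference))    # True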
/snn_free_nonorm_for_bbtest_purpose.py:
--------------------------------------------------------------------------------
1 | #---------------------------------------------------
2 | # Imports
3 | #---------------------------------------------------
4 | from __future__ import print_function
5 | import argparse
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | import torch.optim as optim
10 | from torchvision import datasets, transforms, models
11 | from torch.utils.data.dataloader import DataLoader
12 | from torch.autograd import Variable
13 | from matplotlib import pyplot as plt
14 | from matplotlib.gridspec import GridSpec
15 | import numpy as np
16 | import datetime
17 | import pdb
18 | from vgg_spiking_nodewise import *
19 | import sys
20 | import os
21 | import shutil
22 | import argparse
23 | from attack_model_spike_cnt import Attack
24 |
25 |
26 | class AverageMeter(object):
27 | """Computes and stores the average and current value"""
28 | def __init__(self, name, fmt=':f'):
29 | self.name = name
30 | self.fmt = fmt
31 | self.reset()
32 |
33 | def reset(self):
34 | self.val = 0
35 | self.avg = 0
36 | self.sum = 0
37 | self.count = 0
38 |
39 | def update(self, val, n=1):
40 | self.val = val
41 | self.sum += val * n
42 | self.count += n
43 | self.avg = self.sum / self.count
44 |
45 | def __str__(self):
46 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
47 | return fmtstr.format(**self.__dict__)
48 |
49 | def accuracy(output, target, topk=(1,)):
50 | """Computes the accuracy over the k top predictions for the specified values of k"""
51 | with torch.no_grad():
52 | maxk = max(topk)
53 | batch_size = target.size(0)
54 |
55 | _, pred = output.topk(maxk, 1, True, True)
56 | pred = pred.t()
57 | correct = pred.eq(target.view(1, -1).expand_as(pred))
58 |
59 | res = []
60 | for k in topk:
61 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
62 | res.append(correct_k.mul_(100.0 / batch_size))
63 | return res
64 |
65 | def test(epoch, test_loader, best_Acc, args):
66 |
67 | losses = AverageMeter('Loss')
68 | top1 = AverageMeter('Acc@1')
69 | avg_spike_cnt = []
70 |
71 | if args.dataset == 'CIFAR10':
72 | mean = torch.Tensor(np.array([0.4914, 0.4822, 0.4465])[:, np.newaxis, np.newaxis])
73 | mean = mean.expand(3, 32, 32).cuda()
74 | std = torch.Tensor(np.array([0.2023, 0.1994, 0.2010])[:, np.newaxis, np.newaxis])
75 | std = std.expand(3, 32, 32).cuda()
76 | if args.dataset == 'CIFAR100':
77 | mean = torch.Tensor(np.array([0.5071,0.4867,0.4408])[:, np.newaxis, np.newaxis])
78 | mean = mean.expand(3, 32, 32).cuda()
79 | std = torch.Tensor(np.array([0.2675,0.2565,0.2761])[:, np.newaxis, np.newaxis])
80 | std = std.expand(3, 32, 32).cuda()
81 |
82 | if args.test_only:
83 | temp1 = []
84 | temp2 = []
85 | for key, value in sorted(model.threshold.items(), key=lambda x: (int(x[0][1:]), (x[1]))):
86 | temp1 = temp1+[round(value.item(),2)]
87 | for key, value in sorted(model.leak.items(), key=lambda x: (int(x[0][1:]), (x[1]))):
88 | temp2 = temp2+[round(value.item(),2)]
89 | f.write('\n Thresholds: {}, leak: {}'.format(temp1, temp2))
90 |
91 | with torch.no_grad():
92 | model.eval()
93 | global max_accuracy
94 |
95 | for batch_idx, (data, target) in enumerate(test_loader):
96 |
97 | if torch.cuda.is_available() and args.gpu:
98 | data, target = data.cuda(), target.cuda()
99 |
100 | data = data - mean
101 | data.div_(std)
102 | output, spike_count = model(data)
103 | loss = F.cross_entropy(output,target)
104 | pred = output.max(1,keepdim=True)[1]
105 | correct = pred.eq(target.data.view_as(pred)).cpu().sum()
106 |
107 | losses.update(loss.item(),data.size(0))
108 | top1.update(correct.item()/data.size(0), data.size(0))
109 |
110 | if test_acc_every_batch:
111 |
112 | f.write('\n Images {}/{} Accuracy: {}/{}({:.4f})'
113 | .format(
114 | test_loader.batch_size*(batch_idx+1),
115 | len(test_loader.dataset),
116 | correct.item(),
117 | data.size(0),
118 | top1.avg*100
119 | )
120 | )
121 |
122 | temp1 = []
123 | temp2 = []
124 | for key, value in sorted(model.threshold.items(), key=lambda x: (int(x[0][1:]), (x[1]))):
125 | temp1 = temp1+[value.item()]
126 | for key, value in sorted(model.leak.items(), key=lambda x: (int(x[0][1:]), (x[1]))):
127 | temp2 = temp2+[value.item()]
128 |
129 | if epoch>5 and top1.avg<0.15:
130 | f.write('\n Quitting as the training is not progressing')
131 | exit(0)
132 |
133 | f.write(' test_loss: {:.4f}, test_acc: {:.4f}, time: {}'
134 | .format(
135 | losses.avg,
136 | top1.avg*100,
137 | datetime.timedelta(seconds=(datetime.datetime.now() - start_time).seconds)
138 | )
139 | )
140 | best_Acc.append(top1.avg*100)
141 | return top1.avg, best_Acc
142 |
143 |
144 | def validate_fgsm_bb(val_loader, model_bb, model, args, eps=0.031):
145 | if args.dataset == 'CIFAR10':
146 | mean = torch.Tensor(np.array([0.4914, 0.4822, 0.4465])[:, np.newaxis, np.newaxis])
147 | mean = mean.expand(3, 32, 32).cuda()
148 | std = torch.Tensor(np.array([0.2023, 0.1994, 0.2010])[:, np.newaxis, np.newaxis])
149 | std = std.expand(3, 32, 32).cuda()
150 | if args.dataset == 'CIFAR100':
151 | mean = torch.Tensor(np.array([0.5071,0.4867,0.4408])[:, np.newaxis, np.newaxis])
152 | mean = mean.expand(3, 32, 32).cuda()
153 | std = torch.Tensor(np.array([0.2675,0.2565,0.2761])[:, np.newaxis, np.newaxis])
154 | std = std.expand(3, 32, 32).cuda()
155 | losses = AverageMeter('Loss')
156 | prec1_fgsm = 0.0
157 | prec5_fgsm = 0.0
158 | prec1_pgd = 0.0
159 | prec5_pgd = 0.0
160 | n = 0
161 | model_bb.eval()
162 | model.eval()
163 | attacker = Attack(dataloader=val_loader,
164 | attack_method='fgsm', epsilon=0.031)
165 | for i, (data, target) in enumerate(val_loader):
166 | if torch.cuda.is_available() and args.gpu:
167 | data, target = data.cuda(), target.cuda()
168 | n += target.size(0)
169 | data.requires_grad = False
170 | #Here generate the perturbed data from model_bb
171 | perturbed_data = attacker.attack_method(model_bb, data, target, args)
172 | perturbed_data.sub_(mean).div_(std)
173 | # Use the model_bb generated perturbed data to evaluate on model
174 | output_fgsm,_ = model(perturbed_data)
175 | loss_fgsm = F.cross_entropy(output_fgsm, target)
176 | _, pred_fgsm = output_fgsm.topk(5, 1, largest=True, sorted=True)
177 | target_fgsm = target.view(target.size(0),-1).expand_as(pred_fgsm)
178 | correct_fgsm = pred_fgsm.eq(target_fgsm).float()
179 | prec1_fgsm += correct_fgsm[:,:1].sum()
180 | prec5_fgsm += correct_fgsm[:,:5].sum()
181 | losses.update(loss_fgsm.item(), data.size(0))
182 |
183 | top1_fgsm = 100.*(prec1_fgsm/float(n))
184 | top5_fgsm = 100.*(prec5_fgsm/float(n))
185 | print('Top1 FGSM:{}'.format(top1_fgsm))
186 | return top1_fgsm
187 |
188 |
189 | def validate_pgd_bb(val_loader, model_bb, model, eps=0.031, K=7, a=0.01):
190 | if args.dataset == 'CIFAR10':
191 | mean = torch.Tensor(np.array([0.4914, 0.4822, 0.4465])[:, np.newaxis, np.newaxis])
192 | mean = mean.expand(3, 32, 32).cuda()
193 | std = torch.Tensor(np.array([0.2023, 0.1994, 0.2010])[:, np.newaxis, np.newaxis])
194 | std = std.expand(3, 32, 32).cuda()
195 | if args.dataset == 'CIFAR100':
196 | mean = torch.Tensor(np.array([0.5071,0.4867,0.4408])[:, np.newaxis, np.newaxis])
197 | mean = mean.expand(3, 32, 32).cuda()
198 | std = torch.Tensor(np.array([0.2675,0.2565,0.2761])[:, np.newaxis, np.newaxis])
199 | std = std.expand(3, 32, 32).cuda()
200 | losses = AverageMeter('Loss')
201 | prec1_pgd = 0.0
202 | prec5_pgd = 0.0
203 | n = 0
204 | model_bb.eval()
205 | model.eval()
206 | print('Value of K:{}, eps:{}'.format(K, eps))
207 | for i, (data, target) in enumerate(val_loader):
208 | if torch.cuda.is_available() and args.gpu:
209 | data, target = data.cuda(), target.cuda()
210 | n += target.size(0)
211 | orig_input = data.clone()
212 | randn = torch.FloatTensor(data.size()).uniform_(-eps, eps).cuda()
213 | data += randn
214 | data.clamp_(0, 1.0)
215 | for _ in range(K):
216 | invar = Variable(data, requires_grad=True)
217 | in1 = invar - mean
218 | in1.div_(std)
219 | #Here generate the output for grad computation from model_bb
220 | output,_ = model_bb(in1)
221 | ascend_loss = F.cross_entropy(output, target)
222 | ascend_grad = torch.autograd.grad(ascend_loss, invar)[0]
223 | pert = torch.sign(ascend_grad)*a
224 | data += pert.data
225 | data = torch.max(orig_input-eps, data)
226 | data = torch.min(orig_input+eps, data)
227 | data.clamp_(0, 1.0)
228 | data.sub_(mean).div_(std)
229 | with torch.no_grad():
230 | # compute output on model
231 | output,_ = model(data)
232 | loss_pgd = F.cross_entropy(output, target)
233 |
234 | # measure accuracy and record loss
235 | _, pred_pgd = output.topk(5, 1, largest=True, sorted=True)
236 | target_pgd = target.view(target.size(0),-1).expand_as(pred_pgd)
237 | correct_pgd = pred_pgd.eq(target_pgd).float()
238 | prec1_pgd += correct_pgd[:,:1].sum()
239 | prec5_pgd += correct_pgd[:,:5].sum()
240 | losses.update(loss_pgd.item(), data.size(0))
241 |
242 | top1_pgd = 100.*(prec1_pgd/float(n))
243 | top5_pgd = 100.*(prec5_pgd/float(n))
244 | print('\n Top1 PGD:{}'.format(top1_pgd))
245 | return top1_pgd
246 |
247 | if __name__ == '__main__':
248 | parser = argparse.ArgumentParser(description='SNN training')
249 | parser.add_argument('--gpu', default=True, type=bool, help='use gpu')
250 | parser.add_argument('--dataset', default='CIFAR10', type=str, help='dataset name', choices=['MNIST','CIFAR10','CIFAR100','IMAGENET', 'TINY_IMAGENET'])
251 | parser.add_argument('--batch_size', default=64, type=int, help='minibatch size')
252 | parser.add_argument('-a','--architecture', default='VGG16', type=str, help='network architecture', choices=['VGG4','VGG5','VGG6','VGG9','VGG11','VGG13','VGG16','VGG19','RESNET12','RESNET20','RESNET34'])
253 | parser.add_argument('--pretrained_snn', default='', type=str, help='pretrained SNN for inference')
254 | parser.add_argument('--pretrained_snn_bb', default='', type=str, help='pretrained SNN for BB inference')
255 | parser.add_argument('--test_only', action='store_true', help='perform only inference')
256 | parser.add_argument('--log', action='store_true', help='to print the output on terminal or to log file')
257 | parser.add_argument('--pgd_iter', default=7, type=int, help='number of pgd iterations')
258 | parser.add_argument('--pgd_step', default=0.01, type=float, help='pgd attack step size')
259 | parser.add_argument('--epochs', default=30, type=int, help='number of training epochs')
260 | parser.add_argument('--timesteps', default=20, type=int, help='simulation timesteps')
261 | parser.add_argument('--leak', default=1.0, type=float, help='membrane leak')
262 | parser.add_argument('--default_threshold', default=1.0, type=float, help='initial threshold to train SNN from scratch')
263 | parser.add_argument('--activation', default='Linear', type=str, help='SNN activation function', choices=['Linear'])
264 | parser.add_argument('--optimizer', default='SGD', type=str, help='optimizer for SNN backpropagation', choices=['SGD', 'Adam'])
265 | parser.add_argument('--kernel_size', default=3, type=int, help='filter size for the conv layers')
266 | parser.add_argument('--test_acc_every_batch', action='store_true', help='print acc of every batch during inference')
267 | parser.add_argument('--devices', default='0', type=str, help='list of gpu device(s)')
268 | parser.add_argument('--resume', default='', type=str, help='resume training from this state')
269 | parser.add_argument('--dont_save', action='store_true', help='don\'t save training model during testing')
270 |
271 | args = parser.parse_args()
272 |
273 | os.environ['CUDA_VISIBLE_DEVICES'] = args.devices
274 |
275 | #torch.backends.cudnn.deterministic = True
276 | #torch.backends.cudnn.benchmark = False
277 |
278 | dataset = args.dataset
279 | batch_size = args.batch_size
280 | architecture = args.architecture
281 | pretrained_snn = args.pretrained_snn
282 | pretrained_snn_bb = args.pretrained_snn_bb
283 | epochs = args.epochs
284 | timesteps = args.timesteps
285 | leak = args.leak
286 | default_threshold = args.default_threshold
287 | activation = args.activation
288 | kernel_size = args.kernel_size
289 | test_acc_every_batch= args.test_acc_every_batch
290 | resume = args.resume
291 | start_epoch = 1
292 | max_accuracy = 0.0
293 |
294 | log_file = './logs/'
295 | try:
296 | os.mkdir(log_file)
297 | except OSError:
298 | pass
299 | identifier = 'snn_'+architecture.lower()+'_'+dataset.lower()+'_'+str(timesteps)+'_'+str(datetime.datetime.now())
300 | log_file+=identifier+'.log'
301 |
302 | if args.log:
303 | f = open(log_file, 'w', buffering=1)
304 | else:
305 | f = sys.stdout
306 |
307 | if torch.cuda.is_available() and args.gpu:
308 | torch.set_default_tensor_type('torch.cuda.FloatTensor')
309 |
310 | if dataset == 'CIFAR10':
311 | normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
312 | elif dataset == 'CIFAR100':
313 | normalize = transforms.Normalize((0.5071,0.4867,0.4408), (0.2675,0.2565,0.2761))
314 |
315 | if dataset in ['CIFAR10', 'CIFAR100']:
316 | transform_train = transforms.Compose([
317 | transforms.RandomCrop(32, padding=4),
318 | transforms.RandomHorizontalFlip(),
319 | transforms.ToTensor()
320 | ])
321 | transform_test = transforms.Compose([transforms.ToTensor()])
322 |
323 | if dataset == 'CIFAR10':
324 | trainset = datasets.CIFAR10(root = '../SNN_adversary/cifar_data', train = True, download = True, transform = transform_train)
325 | testset = datasets.CIFAR10(root='../SNN_adversary/cifar_data', train=False, download=True, transform = transform_test)
326 | labels = 10
327 | train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True)
328 | test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False)
329 |
330 | elif dataset == 'CIFAR100':
331 | trainset = datasets.CIFAR100(root = './cifar_data', train = True, download = True, transform = transform_train)
332 | testset = datasets.CIFAR100(root='./cifar_data', train=False, download=True, transform = transform_test)
333 | labels = 100
334 | train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True)
335 | test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False)
336 |
337 | if architecture[0:3].lower() == 'vgg':
338 | model = VGG_SNN(vgg_name = architecture, activation = activation, labels=labels, timesteps=timesteps, leak=leak, default_threshold=default_threshold, dropout=0.2, kernel_size=kernel_size, dataset=dataset)
339 | model_bb = VGG_SNN(vgg_name = architecture, activation = activation, labels=labels, timesteps=timesteps, leak=leak, default_threshold=default_threshold, dropout=0.2, kernel_size=kernel_size, dataset=dataset)
340 |
341 | if pretrained_snn:
342 | if pretrained_snn_bb:
343 | state_bb = torch.load(pretrained_snn_bb, map_location='cpu')
344 | missing_keys, unexpected_keys = model_bb.load_state_dict(state_bb, strict=False)
345 | f.write('\n For the bb model:Missing keys : {}, Unexpected Keys: {}'.format(missing_keys, unexpected_keys))
346 | state = torch.load(pretrained_snn, map_location='cpu')
347 | missing_keys, unexpected_keys = model.load_state_dict(state, strict=False)
348 | else:
349 | print("Please provide an snn file name")
350 |
351 | #model = nn.DataParallel(model)
352 |
353 | if torch.cuda.is_available() and args.gpu:
354 | model.cuda()
355 | if pretrained_snn_bb:
356 | model_bb.cuda()
357 |
358 | if resume:
359 | f.write('\n Resuming from checkpoint {}'.format(resume))
360 | state = torch.load(resume, map_location='cpu')
361 | missing_keys, unexpected_keys = model.load_state_dict(state, strict=False)
362 | f.write('\n Missing keys : {}, Unexpected Keys: {}'.format(missing_keys, unexpected_keys))
363 | #f.write('\n Info: Accuracy of loaded ANN model: {}'.format(state['accuracy']))
364 |
365 | #epoch = state['epoch']
366 | start_epoch = epoch + 1 # note: vestigial; 'epoch' and 'optimizer' are undefined unless the commented lines are restored
367 | #max_accuracy = state['accuracy']
368 | #optimizer.load_state_dict(state['optimizer'])
369 | for param_group in optimizer.param_groups:
370 | learning_rate = param_group['lr']
371 |
372 | f.write('\n Loaded from resume epoch: {}, accuracy: {:.4f} lr: {:.1e}'.format(epoch, max_accuracy, learning_rate))
373 |
374 | test_Acc = [0]
375 | pgd_test_acc = [0]
376 | fgsmtest_acc = [0]
377 | for epoch in range(start_epoch, epochs+1):
378 | start_time = datetime.datetime.now()
379 | top1, test_Acc = test(epoch, test_loader, test_Acc, args)
380 | top1_fgsm = validate_fgsm_bb(test_loader, model_bb, model, args, eps=0.031)
381 | top1_pgd = validate_pgd_bb(test_loader, model_bb, model, eps=0.031, K=7, a=0.01)
382 | fgsmtest_acc.append(top1_fgsm)
383 | pgd_test_acc.append(top1_pgd)
384 | #print('Epoch:{}, TestAcc:{}, PGD acc:{}'.format(epoch, top1, top1_pgd))
385 | print('Epoch:{}, TestAcc:{}, FGSM acc:{}, PGD acc:{}'.format(epoch, top1*100, top1_fgsm, top1_pgd))
386 |
387 | #f.write('\n Highest accuracy: {:.4f}'.format(max_accuracy))
388 |
389 |
390 |
391 |
392 |
--------------------------------------------------------------------------------
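The black-box protocol in this file reduces to one transfer step: craft the perturbation on the surrogate model_bb, then score it on the target model. A minimal standalone sketch of that step (untrained models and random data stand in for the real checkpoints and loader; a CUDA device is assumed):

# Sketch: the transfer-attack core of validate_fgsm_bb. The surrogate
# (model_bb) supplies the gradients that craft the perturbation; the target
# (model) only ever sees the finished adversarial images.
import argparse
import torch
from vgg_spiking_nodewise import VGG_SNN
from attack_model_spike_cnt import Attack

torch.set_default_tensor_type('torch.cuda.FloatTensor')

model = VGG_SNN(vgg_name='VGG11', labels=100, timesteps=8).cuda()     # target
model_bb = VGG_SNN(vgg_name='VGG11', labels=100, timesteps=8).cuda()  # surrogate
model.eval()

mean = torch.tensor([0.5071, 0.4867, 0.4408]).view(3, 1, 1)  # CIFAR-100 stats
std = torch.tensor([0.2675, 0.2565, 0.2761]).view(3, 1, 1)
args = argparse.Namespace(dataset='CIFAR100')

attacker = Attack(dataloader=None, attack_method='fgsm', epsilon=0.031)
data = torch.rand(4, 3, 32, 32).cuda()          # stand-in for a test batch
target = torch.randint(0, 100, (4,)).cuda()

perturbed = attacker.attack_method(model_bb, data, target, args)  # craft on surrogate
perturbed.sub_(mean).div_(std)                                    # manual SNN normalization
with torch.no_grad():
    logits, _ = model(perturbed)                                  # score on target only
    print((logits.argmax(1) == target).float().mean().item())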