├── CIFAR100_VGG16.py
├── Conversion_error.jpg
├── LICENSE
├── README.md
├── converted_CIFAR100_vgg.py
└── utils.py

/CIFAR100_VGG16.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import numpy as np
import time
import os
import random
from utils import Cutout, CIFAR10Policy, evaluate_accuracy


def seed_all(seed=1000):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
max_act = []


def hook(module, input, output):
    # record the 99th-percentile activation of this layer as its "max"
    sort, _ = torch.sort(output.detach().view(-1).cpu())
    max_act.append(sort[int(sort.shape[0] * 0.99) - 1])


class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        hooks = []
        cnn = nn.Sequential(
            nn.Conv2d(3, 64, 3, 1, 1), nn.BatchNorm2d(64), nn.ReLU(inplace=True), nn.Conv2d(64, 64, 3, 1, 1), nn.BatchNorm2d(64), nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(64, 128, 3, 1, 1), nn.BatchNorm2d(128), nn.ReLU(inplace=True), nn.Conv2d(128, 128, 3, 1, 1), nn.BatchNorm2d(128), nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(128, 256, 3, 1, 1), nn.BatchNorm2d(256), nn.ReLU(inplace=True), nn.Conv2d(256, 256, 3, 1, 1), nn.BatchNorm2d(256), nn.ReLU(inplace=True), nn.Conv2d(256, 256, 3, 1, 1), nn.BatchNorm2d(256), nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(256, 512, 3, 1, 1), nn.BatchNorm2d(512), nn.ReLU(inplace=True), nn.Conv2d(512, 512, 3, 1, 1), nn.BatchNorm2d(512), nn.ReLU(inplace=True), nn.Conv2d(512, 512, 3, 1, 1), nn.BatchNorm2d(512), nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(512, 512, 3, 1, 1), nn.BatchNorm2d(512), nn.ReLU(inplace=True), nn.Conv2d(512, 512, 3, 1, 1), nn.BatchNorm2d(512), nn.ReLU(inplace=True), nn.Conv2d(512, 512, 3, 1, 1), nn.BatchNorm2d(512), nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2))

        self.conv = cnn
        self.fc = nn.Linear(512, 100, bias=True)

        for i in range(len(self.conv)):
            hooks.append(self.conv[i].register_forward_hook(hook))
        hooks.append(self.fc.register_forward_hook(hook))
        self.hooks = hooks

    def forward(self, input):
        conv = self.conv(input)
        x = conv.view(conv.shape[0], -1)
        output = self.fc(x)
        return output


def train(net, train_iter, test_iter, optimizer, scheduler, device, num_epochs, losstype='mse'):
    best = 0
    net = net.to(device)
    print("training on ", device)
    if losstype == 'mse':
        loss = torch.nn.MSELoss()
    else:
        loss = torch.nn.CrossEntropyLoss(label_smoothing=0.1)
    losses = []

    for epoch in range(num_epochs):
        for param_group in optimizer.param_groups:
            learning_rate = param_group['lr']

        losss = []
        train_l_sum, train_acc_sum, n, batch_count, start = 0.0, 0.0, 0, 0, time.time()
        for X, y in train_iter:
            X = X.to(device)
            y = y.to(device)
            y_hat = net(X)
            label = y
            if losstype == 'mse':
                label = F.one_hot(y, 100).float()  # CIFAR-100 has 100 classes
            l = loss(y_hat, label)
            losss.append(l.cpu().item())
            optimizer.zero_grad()
            l.backward()
            optimizer.step()
            train_l_sum += l.cpu().item()
            train_acc_sum += (y_hat.argmax(dim=1) == y).sum().cpu().item()
            n += y.shape[0]
            batch_count += 1
        scheduler.step()
        test_acc = evaluate_accuracy(test_iter, net)
        losses.append(np.mean(losss))
        print('epoch %d, lr %.6f, loss %.6f, train acc %.6f, test acc %.6f, time %.1f sec'
              % (epoch + 1, learning_rate, train_l_sum / batch_count, train_acc_sum / n, test_acc, time.time() - start))

        if test_acc > best:
            best = test_acc
            os.makedirs('saved_model', exist_ok=True)  # make sure the checkpoint directory exists
            torch.save(net.state_dict(), 'saved_model/CIFAR100_VGG16_max.pth')


if __name__ == '__main__':
    seed_all(42)
    batch_size = 128
    normalize = transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))
    transform_train = transforms.Compose([transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(),
                                          CIFAR10Policy(),
                                          transforms.ToTensor(),
                                          Cutout(n_holes=1, length=16),
                                          normalize])
    transform_test = transforms.Compose([transforms.ToTensor(), normalize])
    cifar100_train = datasets.CIFAR100(root='./data/', train=True, download=False, transform=transform_train)
    cifar100_test = datasets.CIFAR100(root='./data/', train=False, download=False, transform=transform_test)
    train_iter = torch.utils.data.DataLoader(cifar100_train, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
    test_iter = torch.utils.data.DataLoader(cifar100_test, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)

    lr, num_epochs = 0.1, 300
    net = CNN()
    [net.hooks[i].remove() for i in range(len(net.hooks))]  # the hooks are only needed later, at conversion time
    optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, eta_min=0, T_max=num_epochs)
    train(net, train_iter, test_iter, optimizer, scheduler, device, num_epochs, losstype='crossentropy')

    net.load_state_dict(torch.load("./saved_model/CIFAR100_VGG16_max.pth"))
    net = net.to(device)
    acc = evaluate_accuracy(test_iter, net, device)
    print(acc)
--------------------------------------------------------------------------------
/Conversion_error.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Brain-Cog-Lab/Conversion_Burst/8f68e341b2f92b23c9cae64891def91845649938/Conversion_error.jpg
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2022 Brain-Inspired-Cognitive-Engine

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Efficient and Accurate Conversion of Spiking Neural Network with Burst Spikes

This repository contains code from our paper "Efficient and Accurate Conversion of Spiking Neural Network with Burst Spikes", published at IJCAI 2022.

![Conversion_error](./Conversion_error.jpg)


## Files

- `CIFAR100_VGG16.py`: train an ANN.
- `converted_CIFAR100_vgg.py`: convert the trained ANN into an SNN. This includes getting the max activation values, fusing the `Conv` and `BN` layers, and doing weight normalization.
- `utils.py`: some tricks for data augmentation, plus the `evaluate_accuracy` and `load_imagenet` helpers.


## Requirements

- numpy
- tqdm
- matplotlib
- Pillow
- pytorch >= 1.10.0
- torchvision


## Run

First, train an ANN:

```bash
python CIFAR100_VGG16.py
```

Then, modify the model path in `converted_CIFAR100_vgg.py` and run:

```bash
python converted_CIFAR100_vgg.py
```
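
For orientation, the conversion in `converted_CIFAR100_vgg.py` first folds each (inference-mode) `BN` layer into the preceding `Conv` via `fuse()`. The identity behind that step can be checked in isolation; the sketch below is illustrative only — `fuse_conv_bn` and the toy layer sizes are not part of this repository:

```python
import torch
import torch.nn as nn

def fuse_conv_bn(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> nn.Conv2d:
    # Inference-mode BN computes y = gamma * (x - mean) / sqrt(var + eps) + beta,
    # an affine map that can be folded into the conv's weight and bias.
    var_sqrt = torch.sqrt(bn.running_var + bn.eps)
    gamma, beta, mean = bn.weight.detach(), bn.bias.detach(), bn.running_mean
    w = conv.weight.detach()
    b = conv.bias.detach() if conv.bias is not None else torch.zeros_like(mean)
    fused = nn.Conv2d(conv.in_channels, conv.out_channels, conv.kernel_size,
                      conv.stride, conv.padding, bias=True)
    fused.weight = nn.Parameter(w * (gamma / var_sqrt).reshape(-1, 1, 1, 1))
    fused.bias = nn.Parameter((b - mean) / var_sqrt * gamma + beta)
    return fused

conv, bn = nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8)
bn.running_mean.uniform_(-1, 1)   # pretend the BN has trained statistics
bn.running_var.uniform_(0.5, 2.0)
bn.eval()                         # the fold is only valid for inference-mode BN
x = torch.randn(2, 3, 32, 32)
print(torch.allclose(bn(conv(x)), fuse_conv_bn(conv, bn)(x), atol=1e-5))  # True
```

After fusion, the weight normalization in `fuse_norm_replace()` rescales each layer by the ratio of consecutive recorded maxima (`w /= max_l / max_{l-1}`, `b /= max_l`) so that activations stay within the range the spiking neurons can represent.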
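
The burst mechanism itself is small: in `SNode`, a neuron may emit up to `gamma` spikes per time step instead of at most one, releasing residual membrane potential much faster. A toy trace of that update rule, with made-up inputs:

```python
import torch

gamma, threshold = 5, 1.0          # at most `gamma` spikes per time step
mem = torch.zeros(3)               # membrane potentials of three toy neurons
x = torch.tensor([0.4, 1.3, 3.7])  # constant input current per step

for t in range(4):
    mem = mem + x                                              # integrate
    spike = (mem / threshold).floor().clamp(min=0, max=gamma)  # burst firing
    mem = mem - spike * threshold                              # soft reset keeps the residual
    print(t, spike.tolist())
```

Averaged over time, the per-step spike count approaches the ReLU activation it replaces, which is why the converted SNN resolves large activations in far fewer time steps than a one-spike-per-step IF neuron.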

## Citation

If you use this code in your work, please cite the following paper:

```
@article{li2022efficient,
  title={Efficient and Accurate Conversion of Spiking Neural Network with Burst Spikes},
  author={Yang Li and Yi Zeng},
  journal={arXiv preprint arXiv:2204.13271},
  year={2022},
}
```
--------------------------------------------------------------------------------
/converted_CIFAR100_vgg.py:
--------------------------------------------------------------------------------
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from tqdm import tqdm
from copy import deepcopy
import matplotlib.pyplot as plt
import numpy as np
import time

from CIFAR100_VGG16 import evaluate_accuracy


max_act = []
gamma = 2


def hook(module, input, output):
    '''
    use a hook to easily get the maximum activation of each layer on one training batch
    '''
    out = output.detach()
    out[out > 1] /= gamma  # scale activations above 1 down by gamma before taking the 99.9th percentile
    sort, _ = torch.sort(out.view(-1).cpu())
    max_act.append(sort[int(sort.shape[0] * 0.999) - 1])


class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        hooks = []
        cnn = nn.Sequential(
            nn.Conv2d(3, 64, 3, 1, 1), nn.BatchNorm2d(64), nn.ReLU(inplace=True), nn.Conv2d(64, 64, 3, 1, 1), nn.BatchNorm2d(64), nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(64, 128, 3, 1, 1), nn.BatchNorm2d(128), nn.ReLU(inplace=True), nn.Conv2d(128, 128, 3, 1, 1), nn.BatchNorm2d(128), nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(128, 256, 3, 1, 1), nn.BatchNorm2d(256), nn.ReLU(inplace=True), nn.Conv2d(256, 256, 3, 1, 1), nn.BatchNorm2d(256), nn.ReLU(inplace=True), nn.Conv2d(256, 256, 3, 1, 1), nn.BatchNorm2d(256), nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(256, 512, 3, 1, 1), nn.BatchNorm2d(512), nn.ReLU(inplace=True), nn.Conv2d(512, 512, 3, 1, 1), nn.BatchNorm2d(512), nn.ReLU(inplace=True), nn.Conv2d(512, 512, 3, 1, 1), nn.BatchNorm2d(512), nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(512, 512, 3, 1, 1), nn.BatchNorm2d(512), nn.ReLU(inplace=True), nn.Conv2d(512, 512, 3, 1, 1), nn.BatchNorm2d(512), nn.ReLU(inplace=True), nn.Conv2d(512, 512, 3, 1, 1), nn.BatchNorm2d(512), nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2))

        self.conv = cnn
        self.fc = nn.Linear(512, 100, bias=True)

        for i in range(len(self.conv)):
            hooks.append(self.conv[i].register_forward_hook(hook))
        hooks.append(self.fc.register_forward_hook(hook))
        self.hooks = hooks

    def forward(self, input):
        conv = self.conv(input)
        x = conv.view(conv.shape[0], -1)
        output = self.fc(x)
        return output


# =========== Do weight norm, replace ReLU with SNode =============#
class SNode(nn.Module):
    def __init__(self, smode=True, gamma=5):
        super(SNode, self).__init__()
        self.smode = smode
        self.mem = 0
        self.spike = 0
        self.sum = 0
        self.threshold = 1.0
        self.operation = nn.ReLU(True)
        self.rsum = []
        self.summem = 0
        self.rmem = []
        self.gamma = gamma

    def forward(self, x):
        if not self.smode:
            out = self.operation(x)
        else:
            self.mem = self.mem + x  # integrate the input current
            self.spike = (self.mem / self.threshold).floor().clamp(min=0, max=self.gamma)  # burst: up to gamma spikes per step
            self.mem = self.mem - self.spike  # soft reset keeps the residual potential
            out = self.spike
        return out


class SMaxPool(nn.Module):
    '''
    use lateral inhibition (lateral_inhi) to make the output equal to the real value
    '''
    def __init__(self, smode=True, lateral_inhi=False):
        super(SMaxPool, self).__init__()
        self.smode = smode
        self.lateral_inhi = lateral_inhi
        self.sumspike = 0  # reset by clean_mem_spike between batches
        self.operation = nn.MaxPool2d(kernel_size=2, stride=2)
        self.sum = 0
        self.input = 0

    def forward(self, x):
        if not self.smode:
            out = self.operation(x)
        elif not self.lateral_inhi:
            # pass on the new spikes at the location of the running spike-count maximum
            self.sumspike += x
            single = self.operation(self.sumspike * 1000)
            sum_plus_spike = self.operation(x + self.sumspike * 1000)
            out = sum_plus_spike - single
        else:
            # pool the accumulated spike counts, then inhibit the pooled region
            self.sumspike += x
            out = self.operation(self.sumspike)
            self.sumspike -= F.interpolate(out, scale_factor=2, mode='nearest')
        return out
def fuse_norm_replace(m, max_activation, last_max, smode=True, gamma=5, data_norm=True, lateral_inhi=False):
    '''
    merge conv and bn, then do data norm
    :param m: model
    :param max_activation: the max activation values recorded on one training batch
    :param last_max: the max activation of the previous layer
    :param smode: choose to use spikes
    :param data_norm: scale weights by the ratio of consecutive layer maxima
    :param lateral_inhi: use lateral inhibition in the spiking max-pool
    :return: snn
    '''
    global index
    children = list(m.named_children())
    c, cn = None, None

    for i, (name, child) in enumerate(children):
        ind = index
        if isinstance(child, nn.Linear):
            if data_norm:
                child.weight.data /= max_activation[index] / max_activation[index - 2]
                child.bias.data /= max_activation[index]
                last_max = max_activation[index]
        elif isinstance(child, nn.BatchNorm2d):
            bc = fuse(c, child)  # fold the BN into the preceding conv
            m._modules[cn] = bc
            m._modules[name] = torch.nn.Identity()
            if data_norm:
                m._modules[cn].weight.data /= max_activation[index] / last_max
                m._modules[cn].bias.data /= max_activation[index]
                last_max = max_activation[index]
            c = None
        elif isinstance(child, nn.Conv2d):
            c = child
            cn = name
        elif isinstance(child, nn.ReLU):
            m._modules[name] = SNode(smode=smode, gamma=gamma)
            if not data_norm:
                m._modules[name].threshold = max_activation[index]
                last_max = max_activation[index]
        elif isinstance(child, nn.MaxPool2d):
            m._modules[name] = SMaxPool(smode=smode, lateral_inhi=lateral_inhi)
        elif isinstance(child, nn.AvgPool2d):
            pass
        else:
            fuse_norm_replace(child, max_activation, last_max, smode, gamma, data_norm, lateral_inhi)
            index -= 1  # the recursive call already advanced index; cancel this loop's increment below
        index += 1


def fuse(conv, bn):
    '''
    fuse the conv and bn layer:
    y = gamma * (conv(x) - mean) / sqrt(var + eps) + beta
    '''
    w = conv.weight
    mean, var_sqrt = bn.running_mean, torch.sqrt(bn.running_var + bn.eps)
    gamma, beta = bn.weight, bn.bias  # BN scale and shift
    b = conv.bias if conv.bias is not None else mean.new_zeros(mean.shape)

    w = w * (gamma / var_sqrt).reshape([conv.out_channels, 1, 1, 1])
    b = (b - mean) / var_sqrt * gamma + beta
    fused_conv = nn.Conv2d(conv.in_channels, conv.out_channels, conv.kernel_size, conv.stride, conv.padding, bias=True)
    fused_conv.weight = nn.Parameter(w)
    fused_conv.bias = nn.Parameter(b)
    return fused_conv


def clean_mem_spike(m):
    '''
    when changing batches, clear the membrane potentials and spikes of the last batch
    :param m: snn
    :return:
    '''
    children = list(m.named_children())
    for name, child in children:
        if isinstance(child, SNode):
            child.mem = 0
            child.spike = 0
        elif isinstance(child, SMaxPool):
            child.sumspike = 0
        else:
            clean_mem_spike(child)


def evaluate_snn(test_iter, snn, net, device=None, duration=50, plot=False, linetype=None):
    linetype = '-' if linetype is None else linetype
    accs = []
    acc_sum, n = 0.0, 0
    snn.eval()

    for test_x, test_y in tqdm(test_iter):
        test_x = test_x.to(device)
        test_y = test_y.to(device)
        n = test_y.shape[0]
        out = 0
        with torch.no_grad():
            clean_mem_spike(snn)
            acc = []
            for t in range(duration):
                start = time.time()
                out += snn(test_x)  # accumulate the output over time steps
                result = torch.max(out, 1).indices
                result = result.to(device)
                acc_sum = (result == test_y).float().sum().item()
                acc.append(acc_sum / n)
        accs.append(np.array(acc))

    accs = np.array(accs).mean(axis=0)

    print(max(accs))
    if plot:
        plt.plot(list(range(len(accs))), accs, linetype)
        plt.ylabel('Accuracy')
        plt.xlabel('Time Step')
        # plt.show()
        plt.savefig('./result.jpg')

if __name__ == '__main__':
    global index
    device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')

    batch_size = 128
    normalize = torchvision.transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))
    transform_train = transforms.Compose(
        [transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), transforms.ToTensor(), normalize])
    transform_test = transforms.Compose([transforms.ToTensor(), normalize])
    cifar100_train = datasets.CIFAR100(root='./data/', train=True, download=False, transform=transform_train)
    cifar100_test = datasets.CIFAR100(root='./data/', train=False, download=False, transform=transform_test)
    train_iter = torch.utils.data.DataLoader(cifar100_train, batch_size=200, shuffle=True, num_workers=0)  # 1024
    test_iter = torch.utils.data.DataLoader(cifar100_test, batch_size=batch_size, shuffle=False, num_workers=0)

    # the result of the ANN
    net = CNN()
    net1 = deepcopy(net)
    [net1.hooks[i].remove() for i in range(len(net1.hooks))]
    net1.load_state_dict(torch.load("./saved_model/CIFAR100_VGG16_max.pth", map_location=torch.device(device)))
    net1 = net1.to(device)
    acc = evaluate_accuracy(test_iter, net1, device)
    print("acc on ann is : {:.4f}".format(acc))

    # get the max activations on one training batch
    net2 = deepcopy(net)
    net2.load_state_dict(torch.load("./saved_model/CIFAR100_VGG16_max.pth", map_location=torch.device(device)))
    net2 = net2.to(device)
    _ = evaluate_accuracy(train_iter, net2, device, only_onebatch=True)
    # print(len(max_act))
    [net2.hooks[i].remove() for i in range(len(net2.hooks))]

    # data_norm
    net3 = deepcopy(net2)
    index = 0
    fuse_norm_replace(net3, max_act, last_max=1.0, smode=False, data_norm=True)

    index = 0
    fuse_norm_replace(net2, max_act, last_max=1.0, smode=True, gamma=gamma, data_norm=True, lateral_inhi=False)
    evaluate_snn(test_iter, net2, net3, device=device, duration=256, plot=True, linetype=None)
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
from PIL import Image, ImageEnhance, ImageOps
import numpy as np
import random
import os
from tqdm import tqdm

import torch
import torchvision.transforms as transforms
import torchvision.datasets as datasets


class Cutout(object):
    """Randomly mask out one or more patches from an image.
    reference https://github.com/yhhhli/SNN_Calibration
    Args:
        n_holes (int): Number of patches to cut out of each image.
        length (int): The length (in pixels) of each square patch.
    """

    def __init__(self, n_holes, length):
        self.n_holes = n_holes
        self.length = length

    def __call__(self, img):
        """
        Args:
            img (Tensor): Tensor image of size (C, H, W).
        Returns:
            Tensor: Image with n_holes of dimension length x length cut out of it.
        """
        h = img.size(1)
        w = img.size(2)

        mask = np.ones((h, w), np.float32)

        for n in range(self.n_holes):
            y = np.random.randint(h)
            x = np.random.randint(w)

            y1 = np.clip(y - self.length // 2, 0, h)
            y2 = np.clip(y + self.length // 2, 0, h)
            x1 = np.clip(x - self.length // 2, 0, w)
            x2 = np.clip(x + self.length // 2, 0, w)

            mask[y1: y2, x1: x2] = 0.

        mask = torch.from_numpy(mask)
        mask = mask.expand_as(img)
        img = img * mask

        return img
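
# Illustrative usage note (not part of the original file): Cutout masks tensors,
# so it is composed after ToTensor(), e.g. in CIFAR100_VGG16.py:
#   transforms.Compose([transforms.ToTensor(), Cutout(n_holes=1, length=16)])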

class CIFAR10Policy(object):
    """ Randomly choose one of the best 25 Sub-policies on CIFAR10.
    reference https://github.com/yhhhli/SNN_Calibration
    Example:
    >>> policy = CIFAR10Policy()
    >>> transformed = policy(image)

    Example as a PyTorch Transform:
    >>> transform=transforms.Compose([
    >>>     transforms.Resize(256),
    >>>     CIFAR10Policy(),
    >>>     transforms.ToTensor()])
    """

    def __init__(self, fillcolor=(128, 128, 128)):
        self.policies = [
            SubPolicy(0.1, "invert", 7, 0.2, "contrast", 6, fillcolor),
            SubPolicy(0.7, "rotate", 2, 0.3, "translateX", 9, fillcolor),
            SubPolicy(0.8, "sharpness", 1, 0.9, "sharpness", 3, fillcolor),
            SubPolicy(0.5, "shearY", 8, 0.7, "translateY", 9, fillcolor),
            SubPolicy(0.5, "autocontrast", 8, 0.9, "equalize", 2, fillcolor),

            SubPolicy(0.2, "shearY", 7, 0.3, "posterize", 7, fillcolor),
            SubPolicy(0.4, "color", 3, 0.6, "brightness", 7, fillcolor),
            SubPolicy(0.3, "sharpness", 9, 0.7, "brightness", 9, fillcolor),
            SubPolicy(0.6, "equalize", 5, 0.5, "equalize", 1, fillcolor),
            SubPolicy(0.6, "contrast", 7, 0.6, "sharpness", 5, fillcolor),

            SubPolicy(0.7, "color", 7, 0.5, "translateX", 8, fillcolor),
            SubPolicy(0.3, "equalize", 7, 0.4, "autocontrast", 8, fillcolor),
            SubPolicy(0.4, "translateY", 3, 0.2, "sharpness", 6, fillcolor),
            SubPolicy(0.9, "brightness", 6, 0.2, "color", 8, fillcolor),
            SubPolicy(0.5, "solarize", 2, 0.0, "invert", 3, fillcolor),

            SubPolicy(0.2, "equalize", 0, 0.6, "autocontrast", 0, fillcolor),
            SubPolicy(0.2, "equalize", 8, 0.8, "equalize", 4, fillcolor),
            SubPolicy(0.9, "color", 9, 0.6, "equalize", 6, fillcolor),
            SubPolicy(0.8, "autocontrast", 4, 0.2, "solarize", 8, fillcolor),
            SubPolicy(0.1, "brightness", 3, 0.7, "color", 0, fillcolor),

            SubPolicy(0.4, "solarize", 5, 0.9, "autocontrast", 3, fillcolor),
            SubPolicy(0.9, "translateY", 9, 0.7, "translateY", 9, fillcolor),
            SubPolicy(0.9, "autocontrast", 2, 0.8, "solarize", 3, fillcolor),
            SubPolicy(0.8, "equalize", 8, 0.1, "invert", 3, fillcolor),
            SubPolicy(0.7, "translateY", 9, 0.9, "autocontrast", 1, fillcolor)
        ]

    def __call__(self, img):
        policy_idx = random.randint(0, len(self.policies) - 1)
        return self.policies[policy_idx](img)

    def __repr__(self):
        return "AutoAugment CIFAR10 Policy"

class SVHNPolicy(object):
    """ Randomly choose one of the best 25 Sub-policies on SVHN.
    reference https://github.com/yhhhli/SNN_Calibration
    Example:
    >>> policy = SVHNPolicy()
    >>> transformed = policy(image)

    Example as a PyTorch Transform:
    >>> transform=transforms.Compose([
    >>>     transforms.Resize(256),
    >>>     SVHNPolicy(),
    >>>     transforms.ToTensor()])
    """

    def __init__(self, fillcolor=(128, 128, 128)):
        self.policies = [
            SubPolicy(0.9, "shearX", 4, 0.2, "invert", 3, fillcolor),
            SubPolicy(0.9, "shearY", 8, 0.7, "invert", 5, fillcolor),
            SubPolicy(0.6, "equalize", 5, 0.6, "solarize", 6, fillcolor),
            SubPolicy(0.9, "invert", 3, 0.6, "equalize", 3, fillcolor),
            SubPolicy(0.6, "equalize", 1, 0.9, "rotate", 3, fillcolor),

            SubPolicy(0.9, "shearX", 4, 0.8, "autocontrast", 3, fillcolor),
            SubPolicy(0.9, "shearY", 8, 0.4, "invert", 5, fillcolor),
            SubPolicy(0.9, "shearY", 5, 0.2, "solarize", 6, fillcolor),
            SubPolicy(0.9, "invert", 6, 0.8, "autocontrast", 1, fillcolor),
            SubPolicy(0.6, "equalize", 3, 0.9, "rotate", 3, fillcolor),

            SubPolicy(0.9, "shearX", 4, 0.3, "solarize", 3, fillcolor),
            SubPolicy(0.8, "shearY", 8, 0.7, "invert", 4, fillcolor),
            SubPolicy(0.9, "equalize", 5, 0.6, "translateY", 6, fillcolor),
            SubPolicy(0.9, "invert", 4, 0.6, "equalize", 7, fillcolor),
            SubPolicy(0.3, "contrast", 3, 0.8, "rotate", 4, fillcolor),

            SubPolicy(0.8, "invert", 5, 0.0, "translateY", 2, fillcolor),
            SubPolicy(0.7, "shearY", 6, 0.4, "solarize", 8, fillcolor),
            SubPolicy(0.6, "invert", 4, 0.8, "rotate", 4, fillcolor),
            SubPolicy(0.3, "shearY", 7, 0.9, "translateX", 3, fillcolor),
            SubPolicy(0.1, "shearX", 6, 0.6, "invert", 5, fillcolor),

            SubPolicy(0.7, "solarize", 2, 0.6, "translateY", 7, fillcolor),
            SubPolicy(0.8, "shearY", 4, 0.8, "invert", 8, fillcolor),
            SubPolicy(0.7, "shearX", 9, 0.8, "translateY", 3, fillcolor),
            SubPolicy(0.8, "shearY", 5, 0.7, "autocontrast", 3, fillcolor),
            SubPolicy(0.7, "shearX", 2, 0.1, "invert", 5, fillcolor)
        ]

    def __call__(self, img):
        policy_idx = random.randint(0, len(self.policies) - 1)
        return self.policies[policy_idx](img)

    def __repr__(self):
        return "AutoAugment SVHN Policy"


class SubPolicy(object):
    def __init__(self, p1, operation1, magnitude_idx1, p2, operation2, magnitude_idx2, fillcolor=(128, 128, 128)):
        ranges = {
            "shearX": np.linspace(0, 0.3, 10),
            "shearY": np.linspace(0, 0.3, 10),
            "translateX": np.linspace(0, 150 / 331, 10),
            "translateY": np.linspace(0, 150 / 331, 10),
            "rotate": np.linspace(0, 30, 10),
            "color": np.linspace(0.0, 0.9, 10),
            "posterize": np.round(np.linspace(8, 4, 10), 0).astype(int),  # np.int was removed in NumPy 1.24
            "solarize": np.linspace(256, 0, 10),
            "contrast": np.linspace(0.0, 0.9, 10),
            "sharpness": np.linspace(0.0, 0.9, 10),
            "brightness": np.linspace(0.0, 0.9, 10),
            "autocontrast": [0] * 10,
            "equalize": [0] * 10,
            "invert": [0] * 10
        }

        def rotate_with_fill(img, magnitude):
            rot = img.convert("RGBA").rotate(magnitude)
            return Image.composite(rot, Image.new("RGBA", rot.size, (128,) * 4), rot).convert(img.mode)

        func = {
            "shearX": lambda img, magnitude: img.transform(
                img.size, Image.AFFINE, (1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0),
                Image.BICUBIC, fillcolor=fillcolor),
            "shearY": lambda img, magnitude: img.transform(
                img.size, Image.AFFINE, (1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0),
                Image.BICUBIC, fillcolor=fillcolor),
            "translateX": lambda img, magnitude: img.transform(
                img.size, Image.AFFINE, (1, 0, magnitude * img.size[0] * random.choice([-1, 1]), 0, 1, 0),
                fillcolor=fillcolor),
            "translateY": lambda img, magnitude: img.transform(
                img.size, Image.AFFINE, (1, 0, 0, 0, 1, magnitude * img.size[1] * random.choice([-1, 1])),
                fillcolor=fillcolor),
            "rotate": lambda img, magnitude: rotate_with_fill(img, magnitude),
            # "rotate": lambda img, magnitude: img.rotate(magnitude * random.choice([-1, 1])),
            "color": lambda img, magnitude: ImageEnhance.Color(img).enhance(1 + magnitude * random.choice([-1, 1])),
            "posterize": lambda img, magnitude: ImageOps.posterize(img, magnitude),
            "solarize": lambda img, magnitude: ImageOps.solarize(img, magnitude),
            "contrast": lambda img, magnitude: ImageEnhance.Contrast(img).enhance(
                1 + magnitude * random.choice([-1, 1])),
            "sharpness": lambda img, magnitude: ImageEnhance.Sharpness(img).enhance(
                1 + magnitude * random.choice([-1, 1])),
            "brightness": lambda img, magnitude: ImageEnhance.Brightness(img).enhance(
                1 + magnitude * random.choice([-1, 1])),
            "autocontrast": lambda img, magnitude: ImageOps.autocontrast(img),
            "equalize": lambda img, magnitude: ImageOps.equalize(img),
            "invert": lambda img, magnitude: ImageOps.invert(img)
        }

        # self.name = "{}_{:.2f}_and_{}_{:.2f}".format(
        #     operation1, ranges[operation1][magnitude_idx1],
        #     operation2, ranges[operation2][magnitude_idx2])
        self.p1 = p1
        self.operation1 = func[operation1]
        self.magnitude1 = ranges[operation1][magnitude_idx1]
        self.p2 = p2
        self.operation2 = func[operation2]
        self.magnitude2 = ranges[operation2][magnitude_idx2]

    def __call__(self, img):
        if random.random() < self.p1:
            img = self.operation1(img, self.magnitude1)
        if random.random() < self.p2:
            img = self.operation2(img, self.magnitude2)
        return img

def evaluate_accuracy(data_iter, net, device=None, only_onebatch=False, ind=None):
    if device is None and isinstance(net, torch.nn.Module):
        device = list(net.parameters())[0].device
    acc_sum, n = 0.0, 0
    with torch.no_grad():
        for i, (X, y) in enumerate(tqdm(data_iter)):
            net.eval()
            acc_sum += (net(X.to(device)).argmax(dim=1) == y.to(device)).float().sum().cpu().item()
            net.train()
            n += y.shape[0]

            if only_onebatch: break
            if i == ind: break
    return acc_sum / n
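
# Illustrative note (not part of the original file): only_onebatch=True returns after a
# single batch; converted_CIFAR100_vgg.py uses this to push exactly one training batch
# through the hooked model when recording the max activations:
#   _ = evaluate_accuracy(train_iter, net2, device, only_onebatch=True)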

def load_imagenet(root='/data/raid/floyed/ILSVRC2012', batch_size=128):
    '''
    load ImageNet 2012
    we use images in train/ for training, and images in val/ for testing
    https://github.com/pytorch/examples/tree/master/imagenet
    '''
    IMAGENET_PATH = root
    traindir = os.path.join(IMAGENET_PATH, 'train')
    valdir = os.path.join(IMAGENET_PATH, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize]))

    val_dataset = datasets.ImageFolder(
        valdir,
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize]))

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size, shuffle=False,
        num_workers=4, pin_memory=True)

    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=batch_size, shuffle=False,
        num_workers=4, pin_memory=True)

    return train_loader, val_loader, train_dataset, val_dataset
--------------------------------------------------------------------------------