├── save └── .gitkeep ├── data ├── cifar │ └── .gitkeep ├── mnist │ └── .gitkeep ├── __init__.py └── README.md ├── requirements.txt ├── models ├── __init__.py ├── test.py ├── Update.py ├── vgg.py ├── resnet.py ├── vgg_spiking_bntt.py ├── Fed.py └── vgg_spiking_bntt_activity.py ├── utils ├── __init__.py ├── options.py └── sampling.py ├── test_cifar10.sh ├── test_cifar100.sh ├── .gitignore ├── README.md ├── LICENSE └── main_fed.py /save/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/cifar/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/mnist/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.7.1 2 | torchvision==0.8.2 3 | pysnn -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @python: 3.6 4 | 5 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @python: 3.6 4 | 5 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @python: 3.6 4 | 5 | 
-------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | # Data 2 | 3 | MNIST & CIFAR-10 datasets will be downloaded automatically by the torchvision package. 4 | -------------------------------------------------------------------------------- /test_cifar10.sh: -------------------------------------------------------------------------------- 1 | python main_fed.py --snn --dataset CIFAR10 --num_classes 10 --model VGG9 --optimizer SGD --bs 32 --local_bs 32 --lr 0.1 --lr_reduce 5 --epochs 100 --local_ep 2 --eval_every 1 --num_users 10 --frac 0.2 --iid --gpu 0 --timesteps 20 --straggler_prob 0.0 --grad_noise_stdev 0.0 --result_dir test 2 | -------------------------------------------------------------------------------- /test_cifar100.sh: -------------------------------------------------------------------------------- 1 | python main_fed.py --snn --dataset CIFAR100 --num_classes 100 --model VGG9 --optimizer SGD --bs 32 --local_bs 32 --lr 0.1 --lr_reduce 5 --epochs 100 --local_ep 2 --eval_every 1 --num_users 10 --frac 0.2 --iid --gpu 0 --timesteps 20 --straggler_prob 0.0 --grad_noise_stdev 0.0 --result_dir test_cifar100 --num_channels 3 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # pycharm 2 | .idea/* 3 | 4 | # documents 5 | *.csv 6 | .xls 7 | .xlsx 8 | .pdf 9 | .json 10 | 11 | # macOS 12 | .DS_Store 13 | 14 | # Byte-compiled / optimized / DLL files 15 | __pycache__/ 16 | *.py[cod] 17 | *$py.class 18 | 19 | # C extensions 20 | *.so 21 | 22 | # Distribution / packaging 23 | .Python 24 | env/ 25 | build/ 26 | develop-eggs/ 27 | dist/ 28 | downloads/ 29 | eggs/ 30 | .eggs/ 31 | lib/ 32 | lib64/ 33 | parts/ 34 | sdist/ 35 | var/ 36 | wheels/ 37 | *.egg-info/ 38 | .installed.cfg 39 | *.egg 40 | 41 | # PyInstaller 42 
| # Usually these files are written by a python script from a template 43 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 44 | *.manifest 45 | *.spec 46 | 47 | # virtualenv 48 | .venv 49 | venv/ 50 | ENV/ 51 | 52 | CIFAR* 53 | cifar* 54 | *results* 55 | test/ 56 | ddd20/*dataset*/* 57 | ddd20/data 58 | nmnist/data 59 | *experiments*/ 60 | *rog_script* -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Federated Learning with Spiking Neural Networks 2 | 3 | This repo contains the source code for the paper "Federated Learning with Spiking Neural Networks" (https://arxiv.org/abs/2106.06579). 4 | 5 | ## Requirements 6 | python>=3.7 7 | pytorch>=1.7.1 8 | 9 | ## Run 10 | 11 | For example, to train a federated SNN model with 10 clients and 2 clients participating in each round: 12 | > python main_fed.py --snn --dataset CIFAR10 --num_classes 10 --model VGG9 --optimizer SGD --bs 32 --local_bs 32 --lr 0.1 --lr_reduce 5 --epochs 100 --local_ep 2 --eval_every 1 --num_users 10 --frac 0.2 --iid --gpu 0 --timesteps 20 --result_dir test 13 | 14 | Other options can be found by running 15 | > python main_fed.py --help 16 | 17 | Sample scripts are provided at `test_cifar10.sh` and `test_cifar100.sh`. 
18 | 19 | ## Acknowledgements 20 | Initial Code adopted from https://github.com/shaoxiongji/federated-learning 21 | 22 | Code for SNN training adopted from https://github.com/Intelligent-Computing-Lab-Yale/BNTT-Batch-Normalization-Through-Time 23 | 24 | 25 | 26 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Yeshwanth Venkatesha 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /models/test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @python: 3.6 4 | 5 | import torch 6 | from torch import nn 7 | import torch.nn.functional as F 8 | from torch.utils.data import DataLoader 9 | import sys 10 | import os 11 | 12 | 13 | def test_img(net_g, datatest, args): 14 | net_g.eval() 15 | # testing 16 | test_loss = 0 17 | correct = 0 18 | data_loader = DataLoader(datatest, batch_size=args.bs) 19 | l = len(data_loader) 20 | for idx, (data, target) in enumerate(data_loader): 21 | if args.gpu != -1: 22 | data, target = data.cuda(), target.cuda() 23 | log_probs = net_g(data) 24 | # sum up batch loss 25 | test_loss += F.cross_entropy(log_probs, target, reduction='sum').item() 26 | # get the index of the max log-probability 27 | y_pred = log_probs.data.max(1, keepdim=True)[1] 28 | correct += y_pred.eq(target.data.view_as(y_pred)).long().cpu().sum() 29 | 30 | test_loss /= len(data_loader.dataset) 31 | accuracy = 100.00 * correct / len(data_loader.dataset) 32 | if args.verbose: 33 | print('\nTest set: Average loss: {:.4f} \nAccuracy: {}/{} ({:.2f}%)\n'.format( 34 | test_loss, correct, len(data_loader.dataset), accuracy)) 35 | return accuracy.item(), test_loss 36 | 37 | def comp_activity(net_g, dataset, args): 38 | net_g.eval() 39 | # testing 40 | data_loader = DataLoader(dataset, batch_size=args.bs) 41 | l = len(data_loader) 42 | for idx, (data, target) in enumerate(data_loader): 43 | if args.gpu != -1: 44 | data, target = data.cuda(), target.cuda() 45 | activity = torch.zeros(net_g(data, count_active_layers = True)) 46 | break 47 | batch_count = 0 48 | for idx, (data, target) in enumerate(data_loader): 49 | if args.gpu != -1: 50 | data, target = data.cuda(), target.cuda() 51 | activity += torch.tensor(net_g(data, report_activity = True)) 52 | # sum up batch loss 53 | 
batch_count += 1 54 | activity = activity/batch_count 55 | 56 | return activity -------------------------------------------------------------------------------- /models/Update.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Python version: 3.6 4 | 5 | import torch 6 | from torch import nn, autograd 7 | from torch.utils.data import DataLoader, Dataset 8 | import numpy as np 9 | import random 10 | from sklearn import metrics 11 | import sys 12 | import os 13 | 14 | 15 | class DatasetSplit(Dataset): 16 | def __init__(self, dataset, idxs): 17 | self.dataset = dataset 18 | self.idxs = list(idxs) 19 | 20 | def __len__(self): 21 | return len(self.idxs) 22 | 23 | def __getitem__(self, item): 24 | image, label = self.dataset[self.idxs[item]] 25 | return image, label 26 | 27 | class LocalUpdate(object): 28 | def __init__(self, args, dataset=None, idxs=None): 29 | self.args = args 30 | self.loss_func = nn.CrossEntropyLoss() 31 | self.ldr_train = DataLoader(DatasetSplit(dataset, idxs), batch_size=self.args.local_bs, shuffle=True, drop_last=True) 32 | 33 | def train(self, net): 34 | net.train() 35 | # train and update 36 | if self.args.optimizer == "SGD": 37 | optimizer = torch.optim.SGD(net.parameters(), lr=self.args.lr, momentum=self.args.momentum, weight_decay = self.args.weight_decay) 38 | elif self.args.optimizer == "Adam": 39 | optimizer = torch.optim.Adam(net.parameters(), lr = self.args.lr, weight_decay = self.args.weight_decay, amsgrad = True) 40 | else: 41 | print("Invalid optimizer") 42 | 43 | epoch_loss = [] 44 | for iter in range(self.args.local_ep): 45 | batch_loss = [] 46 | for batch_idx, (images, labels) in enumerate(self.ldr_train): 47 | images, labels = images.to(self.args.device), labels.to(self.args.device) 48 | net.zero_grad() 49 | log_probs = net(images) 50 | loss = self.loss_func(log_probs, labels) 51 | loss.backward() 52 | optimizer.step() 53 | if 
self.args.verbose and batch_idx % 10 == 0: 54 | print('Update Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( 55 | iter, batch_idx * len(images), len(self.ldr_train.dataset), 56 | 100. * batch_idx / len(self.ldr_train), loss.item())) 57 | batch_loss.append(loss.item()) 58 | if self.args.verbose and (batch_idx + 1) % self.args.train_acc_batches == 0: 59 | thresholds = [] 60 | for value in net.module.threshold.values(): 61 | thresholds = thresholds + [round(value.item(), 2)] 62 | print('Epoch: {}, batch {}, threshold {}, leak {}, timesteps {}'.format(iter, batch_idx + 1, thresholds, net.module.leak.item(), net.module.timesteps)) 63 | epoch_loss.append(sum(batch_loss)/len(batch_loss)) 64 | return net.state_dict(), sum(epoch_loss) / len(epoch_loss) -------------------------------------------------------------------------------- /utils/options.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Python version: 3.6 4 | 5 | import argparse 6 | 7 | def args_parser(): 8 | parser = argparse.ArgumentParser() 9 | # federated arguments 10 | parser.add_argument('--epochs', type=int, default=10, help="rounds of training") 11 | parser.add_argument('--num_users', type=int, default=100, help="number of users: K") 12 | parser.add_argument('--frac', type=float, default=0.1, help="the fraction of clients: C") 13 | parser.add_argument('--local_ep', type=int, default=5, help="the number of local epochs: E") 14 | parser.add_argument('--local_bs', type=int, default=16, help="local batch size: B") 15 | parser.add_argument('--bs', type=int, default=16, help="test batch size") 16 | parser.add_argument('--lr', type=float, default=1e-4, help="learning rate") 17 | parser.add_argument('--lr_interval', default='0.33 0.66', type=str, help='intervals at which to reduce lr, expressed as %%age of total epochs') 18 | 19 | parser.add_argument('--lr_reduce', default=10, type=int, help='reduction factor for 
learning rate') 20 | parser.add_argument('--timesteps', default=25, type=int, help='simulation timesteps') 21 | parser.add_argument('--leak', default=1.0, type=float, help='membrane leak') 22 | parser.add_argument('--scaling_factor', default=0.7, type=float, help='scaling factor for thresholds at reduced timesteps') 23 | parser.add_argument('--default_threshold', default=1.0, type=float, help='intial threshold to train SNN from scratch') 24 | parser.add_argument('--activation', default='Linear', type=str, help='SNN activation function', choices=['Linear', 'STDB']) 25 | parser.add_argument('--alpha', default=0.3, type=float, help='parameter alpha for STDB') 26 | parser.add_argument('--beta', default=0.01, type=float, help='parameter beta for STDB') 27 | parser.add_argument('--snn_kernel_size', default=3, type=int, help='filter size for the conv layers') 28 | parser.add_argument('--optimizer', default='SGD', type=str, help='optimizer for SNN backpropagation', choices=['SGD', 'Adam']) 29 | parser.add_argument('--weight_decay', default=1e-4, type=float, help='weight decay parameter for the optimizer') 30 | parser.add_argument('--dropout', default=0.3, type=float, help='dropout percentage for conv layers') 31 | 32 | parser.add_argument('--momentum', type=float, default=0.9, help="SGD momentum (default: 0.5)") 33 | parser.add_argument('--split', type=str, default='user', help="train-test split type, user or sample") 34 | 35 | # model arguments 36 | parser.add_argument('--model', type=str, default='mlp', help='model name') 37 | parser.add_argument('--kernel_num', type=int, default=9, help='number of each kind of kernel') 38 | parser.add_argument('--kernel_sizes', type=str, default='3,4,5', 39 | help='comma-separated kernel size to use for convolution') 40 | parser.add_argument('--norm', type=str, default='batch_norm', help="batch_norm, layer_norm, or None") 41 | parser.add_argument('--num_filters', type=int, default=32, help="number of filters for conv nets") 42 | 
parser.add_argument('--max_pool', type=str, default='True', 43 | help="Whether use max pooling rather than strided convolutions") 44 | 45 | # other arguments 46 | parser.add_argument('--dataset', type=str, default='mnist', help="name of dataset") 47 | parser.add_argument('--iid', action='store_true', help='whether i.i.d or not') 48 | parser.add_argument('--num_classes', type=int, default=10, help="number of classes") 49 | parser.add_argument('--num_channels', type=int, default=3, help="number of channels of imges") 50 | parser.add_argument('--gpu', type=int, default=0, help="GPU ID, -1 for CPU") 51 | parser.add_argument('--stopping_rounds', type=int, default=10, help='rounds of early stopping') 52 | parser.add_argument('--verbose', action='store_true', help='verbose print') 53 | parser.add_argument('--seed', type=int, default=1, help='random seed (default: 1)') 54 | parser.add_argument('--eval_every', type=int, default=10, help='Frequency of model evaluation') 55 | parser.add_argument('--pretrained_model', type=str, default=None, help="Path for the pre-trained mode if any") 56 | parser.add_argument('--result_dir', type=str, default="results", help="Directory to store results") 57 | parser.add_argument('--snn', action='store_true', help="Whether to train SNN or ANN") 58 | parser.add_argument('--train_acc_batches', default=200, type=int, help='print training progress after this many batches') 59 | parser.add_argument('--straggler_prob', type=float, default=0.0, help="straggler probability") 60 | parser.add_argument('--grad_noise_stdev', type=float, default=0.0, help="Noise level for gradients") 61 | parser.add_argument('--dvs', action='store_true', help="Whether the input data is DVS") 62 | parser.add_argument('--modality', type=str, default='aps', help="aps or dvs for the type of data to work on DDD20") 63 | args = parser.parse_args() 64 | return args 65 | -------------------------------------------------------------------------------- /models/vgg.py: 
-------------------------------------------------------------------------------- 1 | ############################# 2 | # @author: Nitin Rathi # 3 | ############################# 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import pdb 8 | import math 9 | 10 | 11 | cfg = { 12 | 'VGG5' : [64, 'A', 128, 128, 'A'], 13 | 'VGG9': [64, 'A', 64, 128, 'A', 128, 256, 'A', 256, 'A', 256], 14 | 'VGG11': [64, 'A', 128, 256, 'A', 512, 512, 'A', 512, 'A', 512, 512], 15 | 'VGG13': [64, 64, 'A', 128, 128, 'A', 256, 256, 'A', 512, 512, 512, 'A', 512], 16 | 'VGG16': [64, 64, 'A', 128, 128, 'A', 256, 256, 256, 'A', 512, 512, 512, 'A', 512, 512, 512], 17 | 'VGG19': [64, 64, 'A', 128, 128, 'A', 256, 256, 256, 256, 'A', 512, 512, 512, 512, 'A', 512, 512, 512, 512] 18 | } 19 | 20 | 21 | class VGG(nn.Module): 22 | def __init__(self, vgg_name='VGG16', labels=10, dataset = 'CIFAR10', kernel_size=3, dropout=0.2): 23 | super(VGG, self).__init__() 24 | 25 | self.dataset = dataset 26 | self.kernel_size = kernel_size 27 | self.dropout = dropout 28 | self.features = self._make_layers(cfg[vgg_name]) 29 | if vgg_name == 'VGG5' and dataset == 'MNIST': 30 | self.classifier = nn.Sequential( 31 | nn.Linear(128*7*7, 4096, bias=False), 32 | nn.ReLU(inplace=True), 33 | nn.Dropout(0.5), 34 | nn.Linear(4096, 4096, bias=False), 35 | nn.ReLU(inplace=True), 36 | nn.Dropout(0.5), 37 | nn.Linear(4096, labels, bias=False) 38 | ) 39 | elif vgg_name!='VGG5' and dataset =='MNIST': 40 | self.classifier = nn.Sequential( 41 | nn.Linear(512*1*1, 4096, bias=False), 42 | nn.ReLU(inplace=True), 43 | nn.Dropout(0.5), 44 | nn.Linear(4096, 4096, bias=False), 45 | nn.ReLU(inplace=True), 46 | nn.Dropout(0.5), 47 | nn.Linear(4096, labels, bias=False) 48 | ) 49 | elif vgg_name == 'VGG5' and dataset == 'DDD20': 50 | self.classifier = nn.Sequential( 51 | nn.Linear(128*10*10, 4096, bias=False), 52 | nn.ReLU(inplace=True), 53 | nn.Dropout(0.5), 54 | nn.Linear(4096, 4096, bias=False), 55 | 
nn.ReLU(inplace=True), 56 | nn.Dropout(0.5), 57 | nn.Linear(4096, labels, bias=False) 58 | ) 59 | elif vgg_name!='VGG5' and dataset =='DDD20': 60 | self.classifier = nn.Sequential( 61 | nn.Linear(256*5*5, 512, bias=False), 62 | nn.ReLU(inplace=True), 63 | nn.Dropout(0.5), 64 | nn.Linear(512, 256, bias=False), 65 | nn.ReLU(inplace=True), 66 | nn.Dropout(0.5), 67 | nn.Linear(256, labels, bias=False) 68 | ) 69 | elif vgg_name == 'VGG5' and dataset!= 'MNIST': 70 | self.classifier = nn.Sequential( 71 | nn.Linear(128*8*8, 4096, bias=False), # 128*8*8 is more consistent with the input dimensions. 512*4*4 is misleading 72 | nn.ReLU(inplace=True), 73 | nn.Dropout(0.5), 74 | nn.Linear(4096, 4096, bias=False), 75 | nn.ReLU(inplace=True), 76 | nn.Dropout(0.5), 77 | nn.Linear(4096, labels, bias=False) 78 | ) 79 | elif vgg_name!='VGG5' and dataset!='MNIST': 80 | self.classifier = nn.Sequential( 81 | nn.Linear(256*2*2, 4096, bias=False), 82 | nn.ReLU(inplace=True), 83 | nn.Dropout(0.5), 84 | nn.Linear(4096, 4096, bias=False), 85 | nn.ReLU(inplace=True), 86 | nn.Dropout(0.5), 87 | nn.Linear(4096, labels, bias=False) 88 | ) 89 | 90 | self._initialize_weights2() 91 | 92 | 93 | def forward(self, x): 94 | out = self.features(x) 95 | out = out.view(out.size(0), -1) 96 | out = self.classifier(out) 97 | return out 98 | 99 | def _initialize_weights2(self): 100 | for m in self.modules(): 101 | 102 | if isinstance(m, nn.Conv2d): 103 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 104 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 105 | if m.bias is not None: 106 | m.bias.data.zero_() 107 | elif isinstance(m, nn.Linear): 108 | n = m.weight.size(1) 109 | m.weight.data.normal_(0, 0.01) 110 | if m.bias is not None: 111 | m.bias.data.zero_() 112 | 113 | def _make_layers(self, cfg): 114 | layers = [] 115 | 116 | if self.dataset == 'MNIST': 117 | in_channels = 1 118 | elif self.dataset == 'DDD20': 119 | in_channels = 1 120 | else: 121 | in_channels = 3 122 | 123 | for x in cfg: 124 | stride = 1 125 | 126 | if x == 'A': 127 | layers.pop() 128 | layers += [nn.AvgPool2d(kernel_size=2, stride=2)] 129 | else: 130 | layers += [nn.Conv2d(in_channels, x, kernel_size=self.kernel_size, padding=(self.kernel_size-1)//2, stride=stride, bias=False), 131 | nn.BatchNorm2d(x), 132 | nn.ReLU(inplace=True) 133 | ] 134 | layers += [nn.Dropout(self.dropout)] 135 | in_channels = x 136 | 137 | 138 | return nn.Sequential(*layers) 139 | 140 | def test(): 141 | for a in cfg.keys(): 142 | #if a=='VGG5': 143 | # continue 144 | net = VGG(a) 145 | x = torch.randn(2,3,32,32) 146 | y = net(x) 147 | print(y.size()) 148 | # For VGG5 change the linear layer in self. 
classifier from '512*2*2' to '512*4*4' 149 | # net = VGG('VGG5') 150 | # x = torch.randn(2,3,32,32) 151 | # y = net(x) 152 | # print(y.size()) 153 | if __name__ == '__main__': 154 | test() 155 | -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | def initialize_weights(module): 9 | if isinstance(module, nn.Conv2d): 10 | nn.init.kaiming_normal_(module.weight.data, mode='fan_out') 11 | elif isinstance(module, nn.BatchNorm2d): 12 | module.weight.data.fill_(1) 13 | module.bias.data.zero_() 14 | elif isinstance(module, nn.Linear): 15 | module.bias.data.zero_() 16 | 17 | 18 | class BasicBlock(nn.Module): 19 | expansion = 1 20 | 21 | def __init__(self, in_channels, out_channels, stride): 22 | super(BasicBlock, self).__init__() 23 | 24 | self.conv1 = nn.Conv2d( 25 | in_channels, 26 | out_channels, 27 | kernel_size=3, 28 | stride=stride, # downsample with first conv 29 | padding=1, 30 | bias=False) 31 | self.bn1 = nn.BatchNorm2d(out_channels) 32 | self.conv2 = nn.Conv2d( 33 | out_channels, 34 | out_channels, 35 | kernel_size=3, 36 | stride=1, 37 | padding=1, 38 | bias=False) 39 | self.bn2 = nn.BatchNorm2d(out_channels) 40 | 41 | self.shortcut = nn.Sequential() 42 | if in_channels != out_channels: 43 | self.shortcut.add_module( 44 | 'conv', 45 | nn.Conv2d( 46 | in_channels, 47 | out_channels, 48 | kernel_size=1, 49 | stride=stride, # downsample 50 | padding=0, 51 | bias=False)) 52 | self.shortcut.add_module('bn', nn.BatchNorm2d(out_channels)) # BN 53 | 54 | def forward(self, x): 55 | y = F.relu(self.bn1(self.conv1(x)), inplace=True) 56 | y = self.bn2(self.conv2(y)) 57 | y += self.shortcut(x) 58 | y = F.relu(y, inplace=True) # apply ReLU after addition 59 | return y 60 | 61 | 62 | class BottleneckBlock(nn.Module): 63 | expansion = 4 64 
| 65 | def __init__(self, in_channels, out_channels, stride): 66 | super(BottleneckBlock, self).__init__() 67 | 68 | bottleneck_channels = out_channels // self.expansion 69 | 70 | self.conv1 = nn.Conv2d( 71 | in_channels, 72 | bottleneck_channels, 73 | kernel_size=1, 74 | stride=1, 75 | padding=0, 76 | bias=False) 77 | self.bn1 = nn.BatchNorm2d(bottleneck_channels) 78 | 79 | self.conv2 = nn.Conv2d( 80 | bottleneck_channels, 81 | bottleneck_channels, 82 | kernel_size=3, 83 | stride=stride, # downsample with 3x3 conv 84 | padding=1, 85 | bias=False) 86 | self.bn2 = nn.BatchNorm2d(bottleneck_channels) 87 | 88 | self.conv3 = nn.Conv2d( 89 | bottleneck_channels, 90 | out_channels, 91 | kernel_size=1, 92 | stride=1, 93 | padding=0, 94 | bias=False) 95 | self.bn3 = nn.BatchNorm2d(out_channels) 96 | 97 | self.shortcut = nn.Sequential() # identity 98 | if in_channels != out_channels: 99 | self.shortcut.add_module( 100 | 'conv', 101 | nn.Conv2d( 102 | in_channels, 103 | out_channels, 104 | kernel_size=1, 105 | stride=stride, # downsample 106 | padding=0, 107 | bias=False)) 108 | self.shortcut.add_module('bn', nn.BatchNorm2d(out_channels)) # BN 109 | 110 | def forward(self, x): 111 | y = F.relu(self.bn1(self.conv1(x)), inplace=True) 112 | y = F.relu(self.bn2(self.conv2(y)), inplace=True) 113 | y = self.bn3(self.conv3(y)) # not apply ReLU 114 | y += self.shortcut(x) 115 | y = F.relu(y, inplace=True) # apply ReLU after addition 116 | return y 117 | 118 | 119 | class Network(nn.Module): 120 | def __init__(self, num_cls = 1): 121 | super(Network, self).__init__() 122 | 123 | # input_shape = config['input_shape'] 124 | input_shape = (1, 1, 40, 40) 125 | n_classes = num_cls 126 | 127 | # base_channels = config['base_channels'] 128 | base_channels = 16 129 | # block_type = config['block_type'] 130 | block_type = 'basic' 131 | # depth = config['depth'] 132 | depth = 20 133 | 134 | assert block_type in ['basic', 'bottleneck'] 135 | if block_type == 'basic': 136 | block = BasicBlock 
137 | n_blocks_per_stage = (depth - 2) // 6 138 | assert n_blocks_per_stage * 6 + 2 == depth 139 | else: 140 | block = BottleneckBlock 141 | n_blocks_per_stage = (depth - 2) // 9 142 | assert n_blocks_per_stage * 9 + 2 == depth 143 | 144 | n_channels = [ 145 | base_channels, base_channels * 2 * block.expansion, 146 | base_channels * 4 * block.expansion 147 | ] 148 | 149 | self.conv = nn.Conv2d( 150 | input_shape[1], 151 | n_channels[0], 152 | kernel_size=3, 153 | stride=1, 154 | padding=1, 155 | bias=False) 156 | self.bn = nn.BatchNorm2d(base_channels) 157 | 158 | self.stage1 = self._make_stage( 159 | n_channels[0], n_channels[0], n_blocks_per_stage, block, stride=1) 160 | self.stage2 = self._make_stage( 161 | n_channels[0], n_channels[1], n_blocks_per_stage, block, stride=2) 162 | self.stage3 = self._make_stage( 163 | n_channels[1], n_channels[2], n_blocks_per_stage, block, stride=2) 164 | 165 | # compute conv feature size 166 | with torch.no_grad(): 167 | self.feature_size = self._forward_conv( 168 | torch.zeros(*input_shape)).view(-1).shape[0] 169 | 170 | self.fc = nn.Linear(self.feature_size, n_classes) 171 | 172 | # initialize weights 173 | self.apply(initialize_weights) 174 | 175 | def _make_stage(self, in_channels, out_channels, n_blocks, block, stride): 176 | stage = nn.Sequential() 177 | for index in range(n_blocks): 178 | block_name = 'block{}'.format(index + 1) 179 | if index == 0: 180 | stage.add_module( 181 | block_name, block( 182 | in_channels, out_channels, stride=stride)) 183 | else: 184 | stage.add_module(block_name, 185 | block(out_channels, out_channels, stride=1)) 186 | return stage 187 | 188 | def _forward_conv(self, x): 189 | x = F.relu(self.bn(self.conv(x)), inplace=True) 190 | x = self.stage1(x) 191 | x = self.stage2(x) 192 | x = self.stage3(x) 193 | x = F.adaptive_avg_pool2d(x, output_size=1) 194 | return x 195 | 196 | def forward(self, x): 197 | x = self._forward_conv(x) 198 | x = x.view(x.size(0), -1) 199 | x = self.fc(x) 200 | return x 
201 | -------------------------------------------------------------------------------- /utils/sampling.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Python version: 3.6 4 | 5 | 6 | import numpy as np 7 | from torchvision import datasets, transforms 8 | 9 | def mnist_iid(dataset, num_users): 10 | """ 11 | Sample I.I.D. client data from MNIST dataset 12 | :param dataset: 13 | :param num_users: 14 | :return: dict of image index 15 | """ 16 | num_items = int(len(dataset)/num_users) 17 | dict_users, all_idxs = {}, [i for i in range(len(dataset))] 18 | for i in range(num_users): 19 | dict_users[i] = set(np.random.choice(all_idxs, num_items, replace=False)) 20 | all_idxs = list(set(all_idxs) - dict_users[i]) 21 | return dict_users 22 | 23 | 24 | def mnist_noniid(dataset, num_users): 25 | """ 26 | Sample non-I.I.D client data from MNIST dataset 27 | :param dataset: 28 | :param num_users: 29 | :return: 30 | """ 31 | total_imgs = dataset.train_labels.shape[0] 32 | num_shards = 200 33 | num_imgs = int(total_imgs / num_shards) 34 | idx_shard = [i for i in range(num_shards)] 35 | dict_users = {i: np.array([], dtype='int64') for i in range(num_users)} 36 | idxs = np.arange(num_shards*num_imgs) 37 | labels = dataset.train_labels 38 | labels = labels[0:num_shards*num_imgs] 39 | 40 | # sort labels 41 | idxs_labels = np.vstack((idxs, labels)) 42 | idxs_labels = idxs_labels[:,idxs_labels[1,:].argsort()] 43 | idxs = idxs_labels[0,:] 44 | 45 | # divide and assign 46 | for i in range(num_users): 47 | rand_set = set(np.random.choice(idx_shard, 2, replace=False)) 48 | idx_shard = list(set(idx_shard) - rand_set) 49 | for rand in rand_set: 50 | dict_users[i] = np.concatenate((dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0) 51 | return dict_users 52 | 53 | 54 | def cifar_iid(dataset, num_users): 55 | """ 56 | Sample I.I.D. 
def PoissonGen(inp, rescale_fac=2.0):
    """Turn an analog input tensor into a signed random spike tensor.

    An element spikes when ``rand * rescale_fac <= abs(inp)``, with ``rand``
    drawn by ``torch.rand_like``; a spike carries the sign of ``inp``, so the
    result only contains -1, 0 and +1.

    :param inp: Input tensor.
    :param rescale_fac: Factor applied to the random draw before the
        comparison; larger values make spikes rarer.
    :return: Spike tensor with the same shape and device as ``inp``.
    """
    # rand_like already allocates on inp's device, so no explicit .cuda() is needed.
    rand_inp = torch.rand_like(inp)
    return torch.mul(torch.le(rand_inp * rescale_fac, torch.abs(inp)).float(), torch.sign(inp))


class SNN_VGG9_BNTT(nn.Module):
    """Spiking VGG-9 with a separate BatchNorm per layer and per time step.

    Seven 3x3 convolutions (three of them followed by 2x2 average pooling)
    feed a 1024-unit linear layer and a linear read-out.  Every convolution
    and the first linear layer own ``timesteps`` BatchNorm modules (the
    ``bntt*`` lists), one for each simulation step.

    :param timesteps: Number of simulation steps, and BatchNorm modules per layer.
    :param leak_mem: Factor the membrane potential is multiplied by each step.
    :param img_size: Height/width of the (square) 3-channel input images.
    :param num_cls: Number of output classes.
    """

    def __init__(self, timesteps=20, leak_mem=0.95, img_size=32, num_cls=10):
        super(SNN_VGG9_BNTT, self).__init__()

        self.img_size = img_size
        self.num_cls = num_cls
        self.timesteps = timesteps
        self.spike_fn = Surrogate_BP_Function.apply
        self.leak_mem = leak_mem
        self.batch_num = self.timesteps

        print(">>>>>>>>>>>>>>>>>>> VGG 9 >>>>>>>>>>>>>>>>>>>>>>")
        # The placeholder was missing, so the value was silently dropped.
        print("***** time step per batchnorm {}".format(self.batch_num))
        print(">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>")

        affine_flag = True
        bias_flag = False

        def conv(in_ch, out_ch):
            return nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=bias_flag)

        def bntt(channels, norm=nn.BatchNorm2d):
            # One normalisation module per time step.
            return nn.ModuleList(
                [norm(channels, eps=1e-4, momentum=0.1, affine=affine_flag) for _ in range(self.batch_num)]
            )

        # Attribute names and creation order are kept so that state_dict keys
        # and seeded initialisation stay the same.
        self.conv1 = conv(3, 64)
        self.bntt1 = bntt(64)
        self.conv2 = conv(64, 64)
        self.bntt2 = bntt(64)
        self.pool1 = nn.AvgPool2d(kernel_size=2)

        self.conv3 = conv(64, 128)
        self.bntt3 = bntt(128)
        self.conv4 = conv(128, 128)
        self.bntt4 = bntt(128)
        self.pool2 = nn.AvgPool2d(kernel_size=2)

        self.conv5 = conv(128, 256)
        self.bntt5 = bntt(256)
        self.conv6 = conv(256, 256)
        self.bntt6 = bntt(256)
        self.conv7 = conv(256, 256)
        self.bntt7 = bntt(256)
        self.pool3 = nn.AvgPool2d(kernel_size=2)

        # Three 2x2 poolings shrink each spatial side by a factor of 8.
        self.fc1 = nn.Linear((self.img_size // 8) * (self.img_size // 8) * 256, 1024, bias=bias_flag)
        self.bntt_fc = bntt(1024, norm=nn.BatchNorm1d)
        self.fc2 = nn.Linear(1024, self.num_cls, bias=bias_flag)

        # Plain lists on purpose: the modules are already registered through
        # the attributes above; these only give forward() an iteration order.
        self.conv_list = [self.conv1, self.conv2, self.conv3, self.conv4, self.conv5, self.conv6, self.conv7]
        self.bntt_list = [self.bntt1, self.bntt2, self.bntt3, self.bntt4, self.bntt5, self.bntt6, self.bntt7, self.bntt_fc]
        # False marks "no pooling after this convolution".
        self.pool_list = [False, self.pool1, False, self.pool2, False, False, self.pool3]

        # Turn off bias of BNTT
        for bn_list in self.bntt_list:
            for bn_temp in bn_list:
                bn_temp.bias = None

        # Initialize the firing thresholds and weights of all conv/linear layers
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                m.threshold = 1.0
                torch.nn.init.xavier_uniform_(m.weight, gain=2)

    def _fire_and_reset(self, mem, threshold):
        """Emit spikes where ``mem`` exceeds ``threshold`` and subtract the threshold there.

        :param mem: Membrane potential after integrating the current input.
        :param threshold: Firing threshold of the layer.
        :return: ``(spikes, mem_after_reset)``.
        """
        mem_thr = (mem / threshold) - 1.0
        spikes = self.spike_fn(mem_thr)
        # The reset term is built outside spike_fn so it carries no surrogate gradient.
        rst = torch.zeros_like(mem)
        rst[mem_thr > 0] = threshold
        return spikes, mem - rst

    def forward(self, inp):
        """Simulate ``timesteps`` steps and return the mean read-out voltage.

        :param inp: Batch of images; the state buffers are sized for
            ``(batch, 3, img_size, img_size)``.
        :return: Tensor of shape ``(batch, num_cls)``: the output of ``fc2``
            accumulated over all steps and divided by ``timesteps``.
        """
        batch_size = inp.size(0)
        # Follow the input's device instead of a hard-coded .cuda().
        device = inp.device
        full, half, quarter = self.img_size, self.img_size // 2, self.img_size // 4

        mem_conv_list = [
            torch.zeros(batch_size, 64, full, full, device=device),
            torch.zeros(batch_size, 64, full, full, device=device),
            torch.zeros(batch_size, 128, half, half, device=device),
            torch.zeros(batch_size, 128, half, half, device=device),
            torch.zeros(batch_size, 256, quarter, quarter, device=device),
            torch.zeros(batch_size, 256, quarter, quarter, device=device),
            torch.zeros(batch_size, 256, quarter, quarter, device=device),
        ]
        mem_fc1 = torch.zeros(batch_size, 1024, device=device)
        mem_fc2 = torch.zeros(batch_size, self.num_cls, device=device)

        for t in range(self.timesteps):
            # A fresh random spike encoding of the same image at every step.
            out_prev = PoissonGen(inp)

            for i, (conv, pool) in enumerate(zip(self.conv_list, self.pool_list)):
                mem = self.leak_mem * mem_conv_list[i] + self.bntt_list[i][t](conv(out_prev))
                out_prev, mem_conv_list[i] = self._fire_and_reset(mem, conv.threshold)
                if pool is not False:
                    out_prev = pool(out_prev)

            out_prev = out_prev.reshape(batch_size, -1)

            mem = self.leak_mem * mem_fc1 + self.bntt_fc[t](self.fc1(out_prev))
            out_prev, mem_fc1 = self._fire_and_reset(mem, self.fc1.threshold)

            # accumulate voltage in the last layer (no leak, no spiking)
            mem_fc2 = mem_fc2 + self.fc2(out_prev)

        return mem_fc2 / self.timesteps
nn.ModuleList([nn.BatchNorm2d(128, eps=1e-4, momentum=0.1, affine=affine_flag) for i in range(self.batch_num)]) 184 | self.pool2 = nn.AvgPool2d(kernel_size=2) 185 | 186 | self.conv3 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1, bias=bias_flag) 187 | self.bntt3 = nn.ModuleList([nn.BatchNorm2d(256, eps=1e-4, momentum=0.1, affine=affine_flag) for i in range(self.batch_num)]) 188 | self.conv4 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=bias_flag) 189 | self.bntt4 = nn.ModuleList([nn.BatchNorm2d(256, eps=1e-4, momentum=0.1, affine=affine_flag) for i in range(self.batch_num)]) 190 | self.pool3 = nn.AvgPool2d(kernel_size=2) 191 | 192 | self.conv5 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1, bias=bias_flag) 193 | self.bntt5 = nn.ModuleList([nn.BatchNorm2d(512, eps=1e-4, momentum=0.1, affine=affine_flag) for i in range(self.batch_num)]) 194 | self.conv6 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=bias_flag) 195 | self.bntt6 = nn.ModuleList([nn.BatchNorm2d(512, eps=1e-4, momentum=0.1, affine=affine_flag) for i in range(self.batch_num)]) 196 | self.pool4 = nn.AvgPool2d(kernel_size=2) 197 | 198 | self.conv7 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=bias_flag) 199 | self.bntt7 = nn.ModuleList([nn.BatchNorm2d(512, eps=1e-4, momentum=0.1, affine=affine_flag) for i in range(self.batch_num)]) 200 | self.conv8 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=bias_flag) 201 | self.bntt8 = nn.ModuleList([nn.BatchNorm2d(512, eps=1e-4, momentum=0.1, affine=affine_flag) for i in range(self.batch_num)]) 202 | self.pool5 = nn.AdaptiveAvgPool2d((1,1)) 203 | 204 | 205 | self.fc1 = nn.Linear(512, 4096, bias=bias_flag) 206 | self.bntt_fc = nn.ModuleList([nn.BatchNorm1d(4096, eps=1e-4, momentum=0.1, affine=affine_flag) for i in range(self.batch_num)]) 207 | self.fc2 = nn.Linear(4096, self.num_cls, bias=bias_flag) 208 | 209 | self.conv_list = [self.conv1, self.conv2, self.conv3, self.conv4, 
self.conv5, self.conv6, self.conv7, self.conv8] 210 | self.bntt_list = [self.bntt1, self.bntt2, self.bntt3, self.bntt4, self.bntt5, self.bntt6, self.bntt7, self.bntt8, self.bntt_fc] 211 | self.pool_list = [self.pool1, self.pool2, False, self.pool3, False, self.pool4, False, self.pool5] 212 | 213 | # Turn off bias of BNTT 214 | for bn_list in self.bntt_list: 215 | for bn_temp in bn_list: 216 | bn_temp.bias = None 217 | 218 | 219 | # Initialize the firing thresholds of all the layers 220 | for m in self.modules(): 221 | if (isinstance(m, nn.Conv2d)): 222 | m.threshold = 1.0 223 | torch.nn.init.xavier_uniform_(m.weight, gain=2) 224 | elif (isinstance(m, nn.Linear)): 225 | m.threshold = 1.0 226 | torch.nn.init.xavier_uniform_(m.weight, gain=2) 227 | 228 | 229 | 230 | 231 | def forward(self, inp): 232 | 233 | batch_size = inp.size(0) 234 | mem_conv1 = torch.zeros(batch_size, 64, self.img_size, self.img_size).cuda() 235 | mem_conv2 = torch.zeros(batch_size, 128, self.img_size // 2, self.img_size // 2).cuda() 236 | mem_conv3 = torch.zeros(batch_size, 256, self.img_size // 4, self.img_size // 4).cuda() 237 | mem_conv4 = torch.zeros(batch_size, 256, self.img_size // 4, self.img_size // 4).cuda() 238 | mem_conv5 = torch.zeros(batch_size, 512, self.img_size // 8, self.img_size // 8).cuda() 239 | mem_conv6 = torch.zeros(batch_size, 512, self.img_size // 8, self.img_size // 8).cuda() 240 | mem_conv7 = torch.zeros(batch_size, 512, self.img_size // 16, self.img_size // 16).cuda() 241 | mem_conv8 = torch.zeros(batch_size, 512, self.img_size // 16, self.img_size // 16).cuda() 242 | mem_conv_list = [mem_conv1, mem_conv2, mem_conv3, mem_conv4, mem_conv5, mem_conv6, mem_conv7, mem_conv8] 243 | 244 | mem_fc1 = torch.zeros(batch_size, 4096).cuda() 245 | mem_fc2 = torch.zeros(batch_size, self.num_cls).cuda() 246 | 247 | 248 | 249 | for t in range(self.timesteps): 250 | 251 | spike_inp = PoissonGen(inp) 252 | out_prev = spike_inp 253 | 254 | for i in range(len(self.conv_list)): 255 | 
def percentile(t: torch.Tensor, q: float) -> Union[int, float]:
    """
    Return the ``q``-th percentile of the flattened input tensor's data.
    CAUTION:
     * Needs PyTorch >= 1.1.0, as ``torch.kthvalue()`` is used.
     * Values are not interpolated, which corresponds to
       ``numpy.percentile(..., interpolation="nearest")``.
    :param t: Input tensor.
    :param q: Percentile to compute, which must be between 0 and 100 inclusive.
    :return: Resulting value (scalar).
    """
    # Note that ``kthvalue()`` works one-based, i.e. the first sorted value
    # indeed corresponds to k=1, not k=0! Use float(q) instead of q directly,
    # so that ``round()`` returns an integer, even if q is a np.float32.
    k = 1 + round(.01 * float(q) * (t.numel() - 1))
    # reshape (not view) so non-contiguous tensors, e.g. transposed weights, work too.
    return t.reshape(-1).kthvalue(k).values.item()


def model_diff(w, w_init):
    """Sum the relative change of every entry of ``w`` with respect to ``w_init``.

    For each key the term is ``norm(w[k] - w_init[k]) / (1 + norm(w_init[k]))``.

    :param w: Mapping from key to tensor (e.g. a state dict).
    :param w_init: Reference mapping with the same keys.
    :return: The summed relative change; the int ``0`` if no key contributes.
    """
    diff = 0
    for k, ref in w_init.items():
        # Skipped by name: the BatchNorm step counters.
        if "num_batches_tracked" in k:
            continue
        diff += torch.linalg.norm(w[k] - ref) / (1 + torch.linalg.norm(ref))
    return diff


def model_deviation(w_locals, w_init):
    """Return ``model_diff`` of every client model against ``w_init`` as floats.

    :param w_locals: Iterable of client mappings.
    :param w_init: Reference mapping.
    :return: List with one Python float per client.
    """
    print("Num clients:", len(w_locals))
    # float() also accepts the plain 0 model_diff returns when nothing
    # contributes, where .item() would raise AttributeError.
    return [float(model_diff(w, w_init)) for w in w_locals]


class FedLearn(object):
    """Server-side aggregation of client models / client updates.

    :param args: Options object; ``straggler_prob`` and ``grad_noise_stdev``
        are read by :meth:`FedAvg`.
    """

    def __init__(self, args):
        self.args = args

    def _update_noise(self, w_k, w_init, k):
        """Draw the Gaussian noise added to one client's tensor ``w_k`` for key ``k``."""
        if w_init:
            # Scale the noise by the mean absolute model update;
            # * 1.0 converts integer tensors into float so mean() is defined.
            scale = torch.mean(torch.abs(w_init[k].cpu() - w_k.cpu()) * 1.0)
            return scale * torch.randn(w_k.size()) * self.args.grad_noise_stdev
        return torch.randn(w_k.size()) * self.args.grad_noise_stdev

    def FedAvg(self, w, w_init=None):
        """Average the client models, dropping random stragglers and adding noise.

        Client 0 always takes part; every other client is dropped with
        probability ``args.straggler_prob``.  Gaussian noise with standard
        deviation ``args.grad_noise_stdev`` is added to each participating
        client's tensors, scaled by the mean absolute update if ``w_init`` is
        given.

        :param w: List of client state dicts.
        :param w_init: Optional state dict the clients started from.
        :return: Averaged state dict; all tensors are on the CPU.
        """
        non_stragglers = [1] * len(w)
        for i in range(1, len(w)):
            if random.uniform(0, 1) < self.args.straggler_prob:
                non_stragglers[i] = 0

        w_avg = copy.deepcopy(w[0])
        for k in w_avg.keys():
            w_avg[k] = w_avg[k].cpu() + self._update_noise(w[0][k], w_init, k)
        for k in w_avg.keys():
            for i in range(1, len(w)):
                if non_stragglers[i] == 1:
                    w_avg[k] = w_avg[k].cpu() + w[i][k].cpu() + self._update_noise(w[i][k], w_init, k)
            w_avg[k] = torch.div(w_avg[k], sum(non_stragglers))
        return w_avg

    @staticmethod
    def _resolve_layers(keys, activity, activity_mask):
        """Look up the per-layer activity and activity mask for every key.

        Keys containing ``"features"`` use the third dot-separated field as
        layer index; keys containing ``"classifier"`` use that field plus the
        last features index plus 3.  Any other key prints ``Unknown Layer!``
        and reuses the values of the previous key.

        :param keys: Keys in state-dict order.
        :param activity: Per-layer activity, or ``None`` if not needed.
        :param activity_mask: Per-layer activity masks, or ``None`` if not needed.
        :return: Dict mapping key to ``(layer_activity, layer_activity_mask)``.
        :raises ValueError: If a key cannot be resolved and there is no
            previous layer to fall back on.
        """
        resolved = {}
        prev = None
        layer_activity = None
        layer_mask = None
        for k in keys:
            if "features" in k:
                idx = int(k.split(sep='.')[2])
                if activity is not None:
                    layer_activity = activity[idx]
                if activity_mask is not None:
                    layer_mask = activity_mask[idx]
                prev = idx
            elif "classifier" in k:
                if prev is None:
                    raise ValueError("Classifier key '{}' found before any features key".format(k))
                idx = int(k.split(sep='.')[2]) + prev + 3
                if activity is not None:
                    if idx in activity.keys():
                        layer_activity = activity[idx]
                    else:
                        # NOTE(review): with a dict this sums the keys, not the
                        # activity values - presumably the mean activity was
                        # intended; kept as is until confirmed against the caller.
                        layer_activity = sum(activity) / len(activity)
                if activity_mask is not None:
                    if idx in activity_mask.keys():
                        layer_mask = activity_mask[idx]
                    else:
                        # Scalar placeholder: tells the caller to fall back to magnitude.
                        layer_mask = torch.tensor(1)
            else:
                print("Unknown Layer!")
                nothing_to_reuse = (activity is not None and layer_activity is None) or \
                                   (activity_mask is not None and layer_mask is None)
                if nothing_to_reuse:
                    raise ValueError("Cannot determine the layer of key '{}'".format(k))
            resolved[k] = (layer_activity, layer_mask)
        return resolved

    def FedAvgSparse(self, w_init, delta_w_locals, th_basis = "magnitude", pruning_type = "uniform", sparsity = 0, activity = None, activity_multiplier = 1, activity_mask = None):
        """Sparsify every client update, average the results and apply them to ``w_init``.

        th_basis -> on what basis the threshold is calculated - magnitude or activity.
        pruning_type -> uniform or dynamic: uniform will have equal sparsity
        among all layers. dynamic has different sparsity for different layers
        based on activity (``100 * (1 - layer_activity / activity_multiplier)``).
        Only magnitude based uniform is applicable for ANNs.

        For the magnitude basis an entry is kept when its absolute value is
        above the requested percentile of that tensor; for the activity basis
        when the layer's activity mask is above the percentile of the mask.
        If the mask of a layer is a scalar tensor, the magnitude rule is used.

        :param w_init: State dict the clients started from.
        :param delta_w_locals: List of client update dicts (same keys as ``w_init``).
        :param sparsity: Percentile (0-100) used by uniform pruning.
        :param activity: Per-layer activity, required for dynamic pruning.
        :param activity_multiplier: Divisor applied to the activity in dynamic pruning.
        :param activity_mask: Per-layer activity masks, required for the activity basis.
        :return: ``(w_avg, delta_w_avg, sparse_delta_w_locals)``: the new
            weights, the mean sparse update and each client's sparse update.
        :raises ValueError: On an unknown mode, an empty client list, or a
            missing ``activity`` / ``activity_mask`` for the chosen mode.
        """
        if th_basis not in ("magnitude", "activity") or pruning_type not in ("uniform", "dynamic"):
            raise ValueError("Unknown threshold basis or pruning_type. Available options: th_basis - magnitude or activity, pruning_type - uniform or dynamic")
        if len(delta_w_locals) == 0:
            raise ValueError("delta_w_locals must contain at least one client update")
        dynamic = pruning_type == "dynamic"
        by_activity = th_basis == "activity"
        if dynamic and activity is None:
            raise ValueError("Layer activity not available. Dynamic sparsity not possible")
        if by_activity and activity_mask is None:
            raise ValueError("Activity mask is not available. Activity based pruning not possible")

        layer_info = None
        if dynamic or by_activity:
            layer_info = self._resolve_layers(
                list(w_init.keys()),
                activity if dynamic else None,
                activity_mask if by_activity else None,
            )
        # Only this mode printed its sparsity/threshold; kept for log compatibility.
        verbose = dynamic and not by_activity

        num_clients = len(delta_w_locals)
        sparse_delta_w_locals = [{} for _ in range(num_clients)]
        delta_w_avg = {}
        w_avg = {}
        for k in w_init.keys():
            layer_activity, layer_mask = layer_info[k] if layer_info is not None else (None, None)
            q = 100 * (1 - layer_activity / activity_multiplier) if dynamic else sparsity

            # The activity mask does not depend on the client, so build it once per key.
            shared_mask = None
            if by_activity and layer_mask.shape != torch.Size([]):
                shared_mask = layer_mask > percentile(layer_mask, q)

            total = None
            for i, delta_w in enumerate(delta_w_locals):
                delta = delta_w[k]
                if shared_mask is not None:
                    mask = shared_mask
                else:
                    magnitude = torch.abs(delta)
                    th = percentile(magnitude, q)
                    if verbose:
                        print("sparsity", q)
                        print("Threshold", th)
                    # Comparing against the Python scalar works on any device.
                    mask = magnitude > th
                sparse = delta * mask
                sparse_delta_w_locals[i][k] = sparse
                # Accumulate over clients (the update used to be overwritten).
                total = sparse if total is None else total + sparse
            # Divide by the number of clients (used to be the number of keys).
            delta_w_avg[k] = torch.div(total, num_clients)
            w_avg[k] = w_init[k] + delta_w_avg[k]
        return w_avg, delta_w_avg, sparse_delta_w_locals

    def count_gradients(self, delta_w_locals, sparse_delta_w_locals):
        """Count, per client, all update entries and the non-zero sparse entries.

        :param delta_w_locals: List of dense client update dicts.
        :param sparse_delta_w_locals: Matching list of sparsified update dicts.
        :return: ``(num_grads, nz_grads)``, two lists with one int per client.
        """
        num_grads = [0] * len(delta_w_locals)
        nz_grads = [0] * len(delta_w_locals)
        for k in delta_w_locals[0].keys():
            for i, (dense, sparse) in enumerate(zip(delta_w_locals, sparse_delta_w_locals)):
                num_grads[i] += dense[k].numel()
                nz_grads[i] += torch.nonzero(sparse[k]).size(0)
        return num_grads, nz_grads
# --------------------------------------------------
# Spiking neuron with piecewise-linear surrogate gradient
# --------------------------------------------------
class LinearSpike(torch.autograd.Function):
    """
    Here we implement our spiking nonlinearity which also implements
    the surrogate gradient. By subclassing torch.autograd.Function,
    we will be able to use all of PyTorch's autograd functionality.
    Here we use the piecewise-linear surrogate gradient as was done
    in Bellec et al. (2018).
    """
    gamma = 0.3  # Controls the dampening of the piecewise-linear surrogate gradient

    @staticmethod
    def forward(ctx, input):
        """
        In the forward pass, we compute a step function of the input Tensor and
        return it (1.0 where input > 0, else 0.0). The input is stashed on ctx
        with ctx.save_for_backward because the backward pass needs it.
        """
        ctx.save_for_backward(input)
        # zeros_like already follows input's device; no .cuda() needed.
        out = torch.zeros_like(input)
        out[input > 0] = 1.0
        return out

    @staticmethod
    def backward(ctx, grad_output):
        """
        In the backward pass, the incoming gradient is multiplied by
        gamma * max(0, 1 - |input|), the piecewise-linear surrogate gradient
        as was done in Bellec et al. (2018).
        """
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad = grad_input * LinearSpike.gamma * F.threshold(1.0 - torch.abs(input), 0, 0)
        return grad


# --------------------------------------------------
# Spiking neuron with exponential surrogate gradient
# --------------------------------------------------
class ExpSpike(torch.autograd.Function):
    """
    Here we implement our spiking nonlinearity which also implements
    the surrogate gradient. By subclassing torch.autograd.Function,
    we will be able to use all of PyTorch's autograd functionality.
    Here we use the exponential surrogate gradient as was done in
    Shrestha et al. (2018).
    """
    alpha = 1.0  # Controls the magnitude of the exponential surrogate gradient
    beta = 10.0  # Controls the steepness of the exponential surrogate gradient

    @staticmethod
    def forward(ctx, input):
        """
        In the forward pass, we compute a step function of the input Tensor and
        return it (1.0 where input > 0, else 0.0). The input is stashed on ctx
        with ctx.save_for_backward because the backward pass needs it.
        """
        ctx.save_for_backward(input)
        # zeros_like already follows input's device; no .cuda() needed.
        out = torch.zeros_like(input)
        out[input > 0] = 1.0
        return out

    @staticmethod
    def backward(ctx, grad_output):
        """
        In the backward pass, the incoming gradient is multiplied by
        alpha * exp(-beta * |input|), the exponential surrogate gradient
        as was done in Shrestha et al. (2018).
        """
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad = grad_input * ExpSpike.alpha * torch.exp(-ExpSpike.beta * torch.abs(input))
        return grad


# --------------------------------------------------
# Spiking neuron with pass-through surrogate gradient
# --------------------------------------------------
class PassThruSpike(torch.autograd.Function):
    """
    Here we implement our spiking nonlinearity which also implements
    the surrogate gradient. By subclassing torch.autograd.Function,
    we will be able to use all of PyTorch's autograd functionality.
    Here we use the pass-through surrogate gradient.
    """

    @staticmethod
    def forward(ctx, input):
        """
        In the forward pass, we compute a step function of the input Tensor and
        return it (1.0 where input > 0, else 0.0). Nothing is stashed on ctx
        since the backward pass does not use the input.
        """
        # zeros_like already follows input's device; no .cuda() needed.
        out = torch.zeros_like(input)
        out[input > 0] = 1.0
        return out

    @staticmethod
    def backward(ctx, grad_output):
        """
        In the backward pass, the incoming gradient is returned unchanged
        (as a copy): the pass-through surrogate gradient.
        """
        grad_input = grad_output.clone()
        return grad_input


# Overwrite the naive spike function by differentiable spiking nonlinearity which implements a surrogate gradient
def init_spike_fn(grad_type):
    """Return the ``apply`` function of the spike class selected by ``grad_type``.

    :param grad_type: One of ``'FastSigm'``, ``'Linear'``, ``'Exp'`` or ``'PassThru'``.
    :return: The matching ``torch.autograd.Function.apply`` callable.
    Any other value terminates the program through ``sys.exit``.
    """
    if (grad_type == 'FastSigm'):
        spike_fn = SuperSpike.apply
    elif (grad_type == 'Linear'):
        spike_fn = LinearSpike.apply
    elif (grad_type == 'Exp'):
        spike_fn = ExpSpike.apply
    elif (grad_type == 'PassThru'):
        spike_fn = PassThruSpike.apply
    else:
        sys.exit("Unknown gradient type '{}'".format(grad_type))
    return spike_fn


# --------------------------------------------------
# Poisson spike generator
#   Positive spike is generated (i.e.  1 is returned) if rand()<=abs(input) and sign(input)= 1
#   Negative spike is generated (i.e. -1 is returned) if rand()<=abs(input) and sign(input)=-1
#   (the random draw is multiplied by rescale_fac before the comparison)
# --------------------------------------------------
def PoissonGen(inp, rescale_fac=2.0):
    """Turn an analog input tensor into a signed random spike tensor.

    :param inp: Input tensor.
    :param rescale_fac: Factor applied to the random draw before it is
        compared with ``abs(inp)``; larger values make spikes rarer.
    :return: Tensor of -1, 0 and +1 with the same shape and device as ``inp``.
    """
    # rand_like already allocates on inp's device; no .cuda() needed.
    rand_inp = torch.rand_like(inp)
    return torch.mul(torch.le(rand_inp * rescale_fac, torch.abs(inp)).float(), torch.sign(inp))
num_cls 226 | 227 | # ConvSNN simulation parameters 228 | self.dt = dt 229 | self.t_end = t_end 230 | # self.num_steps = int(self.t_end / self.dt) 231 | self.num_steps = timesteps 232 | self.inp_rate = inp_rate 233 | self.inp_rescale_fac = 1.0 / (self.dt * self.inp_rate) 234 | self.grad_type = grad_type 235 | self.grad_type_pool = 'PassThru' 236 | self.thresh_init_wnorm = thresh_init_wnorm 237 | self.leak_mem = leak_mem#0.95 # leak_mem 238 | self.drop_rate = drop_rate 239 | self.lnorm_ord = 2 240 | self.scale_thresh = 1.0 241 | self.use_max_out_over_time = use_max_out_over_time 242 | 243 | self.dropout_layer = nn.Dropout2d(p=0.2) 244 | 245 | self.one_stamp = 1 246 | self.batch_num = self.num_steps // self.one_stamp 247 | 248 | print (">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>") 249 | print ("***** time step per batchnorm", self.batch_num) 250 | print (">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>") 251 | affine_flag = True 252 | bias_flag = False 253 | # Instantiate the ConvSNN layers 254 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=bias_flag) 255 | self.bn1_list = nn.ModuleList([nn.BatchNorm2d(64, eps=1e-4, momentum=0.1, affine=affine_flag) for i in range(self.batch_num)]) 256 | self.conv1_1 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=bias_flag) 257 | self.bn1_1_list = nn.ModuleList([nn.BatchNorm2d(64, eps=1e-4, momentum=0.1, affine=affine_flag) for i in range(self.batch_num)]) 258 | self.pool1 = nn.AvgPool2d(kernel_size=2) # Default stride = kernel_size 259 | 260 | self.conv2 = nn.Conv2d(64, 128, kernel_size=self.ksize, stride=1, padding=1, bias=bias_flag) 261 | self.bn2_list = nn.ModuleList([nn.BatchNorm2d(128, eps=1e-4, momentum=0.1, affine=affine_flag) for i in range(self.batch_num)]) 262 | self.conv3 = nn.Conv2d(128, 128, kernel_size=self.ksize, stride=1, padding=1, bias=bias_flag) 263 | self.bn3_list = nn.ModuleList([nn.BatchNorm2d(128, eps=1e-4, momentum=0.1, affine=affine_flag) for i in range(self.batch_num)]) 264 
| self.pool2 = nn.AvgPool2d(kernel_size=2) # Default stride = kernel_size 265 | 266 | self.conv4 = nn.Conv2d(128, 256, kernel_size=self.ksize, stride=1, padding=1, bias=bias_flag) 267 | self.bn4_list = nn.ModuleList([nn.BatchNorm2d(256, eps=1e-4, momentum=0.1, affine=affine_flag) for i in range(self.batch_num)]) 268 | self.conv5 = nn.Conv2d(256, 256, kernel_size=self.ksize, stride=1, padding=1, bias=bias_flag) 269 | self.bn5_list = nn.ModuleList([nn.BatchNorm2d(256, eps=1e-4, momentum=0.1, affine=affine_flag) for i in range(self.batch_num)]) 270 | self.conv6 = nn.Conv2d(256, 256, kernel_size=self.ksize, stride=1, padding=1, bias=bias_flag) 271 | self.bn6_list = nn.ModuleList([nn.BatchNorm2d(256, eps=1e-4, momentum=0.1, affine=affine_flag) for i in range(self.batch_num)]) 272 | self.pool3 = nn.AvgPool2d(kernel_size=2) # Default stride = kernel_size 273 | 274 | 275 | 276 | self.drop = nn.Dropout(p=0.2) 277 | 278 | self.fc1 = nn.Linear((self.img_size//8)*(self.img_size //8)*256, 2*2*256, bias=bias_flag) 279 | self.bnfc_list = nn.ModuleList([nn.BatchNorm1d( 2*2*256, eps=1e-4, momentum=0.1, affine=affine_flag) for i in range(self.batch_num)]) 280 | 281 | self.fc2 = nn.Linear(2*2*256, self.num_cls, bias=bias_flag) 282 | 283 | batchnormlist = [self.bn1_list, self.bn1_1_list, self.bn2_list, self.bn3_list, self.bn4_list, self.bn5_list, 284 | self.bn6_list, self.bnfc_list] 285 | 286 | #TODO turn off bias of batchnorm 287 | for bnlist in batchnormlist: 288 | for bnbn in bnlist: 289 | bnbn.bias = None 290 | 291 | 292 | # Initialize the firing thresholds of all the layers 293 | for m in self.modules(): 294 | if (isinstance(m, nn.Conv2d)): 295 | m.threshold = 1.0 296 | # torch.nn.init.kaiming_normal_(m.weight,a=1) 297 | torch.nn.init.xavier_uniform_(m.weight, gain=2) 298 | 299 | elif (isinstance(m, nn.AvgPool2d)): 300 | m.threshold = 0.75 301 | elif (isinstance(m, nn.Linear)): 302 | m.threshold = 1.0 303 | # torch.nn.init.kaiming_normal_(m.weight,a=1) 304 | 
torch.nn.init.xavier_uniform_(m.weight, gain=2) 305 | 306 | 307 | # Instantiate differentiable spiking nonlinearity 308 | self.spike_fn = init_spike_fn(self.grad_type) 309 | self.spike_pool = init_spike_fn(self.grad_type_pool) 310 | 311 | def fc_init(self): 312 | torch.nn.init.xavier_uniform_(self.fc1.weight) 313 | 314 | torch.nn.init.xavier_uniform_(self.fc2.weight) 315 | 316 | 317 | def forward(self, inp, count_active_layers = False, report_activity = False): 318 | 319 | active_layer_count = 9 320 | if count_active_layers == True: 321 | return active_layer_count 322 | activity = torch.zeros(active_layer_count).cuda() 323 | 324 | # avg_spike_time = [] 325 | # Initialize the neuronal membrane potentials and dropout masks 326 | batch_size = inp.size(0) 327 | mem_conv1 = torch.zeros(batch_size, 64, self.img_size, self.img_size).cuda() 328 | mem_conv1_1 = torch.zeros(batch_size, 64, self.img_size, self.img_size).cuda() 329 | mem_conv2 = torch.zeros(batch_size, 128, self.img_size//2, self.img_size//2).cuda() 330 | mem_conv3 = torch.zeros(batch_size, 128, self.img_size//2, self.img_size//2).cuda() 331 | mem_conv4 = torch.zeros(batch_size, 256, self.img_size//4, self.img_size//4).cuda() 332 | mem_conv5 = torch.zeros(batch_size, 256, self.img_size//4, self.img_size//4).cuda() 333 | mem_conv6 = torch.zeros(batch_size, 256, self.img_size//4, self.img_size//4).cuda() 334 | 335 | 336 | 337 | mem_fc1 = torch.zeros(batch_size, 1024).cuda() 338 | mem_fc2 = torch.zeros(batch_size, self.num_cls).cuda() 339 | 340 | fc_dropout_mask = self.drop(torch.ones([batch_size, 1024]).cuda()) 341 | 342 | for t in range(self.num_steps): 343 | spike_inp = PoissonGen(inp) 344 | activity[0] += torch.count_nonzero(spike_inp.detach())/torch.numel(spike_inp.detach()) 345 | # # spike_x = torch.cat([spike_inp]*22,1)[:,:64,:,:] 346 | # out_prev = spike_inp 347 | # 348 | # Compute the conv1 outputs 349 | mem_thr = (mem_conv1/self.conv1.threshold) - 1.0 350 | out = self.spike_fn(mem_thr) 351 | rst = 
torch.zeros_like(mem_conv1).cuda() 352 | rst[mem_thr>0] = self.conv1.threshold 353 | mem_conv1 = (self.leak_mem*mem_conv1 + self.bn1_list[int(t/self.one_stamp)](self.conv1(spike_inp)) -rst) 354 | out_prev = out.clone() 355 | 356 | activity[1] += torch.count_nonzero(out_prev.detach())/torch.numel(out_prev.detach()) 357 | # Compute the conv1_1 outputs 358 | mem_thr = (mem_conv1_1 / self.conv1_1.threshold) - 1.0 359 | out = self.spike_fn(mem_thr) 360 | rst = torch.zeros_like(mem_conv1_1).cuda() 361 | rst[mem_thr > 0] = self.conv1_1.threshold 362 | mem_conv1_1 = (self.leak_mem * mem_conv1_1 + self.bn1_1_list[int(t/self.one_stamp)](self.conv1_1(out_prev)) - rst) 363 | out_prev = out.clone() 364 | 365 | 366 | # Compute the avgpool1 outputs 367 | out = self.pool1(out_prev) 368 | out_prev = out.clone() 369 | 370 | # mem_thr = (mem_pool1 / self.pool1.threshold) - 1.0 371 | # out = self.spike_pool(mem_thr) 372 | # rst = torch.zeros_like(mem_pool1).cuda() 373 | # rst[mem_thr > 0] = self.pool1.threshold 374 | # mem_pool1 = mem_pool1 + self.pool1(out_prev) - rst 375 | # out_prev = out.clone() 376 | 377 | 378 | activity[2] += torch.count_nonzero(out_prev.detach())/torch.numel(out_prev.detach()) 379 | # Compute the conv2 outputs 380 | mem_thr = (mem_conv2/self.conv2.threshold) - 1.0 381 | out = self.spike_fn(mem_thr) 382 | rst = torch.zeros_like(mem_conv2).cuda() 383 | rst[mem_thr>0] = self.conv2.threshold 384 | mem_conv2 = (self.leak_mem*mem_conv2 + self.bn2_list[int(t/self.one_stamp)](self.conv2(out_prev)) -rst) 385 | out_prev = out.clone() 386 | 387 | 388 | activity[3] += torch.count_nonzero(out_prev.detach())/torch.numel(out_prev.detach()) 389 | # Compute the conv3 outputs 390 | mem_thr = (mem_conv3 / self.conv3.threshold) - 1.0 391 | out = self.spike_fn(mem_thr) 392 | rst = torch.zeros_like(mem_conv3).cuda() 393 | rst[mem_thr > 0] = self.conv3.threshold 394 | mem_conv3 = (self.leak_mem * mem_conv3 + self.bn3_list[int(t/self.one_stamp)](self.conv3(out_prev)) - rst) 395 | 
out_prev = out.clone() 396 | 397 | # Compute the avgpool2 outputs 398 | out = self.pool2(out_prev) 399 | out_prev = out.clone() 400 | # mem_thr = (mem_pool2 / self.pool2.threshold) - 1.0 401 | # out = self.spike_pool(mem_thr) 402 | # rst = torch.zeros_like(mem_pool2).cuda() 403 | # rst[mem_thr > 0] = self.pool2.threshold 404 | # mem_pool2 = mem_pool2 + self.pool2(out_prev) - rst 405 | # out_prev = out.clone() 406 | 407 | 408 | activity[4] += torch.count_nonzero(out_prev.detach())/torch.numel(out_prev.detach()) 409 | # Compute the conv4 outputs 410 | mem_thr = (mem_conv4 / self.conv4.threshold) - 1.0 411 | out = self.spike_fn(mem_thr) 412 | rst = torch.zeros_like(mem_conv4).cuda() 413 | rst[mem_thr > 0] = self.conv4.threshold 414 | mem_conv4 = (self.leak_mem * mem_conv4 + self.bn4_list[int(t/self.one_stamp)](self.conv4(out_prev)) - rst) 415 | out_prev = out.clone() 416 | 417 | 418 | activity[5] += torch.count_nonzero(out_prev.detach())/torch.numel(out_prev.detach()) 419 | # Compute the conv5 outputs 420 | mem_thr = (mem_conv5 / self.conv5.threshold) - 1.0 421 | out = self.spike_fn(mem_thr) 422 | rst = torch.zeros_like(mem_conv5).cuda() 423 | rst[mem_thr > 0] = self.conv5.threshold 424 | mem_conv5 = (self.leak_mem * mem_conv5 + self.bn5_list[int(t/self.one_stamp)](self.conv5(out_prev)) - rst) 425 | out_prev = out.clone() 426 | 427 | 428 | activity[6] += torch.count_nonzero(out_prev.detach())/torch.numel(out_prev.detach()) 429 | # Compute the conv6 outputs 430 | mem_thr = (mem_conv6 / self.conv6.threshold) - 1.0 431 | out = self.spike_fn(mem_thr) 432 | rst = torch.zeros_like(mem_conv6).cuda() 433 | rst[mem_thr > 0] = self.conv6.threshold 434 | mem_conv6 = (self.leak_mem * mem_conv6 + self.bn6_list[int(t/self.one_stamp)](self.conv6(out_prev)) - rst) 435 | out_prev = out.clone() 436 | 437 | # Compute the avgpool3 outputs 438 | out = self.pool3(out_prev) 439 | out_prev = out.clone() 440 | # mem_thr = (mem_pool3 / self.pool3.threshold) - 1.0 441 | # out = 
class SNN_VGG16_TBN(nn.Module):
    """Spiking VGG-16 whose BatchNorm layers are instantiated per time step.

    Each conv layer and ``fc1`` is simulated as a leaky integrate-and-fire
    (LIF) layer: the layer input is passed through the layer and through the
    BatchNorm module that belongs to the current simulation step
    (``bnN_list[t]``), then integrated into a membrane potential.  ``fc2``
    only accumulates its input; the returned logits are that accumulated
    potential divided by the number of simulation steps.

    Args:
        dt, t_end, inp_rate: Stored on the instance; only used here to derive
            ``inp_rescale_fac``.
        grad_type: Surrogate-gradient name handed to ``init_spike_fn``.
        thresh_init_wnorm: If true, thresholds of the linear layers are
            re-initialised from the norm of their weight matrix.
        leak_mem: Multiplicative membrane leak applied at every step.
        img_size: Nominal input height/width (stored for reference).
        inp_maps, c1_maps, c2_maps, fc0_size, drop_rate,
        use_max_out_over_time: Stored for compatibility; not used by
            ``forward``.
        ksize: Kernel size of ``conv2`` .. ``conv12`` (``conv1``/``conv1_1``
            are fixed at 3).
        num_cls: Number of output classes.
        timesteps: Number of simulation steps, and therefore the number of
            BatchNorm modules created per layer.
    """

    # Feature extractor in execution order.  A tuple names a spiking layer and
    # its per-time-step BatchNorm list; a plain string names a pooling module.
    _FEATURE_STAGES = (
        ("conv1", "bn1_list"),
        ("conv1_1", "bn1_1_list"),
        "pool1",
        ("conv2", "bn2_list"),
        ("conv3", "bn3_list"),
        "pool2",
        ("conv4", "bn4_list"),
        ("conv5", "bn5_list"),
        ("conv6", "bn6_list"),
        "pool3",
        ("conv7", "bn7_list"),
        ("conv8", "bn8_list"),
        ("conv9", "bn9_list"),
        "pool4",
        ("conv10", "bn10_list"),
        ("conv11", "bn11_list"),
        ("conv12", "bn12_list"),
        "pool5",
    )

    def __init__(self, dt=0.001, t_end=0.100, inp_rate=100, grad_type='Linear', thresh_init_wnorm=False,
                 leak_mem=0.99, img_size=32, inp_maps=3, c1_maps=64, c2_maps=64, ksize=3, fc0_size=200,
                 num_cls=1000, drop_rate=0.5, use_max_out_over_time=False, timesteps=20):
        super(SNN_VGG16_TBN, self).__init__()

        # ConvSNN architecture parameters
        self.img_size = img_size
        self.inp_maps = inp_maps
        self.c1_maps = 64
        self.c1_dim = self.img_size
        self.c2_maps = 128
        self.c3_maps = 256
        self.c4_maps = 512

        self.ksize = ksize
        self.fc0_size = fc0_size
        self.num_cls = num_cls

        # ConvSNN simulation parameters
        self.dt = dt
        self.t_end = t_end
        self.num_steps = timesteps
        self.inp_rate = inp_rate
        self.inp_rescale_fac = 1.0 / (self.dt * self.inp_rate)
        self.grad_type = grad_type
        self.grad_type_pool = 'PassThru'
        self.thresh_init_wnorm = thresh_init_wnorm
        # Honour the constructor argument (it used to be overwritten with a
        # hard-coded 0.99, which is still the default value).
        self.leak_mem = leak_mem
        self.drop_rate = drop_rate
        self.lnorm_ord = 2
        self.scale_thresh = 1.0
        self.use_max_out_over_time = use_max_out_over_time

        self.dropout_layer = nn.Dropout2d(p=0.2)

        # One BatchNorm module is created for every `one_stamp` time steps.
        self.one_stamp = 1
        self.batch_num = self.num_steps // self.one_stamp

        print (">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>")
        print ("***** time step per batchnorm", self.batch_num)
        print (">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>")

        affine_flag = True

        def bn2d(channels):
            # Per-time-step BatchNorm2d modules for one conv layer.
            return nn.ModuleList([nn.BatchNorm2d(channels, eps=1e-4, momentum=0.1, affine=affine_flag)
                                  for _ in range(self.batch_num)])

        def conv(in_ch, out_ch, kernel):
            return nn.Conv2d(in_ch, out_ch, kernel_size=kernel, stride=1, padding=1, bias=False)

        # Instantiate the ConvSNN layers.  The registration order is kept
        # stable so state_dict key order and RNG consumption do not change.
        self.conv1 = conv(3, 64, 3)
        self.bn1_list = bn2d(64)
        self.conv1_1 = conv(64, 64, 3)
        self.bn1_1_list = bn2d(64)
        self.pool1 = nn.AvgPool2d(kernel_size=2)  # Default stride = kernel_size

        self.conv2 = conv(64, 128, self.ksize)
        self.bn2_list = bn2d(128)
        self.conv3 = conv(128, 128, self.ksize)
        self.bn3_list = bn2d(128)
        self.pool2 = nn.AvgPool2d(kernel_size=2)

        self.conv4 = conv(128, 256, self.ksize)
        self.bn4_list = bn2d(256)
        self.conv5 = conv(256, 256, self.ksize)
        self.bn5_list = bn2d(256)
        self.conv6 = conv(256, 256, self.ksize)
        self.bn6_list = bn2d(256)
        self.pool3 = nn.AvgPool2d(kernel_size=2)

        self.conv7 = conv(256, 512, self.ksize)
        self.bn7_list = bn2d(512)
        self.conv8 = conv(512, 512, self.ksize)
        self.bn8_list = bn2d(512)
        self.conv9 = conv(512, 512, self.ksize)
        self.bn9_list = bn2d(512)
        self.pool4 = nn.AvgPool2d(kernel_size=2)

        self.conv10 = conv(512, 512, self.ksize)
        self.bn10_list = bn2d(512)
        self.conv11 = conv(512, 512, self.ksize)
        self.bn11_list = bn2d(512)
        self.conv12 = conv(512, 512, self.ksize)
        self.bn12_list = bn2d(512)
        self.pool5 = nn.AdaptiveAvgPool2d((1, 1))

        self.fc1 = nn.Linear(512, 4096, bias=False)
        self.bnfc_list = nn.ModuleList([nn.BatchNorm1d(4096, eps=1e-4, momentum=0.1, affine=affine_flag)
                                        for _ in range(self.batch_num)])

        self.fc2 = nn.Linear(4096, self.num_cls, bias=False)

        # Turn off the bias of every BatchNorm (only the scale stays learnable).
        for m in self.modules():
            if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
                m.bias = None

        # Initialize the firing thresholds and weights of all the layers
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                m.threshold = 1.0
                torch.nn.init.xavier_uniform_(m.weight, gain=2)
            elif isinstance(m, nn.AvgPool2d):
                m.threshold = 0.75
            elif isinstance(m, nn.Linear):
                m.threshold = 1.0
                torch.nn.init.xavier_uniform_(m.weight, gain=2)
                # NOTE(review): the original indentation of this branch was
                # lost; it is kept under the Linear case because an order-2
                # matrix norm is only defined for 2-D weights - confirm.
                if self.thresh_init_wnorm:
                    lnorm = LA.norm(m.weight.data, self.lnorm_ord)
                    thresh_init = lnorm * self.scale_thresh
                    m.threshold = torch.from_numpy(np.array([thresh_init])).float().cuda()
                    print('Wl{}norm: {:.2f}; Threshold: {:.2f}\n'.format(self.lnorm_ord, lnorm, m.threshold[0]))

        # Instantiate differentiable spiking nonlinearity
        self.spike_fn = init_spike_fn(self.grad_type)
        self.spike_pool = init_spike_fn(self.grad_type_pool)

    def fc_init(self):
        """Re-initialise both fully connected layers with default-gain Xavier."""
        torch.nn.init.xavier_uniform_(self.fc1.weight)
        torch.nn.init.xavier_uniform_(self.fc2.weight)

    def _lif_step(self, layer, bn, mem, x):
        """Advance one LIF layer by a single time step.

        Args:
            layer: Conv2d/Linear module carrying a ``threshold`` attribute.
            bn: BatchNorm module to use for this time step.
            mem: Membrane potential from the previous step, or ``None`` on the
                first step (a zero state matching the layer output is created).
            x: Input to the layer for this step.

        Returns:
            ``(spikes, new_mem)``.  The spikes are computed from the membrane
            potential *before* the current input is integrated.
        """
        current = bn(layer(x))
        if mem is None:
            # Allocate on the device/dtype of the activations instead of
            # assuming CUDA and a fixed spatial size.
            mem = torch.zeros_like(current)

        mem_thr = (mem / layer.threshold) - 1.0
        spikes = self.spike_fn(mem_thr)

        # Soft reset: subtract the threshold wherever the neuron fired.
        reset = torch.zeros_like(mem)
        reset[mem_thr > 0] = layer.threshold

        new_mem = self.leak_mem * mem + current - reset
        return spikes, new_mem

    def forward(self, inp):
        """Simulate the network for ``num_steps`` steps and return the logits.

        Args:
            inp: Image batch; ``conv1`` expects 3 channels.  Presumably float
                intensities suitable for ``PoissonGen`` - verify against caller.

        Returns:
            Tensor of shape ``(batch, num_cls)``: the membrane potential of
            ``fc2`` accumulated over all steps, divided by ``num_steps``.
        """
        batch_size = inp.size(0)
        membranes = {}  # layer name -> membrane potential, created lazily
        mem_fc2 = torch.zeros(batch_size, self.num_cls, device=inp.device)

        for t in range(self.num_steps):
            bn_idx = t // self.one_stamp
            out_prev = PoissonGen(inp)

            for stage in self._FEATURE_STAGES:
                if isinstance(stage, str):
                    # Pooling stage: plain (non-spiking) average pooling.
                    out_prev = getattr(self, stage)(out_prev)
                    continue
                layer_name, bn_name = stage
                out_prev, membranes[layer_name] = self._lif_step(
                    getattr(self, layer_name),
                    getattr(self, bn_name)[bn_idx],
                    membranes.get(layer_name),
                    out_prev,
                )

            out_prev = out_prev.reshape(batch_size, -1)

            # fc1 is a spiking layer with its own per-step BatchNorm1d.
            out_prev, membranes["fc1"] = self._lif_step(
                self.fc1, self.bnfc_list[bn_idx], membranes.get("fc1"), out_prev)

            # fc2 only integrates (no leak, no spiking, no reset).
            mem_fc2 = mem_fc2 + self.fc2(out_prev)

        return mem_fc2 / self.num_steps
# --------------------------------------------------
# Define a class for recording the SNN train/test loss.
# This class is replicated from Chankyu Lee's SNN-backprop code.
# --------------------------------------------------
class AverageMeter(object):
    """Track the latest value and the running weighted average of a metric.

    Attributes are plain public fields so callers can read them directly:
    ``val`` (last value), ``sum`` (weighted sum), ``count`` (total weight)
    and ``avg`` (``sum / count``).
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all statistics back to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` observed ``n`` times (e.g. a batch mean and its size).

        Args:
            val: The new measurement.
            n: Weight of the measurement; defaults to 1.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # An update with n == 0 on an empty meter must not divide by zero;
        # the average of nothing is reported as 0, like after reset().
        self.avg = self.sum / self.count if self.count else 0
class SNN_VGG11_TBN(nn.Module):
    """Spiking VGG-11 whose BatchNorm layers are instantiated per time step.

    Each conv layer and ``fc1`` is simulated as a leaky integrate-and-fire
    (LIF) layer: the layer input is passed through the layer and through the
    BatchNorm module that belongs to the current simulation step
    (``bnN_list[t]``), then integrated into a membrane potential.  ``fc2``
    only accumulates its input; the returned logits are that accumulated
    potential divided by the number of simulation steps.

    Args:
        dt, t_end: Simulation step and duration.  When ``timesteps`` is not
            given the number of steps is ``int(t_end / dt)``.
        inp_rate: Stored; only used to derive ``inp_rescale_fac``.
        grad_type: Surrogate-gradient name handed to ``init_spike_fn``.
        thresh_init_wnorm: Stored; not used by this class.
        leak_mem: Multiplicative membrane leak applied at every step.
        img_size: Nominal input height/width (stored for reference).
        inp_maps, c1_maps, c2_maps, fc0_size, drop_rate,
        use_max_out_over_time: Stored for compatibility; not used by
            ``forward``.
        ksize: Kernel size of ``conv2`` .. ``conv8`` (``conv1`` is fixed at 3).
        num_cls: Number of output classes.
        timesteps: Optional explicit number of simulation steps, matching the
            parameter of ``SNN_VGG16_TBN``.  ``None`` keeps the historical
            ``int(t_end / dt)`` behaviour.
    """

    # Feature extractor in execution order.  A tuple names a spiking layer and
    # its per-time-step BatchNorm list; a plain string names a pooling module.
    # `pool5` is registered but the forward pass uses `avg_pool` instead.
    _FEATURE_STAGES = (
        ("conv1", "bn1_list"),
        "pool1",
        ("conv2", "bn2_list"),
        "pool2",
        ("conv3", "bn3_list"),
        ("conv4", "bn4_list"),
        "pool3",
        ("conv5", "bn5_list"),
        ("conv6", "bn6_list"),
        "pool4",
        ("conv7", "bn7_list"),
        ("conv8", "bn8_list"),
        "avg_pool",
    )

    def __init__(self, dt=0.001, t_end=0.100, inp_rate=100, grad_type='Linear', thresh_init_wnorm=False,
                 leak_mem=0.99, img_size=32, inp_maps=3, c1_maps=64, c2_maps=64, ksize=3, fc0_size=200,
                 num_cls=1000, drop_rate=0.5, use_max_out_over_time=False, timesteps=None):
        super(SNN_VGG11_TBN, self).__init__()

        # ConvSNN architecture parameters
        self.img_size = img_size
        self.inp_maps = inp_maps
        self.c1_maps = 64
        self.c1_dim = self.img_size
        self.c2_maps = 128
        self.c3_maps = 256
        self.c4_maps = 512

        self.ksize = ksize
        self.fc0_size = fc0_size
        self.num_cls = num_cls

        # ConvSNN simulation parameters
        self.dt = dt
        self.t_end = t_end
        # An explicit step count wins; otherwise fall back to t_end / dt.
        self.num_steps = int(self.t_end / self.dt) if timesteps is None else timesteps
        self.inp_rate = inp_rate
        self.inp_rescale_fac = 1.0 / (self.dt * self.inp_rate)
        self.grad_type = grad_type
        self.grad_type_pool = 'PassThru'
        self.thresh_init_wnorm = thresh_init_wnorm
        self.leak_mem = leak_mem
        self.drop_rate = drop_rate
        self.lnorm_ord = 2
        self.scale_thresh = 1.0
        self.use_max_out_over_time = use_max_out_over_time

        self.dropout_layer = nn.Dropout2d(p=0.2)

        # One BatchNorm module is created for every `one_stamp` time steps.
        self.one_stamp = 1
        self.batch_num = self.num_steps // self.one_stamp

        print (">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>")
        print ("***** time step per batchnorm", self.batch_num)
        print (">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>")
        affine_flag = True
        bias_flag = False

        def bn2d(channels):
            # Per-time-step BatchNorm2d modules for one conv layer.
            return nn.ModuleList([nn.BatchNorm2d(channels, eps=1e-4, momentum=0.1, affine=affine_flag)
                                  for _ in range(self.batch_num)])

        def conv(in_ch, out_ch, kernel):
            return nn.Conv2d(in_ch, out_ch, kernel_size=kernel, stride=1, padding=1, bias=bias_flag)

        # Instantiate the ConvSNN layers.  The registration order is kept
        # stable so state_dict key order and RNG consumption do not change.
        self.conv1 = conv(3, 64, 3)
        self.bn1_list = bn2d(64)
        self.pool1 = nn.AvgPool2d(kernel_size=2)  # Default stride = kernel_size

        self.conv2 = conv(64, 128, self.ksize)
        self.bn2_list = bn2d(128)
        self.pool2 = nn.AvgPool2d(kernel_size=2)

        self.conv3 = conv(128, 256, self.ksize)
        self.bn3_list = bn2d(256)
        self.conv4 = conv(256, 256, self.ksize)
        self.bn4_list = bn2d(256)
        self.pool3 = nn.AvgPool2d(kernel_size=2)

        self.conv5 = conv(256, 512, self.ksize)
        self.bn5_list = bn2d(512)
        self.conv6 = conv(512, 512, self.ksize)
        self.bn6_list = bn2d(512)
        self.pool4 = nn.AvgPool2d(kernel_size=2)

        self.conv7 = conv(512, 512, self.ksize)
        self.bn7_list = bn2d(512)
        self.conv8 = conv(512, 512, self.ksize)
        self.bn8_list = bn2d(512)
        self.pool5 = nn.AvgPool2d(kernel_size=2)

        self.drop = nn.Dropout(p=0.2)

        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))

        self.fc1 = nn.Linear(512, 4*2*2*256, bias=bias_flag)

        self.bnfc_list = nn.ModuleList([nn.BatchNorm1d(4*2*2*256, eps=1e-4, momentum=0.1, affine=affine_flag)
                                        for _ in range(self.batch_num)])

        self.fc2 = nn.Linear(4*2*2*256, self.num_cls, bias=bias_flag)

        # Turn off the bias of every BatchNorm (only the scale stays learnable).
        for m in self.modules():
            if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
                m.bias = None

        # Initialize the firing thresholds and weights of all the layers
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                m.threshold = 1.0
                torch.nn.init.xavier_normal_(m.weight, gain=2)
            elif isinstance(m, nn.AvgPool2d):
                m.threshold = 0.75
            elif isinstance(m, nn.Linear):
                m.threshold = 1.0
                torch.nn.init.xavier_normal_(m.weight, gain=2)

        # Instantiate differentiable spiking nonlinearity
        self.spike_fn = init_spike_fn(self.grad_type)
        self.spike_pool = init_spike_fn(self.grad_type_pool)

    def fc_init(self):
        """Re-initialise both fully connected layers with default-gain Xavier."""
        torch.nn.init.xavier_uniform_(self.fc1.weight)
        torch.nn.init.xavier_uniform_(self.fc2.weight)

    def _lif_step(self, layer, bn, mem, x):
        """Advance one LIF layer by a single time step.

        Args:
            layer: Conv2d/Linear module carrying a ``threshold`` attribute.
            bn: BatchNorm module to use for this time step.
            mem: Membrane potential from the previous step, or ``None`` on the
                first step (a zero state matching the layer output is created).
            x: Input to the layer for this step.

        Returns:
            ``(spikes, new_mem)``.  The spikes are computed from the membrane
            potential *before* the current input is integrated.
        """
        current = bn(layer(x))
        if mem is None:
            # Allocate on the device/dtype of the activations instead of
            # assuming CUDA and a fixed spatial size.
            mem = torch.zeros_like(current)

        mem_thr = (mem / layer.threshold) - 1.0
        spikes = self.spike_fn(mem_thr)

        # Soft reset: subtract the threshold wherever the neuron fired.
        reset = torch.zeros_like(mem)
        reset[mem_thr > 0] = layer.threshold

        new_mem = self.leak_mem * mem + current - reset
        return spikes, new_mem

    def forward(self, inp):
        """Simulate the network for ``num_steps`` steps and return the logits.

        Args:
            inp: Image batch; ``conv1`` expects 3 channels.  Presumably float
                intensities suitable for ``PoissonGen`` - verify against caller.

        Returns:
            Tensor of shape ``(batch, num_cls)``: the membrane potential of
            ``fc2`` accumulated over all steps, divided by ``num_steps``.
        """
        batch_size = inp.size(0)
        membranes = {}  # layer name -> membrane potential, created lazily
        mem_fc2 = torch.zeros(batch_size, self.num_cls, device=inp.device)
        # The former `fc_dropout_mask` was never applied to anything (and had
        # the wrong width for the 4096-unit fc1), so it is no longer built.

        for t in range(self.num_steps):
            bn_idx = t // self.one_stamp
            out_prev = PoissonGen(inp)

            for stage in self._FEATURE_STAGES:
                if isinstance(stage, str):
                    # Pooling stage: plain (non-spiking) average pooling.
                    out_prev = getattr(self, stage)(out_prev)
                    continue
                layer_name, bn_name = stage
                out_prev, membranes[layer_name] = self._lif_step(
                    getattr(self, layer_name),
                    getattr(self, bn_name)[bn_idx],
                    membranes.get(layer_name),
                    out_prev,
                )

            out_prev = out_prev.reshape(batch_size, -1)

            # fc1 is a spiking layer with its own per-step BatchNorm1d.
            out_prev, membranes["fc1"] = self._lif_step(
                self.fc1, self.bnfc_list[bn_idx], membranes.get("fc1"), out_prev)

            # fc2 only integrates (no leak, no spiking, no reset).
            mem_fc2 = mem_fc2 + self.fc2(out_prev)

        return mem_fc2 / self.num_steps