├── .gitignore ├── LICENSE ├── README.md ├── functions ├── __init__.py ├── data_loaders.py └── functions.py ├── main.py └── models ├── VGG_models.py └── layers.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled source # 2 | ################### 3 | *.com 4 | *.class 5 | *.dll 6 | *.exe 7 | *.o 8 | *.so 9 | 10 | # Packages # 11 | ############ 12 | # it's better to unpack these files and commit the raw source 13 | # git has its own built in compression methods 14 | *.7z 15 | *.dmg 16 | *.gz 17 | *.iso 18 | *.jar 19 | *.rar 20 | 21 | # Logs and databases # 22 | ###################### 23 | *.log 24 | *.sql 25 | *.sqlite 26 | 27 | # OS generated files # 28 | ###################### 29 | .DS_Store 30 | .DS_Store? 31 | ._* 32 | *.idea* 33 | .Spotlight-V100 34 | .Trashes 35 | *.xml 36 | *.iml 37 | *.pyc 38 | ehthumbs.db 39 | Thumbs.db 40 | 41 | # Ignore ImageNet Pre-trained Models # 42 | ###################################### 43 | *ImageNet/checkpoint/* -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Intelligent Computing Lab at Yale University 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # NDA_SNN 2 | 3 | PyTorch implementation of Neuromorphic Data Augmentation for SNNs, accepted to ECCV 2022. 4 | Paper link: [Neuromorphic Data Augmentation for Training Spiking Neural Networks](https://arxiv.org/pdf/2203.06145.pdf). 5 | 6 | 7 | ## Dataset Preparation 8 | 9 | For the CIFAR10-DVS dataset, please refer to the Google Drive links below: 10 | 11 | + [Training set](https://drive.google.com/file/d/1pzYnhoUvtcQtxk_Qmy4d2VrhWhy5R-t9/view?usp=sharing) 12 | + [Test set](https://drive.google.com/file/d/1q1k6JJgVH3ZkHWMg2zPtrZak9jRP6ggG/view?usp=sharing) 13 | 14 | For N-Caltech 101, we suggest using the [SpikingJelly](https://github.com/fangwei123456/spikingjelly) package to pre-process the data. 15 | Specifically, initialize the `NCaltech101` class in SpikingJelly as: 16 | 17 | ```python 18 | from spikingjelly.datasets.n_caltech101 import NCaltech101 19 | dataset = NCaltech101(root='data', data_type='frame', frames_number=10, split_by='time') 20 | ``` 21 | If you can initialize this class, you will be able to use our provided dataloader in `functions/data_loaders.py`. 22 | 23 | 24 | ## Run Experiments 25 | 26 | To run a VGG-11 without NDA on CIFAR10-DVS: 27 | 28 | `python main.py --dset dc10 --amp` 29 | 30 | Here, `--amp` enables FP16 mixed-precision training, which accelerates the training stage.
31 | Use `--dset nc101` to change the dataset to NCaltech 101. 32 | 33 | To enable NDA training: 34 | 35 | `python main.py --dset dc10 --amp --nda` 36 | 37 | ### Reference 38 | 39 | If you find our work is interesting, please consider cite us: 40 | 41 | ```bibtex 42 | @article{li2022neuromorphic, 43 | title={Neuromorphic Data Augmentation for Training Spiking Neural Networks}, 44 | author={Li, Yuhang and Kim, Youngeun and Park, Hyoungseob and Geller, Tamar and Panda, Priyadarshini}, 45 | journal={arXiv preprint arXiv:2203.06145}, 46 | year={2022} 47 | } 48 | ``` 49 | 50 | 51 | -------------------------------------------------------------------------------- /functions/__init__.py: -------------------------------------------------------------------------------- 1 | from functions.data_loaders import build_ncaltech, build_dvscifar 2 | from functions.functions import seed_all -------------------------------------------------------------------------------- /functions/data_loaders.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | import random 4 | import torchvision.transforms as transforms 5 | from torch.utils.data import Dataset, DataLoader 6 | import warnings 7 | import os 8 | import numpy as np 9 | from os.path import isfile, join 10 | 11 | warnings.filterwarnings('ignore') 12 | 13 | 14 | class Cutout(object): 15 | """Randomly mask out one or more patches from an image. 16 | Args: 17 | n_holes (int): Number of patches to cut out of each image. 18 | length (int): The length (in pixels) of each square patch. 
19 | """ 20 | 21 | def __init__(self, length): 22 | self.length = length 23 | 24 | def __call__(self, img): 25 | h = img.size(2) 26 | w = img.size(3) 27 | mask = np.ones((h, w), np.float32) 28 | y = np.random.randint(h) 29 | x = np.random.randint(w) 30 | y1 = np.clip(y - self.length // 2, 0, h) 31 | y2 = np.clip(y + self.length // 2, 0, h) 32 | x1 = np.clip(x - self.length // 2, 0, w) 33 | x2 = np.clip(x + self.length // 2, 0, w) 34 | mask[y1: y2, x1: x2] = 0. 35 | mask = torch.from_numpy(mask) 36 | mask = mask.expand_as(img) 37 | img = img * mask 38 | return img 39 | 40 | 41 | class NCaltech101(Dataset): 42 | def __init__(self, data_path='data/n-caltech/frames_number_10_split_by_number', 43 | data_type='train', transform=False): 44 | 45 | self.filepath = os.path.join(data_path) 46 | self.clslist = os.listdir(self.filepath) 47 | self.clslist.sort() 48 | 49 | self.dvs_filelist = [] 50 | self.targets = [] 51 | self.resize = transforms.Resize(size=(48, 48), interpolation=torchvision.transforms.InterpolationMode.NEAREST) 52 | 53 | for i, cls in enumerate(self.clslist): 54 | # print (i, cls) 55 | file_list = os.listdir(os.path.join(self.filepath, cls)) 56 | num_file = len(file_list) 57 | 58 | cut_idx = int(num_file * 0.9) 59 | train_file_list = file_list[:cut_idx] 60 | test_split_list = file_list[cut_idx:] 61 | for file in file_list: 62 | if data_type == 'train': 63 | if file in train_file_list: 64 | self.dvs_filelist.append(os.path.join(self.filepath, cls, file)) 65 | self.targets.append(i) 66 | else: 67 | if file in test_split_list: 68 | self.dvs_filelist.append(os.path.join(self.filepath, cls, file)) 69 | self.targets.append(i) 70 | 71 | self.data_num = len(self.dvs_filelist) 72 | self.data_type = data_type 73 | if data_type != 'train': 74 | counts = np.unique(np.array(self.targets), return_counts=True)[1] 75 | class_weights = counts.sum() / (counts * len(counts)) 76 | self.class_weights = torch.Tensor(class_weights) 77 | self.classes = range(101) 78 | self.transform 
= transform 79 | self.rotate = transforms.RandomRotation(degrees=15) 80 | self.shearx = transforms.RandomAffine(degrees=0, shear=(-15, 15)) 81 | 82 | def __getitem__(self, index): 83 | file_pth = self.dvs_filelist[index] 84 | label = self.targets[index] 85 | data = torch.from_numpy(np.load(file_pth)['frames']).float() 86 | data = self.resize(data) 87 | 88 | if self.transform: 89 | 90 | choices = ['roll', 'rotate', 'shear'] 91 | aug = np.random.choice(choices) 92 | if aug == 'roll': 93 | off1 = random.randint(-3, 3) 94 | off2 = random.randint(-3, 3) 95 | data = torch.roll(data, shifts=(off1, off2), dims=(2, 3)) 96 | if aug == 'rotate': 97 | data = self.rotate(data) 98 | if aug == 'shear': 99 | data = self.shearx(data) 100 | 101 | return data, label 102 | 103 | def __len__(self): 104 | return self.data_num 105 | 106 | 107 | def build_ncaltech(transform=False): 108 | train_dataset = NCaltech101(transform=transform) 109 | val_dataset = NCaltech101(data_type='test', transform=False) 110 | 111 | return train_dataset, val_dataset 112 | 113 | 114 | class DVSCifar10(Dataset): 115 | def __init__(self, root, train=True, transform=None, target_transform=None): 116 | self.root = os.path.expanduser(root) 117 | self.transform = transform 118 | self.target_transform = target_transform 119 | self.train = train 120 | self.resize = transforms.Resize(size=(48, 48), interpolation=torchvision.transforms.InterpolationMode.NEAREST) 121 | self.rotate = transforms.RandomRotation(degrees=30) 122 | self.shearx = transforms.RandomAffine(degrees=0, shear=(-30, 30)) 123 | 124 | def __getitem__(self, index): 125 | """ 126 | Args: 127 | index (int): Index 128 | Returns: 129 | tuple: (image, target) where target is index of the target class. 
130 | """ 131 | data, target = torch.load(self.root + '/{}.pt'.format(index)) 132 | data = self.resize(data.permute([3, 0, 1, 2])) 133 | 134 | if self.transform: 135 | 136 | choices = ['roll', 'rotate', 'shear'] 137 | aug = np.random.choice(choices) 138 | if aug == 'roll': 139 | off1 = random.randint(-5, 5) 140 | off2 = random.randint(-5, 5) 141 | data = torch.roll(data, shifts=(off1, off2), dims=(2, 3)) 142 | if aug == 'rotate': 143 | data = self.rotate(data) 144 | if aug == 'shear': 145 | data = self.shearx(data) 146 | 147 | return data, target.long().squeeze(-1) 148 | 149 | def __len__(self): 150 | return len(os.listdir(self.root)) 151 | 152 | 153 | def build_dvscifar(path='data/cifar-dvs', transform=False): 154 | train_path = path + '/train' 155 | val_path = path + '/test' 156 | train_dataset = DVSCifar10(root=train_path, transform=transform) 157 | val_dataset = DVSCifar10(root=val_path, transform=False) 158 | 159 | return train_dataset, val_dataset 160 | 161 | 162 | def mixup_criterion(criterion, pred, y_a, y_b, lam): 163 | return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b) 164 | 165 | 166 | def rand_bbox(size, lam): 167 | W = size[3] 168 | H = size[4] 169 | cut_rat = np.sqrt(1. 
- lam) 170 | cut_w = np.int(W * cut_rat) 171 | cut_h = np.int(H * cut_rat) 172 | 173 | # uniform 174 | cx = np.random.randint(W) 175 | cy = np.random.randint(H) 176 | 177 | bbx1 = np.clip(cx - cut_w // 2, 0, W) 178 | bby1 = np.clip(cy - cut_h // 2, 0, H) 179 | bbx2 = np.clip(cx + cut_w // 2, 0, W) 180 | bby2 = np.clip(cy + cut_h // 2, 0, H) 181 | 182 | return bbx1, bby1, bbx2, bby2 183 | 184 | 185 | def cutmix_data(input, target, alpha=1.0): 186 | lam = np.random.beta(alpha, alpha) 187 | rand_index = torch.randperm(input.size()[0]).cuda() 188 | 189 | target_a = target 190 | target_b = target[rand_index] 191 | 192 | # generate mixed sample 193 | bbx1, bby1, bbx2, bby2 = rand_bbox(input.size(), lam) 194 | input[:, :, :, bbx1:bbx2, bby1:bby2] = input[rand_index, :, :, bbx1:bbx2, bby1:bby2] 195 | # adjust lambda to exactly match pixel ratio 196 | lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (input.size()[-1] * input.size()[-2])) 197 | return input, target_a, target_b, lam 198 | 199 | 200 | if __name__ == '__main__': 201 | choices = ['roll', 'rotate', 'shear'] 202 | aug = np.random.choice(choices) 203 | print(aug) 204 | 205 | 206 | -------------------------------------------------------------------------------- /functions/functions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import random 4 | import os 5 | import numpy as np 6 | import logging 7 | 8 | 9 | def seed_all(seed=1029): 10 | random.seed(seed) 11 | os.environ['PYTHONHASHSEED'] = str(seed) 12 | np.random.seed(seed) 13 | torch.manual_seed(seed) 14 | torch.cuda.manual_seed(seed) 15 | torch.cuda.manual_seed_all(seed) # if you are using multi-GPU. 
16 | torch.backends.cudnn.benchmark = False 17 | torch.backends.cudnn.deterministic = True 18 | 19 | 20 | def get_logger(filename, verbosity=1, name=None): 21 | level_dict = {0: logging.DEBUG, 1: logging.INFO, 2: logging.WARNING} 22 | formatter = logging.Formatter( 23 | "[%(asctime)s][%(filename)s][line:%(lineno)d][%(levelname)s] %(message)s" 24 | ) 25 | logger = logging.getLogger(name) 26 | logger.setLevel(level_dict[verbosity]) 27 | 28 | fh = logging.FileHandler(filename, "w") 29 | fh.setFormatter(formatter) 30 | logger.addHandler(fh) 31 | 32 | sh = logging.StreamHandler() 33 | sh.setFormatter(formatter) 34 | logger.addHandler(sh) 35 | 36 | return logger -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import shutil 3 | import os 4 | import time 5 | import torch 6 | import logging as logger 7 | import torch.nn as nn 8 | from torch import autocast 9 | from torch.cuda.amp import GradScaler 10 | from models.VGG_models import vgg11 11 | from functions import seed_all, build_ncaltech, build_dvscifar 12 | 13 | parser = argparse.ArgumentParser(description='PyTorch Neuromorphic Data Augmentation') 14 | parser.add_argument('-j', '--workers', default=8, type=int, metavar='N', 15 | help='number of data loading workers (default: 10)') 16 | parser.add_argument('--epochs', default=200, type=int, metavar='N', 17 | help='number of total epochs to run') 18 | parser.add_argument('--dset', default='nc101', type=str, metavar='N', choices=['nc101', 'dc10'], 19 | help='dataset') 20 | parser.add_argument('--model', default='vgg11', type=str, metavar='N', choices=[ 'vgg11'], 21 | help='neural network architecture') 22 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', 23 | help='manual epoch number (useful on restarts)') 24 | parser.add_argument('-b', '--batch_size', default=256, type=int, metavar='N', 25 | 
help='mini-batch size (default: 256), this is the total ' 26 | 'batch size of all GPUs on the current node when ' 27 | 'using Data Parallel or Distributed Data Parallel') 28 | parser.add_argument('--lr', '--learning_rate', default=0.001, type=float, metavar='LR', help='initial learning rate', 29 | dest='lr') 30 | parser.add_argument('--seed', default=1000, type=int, 31 | help='seed for initializing training. ') 32 | parser.add_argument('-T', '--time', default=10, type=int, metavar='N', 33 | help='snn simulation time (default: 2)') 34 | parser.add_argument('--amp', action='store_true', 35 | help='if use amp training.') 36 | parser.add_argument('--nda', action='store_true', 37 | help='if use neuromorphic data augmentation.') 38 | args = parser.parse_args() 39 | 40 | 41 | def train(model, device, train_loader, criterion, optimizer, epoch, scaler, args): 42 | running_loss = 0 43 | model.train() 44 | M = len(train_loader) 45 | total = 0 46 | correct = 0 47 | s_time = time.time() 48 | for i, (images, labels) in enumerate(train_loader): 49 | optimizer.zero_grad() 50 | labels = labels.to(device) 51 | images = images.to(device) 52 | 53 | if args.amp: 54 | with autocast(device_type='cuda', dtype=torch.float16): 55 | outputs = model(images) 56 | mean_out = outputs.mean(1) 57 | loss = criterion(mean_out, labels) 58 | scaler.scale(loss.mean()).backward() 59 | scaler.step(optimizer) 60 | scaler.update() 61 | else: 62 | outputs = model(images) 63 | mean_out = outputs.mean(1) 64 | loss = criterion(mean_out, labels) 65 | loss.mean().backward() 66 | optimizer.step() 67 | 68 | running_loss += loss.item() 69 | total += float(labels.size(0)) 70 | _, predicted = mean_out.cpu().max(1) 71 | correct += float(predicted.eq(labels.cpu()).sum().item()) 72 | e_time = time.time() 73 | return running_loss / M, 100 * correct / total, (e_time-s_time)/60 74 | 75 | 76 | @torch.no_grad() 77 | def test(model, test_loader, device): 78 | correct = 0 79 | total = 0 80 | model.eval() 81 | for batch_idx, 
(inputs, targets) in enumerate(test_loader): 82 | inputs = inputs.to(device) 83 | outputs = model(inputs) 84 | mean_out = outputs.mean(1) 85 | _, predicted = mean_out.cpu().max(1) 86 | total += float(targets.size(0)) 87 | correct += float(predicted.eq(targets).sum().item()) 88 | 89 | correct = torch.tensor([correct]).cuda() 90 | total = torch.tensor([total]).cuda() 91 | final_acc = 100 * correct / total 92 | return final_acc.item() 93 | 94 | 95 | if __name__ == '__main__': 96 | 97 | seed_all(args.seed) 98 | 99 | if args.dset == 'nc101': 100 | train_dataset, val_dataset = build_ncaltech(transform=args.nda) 101 | num_cls = 101 102 | in_c = 2 103 | elif args.dset == 'dc10': 104 | train_dataset, val_dataset = build_dvscifar(transform=args.nda) 105 | num_cls = 10 106 | in_c = 2 107 | else: 108 | raise NotImplementedError 109 | 110 | train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, 111 | num_workers=args.workers, pin_memory=True) 112 | test_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, 113 | num_workers=args.workers, pin_memory=True) 114 | 115 | if args.model == 'vgg11': 116 | model = vgg11(in_c=in_c, num_classes=num_cls) 117 | else: 118 | raise NotImplementedError 119 | 120 | model.T = args.time 121 | model.cuda() 122 | device = next(model.parameters()).device 123 | 124 | scaler = GradScaler() if args.amp else None 125 | 126 | criterion = nn.CrossEntropyLoss().to(device) 127 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr/256 * args.batch_size, weight_decay=1e-4) 128 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, eta_min=0, T_max=args.epochs) 129 | print('start training!') 130 | for epoch in range(args.epochs): 131 | 132 | loss, acc, t_diff = train(model, device, train_loader, criterion, optimizer, epoch, scaler, args) 133 | print('Epoch:[{}/{}]\t loss={:.5f}\t acc={:.3f},\t time elapsed: {}'.format(epoch, args.epochs, loss, acc, 134 | 
t_diff)) 135 | scheduler.step() 136 | facc = test(model, test_loader, device) 137 | print('Epoch:[{}/{}]\t Test acc={:.3f}'.format(epoch, args.epochs, facc)) 138 | -------------------------------------------------------------------------------- /models/VGG_models.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Modified from https://github.com/pytorch/vision.git 3 | ''' 4 | import math 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.init as init 9 | from models.layers import * 10 | 11 | __all__ = [ 12 | 'VGG', 'vgg11', 'vgg13', 'vgg16', 13 | ] 14 | 15 | 16 | class VGG(nn.Module): 17 | ''' 18 | VGG model 19 | ''' 20 | 21 | def __init__(self, cfg, num_classes=10, batch_norm=True, in_c=3, **lif_parameters): 22 | super(VGG, self).__init__() 23 | 24 | self.features, out_c = make_layers(cfg, batch_norm, in_c, **lif_parameters) 25 | self.avgpool = SeqToANNContainer(nn.AdaptiveAvgPool2d((1, 1))) 26 | self.classifier = nn.Sequential( 27 | SeqToANNContainer(nn.Linear(out_c, num_classes)), 28 | ) 29 | for m in self.modules(): 30 | if isinstance(m, nn.Conv2d): 31 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 32 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 33 | m.bias.data.zero_() 34 | 35 | self.add_dim = lambda x: add_dimention(x, self.T) 36 | 37 | def forward(self, x): 38 | x = self.add_dim(x) if len(x.shape) == 4 else x 39 | x = self.features(x) 40 | x = self.avgpool(x) 41 | x = torch.flatten(x, 1) if len(x.shape) == 4 else torch.flatten(x, 2) 42 | x = self.classifier(x) 43 | return x 44 | 45 | 46 | def make_layers(cfg, batch_norm=False, in_c=3, **lif_parameters): 47 | layers = [] 48 | in_channels = in_c 49 | for v in cfg: 50 | if v == 'M': 51 | layers += [SpikeModule(nn.AvgPool2d(kernel_size=2, stride=2))] 52 | else: 53 | conv2d = SpikeModule(nn.Conv2d(in_channels, v, kernel_size=3, padding=1)) 54 | 55 | lif = LIFSpike(**lif_parameters) 56 | 57 | if batch_norm: 58 | bn = tdBatchNorm(v) 59 | layers += [conv2d, bn, lif] 60 | else: 61 | layers += [conv2d, lif] 62 | 63 | in_channels = v 64 | return nn.Sequential(*layers), in_channels 65 | 66 | 67 | cfg = { 68 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 512, 512], 69 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 512, 512], 70 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 512, 512, 512], 71 | } 72 | 73 | 74 | def vgg11(*args, **kwargs): 75 | """VGG 11-layer model (configuration "A") with batch normalization""" 76 | return VGG(cfg['A'], *args, **kwargs) 77 | 78 | 79 | def vgg13(*args, **kwargs): 80 | """VGG 13-layer model (configuration "B") with batch normalization""" 81 | return VGG(cfg['B'], *args, **kwargs) 82 | 83 | 84 | def vgg16(*args, **kwargs): 85 | """VGG 16-layer model (configuration "D") with batch normalization""" 86 | return VGG(cfg['D'], *args, **kwargs) 87 | 88 | 89 | if __name__ == '__main__': 90 | model = vgg16(num_classes=10, width_mult=1) 91 | print(model) 92 | x = torch.rand(2, 3, 32, 32) 93 | y = model(x) 94 | y.sum().backward() 95 | -------------------------------------------------------------------------------- /models/layers.py: 
import torch
import torch.nn as nn


class SeqToANNContainer(nn.Module):
    """Apply ordinary (non-temporal) module(s) to a ``(B, T, ...)`` sequence.

    Axes 0 and 1 are merged before calling the wrapped module, and split again
    afterwards, so the module sees a ``(B * T, ...)`` tensor. Passing several
    modules chains them with ``nn.Sequential``.
    """

    def __init__(self, *args):
        super().__init__()
        if len(args) == 1:
            self.module = args[0]
        else:
            self.module = nn.Sequential(*args)

    def forward(self, x_seq: torch.Tensor):
        y_shape = [x_seq.shape[0], x_seq.shape[1]]
        y_seq = self.module(x_seq.flatten(0, 1).contiguous())
        y_shape.extend(y_seq.shape[1:])
        return y_seq.view(y_shape)


class SpikeModule(nn.Module):
    """Wrap a single module so it runs on ``(B, T, ...)`` input.

    The first two axes are folded into one ``B * T`` batch axis for the call
    and unfolded again on the output.
    """

    def __init__(self, module):
        super().__init__()
        self.ann_module = module

    def forward(self, x):
        B, T, *spatial_dims = x.shape
        out = self.ann_module(x.reshape(B * T, *spatial_dims))
        BT, *spatial_dims = out.shape
        out = out.view(B, T, *spatial_dims).contiguous()
        return out


class _ZIF(torch.autograd.Function):
    """Heaviside spike with a rectangular surrogate gradient.

    Forward: 1.0 where ``input >= 0``, otherwise 0.0 (always float32).
    Backward: inside the window ``|input| < gamma / 2`` the incoming gradient
    is scaled by ``1 / gamma``; outside the window it is zeroed.
    """

    @staticmethod
    def forward(ctx, input, gamma):
        ctx.save_for_backward(input)
        ctx.gamma = gamma
        return (input >= 0).float()

    @staticmethod
    def backward(ctx, grad_output):
        (input, ) = ctx.saved_tensors
        gamma = ctx.gamma
        surrogate = (input.abs() < gamma / 2).float() / gamma
        # gamma is a plain hyper-parameter, so it receives no gradient.
        return grad_output * surrogate, None


def fire_function(gamma):
    """Return a spike function ``f(input)`` whose surrogate window is ``gamma`` wide.

    The autograd Function is defined once at module level and ``gamma`` is
    passed as an argument, rather than creating a new class on every call.
    """
    def fire(input):
        return _ZIF.apply(input, gamma)

    return fire


class LIFSpike(nn.Module):
    """Leaky integrate-and-fire neuron with hard reset, unrolled over axis 1.

    Args:
        thresh: firing threshold; a spike is emitted when ``mem >= thresh``.
        tau: leak factor that multiplies the membrane potential every step.
        gamma: width of the rectangular surrogate-gradient window.

    ``forward`` reads time step ``t`` as ``x[:, t]`` and returns a spike tensor
    with the same shape as ``x``.
    """

    def __init__(self, thresh=0.5, tau=0.25, gamma=1.0):
        super(LIFSpike, self).__init__()
        self.thresh = thresh
        self.tau = tau
        self.gamma = gamma

    def forward(self, x):
        mem = torch.zeros_like(x[:, 0])
        spikes = []
        T = x.shape[1]
        # The spike function depends only on gamma, so build it once per call.
        fire = fire_function(self.gamma)
        for t in range(T):
            mem = mem * self.tau + x[:, t, ...]
            spike = fire(mem - self.thresh)
            # Hard reset: neurons that fired restart from zero potential.
            mem = (1 - spike) * mem
            spikes.append(spike)
        return torch.stack(spikes, dim=1)


def add_dimention(x, T):
    """Repeat a 4-D tensor (e.g. ``(B, C, H, W)``) ``T`` times along a new axis 1.

    Returns a new 5-D tensor. Unlike an in-place ``unsqueeze_``, this leaves the
    caller's tensor shape untouched.
    """
    return x.unsqueeze(1).repeat(1, T, 1, 1, 1)


class tdBatchNorm(nn.BatchNorm2d):
    """``BatchNorm2d`` over ``(B, T, C, H, W)`` input, with statistics shared across B and T.

    The affine weight is initialised to 0.5 instead of 1.
    """

    def __init__(self, channel):
        super(tdBatchNorm, self).__init__(channel)
        # according to tdBN paper, the initialized weight is changed to alpha*Vth
        self.weight.data.mul_(0.5)

    def forward(self, x):
        B, T, *spatial_dims = x.shape
        out = super().forward(x.reshape(B * T, *spatial_dims))
        BT, *spatial_dims = out.shape
        out = out.view(B, T, *spatial_dims).contiguous()
        return out